Skip to content

Commit 4332e6d

Browse files
Merge pull request #1266 from ChrisRackauckas-Claude/add-flexunits-extension
Add FlexUnits.jl extension for ODE solver compatibility
2 parents b3342a4 + b107f0f commit 4332e6d

File tree

5 files changed

+61
-1
lines changed

5 files changed

+61
-1
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3535
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3636
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3737
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
38+
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
3839
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3940
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
4041
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
@@ -51,6 +52,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
5152
DiffEqBaseCUDAExt = "CUDA"
5253
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
5354
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
55+
DiffEqBaseFlexUnitsExt = "FlexUnits"
5456
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
5557
DiffEqBaseGTPSAExt = "GTPSA"
5658
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
@@ -77,6 +79,7 @@ Enzyme = "0.13.100"
7779
FastBroadcast = "0.3.5"
7880
FastClosures = "0.3.2"
7981
FastPower = "1.1"
82+
FlexUnits = "0.4"
8083
ForwardDiff = "0.10, 1"
8184
FunctionWrappers = "1.0"
8285
FunctionWrappersWrappers = "0.1"
@@ -117,6 +120,7 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
117120
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
118121
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
119122
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
123+
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
120124
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
121125
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
122126
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -134,4 +138,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
134138
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
135139

136140
[targets]
137-
test = ["Distributed", "Measurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"]
141+
test = ["Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"]

ext/DiffEqBaseFlexUnitsExt.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module DiffEqBaseFlexUnitsExt
2+
3+
using DiffEqBase
4+
import SciMLBase: unitfulvalue, value
5+
using FlexUnits
6+
7+
# Support adaptive errors should be errorless for exponentiation
8+
value(::Type{Quantity{T, U}}) where {T, U} = T
9+
value(x::Quantity{T, U}) where {T, U} = dstrip(x)
10+
11+
unitfulvalue(::Type{T}) where {T <: Quantity} = T
12+
unitfulvalue(x::Quantity) = x
13+
14+
@inline function DiffEqBase.ODE_DEFAULT_NORM(
15+
u::AbstractArray{
16+
<:Quantity,
17+
N,
18+
},
19+
t
20+
) where {N}
21+
return sqrt(
22+
sum(
23+
x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
24+
zip((value(x) for x in u), Iterators.repeated(t))
25+
) / length(u)
26+
)
27+
end
28+
@inline function DiffEqBase.ODE_DEFAULT_NORM(
29+
u::Array{<:Quantity, N},
30+
t
31+
) where {N}
32+
return sqrt(
33+
sum(
34+
x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
35+
zip((value(x) for x in u), Iterators.repeated(t))
36+
) / length(u)
37+
)
38+
end
39+
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Quantity, t) = abs(value(u))
40+
@inline function DiffEqBase.UNITLESS_ABS2(x::Quantity)
41+
return real(abs2(dstrip(x)))
42+
end
43+
44+
DiffEqBase._rate_prototype(u, t::Quantity, onet) = u / unit(t)
45+
end

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
66
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
77
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
88
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
9+
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1011
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
1112
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"

test/downstream/flexunits.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using FlexUnits, FlexUnits.UnitRegistry, OrdinaryDiffEq, Test
2+
3+
f(du, u, p, t) = du .= 3 * u"1/s" * u
4+
prob = ODEProblem(f, [2.0u"m"], (0.0u"s", 1.0u"s"))
5+
intg = init(prob, Tsit5(), dt = 0.01u"s")
6+
@test_nowarn step!(intg, 0.02u"s", true)
7+
8+
@test SciMLBase.unitfulvalue(1.0u"1/s") == 1.0u"1/s"
9+
@test SciMLBase.value(1.0u"1/s") isa Real

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
@time @safetestset "Null DE Handling" include("downstream/null_de.jl")
4949
@time @safetestset "StaticArrays + AD" include("downstream/static_arrays_ad.jl")
5050
@time @safetestset "Unitful" include("downstream/unitful.jl")
51+
@time @safetestset "FlexUnits" include("downstream/flexunits.jl")
5152
@time @safetestset "Dual Detection Solution" include("downstream/dual_detection_solution.jl")
5253
@time @safetestset "Null Parameters" include("downstream/null_params_test.jl")
5354
@time @safetestset "Ensemble Simulations" include("downstream/ensemble.jl")

0 commit comments

Comments
 (0)