Skip to content

Commit c2d6850

Browse files
Merge pull request #1255 from ChrisRackauckas-Claude/use-differentiation-interface-for-ad-tests
Refactor AD tests to use DifferentiationInterface with version-dependent backends
2 parents 2f184a6 + 1e99ac9 commit c2d6850

File tree

8 files changed

+298
-62
lines changed

8 files changed

+298
-62
lines changed

.github/workflows/Tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ jobs:
2727
- Downstream
2828
- Downstream2
2929
- Static
30+
exclude:
31+
# Downstream tests use AD backends that have issues on Julia pre-release
32+
- version: "pre"
33+
group: Downstream
34+
- version: "pre"
35+
group: Downstream2
3036
uses: "SciML/.github/.github/workflows/tests.yml@v1"
3137
with:
3238
julia-version: "${{ matrix.version }}"

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ DiffEqBaseTrackerExt = "Tracker"
6565
DiffEqBaseUnitfulExt = "Unitful"
6666

6767
[compat]
68+
ADTypes = "1"
6869
ArrayInterface = "7.8"
6970
BracketingNonlinearSolve = "1.6.2"
7071
CUDA = "5"
7172
ChainRulesCore = "1"
7273
ConcreteStructs = "0.2.3"
74+
DifferentiationInterface = "0.7"
7375
Distributions = "0.25"
7476
DocStringExtensions = "0.9"
7577
Enzyme = "0.13.100"
@@ -106,17 +108,22 @@ SymbolicIndexingInterface = "0.3.39"
106108
Tracker = "0.2"
107109
TruncatedStacktraces = "1"
108110
Unitful = "1"
111+
Zygote = "0.6, 0.7"
109112
julia = "1.10"
110113

111114
[extras]
115+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
112116
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
113117
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
118+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
114119
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
115120
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
121+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
116122
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
117123
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
118124
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
119125
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
126+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
120127
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
121128
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
122129
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -126,6 +133,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
126133
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
127134
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
128135
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
136+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
129137

130138
[targets]
131139
test = ["Distributed", "Measurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"]

test/downstream/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
44
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
55
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
66
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
7+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
8+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
79
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
810
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
911
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
1012
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
13+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1114
MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"
1215
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1316
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
@@ -24,5 +27,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2427

2528
[compat]
2629
ADTypes = "1"
30+
DifferentiationInterface = "0.7"
31+
Enzyme = "0.13"
32+
Mooncake = "0.4"
2733
MultiScaleArrays = "1.8"
2834
OrdinaryDiffEq = "6.91.0"

test/downstream/complex_number_ad.jl

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,43 @@
11
using LinearAlgebra, OrdinaryDiffEq, Test
2-
import ForwardDiff
2+
using SciMLSensitivity # Required for reverse-mode AD
3+
using DifferentiationInterface
4+
using ADTypes: AutoForwardDiff, AutoMooncake
5+
6+
# Load backends for all versions (required for DifferentiationInterface extensions)
7+
using ForwardDiff
8+
using Mooncake
9+
10+
# Version-dependent imports
11+
if VERSION <= v"1.11"
12+
using Zygote
13+
using ADTypes: AutoZygote
14+
end
15+
if VERSION <= v"1.11"
16+
using Enzyme
17+
using ADTypes: AutoEnzyme
18+
end
19+
20+
# Define backends based on Julia version
21+
# ForwardDiff: All versions
22+
# Mooncake: All versions
23+
# Zygote: Julia <= 1.11
24+
# Enzyme: Julia <= 1.11
25+
function get_test_backends()
26+
backends = Pair{String, Any}[]
27+
# ForwardDiff on all versions
28+
push!(backends, "ForwardDiff" => AutoForwardDiff())
29+
# Mooncake on all versions
30+
push!(backends, "Mooncake" => AutoMooncake(; config = nothing))
31+
# Zygote only on Julia <= 1.11
32+
if VERSION <= v"1.11"
33+
push!(backends, "Zygote" => AutoZygote())
34+
end
35+
# Enzyme only on Julia <= 1.11
36+
if VERSION <= v"1.11"
37+
push!(backends, "Enzyme" => AutoEnzyme())
38+
end
39+
return backends
40+
end
341

442
# setup
543
pd = 3
@@ -64,16 +102,26 @@ function assert_fun()
64102
end
65103
@assert all([assert_fun() for _ in 1:(2^6)])
66104

67-
# test ad with ForwardDiff
68-
function test_ad()
105+
# test ad with DifferentiationInterface using multiple backends
106+
backends = get_test_backends()
107+
# Note: Mooncake and Enzyme are excluded due to issues with complex number ODE differentiation
108+
backends_for_complex = filter(b -> b[1] ∉ ("Mooncake", "Enzyme"), backends)
109+
110+
function test_ad_with_backend(backend, name)
69111
p0 = rand(3)
70-
grad_real = ForwardDiff.gradient(loss_via_real, p0)
71-
grad_complex = ForwardDiff.gradient(loss, p0)
112+
grad_real = DifferentiationInterface.gradient(loss_via_real, backend, p0)
113+
grad_complex = DifferentiationInterface.gradient(loss, backend, p0)
72114
any(isnan.(grad_complex)) &&
73-
@warn "NaN detected in gradient using ode with complex numbers !!"
74-
any(isnan.(grad_real)) && @warn "NaN detected in gradient using realified ode !!"
115+
@warn "NaN detected in gradient using ode with complex numbers with $name !!"
116+
any(isnan.(grad_real)) && @warn "NaN detected in gradient using realified ode with $name !!"
75117
rel_err = norm(grad_complex - grad_real) / max(norm(grad_complex), norm(grad_real))
76-
return isapprox(grad_complex, grad_real; rtol = 1.0e-6) ? true : (@show rel_err; false)
118+
return isapprox(grad_complex, grad_real; rtol = 1.0e-6) ? true : (@show name, rel_err; false)
77119
end
78120

79-
@time @test all([test_ad() for _ in 1:(2^6)])
121+
@testset "Complex number AD tests" begin
122+
for (name, backend) in backends_for_complex
123+
@testset "AD via ode with complex numbers ($name)" begin
124+
@time @test all([test_ad_with_backend(backend, name) for _ in 1:(2^6)])
125+
end
126+
end
127+
end

test/downstream/ensemble_ad.jl

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,44 @@
1-
using OrdinaryDiffEq, ForwardDiff, Zygote, Test
1+
using OrdinaryDiffEq, Test
22
using SciMLSensitivity
33
using Random
4+
using DifferentiationInterface
5+
using ADTypes: AutoForwardDiff, AutoMooncake
6+
7+
# Load backends for all versions (required for DifferentiationInterface extensions)
8+
using ForwardDiff
9+
using Mooncake
10+
11+
# Version-dependent imports
12+
if VERSION <= v"1.11"
13+
using Zygote
14+
using ADTypes: AutoZygote
15+
end
16+
if VERSION <= v"1.11"
17+
using Enzyme
18+
using ADTypes: AutoEnzyme
19+
end
20+
21+
# Define backends based on Julia version
22+
# ForwardDiff: All versions
23+
# Mooncake: All versions
24+
# Zygote: Julia <= 1.11
25+
# Enzyme: Julia <= 1.11
26+
function get_test_backends()
27+
backends = Pair{String, Any}[]
28+
# ForwardDiff on all versions
29+
push!(backends, "ForwardDiff" => AutoForwardDiff())
30+
# Mooncake on all versions
31+
push!(backends, "Mooncake" => AutoMooncake(; config = nothing))
32+
# Zygote only on Julia <= 1.11
33+
if VERSION <= v"1.11"
34+
push!(backends, "Zygote" => AutoZygote())
35+
end
36+
# Enzyme only on Julia <= 1.11
37+
if VERSION <= v"1.11"
38+
push!(backends, "Enzyme" => AutoEnzyme())
39+
end
40+
return backends
41+
end
442

543
function dt!(du, u, p, t)
644
x, y = u
@@ -40,10 +78,18 @@ end
4078

4179
test_loss(p, prob_ode)
4280

43-
@time gs = Zygote.gradient(p) do p
44-
test_loss(p, prob_ode)
81+
# Test gradient computation with DifferentiationInterface for each backend
82+
# Note: Mooncake and Enzyme are excluded for ensemble tests
83+
# Mooncake: StackOverflowError in rule compilation
84+
# Enzyme: Issues with ensemble problem differentiation
85+
backends = get_test_backends()
86+
backends_for_ensemble = filter(b -> b[1] ∉ ("Mooncake", "Enzyme"), backends)
87+
for (name, backend) in backends_for_ensemble
88+
@testset "Ensemble test_loss gradient with $name" begin
89+
@time gs = DifferentiationInterface.gradient(p -> test_loss(p, prob_ode), backend, p)
90+
@test gs isa Vector
91+
end
4592
end
46-
@test gs[1] isa Vector
4793

4894
### https://github.com/SciML/DiffEqFlux.jl/issues/595
4995

@@ -61,9 +107,16 @@ function sum_of_solution(x)
61107
_prob = remake(prob, u0 = x[1:2], p = x[3:end])
62108
return sum(solve(_prob, Tsit5(), saveat = 0.1))
63109
end
64-
Zygote.gradient(sum_of_solution, [u0; p])
65110

66-
# Testing ensemble problem. Works with ForwardDiff. Does not work with Zygote.
111+
# Test sum_of_solution gradient with all backends
112+
for (name, backend) in backends
113+
@testset "sum_of_solution gradient with $name" begin
114+
gs = DifferentiationInterface.gradient(sum_of_solution, backend, [u0; p])
115+
@test gs isa Vector
116+
end
117+
end
118+
119+
# Testing ensemble problem with ForwardDiff
67120
N = 3
68121
eu0 = rand(N, 2)
69122
ep = rand(N, 4)
@@ -106,7 +159,17 @@ end
106159

107160
sum_of_e_solution(ep)
108161

109-
x = ForwardDiff.gradient(sum_of_e_solution, ep)
110-
y = Zygote.gradient(sum_of_e_solution, ep)[1] # Zygote second to test cache of forward pass
111-
@test x ≈ y
112-
@test cache[] == 0:0.1:10.0 # test prob.kwargs is forwarded
162+
# Test ensemble AD with multiple backends and compare results
163+
# Note: Mooncake and Enzyme are excluded for ensemble tests (same as above)
164+
@testset "Ensemble AD comparison across backends" begin
165+
# Use ForwardDiff as reference
166+
x_ref = DifferentiationInterface.gradient(sum_of_e_solution, AutoForwardDiff(), ep)
167+
168+
for (name, backend) in backends_for_ensemble
169+
@testset "sum_of_e_solution gradient with $name" begin
170+
x = DifferentiationInterface.gradient(sum_of_e_solution, backend, ep)
171+
@test x ≈ x_ref
172+
end
173+
end
174+
@test cache[] == 0:0.1:10.0 # test prob.kwargs is forwarded
175+
end

test/downstream/gtpsa.jl

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test
1+
using OrdinaryDiffEq, GTPSA, Test
2+
using DifferentiationInterface
3+
using ADTypes: AutoForwardDiff
4+
using ForwardDiff
5+
6+
# GTPSA is itself an AD engine - these tests compare GTPSA jacobian/hessian
7+
# results against ForwardDiff as a reference implementation
28

39
# ODEProblem 1 =======================
410

@@ -20,23 +26,28 @@ sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol = 1.0e-16, abstol = 1.0e-16)
2026

2127
@test sol.u[end] ≈ scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
2228

23-
# Compare Jacobian against ForwardDiff
24-
J_FD = ForwardDiff.jacobian([x..., p...]) do t
29+
# Compare Jacobian against AD backends using DifferentiationInterface
30+
function sol_end_problem1(t)
2531
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
2632
sol = solve(prob, Tsit5(), reltol = 1.0e-16, abstol = 1.0e-16)
27-
sol.u[end]
33+
return sol.u[end]
2834
end
2935

30-
@test J_FD ≈ GTPSA.jacobian(sol_GTPSA.u[end], include_params = true)
36+
@testset "GTPSA Problem 1 Jacobian tests" begin
37+
J_AD = DifferentiationInterface.jacobian(sol_end_problem1, AutoForwardDiff(), [x..., p...])
38+
@test J_AD ≈ GTPSA.jacobian(sol_GTPSA.u[end], include_params = true)
39+
end
3140

32-
# Compare Hessians against ForwardDiff
33-
for i in 1:3
34-
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
35-
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
36-
sol = solve(prob, Tsit5(), reltol = 1.0e-16, abstol = 1.0e-16)
37-
sol.u[end][i]
41+
@testset "GTPSA Problem 1 Hessian tests" begin
42+
for i in 1:3
43+
function sol_end_i_problem1(t)
44+
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
45+
sol = solve(prob, Tsit5(), reltol = 1.0e-16, abstol = 1.0e-16)
46+
return sol.u[end][i]
47+
end
48+
Hi_AD = DifferentiationInterface.hessian(sol_end_i_problem1, AutoForwardDiff(), [x..., p...])
49+
@test Hi_AD ≈ GTPSA.hessian(sol_GTPSA.u[end][i], include_params = true)
3850
end
39-
@test Hi_FD ≈ GTPSA.hessian(sol_GTPSA.u[end][i], include_params = true)
4051
end
4152

4253
# ODEProblem 2 =======================
@@ -49,31 +60,36 @@ function qdot!(dp, p, q, params, t)
4960
]
5061
end
5162

52-
prob = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0, 25.0))
53-
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
63+
prob2 = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0, 25.0))
64+
sol2 = solve(prob2, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
5465

55-
desc = Descriptor(6, 2) # 6 variables to 2nd order
56-
dx = @vars(desc) # identity map
57-
prob_GTPSA = DynamicalODEProblem(pdot!, qdot!, dx[1:3], dx[4:6], (0.0, 25.0))
58-
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
66+
desc2 = Descriptor(6, 2) # 6 variables to 2nd order
67+
dx2 = @vars(desc2) # identity map
68+
prob_GTPSA2 = DynamicalODEProblem(pdot!, qdot!, dx2[1:3], dx2[4:6], (0.0, 25.0))
69+
sol_GTPSA2 = solve(prob_GTPSA2, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
5970

60-
@test sol.u[end] ≈ scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
71+
@test sol2.u[end] ≈ scalar.(sol_GTPSA2.u[end]) # scalar gets 0th order part
6172

62-
# Compare Jacobian against ForwardDiff
63-
J_FD = ForwardDiff.jacobian(zeros(6)) do t
73+
# Compare Jacobian against AD backends using DifferentiationInterface
74+
function sol_end_problem2(t)
6475
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
6576
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
66-
sol.u[end]
77+
return sol.u[end]
6778
end
6879

69-
@test J_FD ≈ GTPSA.jacobian(sol_GTPSA.u[end], include_params = true)
80+
@testset "GTPSA Problem 2 Jacobian tests" begin
81+
J_AD = DifferentiationInterface.jacobian(sol_end_problem2, AutoForwardDiff(), zeros(6))
82+
@test J_AD ≈ GTPSA.jacobian(sol_GTPSA2.u[end], include_params = true)
83+
end
7084

71-
# Compare Hessians against ForwardDiff
72-
for i in 1:6
73-
Hi_FD = ForwardDiff.hessian(zeros(6)) do t
74-
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
75-
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
76-
sol.u[end][i]
85+
@testset "GTPSA Problem 2 Hessian tests" begin
86+
for i in 1:6
87+
function sol_end_i_problem2(t)
88+
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
89+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1.0e-16, abstol = 1.0e-16)
90+
return sol.u[end][i]
91+
end
92+
Hi_AD = DifferentiationInterface.hessian(sol_end_i_problem2, AutoForwardDiff(), zeros(6))
93+
@test Hi_AD ≈ GTPSA.hessian(sol_GTPSA2.u[end][i], include_params = true)
7794
end
78-
@test Hi_FD ≈ GTPSA.hessian(sol_GTPSA.u[end][i], include_params = true)
7995
end

0 commit comments

Comments
 (0)