Summary
During the refactoring of AD tests to use DifferentiationInterface (#1255), we discovered that Enzyme is missing rules for several cases.
Affected Tests
1. ModelingToolkit (MTK) type differentiation - null_de.jl
Error: StackOverflowError in EnzymeInterpreter
Cause: Missing Enzyme rules for complex ModelingToolkit types
Workaround: Enzyme backend is filtered out for MTK-related derivative tests in null_de.jl:119
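For reference, the correct value of the derivative in this test is known in closed form. With D(x) ~ P and x(0) = 0, the solution is x(t) = P t, so x(1) = P and dx(1)/dP is exactly 1. Below is a minimal sketch of checking that value with a non-Enzyme backend. It assumes ForwardDiff.jl is installed and that it handles this MTK problem; this issue does not say which other backends pass.

```julia
using OrdinaryDiffEq, ModelingToolkit, DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff  # assumption: loading ForwardDiff activates the DI backend
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters P
@variables x(t)
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN])

# x(1) as a function of the parameter P, same as in the reproducer
x_at_1(P) = solve(remake(prob; p = [sys.P => P]), Tsit5())(1.0, idxs = x)

# Since x(t) = P * t, the derivative of x(1) with respect to P should be 1
d_ref = DifferentiationInterface.derivative(x_at_1, AutoForwardDiff(), 1.0)
```

Any backend that differentiates this problem correctly, Enzyme included once the rules exist, should agree with this value.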
Reproducer:
```julia
using OrdinaryDiffEq, ModelingToolkit, DifferentiationInterface
using ADTypes: AutoEnzyme
using Enzyme
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters P
@variables x(t)
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN])
x_at_1(P) = solve(remake(prob; p = [sys.P => P]), Tsit5())(1.0, idxs = x)

# This will fail with StackOverflowError - missing rules
DifferentiationInterface.derivative(x_at_1, AutoEnzyme(), 1.0)
```
2. Ensemble problem differentiation - ensemble_ad.jl
Error: Issues with ensemble problem differentiation
Cause: Missing Enzyme rules for EnsembleProblem
Workaround: Enzyme backend is filtered out for ensemble gradient tests in ensemble_ad.jl:86 and ensemble_ad.jl:168
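The expected structure of the gradient can be read off the reproducer that follows. `sum_of_e_solution` only uses `sol.u[1]`, the first trajectory, so the gradient with respect to `ep` should be exactly zero in rows 2 to N and generally nonzero in row 1. Below is a minimal sketch of that check using ForwardDiff. It assumes ForwardDiff.jl is installed and handles this ensemble problem; this issue does not say so.

```julia
using OrdinaryDiffEq, DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff  # assumption: loading ForwardDiff activates the DI backend

# Lotka-Volterra right-hand side, as in the reproducer
function f!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

prob = ODEProblem(f!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

function sum_of_e_solution(params)
    ensemble_prob = EnsembleProblem(
        prob,
        prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = params[i, :])
    )
    sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
    return sum(Array(sol.u[1]))  # only trajectory 1 enters the loss
end

# Row i of g holds the sensitivities to the parameters of trajectory i
g = DifferentiationInterface.gradient(sum_of_e_solution, AutoForwardDiff(), ep)
```

An Enzyme gradient that is nonzero in rows 2 to N would therefore indicate a wrong result, not just a missing rule.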
Reproducer:
```julia
using OrdinaryDiffEq, DifferentiationInterface
using ADTypes: AutoEnzyme
using Enzyme
using SciMLSensitivity

function f!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(f!, u0, (0.0, 10.0), p)

N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

function sum_of_e_solution(params)
    ensemble_prob = EnsembleProblem(
        prob,
        prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = params[i, :])
    )
    sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
    return sum(Array(sol.u[1]))
end

# This will fail - missing rules
DifferentiationInterface.gradient(sum_of_e_solution, AutoEnzyme(), ep)
```
3. Complex number ODE differentiation - complex_number_ad.jl
Error: Issues with complex number ODE differentiation
Cause: Missing Enzyme rules for complex number ODE problems
Workaround: Enzyme backend is filtered out for complex number tests in complex_number_ad.jl:107
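All three workarounds follow the same pattern: the tests build a list of AD backends, and Enzyme is dropped for the affected cases. Below is a minimal sketch of that pattern using ADTypes backend types. The names `backends`, `affected_by_issue`, and `test_backends` are hypothetical and do not reproduce the code at the cited lines.

```julia
using ADTypes: AutoEnzyme, AutoForwardDiff, AutoFiniteDiff

# Hypothetical list of backends exercised by a test file
backends = [AutoForwardDiff(), AutoFiniteDiff(), AutoEnzyme()]

# Hypothetical switch: true for tests hit by the missing Enzyme rules
affected_by_issue = true

# Drop Enzyme only for the affected tests and keep every other backend
test_backends = affected_by_issue ? filter(b -> !(b isa AutoEnzyme), backends) : backends
```

Filtering by type rather than by position keeps the workaround valid if the backend list changes. Removing it once the rules land means deleting a single condition.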
Action Items
Enzyme rules for these cases need to be implemented in the appropriate packages (SciMLSensitivity, ModelingToolkit, etc.).
Related
- PR #1255: Refactor AD tests to use DifferentiationInterface with version-dependent backends
- Issue #1256: Mooncake AD backend limitations with MTK and Ensemble problems
- Issue #1257: Zygote AD backend limitations with MTK types