
Enzyme AD backend limitations on Julia v1.12+  #1258

@ChrisRackauckas-Claude

Description


Summary

While refactoring the AD tests to use DifferentiationInterface (#1255), we found that Enzyme is missing rules for several cases, so the Enzyme backend is currently skipped in the affected tests.

Affected Tests

1. ModelingToolkit (MTK) type differentiation - null_de.jl

Error: StackOverflowError in EnzymeInterpreter

Cause: Missing Enzyme rules for complex ModelingToolkit types

Workaround: Enzyme backend is filtered out for MTK-related derivative tests in null_de.jl:119

Reproducer:

using OrdinaryDiffEq, ModelingToolkit, DifferentiationInterface
using ADTypes: AutoEnzyme
using Enzyme
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters P
@variables x(t)
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN])

x_at_1(P) = solve(remake(prob; p = [sys.P => P]), Tsit5())(1.0, idxs = x)

# This will fail with StackOverflowError - missing rules
DifferentiationInterface.derivative(x_at_1, AutoEnzyme(), 1.0)

2. Ensemble problem differentiation - ensemble_ad.jl

Error: Issues with ensemble problem differentiation

Cause: Missing Enzyme rules for EnsembleProblem

Workaround: Enzyme backend is filtered out for ensemble gradient tests in ensemble_ad.jl:86 and ensemble_ad.jl:168

Reproducer:

using OrdinaryDiffEq, DifferentiationInterface
using ADTypes: AutoEnzyme
using Enzyme
using SciMLSensitivity

function f!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(f!, u0, (0.0, 10.0), p)

N = 3
eu0 = rand(N, 2)
ep = rand(N, 4)

function sum_of_e_solution(params)
    ensemble_prob = EnsembleProblem(
        prob,
        prob_func = (prob, i, repeat) -> remake(prob, u0 = eu0[i, :], p = params[i, :])
    )
    sol = solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = N)
    return sum(Array(sol.u[1]))
end

# This will fail - missing rules
DifferentiationInterface.gradient(sum_of_e_solution, AutoEnzyme(), ep)

3. Complex number ODE differentiation - complex_number_ad.jl

Error: Issues with complex number ODE differentiation

Cause: Missing Enzyme rules for complex number ODE problems

Workaround: Enzyme backend is filtered out for complex number tests in complex_number_ad.jl:107
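Each workaround above amounts to dropping `AutoEnzyme` from the list of backends a test iterates over. Here is a minimal sketch of that pattern. It assumes the tests keep their backends in a vector; the variable names and the two-backend list are illustrative and are not copied from null_de.jl, ensemble_ad.jl, or complex_number_ad.jl.

```julia
using ADTypes: AutoForwardDiff, AutoEnzyme

# Illustrative backend list (hypothetical); the real test files may use more backends.
backends = [AutoForwardDiff(), AutoEnzyme()]

# Skip Enzyme until the missing rules exist. `AutoEnzyme` is a parametric
# type, so `b isa AutoEnzyme` matches it regardless of mode or annotation.
test_backends = filter(b -> !(b isa AutoEnzyme), backends)
```

A test would then loop `for backend in test_backends` and call, for example, `DifferentiationInterface.derivative(x_at_1, backend, 1.0)`. Once the rules land, removing the `filter` re-enables Enzyme without any other changes to the test.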

Action Items

Enzyme rules for these cases need to be implemented in the appropriate packages (SciMLSensitivity, ModelingToolkit, etc.). Once they exist, the Enzyme filters listed above can be removed from the tests.
