Skip to content

Commit a84c071

Browse files
Ig-dolcipbrubeck
andauthored
Jax ml operator fix (#4041)
--------- Co-authored-by: Pablo Brubeck <brubeck@protonmail.com>
1 parent 9e21c8b commit a84c071

File tree

11 files changed

+101
-80
lines changed

11 files changed

+101
-80
lines changed

firedrake/external_operators/abstract_external_operators.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,22 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
4747
Parameters
4848
----------
4949
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
50-
Operands of the external operator.
50+
Operands of the external operator.
5151
function_space : firedrake.functionspaceimpl.WithGeometryBase
52-
The function space the external operator is mapping to.
52+
The function space the external operator is mapping to.
5353
derivatives : tuple
54-
Tuple specifiying the derivative multiindex.
54+
Tuple specifying the derivative multi-index.
5555
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
56-
Tuple containing the arguments of the linear form associated with the external operator,
57-
i.e. the arguments with respect to which the external operator is linear. Those arguments
58-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
59-
as a result of taking the action on a given function.
56+
Tuple containing the arguments of the linear form associated with the external operator,
57+
i.e. the arguments with respect to which the external operator is linear. Those arguments can
58+
be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
59+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
60+
of taking the action on a given function.
6061
operator_data : dict
61-
Dictionary containing the data of the external operator, i.e. the external data
62-
specific to the external operator subclass considered. This dictionary will be passed on
63-
over the UFL symbolic reconstructions making the operator data accessible to the external operators
64-
arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
62+
Dictionary containing the data of the external operator, i.e. the external data
63+
specific to the external operator subclass considered. This dictionary will be passed on
64+
over the UFL symbolic reconstructions making the operator data accessible to the external operators
65+
arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
6566
"""
6667
from firedrake_citations import Citations
6768
Citations().register("Bouziani2021")

firedrake/external_operators/ml_operator.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,25 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
1515
Parameters
1616
----------
1717
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
18-
Operands of the ML operator.
18+
Operands of the ML operator.
1919
function_space : firedrake.functionspaceimpl.WithGeometryBase
20-
The function space the ML operator is mapping to.
20+
The function space the ML operator is mapping to.
2121
derivatives : tuple
22-
Tuple specifiying the derivative multiindex.
22+
Tuple specifying the derivative multi-index.
2323
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
24-
Tuple containing the arguments of the linear form associated with the ML operator,
25-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
26-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
27-
as a result of taking the action on a given function.
24+
Tuple containing the arguments of the linear form associated with the ML operator,
25+
i.e. the arguments with respect to which the ML operator is linear. Those arguments can
26+
be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
27+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
28+
of taking the action on a given function. If argument slots are not provided, then they will
29+
be generated in the :class:`.AbstractExternalOperator` constructor.
2830
operator_data : dict
29-
Dictionary to stash external data specific to the ML operator. This dictionary must
30-
at least contain the following:
31-
(i) 'model': The machine learning model implemented in the ML framework considered.
32-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
33-
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
31+
Dictionary to stash external data specific to the ML operator. This dictionary must
32+
at least contain the following:
33+
(i) 'model': The machine learning model implemented in the ML framework considered.
34+
(ii) 'inputs_format': The format of the inputs to the ML model: ``0`` for models acting globally
35+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
36+
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
3437
"""
3538
AbstractExternalOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
3639
argument_slots=argument_slots, operator_data=operator_data)

firedrake/ml/jax/fem_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def bwd(self, _, grad_output: "jax.Array") -> "jax.Array":
9292
adj_input = float(adj_input)
9393

9494
# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
95-
adj_output = self.F.derivative(adj_input=adj_input)
95+
adj_output = self.F.derivative(adj_input=adj_input, options={'riesz_representation': None})
9696

9797
# Tuplify adjoint output
9898
adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output

firedrake/ml/jax/ml_operator.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,40 @@ def __init__(
3939
*operands: Union[ufl.core.expr.Expr, ufl.form.BaseForm],
4040
function_space: WithGeometryBase,
4141
derivatives: Optional[tuple] = None,
42-
argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]],
42+
argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] = (),
4343
operator_data: Optional[dict] = {}
4444
):
45-
"""External operator class representing machine learning models implemented in JAX.
45+
"""
46+
External operator class representing machine learning models implemented in JAX.
4647
4748
The :class:`.JaxOperator` allows users to embed machine learning models implemented in JAX
48-
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator` is
49-
delegated to the specified JAX model. Similarly, differentiation through the :class:`.JaxOperator`
50-
class is achieved using JAX differentiation on the JAX model associated with the :class:`.JaxOperator` object.
49+
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator`
50+
is delegated to the specified JAX model. Similarly, differentiation through the
51+
:class:`.JaxOperator` is achieved using JAX differentiation on the associated JAX model.
5152
5253
Parameters
5354
----------
5455
*operands
55-
Operands of the :class:`.JaxOperator`.
56+
Operands of the :class:`.JaxOperator`.
5657
function_space
57-
The function space the ML operator is mapping to.
58+
The function space the ML operator is mapping to.
5859
derivatives
59-
Tuple specifiying the derivative multiindex.
60-
*argument_slots
61-
Tuple containing the arguments of the linear form associated with the ML operator,
62-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
63-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
64-
as a result of taking the action on a given function.
60+
Tuple specifying the derivative multi-index.
61+
argument_slots
62+
Tuple containing the arguments of the linear form associated with the ML operator,
63+
i.e., the arguments with respect to which the ML operator is linear. These arguments
64+
can be ``ufl.argument.BaseArgument`` objects, as a result of differentiation,
65+
or both ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects,
66+
as a result of taking the action on a given function.
6567
operator_data
66-
Dictionary to stash external data specific to the ML operator. This dictionary must
67-
at least contain the following:
68-
(i) 'model': The machine learning model implemented in JaX
69-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
70-
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
68+
Dictionary to stash external data specific to the ML operator. This dictionary must
69+
contain the following:
70+
(i) ``'model'`` : The machine learning model implemented in JAX.
71+
(ii) ``'inputs_format'`` : The format of the inputs to the ML model: ``0`` for models acting
72+
globally on the inputs. ``1`` for models acting locally/pointwise on the inputs.
73+
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
7174
"""
75+
7276
MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
7377
argument_slots=argument_slots, operator_data=operator_data)
7478

@@ -90,8 +94,7 @@ def _pre_forward_callback(self, *operands: Union[Function, Cofunction], unsqueez
9094

9195
def _post_forward_callback(self, y_P: "jax.Array") -> Union[Function, Cofunction]:
9296
"""Callback function to convert the JAX output of the ML model to a Firedrake function."""
93-
space = self.ufl_function_space()
94-
return from_jax(y_P, space)
97+
return from_jax(y_P, self.ufl_function_space())
9598

9699
# -- JAX routines for computing AD-based quantities -- #
97100

firedrake/ml/pytorch/fem_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def backward(ctx, grad_output):
8383
adj_input = float(adj_input)
8484

8585
# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
86-
adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"})
86+
adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None})
8787

8888
# Tuplify adjoint output
8989
adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output

firedrake/ml/pytorch/ml_operator.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,24 @@ class is achieved via the `torch.autograd` module, which provides automatic diff
4040
Parameters
4141
----------
4242
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
43-
Operands of the :class:`.PytorchOperator`.
43+
Operands of the :class:`.PytorchOperator`.
4444
function_space : firedrake.functionspaceimpl.WithGeometryBase
45-
The function space the ML operator is mapping to.
45+
The function space the ML operator is mapping to.
4646
derivatives : tuple
47-
Tuple specifiying the derivative multiindex.
47+
Tuple specifying the derivative multi-index.
4848
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
49-
Tuple containing the arguments of the linear form associated with the ML operator,
50-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
51-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
52-
as a result of taking the action on a given function.
49+
Tuple containing the arguments of the linear form associated with the ML operator, i.e. the
50+
arguments with respect to which the ML operator is linear. Those arguments can be
51+
``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
52+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects, as a result
53+
of taking the action on a given function.
5354
operator_data : dict
54-
Dictionary to stash external data specific to the ML operator. This dictionary must
55-
at least contain the following:
56-
(i) 'model': The machine learning model implemented in PyTorch.
57-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
58-
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
55+
Dictionary to stash external data specific to the ML operator. This dictionary must
56+
at least contain the following:
57+
(i) ``'model'``: The machine learning model implemented in PyTorch.
58+
(ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
59+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
60+
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
5961
"""
6062
MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
6163
argument_slots=argument_slots, operator_data=operator_data)
@@ -98,8 +100,7 @@ def _pre_forward_callback(self, *operands, unsqueeze=False):
98100

99101
def _post_forward_callback(self, y_P):
100102
"""Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
101-
space = self.ufl_function_space()
102-
return from_torch(y_P, space)
103+
return from_torch(y_P, self.ufl_function_space())
103104

104105
# -- PyTorch routines for computing AD based quantities via `torch.autograd.functional` -- #
105106

scripts/firedrake-install

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ def build_and_install_jax():
13031303
"""Install JAX for a CPU or CUDA backend."""
13041304
log.info("Installing JAX (backend: %s)" % args.jax)
13051305
version_name = "jax" if args.jax == "cpu" else "jax[cuda12]"
1306-
run_pip_install([version_name])
1306+
run_pip_install([version_name] + ["jaxlib"] + ["ml_dtypes"] + ["opt_einsum"])
13071307

13081308

13091309
def build_and_install_slepc():

tests/firedrake/conftest.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,25 +98,27 @@ def pytest_collection_modifyitems(session, config, items):
9898
if item.get_closest_marker("skipmumps") is not None:
9999
item.add_marker(pytest.mark.skip("MUMPS not installed with PETSc"))
100100

101-
if not torch_backend:
102-
if item.get_closest_marker("skiptorch") is not None:
103-
item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
104-
105-
if not jax_backend:
106-
if item.get_closest_marker("skipjax") is not None:
107-
item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
108-
109-
if not matplotlib_installed:
110-
if item.get_closest_marker("skipplot") is not None:
111-
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
112-
113-
if not netgen_installed:
114-
if item.get_closest_marker("skipnetgen") is not None:
115-
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
116-
117-
if not vtk_installed:
118-
if item.get_closest_marker("skipvtk") is not None:
119-
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))
101+
import os
102+
if os.getenv("FIREDRAKE_CI_TESTS") != "1":
103+
if not torch_backend:
104+
if item.get_closest_marker("skiptorch") is not None:
105+
item.add_marker(pytest.mark.skip(reason="Test makes no sense if PyTorch is not installed"))
106+
107+
if not jax_backend:
108+
if item.get_closest_marker("skipjax") is not None:
109+
item.add_marker(pytest.mark.skip(reason="Test makes no sense if JAX is not installed"))
110+
111+
if not matplotlib_installed:
112+
if item.get_closest_marker("skipplot") is not None:
113+
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Matplotlib is installed"))
114+
115+
if not netgen_installed:
116+
if item.get_closest_marker("skipnetgen") is not None:
117+
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless Netgen and ngsPETSc are installed"))
118+
119+
if not vtk_installed:
120+
if item.get_closest_marker("skipvtk") is not None:
121+
item.add_marker(pytest.mark.skip(reason="Test cannot be run unless VTK is installed"))
120122

121123

122124
@pytest.fixture(scope="module", autouse=True)

tests/firedrake/external_operators/test_external_operators_adjoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,6 @@ def J(f):
7777
c = Control(f)
7878
Jhat = ReducedFunctional(J(f), c)
7979

80-
f_opt = minimize(Jhat, tol=1e-6, method="BFGS")
80+
f_opt = minimize(Jhat, tol=1e-4, method="BFGS")
8181

8282
assert assemble((f_exact - f_opt)**2 * dx) / assemble(f_exact**2 * dx) < 1e-5

tests/firedrake/external_operators/test_jax_operator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def test_forward(u, nn):
9696

9797
# Assemble NeuralNet operator
9898
assembled_N = assemble(N)
99+
assert isinstance(assembled_N, Function)
99100

100101
# Convert from Firedrake to JAX
101102
x_P = to_jax(u)
@@ -125,6 +126,8 @@ def test_jvp(u, nn):
125126
# Assemble
126127
dN = assemble(dN)
127128

129+
assert isinstance(dN, Function)
130+
128131
# Convert from Firedrake to JAX
129132
delta_u_P = to_jax(delta_u)
130133
u_P = to_jax(u)
@@ -153,6 +156,8 @@ def test_vjp(u, nn):
153156
# Assemble
154157
dN_adj = assemble(dNdu)
155158

159+
assert isinstance(dN_adj, Cofunction)
160+
156161
# Convert from Firedrake to JAX
157162
delta_N_P = to_jax(delta_N)
158163
u_P = to_jax(u)

0 commit comments

Comments
 (0)