@@ -82,12 +82,21 @@ class for more details. Most importantly, the `simulation_type` needs to be set
8282`SimulationType.MINIMIZATION ` (see
8383:py:class: `SimulationType <mlip.simulation.enums.SimulationType> `).
8484
85+ .. note ::
86+
87+ The default timestep of 1.0 fs that is common for MD simulations may not be optimal
88+ for energy minimizations. We recommend to set this value to 0.1 fs when using the
89+ `SimulationType.MINIMIZATION ` mode with the JAX-MD backend.
90+
8591**Algorithms **: For MD, the NVT-Langevin algorithm is used
8692(see `here <https://jax-md.readthedocs.io/en/main/jax_md.simulate.html#jax_md.simulate.nvt_langevin >`_).
8793For energy minimization, the FIRE algorithm is used
8894(see `here <https://jax-md.readthedocs.io/en/main/jax_md.minimize.html#jax_md.minimize.fire_descent >`_).
8995We plan to provide more options in future versions of the library.
9096
97+ Furthermore, for MD simulations, we support running them in a **batched manner **.
98+ See :ref: `this <batched_simulations >` section below for more information.
99+
91100.. note ::
92101
93102 A special feature of the JAX-MD backend is that a simulation is divided into
@@ -203,6 +212,45 @@ The logger must be attached before starting the simulation.
203212In ASE, this logging function will be called depending on the logging interval set,
204213and in JAX-MD, it will be called after every episode.
205214
215+ .. _batched_simulations :
216+
217+ Batched simulations with JAX-MD
218+ -------------------------------
219+
220+ For MD simulations or energy minimizations with JAX-MD, we support running them in a
221+ batched manner for multiple systems. The API for this is straightforward,
222+ instead of passing a single `ase.Atoms ` object to the engine, we pass a list of them.
223+ After the simulation, the simulation state will contain lists of properties,
224+ for example, a list of position arrays (i.e., the trajectories) instead of a single
225+ position array. Note that it is also supported that the input molecules have
226+ varying sizes. See example code below:
227+
228+ .. code-block :: python
229+
230+ from ase.io import read as ase_read
231+ from mlip.simulation.jax_md import JaxMDSimulationEngine
232+
233+ systems = []
234+ for path in [" /path/to/mol_1" , " /path/to/mol_2" , " /path/to/mol_3" ]:
235+ atoms = ase_read(path)
236+ systems.append(atoms)
237+
238+ force_field, md_config = _get_from_somewhere() # placeholder
239+ md_engine = JaxMDSimulationEngine(systems, force_field, md_config)
240+ md_engine.run()
241+
242+ # Fetch results:
243+ # Get trajectory and temperatures for "/path/to/mol_2" (indexing starts at 0)
244+ md_state = md_engine.state
245+ print (md_state.positions[1 ])
246+ print (md_state.temperature[1 ])
247+
248+ # Compute time, for example, is not a list
249+ print (md_state.compute_time_seconds)
250+
251+ The example above works for both energy minimizations and MD simulations in the same
252+ way.
253+
206254.. _batched_inference :
207255
208256Batched inference
0 commit comments