Commit 526bf78

github-actions[bot] authored and chrbrunk committed

release: create release-0.1.5 branch

1 parent 5543db5 · commit 526bf78

117 files changed: +1723 additions, −833 deletions


.github/workflows/deploy_docs.yaml

Lines changed: 6 additions & 7 deletions

@@ -14,18 +14,17 @@ jobs:
       image: python:3.12-slim-bullseye
     steps:
       - uses: actions/checkout@v4
-      - name: Install dependencies
+
+      - name: Setup
         run: |
           apt-get update && apt-get install -y coreutils git
-          POETRY_VERSION=1.8.4
-          pip install -U pip setuptools
-          pip install poetry==${POETRY_VERSION}
-          poetry install
-          poetry run pip install git+https://github.com/jax-md/jax-md.git
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
 
       - name: Sphinx build
         run: |
-          poetry run sphinx-build -b html docs/source _build
+          uv run sphinx-build -b html docs/source _build
 
       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@v3

.github/workflows/publish.yaml

Lines changed: 7 additions & 8 deletions

@@ -17,14 +17,13 @@ jobs:
         with:
           token: ${{ secrets.GITHUB_TOKEN }}
 
-      - name: install poetry
-        run: |
-          POETRY_VERSION=1.8.4
-          pip install -U pip setuptools
-          pip install poetry==${POETRY_VERSION}
+      - name: setup
+        run: apt-get update && apt-get install -y coreutils
+
+      - name: install uv
+        uses: astral-sh/setup-uv@v6
 
       - name: build and publish
         run: |
-          export POETRY_PYPI_TOKEN_PYPI=${{secrets.POETRY_PYPI_TOKEN}}
-          poetry build
-          poetry publish
+          uv build
+          uv publish --token ${{secrets.POETRY_PYPI_TOKEN}}

.github/workflows/tests_and_linters.yaml

Lines changed: 22 additions & 18 deletions

@@ -22,16 +22,16 @@ jobs:
       - name: checkout code 📦
         uses: actions/checkout@v4
 
-      - name: install poetry
-        run: |
-          POETRY_VERSION=1.8.4
-          pip install -U pip setuptools
-          pip install poetry==${POETRY_VERSION}
-          poetry install
+      - name: install uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: setup env
+        run: uv sync --only-dev
 
       - name: run linters 🖌️
         run: |
-          poetry run pre-commit run --all-files --verbose
+          git init
+          uv run --no-sync pre-commit run --all-files --verbose
 
   tests:
     runs-on: ubuntu-latest
@@ -41,21 +41,25 @@
       - name: checkout code 📂
         uses: actions/checkout@v4
 
-      - name: install poetry
+      - name: setup
         run: |
-          POETRY_VERSION=1.8.4
-          pip install -U pip setuptools
-          pip install poetry==${POETRY_VERSION}
-          poetry install
-          poetry add git+https://github.com/jax-md/jax-md.git
+          apt-get update && apt-get install -y \
+            coreutils \
+            git
+
+      - name: install uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: setup env
+        run: uv sync
 
       - name: run tests 🧪
         run: |
-          poetry run pytest --ignore tests/experiments --verbose \
+          uv run --no-sync pytest --ignore tests/experiments --verbose \
             --cov-report xml:coverage.xml \
             --cov-report term-missing \
             --junitxml=pytest.xml \
             --cov=mlip tests/
 
       - name: pytest coverage comment
         id: coverageComment

CHANGELOG

Lines changed: 19 additions & 0 deletions

@@ -1,5 +1,24 @@
 # Changelog
 
+## Release 0.1.5
+
+- Adding a batched-simulations feature for MD simulations and energy minimizations
+  with the JAX-MD backend.
+- Removing the now-redundant `stress_virial` prediction.
+- Fixing correctness of the `stress` and 0K `pressure` predictions. In 0.1.4,
+  the stress computation took a derivative with respect to the cell but with
+  fixed positions. Now, the strain also acts on positions within the unit cell,
+  thus deforming the material homogeneously. This rigorously
+  translation-invariant stress removes the need for any virial-term correction
+  of cell boundary effects. See for instance
+  [Thompson, Plimpton and Mattson 2009, eq. (2)](https://doi.org/10.1063/1.3245303).
+- Migrating from poetry to uv for dependency and package management.
+- Improving an inefficient logging strategy in the ASE simulation backend.
+- Clarifying in the documentation that we recommend a smaller timestep
+  when running energy minimizations with the JAX-MD simulation backend.
+- Removing the need for a separate install command for the JAX-MD dependency.
+- Adding an easier install method for GPU-compatible JAX.
+
 ## Release 0.1.4
 
 - Removing constraints on some dependencies, such as numpy, jax, and flax. The mlip
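The stress fix described above can be made concrete with a short sketch. The following is a minimal JAX illustration of the stated idea (strain applied to both cell and positions), not the mlip implementation; the `energy_fn(positions, cell)` signature and the function name are assumptions for illustration:

```python
import jax
import jax.numpy as jnp


def stress_from_strain(energy_fn, positions, cell):
    """Stress as the strain derivative of the energy (illustrative sketch only).

    Assumes energy_fn(positions, cell) -> scalar energy, positions of shape
    (n_atoms, 3), and cell of shape (3, 3) with lattice vectors as rows.
    """

    def energy_at_strain(strain):
        # Deform the cell AND the positions homogeneously: x -> x @ (I + strain).
        deformation = jnp.eye(3) + strain
        return energy_fn(positions @ deformation, cell @ deformation)

    volume = jnp.abs(jnp.linalg.det(cell))
    # Stress = (1/V) * dE/d(strain) at zero strain; because the strain moves
    # the atoms together with the cell, no virial correction is needed.
    return jax.grad(energy_at_strain)(jnp.zeros((3, 3))) / volume
```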

Dockerfile

Lines changed: 2 additions & 2 deletions

@@ -1,10 +1,10 @@
-FROM python:3.10.12-slim-bullseye
+FROM python:3.12-slim-bullseye
 
 WORKDIR /app
 
 RUN apt-get update && apt-get install -y git wget
 
-RUN pip install mlip "jax[cuda12]==0.6.2" huggingface_hub git+https://github.com/jax-md/jax-md.git notebook
+RUN pip install mlip[cuda] huggingface_hub notebook
 
 RUN wget https://raw.githubusercontent.com/instadeepai/mlip/refs/heads/main/tutorials/simulation_tutorial.ipynb \
     https://raw.githubusercontent.com/instadeepai/mlip/refs/heads/main/tutorials/model_training_tutorial.ipynb \

README.md

Lines changed: 12 additions & 18 deletions

@@ -1,5 +1,9 @@
 # 🪩 MLIP: Machine Learning Interatomic Potentials 🚀
 
+[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
+[![Python 3.11](https://img.shields.io/badge/python-3.11%20%7C%203.12%20%7C%203.13-blue)](https://www.python.org/downloads/release/python-3110/)
+[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
+[![Tests and Linters 🧪](https://github.com/instadeepai/mlip/actions/workflows/tests_and_linters.yaml/badge.svg?branch=main)](https://github.com/instadeepai/mlip/actions/workflows/tests_and_linters.yaml)
 ![badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/mlipbot/b6e4bf384215e60775699a83c3c00aef/raw/pytest-coverage-comment.json)
 
 ## 👀 Overview
@@ -12,6 +16,7 @@ the following functionality:
 - Training and fine-tuning MLIP models
 - Batched inference with trained MLIP models
 - MD simulations with MLIP models using multiple simulation backends (for now: JAX-MD and ASE)
+- Batched MD simulations and energy minimizations with the JAX-MD simulation backend.
 - Energy minimizations with MLIP models using the same simulation backends as for MD.
 
 The purpose of the library is to provide users with a toolbox
@@ -49,29 +54,18 @@ pip install mlip
 
 However, this command **only installs the regular CPU version** of JAX.
 We recommend that the library is run on GPU.
-This requires also installing the necessary versions
-of [jaxlib](https://pypi.org/project/jaxlib/) which can also be installed via pip. See
-the [installation guide of JAX](https://docs.jax.dev/en/latest/installation.html) for
-more information.
-At time of release, the following install command is supported:
+Use this command instead to install the GPU-compatible version:
 
 ```bash
-pip install -U "jax[cuda12]"
+pip install mlip[cuda]
 ```
 
-Note that using the TPU version of *jaxlib* is, in principle, also supported by
-this library. However, it has not been thoroughly tested and should therefore be
-considered an experimental feature.
+**This command installs the CUDA 12 version of JAX.** For different versions, please
+install *mlip* without the `cuda` flag and install the desired JAX version via pip.
 
-Also, some tasks in *mlip* will
-require [JAX-MD](https://github.com/jax-md/jax-md>) as a dependency. As the newest
-version of JAX-MD is not available on PyPI yet, this dependency will not
-be shipped with *mlip* automatically and instead must be installed
-directly from the GitHub repository, like this:
-
-```bash
-pip install git+https://github.com/jax-md/jax-md.git
-```
+Note that using the TPU version of JAX is, in principle, also supported by
+this library. You need to install it separately via pip. However, it has not been
+thoroughly tested and should therefore be considered an experimental feature.
 
 ## ⚡ Examples
 
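A quick way to verify that a `pip install mlip[cuda]` setup (or the Docker image above) actually sees the GPU is to query JAX directly; this uses only standard JAX calls, nothing mlip-specific:

```python
import jax

# Lists CudaDevice entries when the CUDA 12 build is active;
# a CPU-only install shows [CpuDevice(id=0)] instead.
print(jax.devices())
print(jax.default_backend())  # "gpu" on a working GPU setup, else "cpu"
```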

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ in JAX. It contains the following features:
 * Batched inference with trained MLIP models
 * MD simulations with MLIP models using multiple simulation backends
 * Energy minimizations with MLIP models using multiple simulation backends
+* Batched MD simulations and energy minimizations with the JAX-MD simulation backend
 * Fine-tuning of pre-trained MLIP models
 
 As a first step, we recommend that you check out our page on :ref:`installation`

docs/source/installation/index.rst

Lines changed: 7 additions & 18 deletions

@@ -11,26 +11,15 @@ The *mlip* library can be installed via pip:
 
 However, this command **only installs the regular CPU version** of JAX.
 We recommend that the library is run on GPU.
-This requires also installing the necessary versions
-of `jaxlib <https://pypi.org/project/jaxlib/>`_ which can also be installed via pip. See
-the `installation guide of JAX <https://docs.jax.dev/en/latest/installation.html>`_ for
-more information.
-At time of release, the following install command is supported:
+Use this command instead to install the GPU-compatible version:
 
 .. code-block:: bash
 
-    pip install -U "jax[cuda12]"
+    pip install mlip[cuda]
 
-Note that using the TPU version of *jaxlib* is, in principle, also supported by
-this library. However, it has not been thoroughly tested and should therefore be
-considered an experimental feature.
+**This command installs the CUDA 12 version of JAX.** For different versions, please
+install *mlip* without the `cuda` flag and install the desired JAX version via pip.
 
-Also, some tasks in *mlip* will
-require `JAX-MD <https://github.com/jax-md/jax-md>`_ as a dependency. As the newest
-version of JAX-MD is not available on PyPI yet, this dependency will not
-be shipped with *mlip* automatically and instead must be installed
-directly from the GitHub repository, like this:
-
-.. code-block:: bash
-
-    pip install git+https://github.com/jax-md/jax-md.git
+Note that using the TPU version of JAX is, in principle, also supported by
+this library. You need to install it separately via pip. However, it has not been
+thoroughly tested and should therefore be considered an experimental feature.

docs/source/user_guide/simulations.rst

Lines changed: 48 additions & 0 deletions

@@ -82,12 +82,21 @@ class for more details. Most importantly, the `simulation_type` needs to be set
 `SimulationType.MINIMIZATION` (see
 :py:class:`SimulationType <mlip.simulation.enums.SimulationType>`).
 
+.. note::
+
+   The default timestep of 1.0 fs that is common for MD simulations may not be optimal
+   for energy minimizations. We recommend setting this value to 0.1 fs when using the
+   `SimulationType.MINIMIZATION` mode with the JAX-MD backend.
+
 **Algorithms**: For MD, the NVT-Langevin algorithm is used
 (see `here <https://jax-md.readthedocs.io/en/main/jax_md.simulate.html#jax_md.simulate.nvt_langevin>`_).
 For energy minimization, the FIRE algorithm is used
 (see `here <https://jax-md.readthedocs.io/en/main/jax_md.minimize.html#jax_md.minimize.fire_descent>`_).
 We plan to provide more options in future versions of the library.
 
+Furthermore, for MD simulations, we support running them in a **batched manner**.
+See :ref:`this <batched_simulations>` section below for more information.
+
 .. note::
 
    A special feature of the JAX-MD backend is that a simulation is divided into
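To make the timestep note in the hunk above concrete, here is a hedged sketch of a minimization setup. `SimulationType` and the `simulation_type` field are confirmed by this diff; the `JaxMDSimulationEngine.Config` constructor path and the `timestep_fs` field name are assumptions made for illustration only:

```python
from mlip.simulation.enums import SimulationType
from mlip.simulation.jax_md import JaxMDSimulationEngine

# Assumed config shape: Config and timestep_fs are illustrative guesses; only
# simulation_type and SimulationType.MINIMIZATION appear in the docs above.
md_config = JaxMDSimulationEngine.Config(
    simulation_type=SimulationType.MINIMIZATION,
    timestep_fs=0.1,  # smaller than the common 1.0 fs MD default, per the note
)
```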
@@ -203,6 +212,45 @@ The logger must be attached before starting the simulation.
 In ASE, this logging function will be called depending on the logging interval set,
 and in JAX-MD, it will be called after every episode.
 
+.. _batched_simulations:
+
+Batched simulations with JAX-MD
+-------------------------------
+
+For MD simulations or energy minimizations with JAX-MD, we support running them in a
+batched manner for multiple systems. The API for this is straightforward:
+instead of passing a single `ase.Atoms` object to the engine, we pass a list of them.
+After the simulation, the simulation state will contain lists of properties,
+for example, a list of position arrays (i.e., the trajectories) instead of a single
+position array. Note that input systems of varying sizes are also supported.
+See the example code below:
+
+.. code-block:: python
+
+    from ase.io import read as ase_read
+    from mlip.simulation.jax_md import JaxMDSimulationEngine
+
+    systems = []
+    for path in ["/path/to/mol_1", "/path/to/mol_2", "/path/to/mol_3"]:
+        atoms = ase_read(path)
+        systems.append(atoms)
+
+    force_field, md_config = _get_from_somewhere()  # placeholder
+    md_engine = JaxMDSimulationEngine(systems, force_field, md_config)
+    md_engine.run()
+
+    # Fetch results:
+    # Get trajectory and temperatures for "/path/to/mol_2" (indexing starts at 0)
+    md_state = md_engine.state
+    print(md_state.positions[1])
+    print(md_state.temperature[1])
+
+    # Compute time, for example, is not a list
+    print(md_state.compute_time_seconds)
+
+The example above works for both energy minimizations and MD simulations in the
+same way.
+
 .. _batched_inference:
 
 Batched inference
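As a brief follow-up to the batched example in this diff, the sketch below shows one plausible way to post-process the batched state. It assumes, as the snippet suggests, that each entry of `md_state.positions` is a trajectory array of shape `(n_frames, n_atoms, 3)`; the output filenames are invented for illustration:

```python
from ase.io import write as ase_write

# Reuses `systems` and `md_state` from the example above; hypothetical
# post-processing, not part of the mlip API.
for i, (atoms, trajectory) in enumerate(zip(systems, md_state.positions)):
    final_frame = atoms.copy()
    final_frame.positions = trajectory[-1]  # last recorded frame of system i
    ase_write(f"final_mol_{i}.xyz", final_frame)  # made-up output path
```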
