Skip to content

Commit ea15fd3

Browse files
Merge pull request #1079 from materialsproject/more-mlffs
Add MatterSim, Allegro and OCP models (`fairchem-core`) to `ase_calculators`
2 parents e8f2670 + a1b50a7 commit ea15fd3

File tree

11 files changed

+98
-55
lines changed

11 files changed

+98
-55
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ forcefields = [
6161
"matgl>=1.2.1",
6262
"torchdata<=0.7.1", # TODO: remove when issue fixed
6363
"quippy-ase>=0.9.14",
64+
"mattersim>=1.0.1",
6465
"sevenn>=0.9.3",
6566
"deepmd-kit>=2.1.4",
6667
]

src/atomate2/common/jobs/qha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Jobs for running qha calculations."""
1+
"""Jobs for running QHA calculations."""
22

33
from __future__ import annotations
44

src/atomate2/common/schemas/qha.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Schemas for qha documents."""
1+
"""Schemas for QHA documents."""
22

33
import logging
44
from typing import Union
@@ -15,7 +15,7 @@
1515

1616

1717
class PhononQHADoc(StructureMetadata, extra="allow"): # type: ignore[call-arg]
18-
"""Collection of all data produced by the qha workflow."""
18+
"""Collection of all data produced by the QHA workflow."""
1919

2020
structure: Structure | None = Field(
2121
None, description="Structure of Materials Project."
@@ -62,7 +62,7 @@ class PhononQHADoc(StructureMetadata, extra="allow"): # type: ignore[call-arg]
6262
description="Gruneisen parameters at temperatures.Shape: (temperatures,)",
6363
)
6464
pressure: float | None = Field(
65-
None, description="Pressure in GPA at which Gibb's energy was computed"
65+
None, description="Pressure in GPa at which the Gibbs energy was computed."
6666
)
6767
t_max: float | None = Field(
6868
None,
@@ -106,7 +106,7 @@ def from_phonon_runs(
106106
eos_type: str = "vinet",
107107
**kwargs,
108108
) -> Self:
109-
"""Generate qha results.
109+
"""Generate QHA results.
110110
111111
Parameters
112112
----------
@@ -149,35 +149,28 @@ def from_phonon_runs(
149149

150150
# create some plots here
151151
# add kwargs to change the names and file types
152+
fig_ext = kwargs.get("plot_type", "pdf")
152153
qha.plot_helmholtz_volume().savefig(
153-
f"{kwargs.get('helmholtz_volume_filename', 'helmholtz_volume')}"
154-
f".{kwargs.get('plot_type', 'pdf')}"
154+
f"{kwargs.get('helmholtz_volume_filename', 'helmholtz_volume')}.{fig_ext}"
155155
)
156156
qha.plot_volume_temperature().savefig(
157-
f"{kwargs.get('volume_temperature_plot', 'volume_temperature')}"
158-
f".{kwargs.get('plot_type', 'pdf')}"
157+
f"{kwargs.get('volume_temperature_plot', 'volume_temperature')}.{fig_ext}"
159158
)
160159
qha.plot_thermal_expansion().savefig(
161-
f"{kwargs.get('thermal_expansion_plot', 'thermal_expansion')}"
162-
f".{kwargs.get('plot_type', 'pdf')}"
160+
f"{kwargs.get('thermal_expansion_plot', 'thermal_expansion')}.{fig_ext}"
163161
)
164162
qha.plot_gibbs_temperature().savefig(
165-
f"{kwargs.get('gibbs_temperature_plot', 'gibbs_temperature')}"
166-
f".{kwargs.get('plot_type', 'pdf')}"
163+
f"{kwargs.get('gibbs_temperature_plot', 'gibbs_temperature')}.{fig_ext}"
167164
)
168165
qha.plot_bulk_modulus_temperature().savefig(
169-
f"{kwargs.get('bulk_modulus_plot', 'bulk_modulus_temperature')}"
170-
f".{kwargs.get('plot_type', 'pdf')}"
166+
f"{kwargs.get('bulk_modulus_plot', 'bulk_modulus_temperature')}.{fig_ext}"
171167
)
172168
qha.plot_heat_capacity_P_numerical().savefig(
173-
f"{kwargs.get('heat_capacity_plot', 'heat_capacity_P_numerical')}"
174-
f".{kwargs.get('plot_type', 'pdf')}"
169+
f"{kwargs.get('heat_capacity_plot', 'heat_capacity_P_numerical')}.{fig_ext}"
175170
)
176171
# qha.plot_heat_capacity_P_polyfit().savefig("heat_capacity_P_polyfit.eps")
177-
qha.plot_gruneisen_temperature().savefig(
178-
f"{kwargs.get('gruneisen_temperature_plot', 'gruneisen_temperature')}"
179-
f".{kwargs.get('plot_type', 'pdf')}"
180-
)
172+
ge_temp_plot = kwargs.get("gruneisen_temperature_plot", "gruneisen_temperature")
173+
qha.plot_gruneisen_temperature().savefig(f"{ge_temp_plot}.{fig_ext}")
181174

182175
qha.write_helmholtz_volume(
183176
filename=kwargs.get("helmholtz_volume_datafile", "helmholtz_volume.dat")
@@ -197,21 +190,16 @@ def from_phonon_runs(
197190
qha.write_gibbs_temperature(
198191
filename=kwargs.get("gibbs_temperature_datafile", "gibbs_temperature.dat")
199192
)
200-
qha.write_gruneisen_temperature(
201-
filename=kwargs.get(
202-
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
203-
)
193+
ge_temp_file = kwargs.get(
194+
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
204195
)
196+
qha.write_gruneisen_temperature(filename=ge_temp_file)
205197
qha.write_heat_capacity_P_numerical(
206198
filename=kwargs.get(
207199
"heat_capacity_datafile", "heat_capacity_P_numerical.dat"
208200
)
209201
)
210-
qha.write_gruneisen_temperature(
211-
filename=kwargs.get(
212-
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
213-
)
214-
)
202+
qha.write_gruneisen_temperature(filename=ge_temp_file)
215203

216204
# write files as well - might be easier for plotting
217205

src/atomate2/forcefields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22

33
# ensure that this is still importable for legacy jobs
44
from atomate2.forcefields.utils import MLFF, _get_formatted_ff_name
5+
6+
__all__ = ["MLFF"]

src/atomate2/forcefields/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
4242
MATPES_R2SCAN = "MatPES-r2SCAN"
4343
MATPES_PBE = "MatPES-PBE"
4444
DeepMD = "DeepMD"
45+
Allegro = "Allegro"
46+
OCP = "OCP" # for loading model checkpoint with fairchem.core.OCPCalculator
47+
MatterSim = "MatterSim"
4548

4649
@classmethod
4750
def _missing_(cls, value: Any) -> Any:
@@ -353,6 +356,23 @@ def ase_calculator(
353356

354357
calculator = DP(**kwargs)
355358

359+
case MLFF.Allegro:
360+
from allegro.ase import AllegroCalculator
361+
362+
calculator = AllegroCalculator.from_deployed_model(**kwargs)
363+
364+
case MLFF.OCP:
365+
# Not available on PyPI, needs to be installed from source
366+
# see https://github.com/FAIR-Chem/fairchem?tab=readme-ov-file#installation
367+
from fairchem.core import OCPCalculator
368+
369+
calculator = OCPCalculator(**kwargs)
370+
371+
case MLFF.MatterSim:
372+
from mattersim.forcefield import MatterSimCalculator
373+
374+
calculator = MatterSimCalculator(**kwargs)
375+
356376
elif isinstance(calculator_meta, dict):
357377
calc_cls = _load_calc_cls(calculator_meta)
358378
calculator = calc_cls(**kwargs)

src/atomate2/vasp/flows/qha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class QhaMaker(CommonQhaMaker):
2020
First relax a structure using relax_maker.
2121
Then perform a series of deformations on the relaxed structure, and
2222
then compute harmonic phonons for each deformed structure.
23-
Finally, compute Gibb's free energy.
23+
Finally, compute Gibbs free energy.
2424
2525
Parameters
2626
----------

src/atomate2/vasp/sets/eos.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class EosSetGenerator(VaspInputGenerator):
2525
force_gamma: bool = True
2626
auto_ismear: bool = False
2727
auto_kspacing: bool = False
28-
inherit_incar: bool = False
28+
inherit_incar: bool | list[str] = False
2929

3030
@property
3131
def incar_updates(self) -> dict:
@@ -60,7 +60,7 @@ class MPLegacyEosRelaxSetGenerator(VaspInputGenerator):
6060
config_dict: dict = field(default_factory=lambda: MPRelaxSet.CONFIG)
6161
auto_ismear: bool = False
6262
auto_kspacing: bool = False
63-
inherit_incar: bool = False
63+
inherit_incar: bool | list[str] = False
6464

6565
@property
6666
def incar_updates(self) -> dict:
@@ -103,7 +103,7 @@ class MPLegacyEosStaticSetGenerator(EosSetGenerator):
103103
config_dict: dict = field(default_factory=lambda: MPRelaxSet.CONFIG)
104104
auto_ismear: bool = False
105105
auto_kspacing: bool = False
106-
inherit_incar: bool = False
106+
inherit_incar: bool | list[str] = False
107107

108108
@property
109109
def incar_updates(self) -> dict:
@@ -138,7 +138,7 @@ class MPGGAEosRelaxSetGenerator(VaspInputGenerator):
138138
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
139139
auto_ismear: bool = False
140140
auto_kspacing: bool = False
141-
inherit_incar: bool = False
141+
inherit_incar: bool | list[str] = False
142142

143143
@property
144144
def incar_updates(self) -> dict:
@@ -173,7 +173,7 @@ class MPGGAEosStaticSetGenerator(EosSetGenerator):
173173
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
174174
auto_ismear: bool = False
175175
auto_kspacing: bool = False
176-
inherit_incar: bool = False
176+
inherit_incar: bool | list[str] = False
177177

178178
@property
179179
def incar_updates(self) -> dict:
@@ -207,7 +207,7 @@ class MPMetaGGAEosStaticSetGenerator(VaspInputGenerator):
207207
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
208208
auto_ismear: bool = False
209209
auto_kspacing: bool = False
210-
inherit_incar: bool = False
210+
inherit_incar: bool | list[str] = False
211211

212212
@property
213213
def incar_updates(self) -> dict:
@@ -250,7 +250,7 @@ class MPMetaGGAEosRelaxSetGenerator(VaspInputGenerator):
250250
bandgap_tol: float = 1e-4
251251
auto_ismear: bool = False
252252
auto_kspacing: bool = False
253-
inherit_incar: bool = False
253+
inherit_incar: bool | list[str] = False
254254

255255
@property
256256
def incar_updates(self) -> dict:
@@ -295,7 +295,7 @@ class MPMetaGGAEosPreRelaxSetGenerator(VaspInputGenerator):
295295
bandgap_tol: float = 1e-4
296296
auto_ismear: bool = False
297297
auto_kspacing: bool = False
298-
inherit_incar: bool = False
298+
inherit_incar: bool | list[str] = False
299299

300300
@property
301301
def incar_updates(self) -> dict:

tests/forcefields/test_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
try:
1818
import dgl
19-
except ImportError:
19+
except Exception: # noqa: BLE001
2020
dgl = None
2121

2222

tests/forcefields/test_md.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@ def test_maker_initialization():
4242
) == ForceFieldMDMaker(force_field_name=mlff)
4343

4444

45-
@pytest.mark.parametrize("ff_name, use_emmet_models", product(MLFF, [True, False]))
45+
_mlffs_for_test = set(MLFF).difference(
46+
map(MLFF, ("Forcefield", "MatterSim", "Allegro", "OCP", "M3GNet", "MACE"))
47+
)
48+
_md_test_params = sorted(product(_mlffs_for_test, [True, False]), key=lambda x: str(x))
49+
50+
51+
@pytest.mark.parametrize("ff_name, use_emmet_models", _md_test_params)
4652
def test_ml_ff_md_maker(
4753
ff_name,
4854
use_emmet_models,
@@ -53,14 +59,10 @@ def test_ml_ff_md_maker(
5359
clean_dir,
5460
get_deepmd_pretrained_model_path,
5561
):
56-
if ff_name in map(MLFF, ("Forcefield", "MACE")):
57-
return # nothing to test here, MLFF.Forcefield is just a generic placeholder
5862
if ff_name == MLFF.GAP and sys.version_info >= (3, 12):
5963
pytest.skip(
6064
"GAP model not compatible with Python 3.12, waiting on https://github.com/libAtoms/QUIP/issues/645"
6165
)
62-
if ff_name == MLFF.M3GNet:
63-
pytest.skip("M3GNet requires DGL which is PyTorch 2.4 incompatible")
6466

6567
n_steps = 5
6668

tests/forcefields/test_phonon.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,19 @@ def test_phonon_maker_initialization_with_all_mlff(
6868
# DGL which is PyTorch 2.4 incompatible, raises
6969
# "FileNotFoundError: Cannot find DGL C++ libgraphbolt_pytorch_2.4.1.so"
7070
skip_mlff = set(
71-
map(MLFF, ["Forcefield", "GAP", "M3GNet", "MATPES_R2SCAN", "MATPES_PBE"])
71+
map(
72+
MLFF,
73+
[
74+
"Forcefield",
75+
"GAP",
76+
"M3GNet",
77+
"MATPES_R2SCAN",
78+
"MATPES_PBE",
79+
"Allegro",
80+
"OCP",
81+
"MatterSim",
82+
],
83+
)
7284
)
7385
for mlff in set(MLFF).difference(skip_mlff):
7486
calc_kwargs = {

0 commit comments

Comments
 (0)