Skip to content

Commit f1c60f0

Browse files
committed
Add explicitly listed artifacts to sources
1 parent 91bcf64 commit f1c60f0

File tree

7 files changed

+45
-2
lines changed

7 files changed

+45
-2
lines changed

examples/advanced_example_config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ seml:
6161
description: "An advanced example configuration. We can also use variable interpolation here: ${config.model.model_type}"
6262
reschedule_timeout: 300 # The time (in seconds) that are left on the job before SEML will try to reschedule unfinished experiments.
6363
# Note that you have to implement a `reschedule_hook` to use this feature.
64+
additional_artifacts:
65+
- artifacts/something
6466

6567
slurm:
6668
- experiments_per_job: 1

examples/advanced_example_experiment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
are parsed by a specific method. This avoids having one large "main" function which takes all parameters as input.
66
"""
77

8+
import logging
9+
810
import numpy as np
911
from seml import Experiment
1012

@@ -132,6 +134,11 @@ def init_preprocessing(self, mean: float, std: float):
132134
def init_augmentation(self, flip: bool):
133135
self.augmentation_parameters = (flip,)
134136

137+
def init_artifacts(self):
138+
# Load token from artifact specified in `seml.additional_artifacts
139+
with open("artifacts/something") as f:
140+
logging.info(f"Loaded artifact {f.read().strip()}")
141+
135142
def init_all(self):
136143
"""
137144
Sequentially run the sub-initializers of the experiment.
@@ -141,6 +148,7 @@ def init_all(self):
141148
self.init_optimizer()
142149
self.init_preprocessing()
143150
self.init_augmentation()
151+
self.init_artifacts()
144152

145153
@ex.capture(prefix="training")
146154
def train(self, patience, num_epochs):

examples/artifacts/something

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
<file content of the artifact>

src/seml/document.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class SemlDocBase(TypedDict, total=False):
4141
name: str
4242
stash_all_py_files: bool
4343
reschedule_timeout: int | None
44+
additional_artifacts: list[str]
4445

4546

4647
class SemlFileConfig(SemlDocBase, total=False):

src/seml/experiment/sources.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from seml.utils import (
1515
assert_package_installed,
1616
is_local_file,
17+
recursively_list_files,
1718
s_if,
1819
working_directory,
1920
)
@@ -75,7 +76,12 @@ def import_exe(executable: str, conda_env: str | None, working_dir: str):
7576

7677

7778
def get_imported_sources(
78-
executable, root_dir, conda_env, working_dir, stash_all_py_files: bool
79+
executable,
80+
root_dir,
81+
conda_env,
82+
working_dir,
83+
stash_all_py_files: bool,
84+
additional_artifacts: list[str] | None = None,
7985
) -> set[str]:
8086
"""Get the sources imported by the given executable.
8187
@@ -85,6 +91,7 @@ def get_imported_sources(
8591
conda_env (_type_): The experiment's Anaconda environment.
8692
working_dir (_type_): The working directory of the experiment.
8793
stash_all_py_files (_type_): Whether to stash all .py files in the working directory.
94+
additional_artifacts: list[str] | None: Additional artifacts to put into the source code files.
8895
8996
Returns:
9097
List[str]: The sources imported by the given executable.
@@ -114,6 +121,18 @@ def get_imported_sources(
114121
if is_local_file(file, root_path):
115122
sources.add(str(file))
116123

124+
for artifact in set().union(
125+
*(recursively_list_files(path) for path in additional_artifacts or [])
126+
):
127+
artifact = artifact.expanduser().resolve()
128+
# Check that the artifact is in `working_dir`
129+
if artifact.is_file() and is_local_file(str(artifact), root_path):
130+
sources.add(str(artifact))
131+
else:
132+
logging.warning(
133+
f'Additional artifact {artifact} is not a subpath of the root directory '
134+
f'{root_path} and will be ignored.'
135+
)
117136
return sources
118137

119138

@@ -124,13 +143,13 @@ def upload_sources(
124143

125144
with working_directory(seml_config['working_dir']):
126145
root_dir = str(Path(seml_config['working_dir']).expanduser().resolve())
127-
128146
sources = get_imported_sources(
129147
seml_config['executable'],
130148
root_dir=root_dir,
131149
conda_env=seml_config['conda_environment'],
132150
working_dir=seml_config['working_dir'],
133151
stash_all_py_files=seml_config.get('stash_all_py_files', False),
152+
additional_artifacts=seml_config.get('additional_artifacts', []),
134153
)
135154
executable_abs = str(Path(seml_config['executable']).expanduser().resolve())
136155

src/seml/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ class Settings(SettingsDict):
237237
'description',
238238
'stash_all_py_files',
239239
'reschedule_timeout',
240+
'additional_artifacts',
240241
],
241242
'SEML_CONFIG_VALUE_VERSION': 'version',
242243
'VALID_SLURM_CONFIG_VALUES': [

src/seml/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,3 +788,14 @@ def drop_typeddict_difference(obj: TD1, cls: type[TD1], cls2: type[TD2]) -> TD2:
788788
if k in result:
789789
del result[k]
790790
return result # type: ignore
791+
792+
793+
def recursively_list_files(path: Path | str) -> set[Path]:
794+
"""Recursively lists all (resolved) files in the directory."""
795+
path = Path(path)
796+
if path.expanduser().resolve().is_file():
797+
return {path.expanduser().resolve()}
798+
elif path.expanduser().resolve().is_dir():
799+
return {p.expanduser().resolve() for p in path.rglob('*') if p.is_file()}
800+
else:
801+
raise ValueError(f'Path {path} is neither a file nor a directory.')

0 commit comments

Comments
 (0)