1414from seml .utils import (
1515 assert_package_installed ,
1616 is_local_file ,
17+ recursively_list_files ,
1718 s_if ,
1819 working_directory ,
1920)
@@ -75,7 +76,12 @@ def import_exe(executable: str, conda_env: str | None, working_dir: str):
7576
7677
7778def get_imported_sources (
78- executable , root_dir , conda_env , working_dir , stash_all_py_files : bool
79+ executable ,
80+ root_dir ,
81+ conda_env ,
82+ working_dir ,
83+ stash_all_py_files : bool ,
84+ additional_artifacts : list [str ] | None = None ,
7985) -> set [str ]:
8086 """Get the sources imported by the given executable.
8187
@@ -85,6 +91,7 @@ def get_imported_sources(
8591 conda_env (_type_): The experiment's Anaconda environment.
8692 working_dir (_type_): The working directory of the experiment.
8793 stash_all_py_files (_type_): Whether to stash all .py files in the working directory.
94+ additional_artifacts: list[str] | None: Additional artifacts to put into the source code files.
8895
8996 Returns:
9097 List[str]: The sources imported by the given executable.
@@ -114,6 +121,18 @@ def get_imported_sources(
114121 if is_local_file (file , root_path ):
115122 sources .add (str (file ))
116123
124+ for artifact in set ().union (
125+ * (recursively_list_files (path ) for path in additional_artifacts or [])
126+ ):
127+ artifact = artifact .expanduser ().resolve ()
128+ # Check that the artifact is in `working_dir`
129+ if artifact .is_file () and is_local_file (str (artifact ), root_path ):
130+ sources .add (str (artifact ))
131+ else :
132+ logging .warning (
133+ f'Additional artifact { artifact } is not a subpath of the root directory '
134+ f'{ root_path } and will be ignored.'
135+ )
117136 return sources
118137
119138
@@ -124,13 +143,13 @@ def upload_sources(
124143
125144 with working_directory (seml_config ['working_dir' ]):
126145 root_dir = str (Path (seml_config ['working_dir' ]).expanduser ().resolve ())
127-
128146 sources = get_imported_sources (
129147 seml_config ['executable' ],
130148 root_dir = root_dir ,
131149 conda_env = seml_config ['conda_environment' ],
132150 working_dir = seml_config ['working_dir' ],
133151 stash_all_py_files = seml_config .get ('stash_all_py_files' , False ),
152+ additional_artifacts = seml_config .get ('additional_artifacts' , []),
134153 )
135154 executable_abs = str (Path (seml_config ['executable' ]).expanduser ().resolve ())
136155
0 commit comments