From 7b384ca92a64b8e97a9ed3858d95bd763b2ef0d5 Mon Sep 17 00:00:00 2001 From: Carsten Lueth Date: Wed, 13 Aug 2025 13:15:59 +0200 Subject: [PATCH] add possibility for external trainers --- nnunetv2/inference/predict_from_raw_data.py | 8 +- nnunetv2/paths.py | 1 + nnunetv2/run/run_training.py | 12 +-- nnunetv2/utilities/find_class_by_name.py | 105 +++++++++++++++++--- nnunetv2/utilities/find_objects.py | 51 ++++++++++ 5 files changed, 146 insertions(+), 31 deletions(-) create mode 100644 nnunetv2/utilities/find_objects.py diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index b14d86ec5..eea4503d0 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -28,7 +28,7 @@ from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ compute_steps_for_sliding_window from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy -from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.find_objects import recursive_find_trainer_class_by_name from nnunetv2.utilities.helpers import empty_cache, dummy_context from nnunetv2.utilities.json_export import recursive_fix_for_json_export from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels @@ -96,11 +96,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, configuration_manager = plans_manager.get_configuration(configuration_name) # restore network num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) - trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), - trainer_name, 'nnunetv2.training.nnUNetTrainer') - if trainer_class is None: - raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. ' - f'Please place it there (in any .py file)!') + trainer_class = recursive_find_trainer_class_by_name(trainer_name) network = trainer_class.build_network_architecture( configuration_manager.network_arch_class_name, configuration_manager.network_arch_init_kwargs, diff --git a/nnunetv2/paths.py b/nnunetv2/paths.py index 8dc466802..c661e0ef5 100644 --- a/nnunetv2/paths.py +++ b/nnunetv2/paths.py @@ -21,6 +21,7 @@ nnUNet_raw = os.environ.get('nnUNet_raw') nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed') nnUNet_results = os.environ.get('nnUNet_results') +nnUNet_extTrainer = os.environ.get("nnUNet_extTrainer") if nnUNet_raw is None: print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files " diff --git a/nnunetv2/run/run_training.py b/nnunetv2/run/run_training.py index 36127110c..5689fd77c 100644 --- a/nnunetv2/run/run_training.py +++ b/nnunetv2/run/run_training.py @@ -12,7 +12,7 @@ from nnunetv2.run.load_pretrained_weights import load_pretrained_weights from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name -from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.find_objects import recursive_find_trainer_class_by_name from torch.backends import cudnn @@ -36,15 +36,7 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str], plans_identifier: str = 'nnUNetPlans', device: torch.device = torch.device('cuda')): # load nnunet class and do sanity checks - nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), - trainer_name, 'nnunetv2.training.nnUNetTrainer') - if nnunet_trainer is None: - raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in ' - f'nnunetv2.training.nnUNetTrainer (' - f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere ' - f'else, please move it there.') - assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \ - 'nnUNetTrainer' + nnunet_trainer = recursive_find_trainer_class_by_name(trainer_name) # handle dataset input. If it's an ID we need to convert to int from string if dataset_name_or_id.startswith('Dataset'): diff --git a/nnunetv2/utilities/find_class_by_name.py b/nnunetv2/utilities/find_class_by_name.py index 223b3acc3..ccda616f2 100644 --- a/nnunetv2/utilities/find_class_by_name.py +++ b/nnunetv2/utilities/find_class_by_name.py @@ -1,24 +1,99 @@ import importlib import pkgutil +import sys +from contextlib import contextmanager +from os.path import abspath, join + from batchgenerators.utilities.file_and_folder_operations import * -def recursive_find_python_class(folder: str, class_name: str, current_module: str): - tr = None - for importer, modname, ispkg in pkgutil.iter_modules([folder]): - # print(modname, ispkg) - if not ispkg: - m = importlib.import_module(current_module + "." + modname) - if hasattr(m, class_name): - tr = getattr(m, class_name) - break +@contextmanager +def temporarily_extend_syspath(path: str): + """ + Context manager to temporarily add a directory to sys.path. + If the path is not already in sys.path, it gets added and then removed on exit. + """ + path = abspath(path) + already_present = path in sys.path + if not already_present: + sys.path.insert(0, path) + try: + yield + finally: + if not already_present and path in sys.path: + sys.path.remove(path) + + +def recursive_find_python_class( + folder: str, + class_name: str, + current_module: str | None, + base_folder: str | None = None, + verbose: bool = False, +): + """ + Recursively searches for a class with the given name in a Python package directory tree. + Parameters + ---------- + folder : str + The directory path to start the search in. + class_name : str + The name of the class to search for. + current_module : str or None + The dotted Python module path corresponding to `folder`. + E.g., "my_package.subpackage". Can be None if starting from a flat folder. + base_folder : str or None, optional + The root directory that should be temporarily added to sys.path to allow imports. + If None, `folder` is used. + verbose : bool, optional + If True, print progress during the search. + Returns + ------- + type or None + The found class object, or None if not found. + """ + if base_folder is None: + base_folder = folder + + with temporarily_extend_syspath(base_folder): + if verbose: + print( + f"Searching for class {class_name} in folder {folder} with current module {current_module}" + ) - if tr is None: + # Search modules (non-packages) in the folder + for importer, modname, ispkg in pkgutil.iter_modules([folder]): + if not ispkg: + search_module = ( + modname if current_module is None else f"{current_module}.{modname}" + ) + if verbose: + print(f" Inspecting module: {search_module}") + try: + m = importlib.import_module(search_module) + if hasattr(m, class_name): + if verbose: + print(f"Found class {class_name} in {search_module}") + return getattr(m, class_name) + except Exception as e: + print(f"Warning: Could not import module {search_module}: {e}") + + # Recurse into subpackages for importer, modname, ispkg in pkgutil.iter_modules([folder]): if ispkg: - next_current_module = current_module + "." + modname - tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) - if tr is not None: - break - return tr + next_folder = join(folder, modname) + next_module = ( + modname if current_module is None else f"{current_module}.{modname}" + ) + result = recursive_find_python_class( + next_folder, + class_name, + current_module=next_module, + base_folder=base_folder, + verbose=verbose, + ) + if result is not None: + return result + + return None diff --git a/nnunetv2/utilities/find_objects.py b/nnunetv2/utilities/find_objects.py new file mode 100644 index 000000000..92e93d462 --- /dev/null +++ b/nnunetv2/utilities/find_objects.py @@ -0,0 +1,51 @@ +import os +from os.path import join + +import nnunetv2 +from nnunetv2.paths import nnUNet_extTrainer +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + + +def recursive_find_trainer_class_by_name(trainer_name: str): + # Import here is necessary to avoid circular import + # this function is used in the training and inference scripts + # but the inference script needs to import the trainer class + from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + # load nnunet class and do sanity checks + nnunet_trainer = recursive_find_python_class( + join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, + "nnunetv2.training.nnUNetTrainer", + nnunetv2.__path__[0], + ) + + if nnunet_trainer is None: + if nnUNet_extTrainer: + ext_paths = nnUNet_extTrainer.split(os.pathsep) + print( + f"Trainer '{trainer_name}' not found in nnunetv2.training.nnUNetTrainer.\n" + f"Searching in external trainer paths from environment variable 'nnUNet_extTrainer'..." + ) + for path in ext_paths: + if path.strip() and os.path.exists(path): + print(f"Searching in: {path}") + nnunet_trainer = recursive_find_python_class( + path, trainer_name, None, base_folder=path, verbose=True + ) + if nnunet_trainer is not None: + print(f"Using trainer '{trainer_name}' from: {path}") + break + if nnunet_trainer is None: + raise RuntimeError( + f"Could not find requested nnunet trainer {trainer_name} in " + f"nnunetv2.training.nnUNetTrainer (" + f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}).' + f"If the trainer is located elsewhere, please move it there or specify the external path via the " + f"`nnUNet_extTrainer` environment variable." + f"nnUNet_extTrainer={os.environ.get('nnUNet_extTrainer', '')}" + ) + assert issubclass(nnunet_trainer, nnUNetTrainer), ( + "The requested nnunet trainer class must inherit from 'nnUNetTrainer'" + ) + return nnunet_trainer