Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \
compute_steps_for_sliding_window
from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.find_objects import recursive_find_trainer_class_by_name
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
Expand Down Expand Up @@ -96,11 +96,7 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
configuration_manager = plans_manager.get_configuration(configuration_name)
# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
trainer_class = recursive_find_trainer_class_by_name(trainer_name)
network = trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
Expand Down
1 change: 1 addition & 0 deletions nnunetv2/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
nnUNet_raw = os.environ.get('nnUNet_raw')
nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
nnUNet_results = os.environ.get('nnUNet_results')
nnUNet_extTrainer = os.environ.get("nnUNet_extTrainer")

if nnUNet_raw is None:
print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files "
Expand Down
12 changes: 2 additions & 10 deletions nnunetv2/run/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.find_objects import recursive_find_trainer_class_by_name
from torch.backends import cudnn


Expand All @@ -36,15 +36,7 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str],
plans_identifier: str = 'nnUNetPlans',
device: torch.device = torch.device('cuda')):
# load nnunet class and do sanity checks
nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if nnunet_trainer is None:
raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
f'nnunetv2.training.nnUNetTrainer ('
f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
f'else, please move it there.')
assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
'nnUNetTrainer'
nnunet_trainer = recursive_find_trainer_class_by_name(trainer_name)

# handle dataset input. If it's an ID we need to convert to int from string
if dataset_name_or_id.startswith('Dataset'):
Expand Down
105 changes: 90 additions & 15 deletions nnunetv2/utilities/find_class_by_name.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,99 @@
import importlib
import pkgutil
import sys
from contextlib import contextmanager
from os.path import abspath, join


from batchgenerators.utilities.file_and_folder_operations import *


def recursive_find_python_class(folder: str, class_name: str, current_module: str):
tr = None
for importer, modname, ispkg in pkgutil.iter_modules([folder]):
# print(modname, ispkg)
if not ispkg:
m = importlib.import_module(current_module + "." + modname)
if hasattr(m, class_name):
tr = getattr(m, class_name)
break
@contextmanager
def temporarily_extend_syspath(path: str):
"""
Context manager to temporarily add a directory to sys.path.
If the path is not already in sys.path, it gets added and then removed on exit.
"""
path = abspath(path)
already_present = path in sys.path
if not already_present:
sys.path.insert(0, path)
try:
yield
finally:
if not already_present and path in sys.path:
sys.path.remove(path)


def recursive_find_python_class(
folder: str,
class_name: str,
current_module: str | None,
base_folder: str | None = None,
verbose: bool = False,
):
"""
Recursively searches for a class with the given name in a Python package directory tree.
Parameters
----------
folder : str
The directory path to start the search in.
class_name : str
The name of the class to search for.
current_module : str or None
The dotted Python module path corresponding to `folder`.
E.g., "my_package.subpackage". Can be None if starting from a flat folder.
base_folder : str or None, optional
The root directory that should be temporarily added to sys.path to allow imports.
If None, `folder` is used.
verbose : bool, optional
If True, print progress during the search.
Returns
-------
type or None
The found class object, or None if not found.
"""
if base_folder is None:
base_folder = folder

with temporarily_extend_syspath(base_folder):
if verbose:
print(
f"Searching for class {class_name} in folder {folder} with current module {current_module}"
)

if tr is None:
# Search modules (non-packages) in the folder
for importer, modname, ispkg in pkgutil.iter_modules([folder]):
if not ispkg:
search_module = (
modname if current_module is None else f"{current_module}.{modname}"
)
if verbose:
print(f" Inspecting module: {search_module}")
try:
m = importlib.import_module(search_module)
if hasattr(m, class_name):
if verbose:
print(f"Found class {class_name} in {search_module}")
return getattr(m, class_name)
except Exception as e:
print(f"Warning: Could not import module {search_module}: {e}")

# Recurse into subpackages
for importer, modname, ispkg in pkgutil.iter_modules([folder]):
if ispkg:
next_current_module = current_module + "." + modname
tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
if tr is not None:
break
return tr
next_folder = join(folder, modname)
next_module = (
modname if current_module is None else f"{current_module}.{modname}"
)
result = recursive_find_python_class(
next_folder,
class_name,
current_module=next_module,
base_folder=base_folder,
verbose=verbose,
)
if result is not None:
return result

return None
51 changes: 51 additions & 0 deletions nnunetv2/utilities/find_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from os.path import join

import nnunetv2
from nnunetv2.paths import nnUNet_extTrainer
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class


def recursive_find_trainer_class_by_name(trainer_name: str):
# Import here is necessary to avoid circular import
# this function is used in the training and inference scripts
# but the inference script needs to import the trainer class
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

# load nnunet class and do sanity checks
nnunet_trainer = recursive_find_python_class(
join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name,
"nnunetv2.training.nnUNetTrainer",
nnunetv2.__path__[0],
)

if nnunet_trainer is None:
if nnUNet_extTrainer:
ext_paths = nnUNet_extTrainer.split(os.pathsep)
print(
f"Trainer '{trainer_name}' not found in nnunetv2.training.nnUNetTrainer.\n"
f"Searching in external trainer paths from environment variable 'nnUNet_extTrainer'..."
)
for path in ext_paths:
if path.strip() and os.path.exists(path):
print(f"Searching in: {path}")
nnunet_trainer = recursive_find_python_class(
path, trainer_name, None, base_folder=path, verbose=True
)
if nnunet_trainer is not None:
print(f"Using trainer '{trainer_name}' from: {path}")
break
if nnunet_trainer is None:
raise RuntimeError(
f"Could not find requested nnunet trainer {trainer_name} in "
f"nnunetv2.training.nnUNetTrainer ("
f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}).'
f"If the trainer is located elsewhere, please move it there or specify the external path via the "
f"`nnUNet_extTrainer` environment variable."
f"nnUNet_extTrainer={os.environ.get('nnUNet_extTrainer', '')}"
)
assert issubclass(nnunet_trainer, nnUNetTrainer), (
"The requested nnunet trainer class must inherit from 'nnUNetTrainer'"
)
return nnunet_trainer
Loading