diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 88b9004..46d7ffb 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -12,8 +12,9 @@ - Use supported image formats: `.nii`, `.nii.gz`, `.mha`, `.nrrd`. - Ensure that the code is syntactically correct and adheres to the project's coding standards. - Be sure about the documentation and comments. They should be clear and concise and use the correct Python docstring format. -- Create commit messages with a detailed description of the changes made, including any bug fixes or new features. -- Be as much specific as possible in the commit messages, including the files affected and the nature of the changes. +- Create commit messages with a as much details and description as possible in order to explain all the relevant information about the changes made, including any bug fixes or new features. +- Be as much specific and complete as possible in the commit messages, including the files affected and the nature of the changes. +- Organize the commit messages in bullet points for better readability, showing all the relevant information that is needed to explain the changes made. - Uses for commit messages prefixes the following pattern: - `ENH:` for new features and code enhancements - `BUG:` for bug fixes and general corrections diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 47b9de4..82e4ea3 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -31,12 +31,12 @@ Please select the relevant option(s): - [ ] Code passes linting checks (`task lint-check`) - [ ] Code includes appropriate inline documentation -**Note:** According to our [contribution guidelines](docs/contribute.md), please target the `develop` branch for most contributions to ensure proper testing before merging to `main`. +**Note:** According to our [contribution guidelines](https://asltk.readthedocs.io/en/main/contribute/), please target the `develop` branch for most contributions to ensure proper testing before merging to `main`. **By submitting this pull request, I confirm:** -- [ ] I have read and followed the [contribution guidelines](docs/contribute.md) +- [ ] I have read and followed the [contribution guidelines](https://asltk.readthedocs.io/en/main/contribute/) - [ ] I have tested my changes thoroughly - [ ] I understand this is an open source project and my contributions may be used by others -- [ ] I agree to the project's [Code of Conduct](CODE_OF_CONDUCT.md) +- [ ] I agree to the project's [Code of Conduct](https://github.com/LOAMRI/asltk/blob/main/CODE_OF_CONDUCT.md) \ No newline at end of file diff --git a/.github/workflows/bumpversion_publish_workflow.yaml b/.github/workflows/bumpversion_publish_workflow.yaml index 914733d..0036ac2 100644 --- a/.github/workflows/bumpversion_publish_workflow.yaml +++ b/.github/workflows/bumpversion_publish_workflow.yaml @@ -28,7 +28,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.10' - name: Install Poetry run: pip install poetry diff --git a/.github/workflows/ci_develop.yaml b/.github/workflows/ci_develop.yaml index 8355ae8..ae3b713 100644 --- a/.github/workflows/ci_develop.yaml +++ b/.github/workflows/ci_develop.yaml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: @@ -54,7 +54,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: - name: Clone repo @@ -106,7 +106,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: - name: Clone repo diff --git a/.github/workflows/ci_main.yaml b/.github/workflows/ci_main.yaml index eb0b74f..3c32046 100644 --- a/.github/workflows/ci_main.yaml +++ b/.github/workflows/ci_main.yaml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: @@ -53,7 +53,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: - name: Clone repo @@ -105,7 +105,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python-version: ["3.9"] + python-version: ["3.10"] steps: - name: Clone repo diff --git a/asltk/asldata.py b/asltk/asldata.py index 2b83b4f..a72a56f 100644 --- a/asltk/asldata.py +++ b/asltk/asldata.py @@ -1,12 +1,13 @@ import copy import os import warnings +from typing import Union import numpy as np -from asltk.logging_config import get_logger, log_data_info, log_function_call +from asltk.logging_config import get_logger, log_data_info from asltk.utils.image_manipulation import collect_data_volumes -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO class ASLData: @@ -69,40 +70,53 @@ def __init__( if isinstance(kwargs.get('pcasl'), str): pcasl_path = kwargs.get('pcasl') logger.info(f'Loading ASL image from: {pcasl_path}') - self._asl_image = load_image(pcasl_path) + self._asl_image = ImageIO(image_path=pcasl_path) if self._asl_image is not None: log_data_info( - 'ASL image', self._asl_image.shape, pcasl_path + 'ASL image', + self._asl_image.get_as_numpy().shape, + pcasl_path, ) elif isinstance(kwargs.get('pcasl'), np.ndarray): - self._asl_image = kwargs.get('pcasl') - logger.info('ASL image loaded as numpy array') + self._asl_image = ImageIO(image_array=kwargs.get('pcasl')) + logger.info('ASL image loaded') log_data_info( - 'ASL image', self._asl_image.shape, 'numpy array' + 'ASL image', self._asl_image.get_as_numpy().shape ) if kwargs.get('m0') is not None: + average_m0 = kwargs.get('average_m0', False) + if isinstance(kwargs.get('m0'), str): m0_path = kwargs.get('m0') logger.info(f'Loading M0 image from: {m0_path}') - self._m0_image = load_image(m0_path) + self._m0_image = ImageIO( + image_path=m0_path, average_m0=average_m0 + ) # Check if M0 image is 4D and warn if so if ( self._m0_image is not None - and len(self._m0_image.shape) > 3 + and len(self._m0_image.get_as_numpy().shape) > 3 ): warnings.warn('M0 image has more than 3 dimensions.') if self._m0_image is not None: - log_data_info('M0 image', self._m0_image.shape, m0_path) + log_data_info( + 'M0 image', + self._m0_image.get_as_numpy().shape, + m0_path, + ) elif isinstance(kwargs.get('m0'), np.ndarray): - self._m0_image = kwargs.get('m0') + self._m0_image = ImageIO( + image_array=kwargs.get('m0'), average_m0=average_m0 + ) logger.info('M0 image loaded as numpy array') - log_data_info('M0 image', self._m0_image.shape, 'numpy array') - - if kwargs.get('average_m0', False): - self._m0_image = np.mean(self._m0_image, axis=0) + log_data_info( + 'M0 image', + self._m0_image.get_as_numpy().shape, + 'numpy array', + ) self._parameters['ld'] = ( [] if kwargs.get('ld_values') is None else kwargs.get('ld_values') @@ -133,8 +147,8 @@ def __init__( logger.debug('ASLData object created successfully') - def set_image(self, image, spec: str): - """Insert a image necessary to define de ASL data processing. + def set_image(self, image: Union[str, np.ndarray], spec: str, **kwargs): + """Insert an image necessary to define the ASL data processing. The `spec` parameters specifies what is the type of image to be used in ASL processing step. Choose one of the options: `m0` for the M0 volume, @@ -152,7 +166,7 @@ def set_image(self, image, spec: str): >>> data = ASLData() >>> path_m0 = './tests/files/m0.nii.gz' # M0 file with shape (5,35,35) >>> data.set_image(path_m0, spec='m0') - >>> data('m0').shape + >>> data('m0').get_as_numpy().shape (5, 35, 35) Args: @@ -161,10 +175,18 @@ def set_image(self, image, spec: str): """ if isinstance(image, str) and os.path.exists(image): if spec == 'm0': - self._m0_image = load_image(image) + self._m0_image = ImageIO(image, **kwargs) elif spec == 'pcasl': - self._asl_image = load_image(image) + self._asl_image = ImageIO(image, **kwargs) elif isinstance(image, np.ndarray): + warnings.warn( + 'Using numpy array as image input does not preserve metadata or image properties.' + ) + if spec == 'm0': + self._m0_image = ImageIO(image_array=image, **kwargs) + elif spec == 'pcasl': + self._asl_image = ImageIO(image_array=image, **kwargs) + elif isinstance(image, ImageIO): if spec == 'm0': self._m0_image = image elif spec == 'pcasl': @@ -277,9 +299,11 @@ def __call__(self, spec: str): Examples: >>> data = ASLData(pcasl='./tests/files/t1-mri.nrrd') >>> type(data('pcasl')) + + >>> type(data('pcasl').get_as_numpy()) - >>> np.min(data('pcasl')) + >>> np.min(data('pcasl').get_as_numpy()) 0 Returns: @@ -327,7 +351,7 @@ def _check_ld_pld_sizes(self, ld, pld): ) def _check_m0_dimension(self): - if len(self._m0_image.shape) > 3: + if len(self._m0_image.get_as_numpy().shape) > 3: warnings.warn( 'M0 image has more than 3 dimensions. ' 'This may cause issues in processing. ' diff --git a/asltk/aux_methods.py b/asltk/aux_methods.py index d91b95d..e8d4b07 100644 --- a/asltk/aux_methods.py +++ b/asltk/aux_methods.py @@ -4,15 +4,38 @@ import numpy as np from asltk.smooth import isotropic_gaussian, isotropic_median +from asltk.utils.io import ImageIO -def _check_mask_values(mask, label, ref_shape): - # Check wheter mask input is an numpy array - if not isinstance(mask, np.ndarray): - raise TypeError(f'mask is not an numpy array. Type {type(mask)}') +def _check_mask_values(mask: ImageIO, label, ref_shape): + """Validate mask array for brain mask processing. + + This function performs comprehensive validation of brain mask data to ensure + it meets the requirements for ASL processing. It checks data type, binary + format compliance, label presence, and dimensional compatibility. + + Args: + mask (np.ndarray): The brain mask image to validate. + label (int or float): The label value to search for in the mask. + ref_shape (tuple): The reference shape that the mask should match. + + Raises: + TypeError: If mask is not a numpy array or dimensions don't match. + ValueError: If the specified label value is not found in the mask. + + Warnings: + UserWarning: If mask contains more than 2 unique values (not strictly binary). + """ + # Check wheter mask input is an ImageIO object + if not isinstance(mask, ImageIO): + raise TypeError( + f'mask is not an ImageIO object. Type {type(mask)} is not allowed.' + ) + + mask_array = mask.get_as_numpy() # Check whether the mask provided is a binary image - unique_values = np.unique(mask) + unique_values = np.unique(mask_array) if unique_values.size > 2: warnings.warn( 'Mask image is not a binary image. Any value > 0 will be assumed as brain label.', @@ -29,7 +52,7 @@ def _check_mask_values(mask, label, ref_shape): raise ValueError('Label value is not found in the mask provided.') # Check whether the dimensions between mask and input volume matches - mask_shape = mask.shape + mask_shape = mask_array.shape if mask_shape != ref_shape: raise TypeError( f'Image mask dimension does not match with input 3D volume. Mask shape {mask_shape} not equal to {ref_shape}' @@ -37,10 +60,10 @@ def _check_mask_values(mask, label, ref_shape): def _apply_smoothing_to_maps( - maps: Dict[str, np.ndarray], + maps: Dict[str, ImageIO], smoothing: Optional[str] = None, smoothing_params: Optional[Dict[str, Any]] = None, -) -> Dict[str, np.ndarray]: +) -> Dict[str, ImageIO]: """Apply smoothing filter to all maps in the dictionary. This function applies the specified smoothing filter to all map arrays @@ -117,7 +140,7 @@ def _apply_smoothing_to_maps( # Apply smoothing to all maps smoothed_maps = {} for key, map_array in maps.items(): - if isinstance(map_array, np.ndarray): + if isinstance(map_array, ImageIO): try: smoothed_maps[key] = smooth_func(map_array, **smoothing_params) except Exception as e: diff --git a/asltk/reconstruction/cbf_mapping.py b/asltk/reconstruction/cbf_mapping.py index 8c743b8..3422191 100644 --- a/asltk/reconstruction/cbf_mapping.py +++ b/asltk/reconstruction/cbf_mapping.py @@ -10,6 +10,7 @@ from asltk.logging_config import get_logger, log_processing_step from asltk.models.signal_dynamic import asl_model_buxton from asltk.mri_parameters import MRIParameters +from asltk.utils.io import ImageIO # Global variables to assist multi cpu threading cbf_map = None @@ -55,11 +56,11 @@ def __init__(self, asl_data: ASLData) -> None: 'ASLData is incomplete. CBFMapping need pcasl and m0 images.' ) - self._brain_mask = np.ones(self._asl_data('m0').shape) - self._cbf_map = np.zeros(self._asl_data('m0').shape) - self._att_map = np.zeros(self._asl_data('m0').shape) + self._brain_mask = np.ones(self._asl_data('m0').get_as_numpy().shape) + self._cbf_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._att_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) - def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): + def set_brain_mask(self, brain_mask: ImageIO, label: int = 1): """Defines a brain mask to limit CBF mapping calculations to specific regions. A brain mask significantly improves processing speed by limiting calculations @@ -93,28 +94,30 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): ... ) >>> cbf_mapper = CBFMapping(asl_data) >>> # Create a simple brain mask (center region only) - >>> mask_shape = asl_data('m0').shape # Get M0 dimensions - >>> brain_mask = np.zeros(mask_shape) - >>> brain_mask[2:6, 10:25, 10:25] = 1 # Define brain region + >>> mask_shape = asl_data('m0').get_as_numpy().shape # Get M0 dimensions + >>> brain_mask = ImageIO(image_array=np.zeros(mask_shape)) + >>> adjusted_brain_mask = brain_mask.get_as_numpy().copy() + >>> adjusted_brain_mask[2:6, 10:25, 10:25] = 1 # Define brain region + >>> brain_mask.update_image_data(adjusted_brain_mask) >>> cbf_mapper.set_brain_mask(brain_mask) Load and use an existing brain mask: >>> # Load pre-computed brain mask - >>> from asltk.utils.io import load_image - >>> brain_mask = load_image('./tests/files/m0_brain_mask.nii.gz') + >>> from asltk.utils.io import ImageIO + >>> brain_mask = ImageIO('./tests/files/m0_brain_mask.nii.gz') >>> cbf_mapper.set_brain_mask(brain_mask) Use multi-label mask (select specific region): >>> # Assuming a segmentation mask with different tissue labels - >>> segmentation_mask = np.random.randint(0, 4, mask_shape) # Example + >>> segmentation_mask = ImageIO(image_array=np.random.randint(0, 4, mask_shape)) # Example >>> # Use only label 2 (e.g., grey matter) >>> cbf_mapper.set_brain_mask(segmentation_mask, label=2) Automatic thresholding of M0 image as mask: >>> # Use M0 intensity to create brain mask - >>> m0_data = asl_data('m0') + >>> m0_data = asl_data('m0').get_as_numpy() >>> threshold = np.percentile(m0_data, 20) # Bottom 20% as background - >>> auto_mask = (m0_data > threshold).astype(np.uint8) + >>> auto_mask = ImageIO(image_array=(m0_data > threshold).astype(np.uint8)) >>> cbf_mapper.set_brain_mask(auto_mask) Raises: @@ -123,9 +126,18 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): logger = get_logger('cbf_mapping') logger.info(f'Setting brain mask with label {label}') - _check_mask_values(brain_mask, label, self._asl_data('m0').shape) + if not isinstance(brain_mask, ImageIO): + raise ValueError( + f'mask is not an ImageIO object. Type {type(brain_mask)}' + ) + + brain_mask_array = brain_mask.get_as_numpy() + + _check_mask_values( + brain_mask, label, self._asl_data('m0').get_as_numpy().shape + ) - binary_mask = (brain_mask == label).astype(np.uint8) * label + binary_mask = (brain_mask_array == label).astype(np.uint8) * label self._brain_mask = binary_mask mask_volume = np.sum(binary_mask > 0) @@ -154,8 +166,10 @@ def get_brain_mask(self): >>> current_mask = cbf_mapper.get_brain_mask() Verify brain mask after setting: - >>> brain_mask = np.ones(asl_data('m0').shape) - >>> brain_mask[0:4, :, :] = 0 # Remove some slices + >>> brain_mask = ImageIO(image_array=np.ones(asl_data('m0').get_as_numpy().shape)) + >>> new_brain_mask = brain_mask.get_as_numpy().copy() + >>> new_brain_mask[0:4, :, :] = 0 # Remove some slices + >>> brain_mask.update_image_data(new_brain_mask) >>> cbf_mapper.set_brain_mask(brain_mask) >>> updated_mask = cbf_mapper.get_brain_mask() """ @@ -215,6 +229,7 @@ def create_map( Examples: # doctest: +SKIP Basic CBF mapping with default parameters: >>> from asltk.asldata import ASLData + >>> from asltk.utils.io import ImageIO >>> from asltk.reconstruction import CBFMapping >>> import numpy as np >>> # Load ASL data with LD/PLD values @@ -226,7 +241,7 @@ def create_map( ... ) >>> cbf_mapper = CBFMapping(asl_data) >>> # Set brain mask (recommended for faster processing) - >>> brain_mask = np.ones((5, 35, 35)) # Example mask + >>> brain_mask = ImageIO(image_array=np.ones((5, 35, 35))) # Example mask >>> cbf_mapper.set_brain_mask(brain_mask) >>> # Generate maps >>> results = cbf_mapper.create_map() # doctest: +SKIP @@ -290,9 +305,9 @@ def create_map( BuxtonX = [self._asl_data.get_ld(), self._asl_data.get_pld()] x_axis, y_axis, z_axis = ( - self._asl_data('m0').shape[2], - self._asl_data('m0').shape[1], - self._asl_data('m0').shape[0], + self._asl_data('m0').get_as_numpy().shape[2], + self._asl_data('m0').get_as_numpy().shape[1], + self._asl_data('m0').get_as_numpy().shape[0], ) logger.info( @@ -351,10 +366,20 @@ def create_map( f'ATT statistics - Mean: {np.mean(att_values):.4f}, Std: {np.std(att_values):.4f}' ) + # Prepare output maps + cbf_map_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_image.update_image_data(self._cbf_map) + + cbf_map_norm_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_norm_image.update_image_data(self._cbf_map * (60 * 60 * 1000)) + + att_map_image = ImageIO(self._asl_data('m0').get_image_path()) + att_map_image.update_image_data(self._att_map) + output_maps = { - 'cbf': self._cbf_map, - 'cbf_norm': self._cbf_map * (60 * 60 * 1000), - 'att': self._att_map, + 'cbf': cbf_map_image, + 'cbf_norm': cbf_map_norm_image, + 'att': att_map_image, } # Apply smoothing if requested @@ -381,14 +406,14 @@ def _cbf_process_slice( for j in range(y_axis): for k in range(z_axis): if brain_mask[k, j, i] != 0: - m0_px = asl_data('m0')[k, j, i] + m0_px = asl_data('m0').get_as_numpy()[k, j, i] def mod_buxton(Xdata, par1, par2): return asl_model_buxton( Xdata[0], Xdata[1], m0_px, par1, par2 ) - Ydata = asl_data('pcasl')[0, :, k, j, i] + Ydata = asl_data('pcasl').get_as_numpy()[0, :, k, j, i] # Calculate the processing index for the 3D space index = k * (y_axis * x_axis) + j * x_axis + i diff --git a/asltk/reconstruction/multi_dw_mapping.py b/asltk/reconstruction/multi_dw_mapping.py index eb6d2b6..8a30170 100644 --- a/asltk/reconstruction/multi_dw_mapping.py +++ b/asltk/reconstruction/multi_dw_mapping.py @@ -12,6 +12,7 @@ from asltk.models.signal_dynamic import asl_model_multi_dw from asltk.mri_parameters import MRIParameters from asltk.reconstruction import CBFMapping +from asltk.utils.io import ImageIO # Global variables to assist multi cpu threading cbf_map = None @@ -88,22 +89,22 @@ def __init__(self, asl_data: ASLData): 'ASLData is incomplete. MultiDW_ASLMapping need a list of DW values.' ) - self._brain_mask = np.ones(self._asl_data('m0').shape) - self._cbf_map = np.zeros(self._asl_data('m0').shape) - self._att_map = np.zeros(self._asl_data('m0').shape) + self._brain_mask = np.ones(self._asl_data('m0').get_as_numpy().shape) + self._cbf_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._att_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) self._b_values = self._asl_data.get_dw() # self._A1 = np.zeros(tuple([len(self._b_values)]) + self._asl_data('m0').shape) - self._A1 = np.zeros(self._asl_data('m0').shape) + self._A1 = np.zeros(self._asl_data('m0').get_as_numpy().shape) # self._D1 = np.zeros(tuple([1]) +self._asl_data('m0').shape) - self._D1 = np.zeros(self._asl_data('m0').shape) - self._A2 = np.zeros(self._asl_data('m0').shape) + self._D1 = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._A2 = np.zeros(self._asl_data('m0').get_as_numpy().shape) # self._A2 = np.zeros(tuple([len(self._b_values)]) + self._asl_data('m0').shape) # self._D2 = np.zeros(tuple([1]) +self._asl_data('m0').shape) - self._D2 = np.zeros(self._asl_data('m0').shape) - self._kw = np.zeros(self._asl_data('m0').shape) + self._D2 = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._kw = np.zeros(self._asl_data('m0').get_as_numpy().shape) - def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): + def set_brain_mask(self, brain_mask: ImageIO, label: int = 1): """Set brain mask for MultiDW-ASL processing (strongly recommended). A brain mask is especially important for multi-diffusion-weighted ASL @@ -132,9 +133,11 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): ... ) >>> mdw_mapper = MultiDW_ASLMapping(asl_data) >>> # Create conservative brain mask (center region only) - >>> mask_shape = asl_data('m0').shape - >>> brain_mask = np.zeros(mask_shape) - >>> brain_mask[1:4, 5:30, 5:30] = 1 # Conservative brain region + >>> mask_shape = asl_data('m0').get_as_numpy().shape + >>> brain_mask = ImageIO(image_array=np.zeros(mask_shape)) + >>> adjusted_brain_mask = brain_mask.get_as_numpy() + >>> adjusted_brain_mask[1:4, 5:30, 5:30] = 1 # Conservative brain region + >>> brain_mask.update_image_data(adjusted_brain_mask) >>> mdw_mapper.set_brain_mask(brain_mask) Note: @@ -142,9 +145,19 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): mask initially to test parameters and processing time, then expand to full brain analysis once satisfied with results. """ - _check_mask_values(brain_mask, label, self._asl_data('m0').shape) + if not isinstance(brain_mask, ImageIO): + raise TypeError( + 'Brain mask must be an instance of ImageIO. ' + 'Use ImageIO to load or create the mask.' + ) + + _check_mask_values( + brain_mask, label, self._asl_data('m0').get_as_numpy().shape + ) - binary_mask = (brain_mask == label).astype(np.uint8) * label + binary_mask = (brain_mask.get_as_numpy() == label).astype( + np.uint8 + ) * label self._brain_mask = binary_mask def get_brain_mask(self): @@ -155,7 +168,7 @@ def get_brain_mask(self): """ return self._brain_mask - def set_cbf_map(self, cbf_map: np.ndarray): + def set_cbf_map(self, cbf_map: ImageIO): """Set the CBF map to the MultiDW_ASLMapping object. Note: @@ -166,9 +179,9 @@ def set_cbf_map(self, cbf_map: np.ndarray): Args: cbf_map (np.ndarray): The CBF map that is set in the MultiDW_ASLMapping object """ - self._cbf_map = cbf_map + self._cbf_map = cbf_map.get_as_numpy() - def get_cbf_map(self) -> np.ndarray: + def get_cbf_map(self) -> ImageIO: """Get the CBF map storaged at the MultiDW_ASLMapping object Returns: @@ -177,13 +190,13 @@ def get_cbf_map(self) -> np.ndarray: """ return self._cbf_map - def set_att_map(self, att_map: np.ndarray): + def set_att_map(self, att_map: ImageIO): """Set the ATT map to the MultiDW_ASLMapping object. Args: att_map (np.ndarray): The ATT map that is set in the MultiDW_ASLMapping object """ - self._att_map = att_map + self._att_map = att_map.get_as_numpy() def get_att_map(self): """Get the ATT map storaged at the MultiDW_ASLMapping object @@ -267,8 +280,10 @@ def create_map( ... ) >>> mdw_mapper = MultiDW_ASLMapping(asl_data) >>> # Set brain mask for faster processing (recommended) - >>> brain_mask = np.ones(asl_data('m0').shape) - >>> brain_mask[0:2, :, :] = 0 # Remove some background slices + >>> brain_mask = ImageIO(image_array=np.ones(asl_data('m0').get_as_numpy().shape)) + >>> adjusted_brain_mask = brain_mask.get_as_numpy().copy() + >>> adjusted_brain_mask[0:2, :, :] = 0 # Remove some background slices + >>> brain_mask.update_image_data(adjusted_brain_mask) >>> mdw_mapper.set_brain_mask(brain_mask) >>> # Generate all maps (may take several minutes) >>> results = mdw_mapper.create_map() # doctest: +SKIP @@ -291,7 +306,7 @@ def create_map( set_att_map(): Provide pre-computed ATT map CBFMapping: For basic CBF/ATT mapping """ - self._basic_maps.set_brain_mask(self._brain_mask) + self._basic_maps.set_brain_mask(ImageIO(image_array=self._brain_mask)) basic_maps = {'cbf': self._cbf_map, 'att': self._att_map} if np.mean(self._cbf_map) == 0 or np.mean(self._att_map) == 0: @@ -300,12 +315,16 @@ def create_map( '[blue][INFO] The CBF/ATT map were not provided. Creating these maps before next step...' ) # pragma: no cover basic_maps = self._basic_maps.create_map() # pragma: no cover - self._cbf_map = basic_maps['cbf'] # pragma: no cover - self._att_map = basic_maps['att'] # pragma: no cover + self._cbf_map = basic_maps[ + 'cbf' + ].get_as_numpy() # pragma: no cover + self._att_map = basic_maps[ + 'att' + ].get_as_numpy() # pragma: no cover - x_axis = self._asl_data('m0').shape[2] # height - y_axis = self._asl_data('m0').shape[1] # width - z_axis = self._asl_data('m0').shape[0] # depth + x_axis = self._asl_data('m0').get_as_numpy().shape[2] # height + y_axis = self._asl_data('m0').get_as_numpy().shape[1] # width + z_axis = self._asl_data('m0').get_as_numpy().shape[0] # depth # TODO Fix print('multiDW-ASL processing...') @@ -325,7 +344,8 @@ def mod_diff(Xdata, par1, par2, par3, par4): # M(t,b)/M(t,0) Ydata = ( - self._asl_data('pcasl')[:, :, k, j, i] + self._asl_data('pcasl') + .get_as_numpy()[:, :, k, j, i] .reshape( ( len(self._asl_data.get_ld()) @@ -334,7 +354,7 @@ def mod_diff(Xdata, par1, par2, par3, par4): ) ) .flatten() - / self._asl_data('m0')[k, j, i] + / self._asl_data('m0').get_as_numpy()[k, j, i] ) try: @@ -363,7 +383,7 @@ def mod_diff(Xdata, par1, par2, par3, par4): self._D2[k, j, i] = 0 # Calculates the Mc fitting to alpha = kw + T1blood - m0_px = self._asl_data('m0')[k, j, i] + m0_px = self._asl_data('m0').get_as_numpy()[k, j, i] # def mod_2comp(Xdata, par1): # ... @@ -411,16 +431,43 @@ def mod_diff(Xdata, par1, par2, par3, par4): # # Adjusting output image boundaries # self._kw = self._adjust_image_limits(self._kw, par0[0]) + # Prepare output maps + cbf_map_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_image.update_image_data(self._cbf_map) + + cbf_map_norm_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_norm_image.update_image_data( + self._cbf_map * (60 * 60 * 1000) + ) # Convert to mL/100g/min + + att_map_image = ImageIO(self._asl_data('m0').get_image_path()) + att_map_image.update_image_data(self._att_map) + + a1_map_image = ImageIO(self._asl_data('m0').get_image_path()) + a1_map_image.update_image_data(self._A1) + + d1_map_image = ImageIO(self._asl_data('m0').get_image_path()) + d1_map_image.update_image_data(self._D1) + + a2_map_image = ImageIO(self._asl_data('m0').get_image_path()) + a2_map_image.update_image_data(self._A2) + + d2_map_image = ImageIO(self._asl_data('m0').get_image_path()) + d2_map_image.update_image_data(self._D2) + + kw_map_image = ImageIO(self._asl_data('m0').get_image_path()) + kw_map_image.update_image_data(self._kw) + # Create output maps dictionary output_maps = { - 'cbf': self._cbf_map, - 'cbf_norm': self._cbf_map * (60 * 60 * 1000), - 'att': self._att_map, - 'a1': self._A1, - 'd1': self._D1, - 'a2': self._A2, - 'd2': self._D2, - 'kw': self._kw, + 'cbf': cbf_map_image, + 'cbf_norm': cbf_map_norm_image, + 'att': att_map_image, + 'a1': a1_map_image, + 'd1': d1_map_image, + 'a2': a2_map_image, + 'd2': d2_map_image, + 'kw': kw_map_image, } # Apply smoothing if requested diff --git a/asltk/reconstruction/multi_te_mapping.py b/asltk/reconstruction/multi_te_mapping.py index 69f1a75..58c4191 100644 --- a/asltk/reconstruction/multi_te_mapping.py +++ b/asltk/reconstruction/multi_te_mapping.py @@ -11,6 +11,7 @@ from asltk.models.signal_dynamic import asl_model_multi_te from asltk.mri_parameters import MRIParameters from asltk.reconstruction import CBFMapping +from asltk.utils.io import ImageIO # Global variables to assist multi cpu threading cbf_map = None @@ -91,12 +92,12 @@ def __init__(self, asl_data: ASLData) -> None: 'ASLData is incomplete. MultiTE_ASLMapping need a list of TE values.' ) - self._brain_mask = np.ones(self._asl_data('m0').shape) - self._cbf_map = np.zeros(self._asl_data('m0').shape) - self._att_map = np.zeros(self._asl_data('m0').shape) - self._t1blgm_map = np.zeros(self._asl_data('m0').shape) + self._brain_mask = np.ones(self._asl_data('m0').get_as_numpy().shape) + self._cbf_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._att_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) + self._t1blgm_map = np.zeros(self._asl_data('m0').get_as_numpy().shape) - def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): + def set_brain_mask(self, brain_mask: ImageIO, label: int = 1): """Defines whether a brain a mask is applied to the CBFMapping calculation @@ -112,9 +113,18 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): Args: brain_mask (np.ndarray): The image representing the brain mask label (int, optional): The label value used to define the foreground tissue (brain). Defaults to 1. """ - _check_mask_values(brain_mask, label, self._asl_data('m0').shape) + if not isinstance(brain_mask, ImageIO): + raise TypeError( + 'The brain_mask parameter must be an instance of ImageIO.' + ) + + _check_mask_values( + brain_mask, label, self._asl_data('m0').get_as_numpy().shape + ) - binary_mask = (brain_mask == label).astype(np.uint8) * label + binary_mask = (brain_mask.get_as_numpy() == label).astype( + np.uint8 + ) * label self._brain_mask = binary_mask def get_brain_mask(self): @@ -125,7 +135,7 @@ def get_brain_mask(self): """ return self._brain_mask - def set_cbf_map(self, cbf_map: np.ndarray): + def set_cbf_map(self, cbf_map: ImageIO): """Set the CBF map to the MultiTE_ASLMapping object. Note: @@ -136,7 +146,7 @@ def set_cbf_map(self, cbf_map: np.ndarray): Args: cbf_map (np.ndarray): The CBF map that is set in the MultiTE_ASLMapping object """ - self._cbf_map = cbf_map + self._cbf_map = cbf_map.get_as_numpy() def get_cbf_map(self) -> np.ndarray: """Get the CBF map storaged at the MultiTE_ASLMapping object @@ -147,13 +157,13 @@ def get_cbf_map(self) -> np.ndarray: """ return self._cbf_map - def set_att_map(self, att_map: np.ndarray): + def set_att_map(self, att_map: ImageIO): """Set the ATT map to the MultiTE_ASLMapping object. Args: att_map (np.ndarray): The ATT map that is set in the MultiTE_ASLMapping object """ - self._att_map = att_map + self._att_map = att_map.get_as_numpy() def get_att_map(self): """Get the ATT map storaged at the MultiTE_ASLMapping object @@ -239,6 +249,7 @@ def create_map( Basic multi-TE ASL analysis: >>> from asltk.asldata import ASLData >>> from asltk.reconstruction import MultiTE_ASLMapping + >>> from asltk.utils.io import ImageIO >>> import numpy as np >>> # Load multi-TE ASL data >>> asl_data = ASLData( @@ -250,7 +261,7 @@ def create_map( ... ) >>> mte_mapper = MultiTE_ASLMapping(asl_data) >>> # Set brain mask for faster processing - >>> brain_mask = np.ones(asl_data('m0').shape) + >>> brain_mask = ImageIO(image_array=np.ones(asl_data('m0').get_as_numpy().shape)) >>> mte_mapper.set_brain_mask(brain_mask) >>> # Generate all maps >>> results = mte_mapper.create_map() # doctest: +SKIP @@ -288,8 +299,7 @@ def create_map( set_att_map(): Provide pre-computed ATT map CBFMapping: For basic CBF/ATT mapping """ - # # TODO As entradas ub, lb e par0 não são aplicadas para CBF. Pensar se precisa ter essa flexibilidade para acertar o CBF interno à chamada - self._basic_maps.set_brain_mask(self._brain_mask) + self._basic_maps.set_brain_mask(ImageIO(image_array=self._brain_mask)) basic_maps = {'cbf': self._cbf_map, 'att': self._att_map} if np.mean(self._cbf_map) == 0 or np.mean(self._att_map) == 0: @@ -298,8 +308,8 @@ def create_map( '[blue][INFO] The CBF/ATT map were not provided. Creating these maps before next step...' ) basic_maps = self._basic_maps.create_map() - self._cbf_map = basic_maps['cbf'] - self._att_map = basic_maps['att'] + self._cbf_map = basic_maps['cbf'].get_as_numpy() + self._att_map = basic_maps['att'].get_as_numpy() global asl_data, brain_mask, cbf_map, att_map, t2bl, t2gm asl_data = self._asl_data @@ -312,9 +322,9 @@ def create_map( t2bl = self.T2bl t2gm = self.T2gm - x_axis = self._asl_data('m0').shape[2] # height - y_axis = self._asl_data('m0').shape[1] # width - z_axis = self._asl_data('m0').shape[0] # depth + x_axis = self._asl_data('m0').get_as_numpy().shape[2] # height + y_axis = self._asl_data('m0').get_as_numpy().shape[1] # width + z_axis = self._asl_data('m0').get_as_numpy().shape[0] # depth tblgm_map_shared = Array('d', z_axis * y_axis * x_axis, lock=False) @@ -356,12 +366,25 @@ def create_map( # Adjusting output image boundaries self._t1blgm_map = self._adjust_image_limits(self._t1blgm_map, par0[0]) + # Prepare output maps + cbf_map_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_image.update_image_data(self._cbf_map) + + cbf_map_norm_image = ImageIO(self._asl_data('m0').get_image_path()) + cbf_map_norm_image.update_image_data(self._cbf_map * (60 * 60 * 1000)) + + att_map_image = ImageIO(self._asl_data('m0').get_image_path()) + att_map_image.update_image_data(self._att_map) + + t1blgm_map_image = ImageIO(self._asl_data('m0').get_image_path()) + t1blgm_map_image.update_image_data(self._t1blgm_map) + # Create output maps dictionary output_maps = { - 'cbf': self._cbf_map, - 'cbf_norm': self._cbf_map * (60 * 60 * 1000), - 'att': self._att_map, - 't1blgm': self._t1blgm_map, + 'cbf': cbf_map_image, + 'cbf_norm': cbf_map_norm_image, + 'att': att_map_image, + 't1blgm': t1blgm_map_image, } # Apply smoothing if requested @@ -414,7 +437,7 @@ def _tblgm_multite_process_slice( for j in range(y_axis): for k in range(z_axis): if brain_mask[k, j, i] != 0: - m0_px = asl_data('m0')[k, j, i] + m0_px = asl_data('m0').get_as_numpy()[k, j, i] def mod_2comp(Xdata, par1): return asl_model_multi_te( @@ -430,7 +453,8 @@ def mod_2comp(Xdata, par1): ) Ydata = ( - asl_data('pcasl')[:, :, k, j, i] + asl_data('pcasl') + .get_as_numpy()[:, :, k, j, i] .reshape( ( len(ld_arr) * len(te_arr), diff --git a/asltk/reconstruction/t2_mapping.py b/asltk/reconstruction/t2_mapping.py index bfdbe05..7a9e2f6 100644 --- a/asltk/reconstruction/t2_mapping.py +++ b/asltk/reconstruction/t2_mapping.py @@ -1,3 +1,4 @@ +import warnings from multiprocessing import Array, Pool, cpu_count import numpy as np @@ -9,6 +10,7 @@ from asltk.aux_methods import _apply_smoothing_to_maps, _check_mask_values from asltk.logging_config import get_logger, log_processing_step from asltk.mri_parameters import MRIParameters +from asltk.utils.io import ImageIO # Global variables for multiprocessing t2_map_shared = None @@ -45,11 +47,13 @@ def __init__(self, asl_data: ASLData) -> None: if self._asl_data.get_dw() is not None: raise ValueError('ASLData must not include DW values.') - self._brain_mask = np.ones(self._asl_data('m0').shape) + self._brain_mask = ImageIO( + image_array=np.ones(self._asl_data('m0').get_as_numpy().shape) + ) self._t2_maps = None # Will be 4D: (N_PLDS, Z, Y, X) self._mean_t2s = None - def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): + def set_brain_mask(self, brain_mask: ImageIO, label: int = 1): """ Set a brain mask to restrict T2 fitting to specific voxels. @@ -59,8 +63,14 @@ def set_brain_mask(self, brain_mask: np.ndarray, label: int = 1): The mask should be a 3D numpy array matching the spatial dimensions of the ASL data. """ - _check_mask_values(brain_mask, label, self._asl_data('m0').shape) - binary_mask = (brain_mask == label).astype(np.uint8) * label + _check_mask_values( + brain_mask, label, self._asl_data('m0').get_as_numpy().shape + ) + + binary_mask = ImageIO( + image_array=(brain_mask.get_as_numpy() == label).astype(np.uint8) + * label + ) self._brain_mask = binary_mask def get_t2_maps(self): @@ -81,7 +91,11 @@ def get_mean_t2s(self): return self._mean_t2s def create_map( - self, cores=cpu_count(), smoothing=None, smoothing_params=None + self, + cores=cpu_count(), + smoothing=None, + smoothing_params=None, + suppress_warnings=False, ): """ Compute T2 maps using multi-echo ASL data and a brain mask, with multiprocessing. @@ -102,69 +116,100 @@ def create_map( logger = get_logger('t2_mapping') logger.info('Starting T2 map creation') - data = self._asl_data('pcasl') - mask = self._brain_mask - TEs = np.array(self._te_values) - PLDs = np.array(self._pld_values) - n_tes, n_plds, z_axis, y_axis, x_axis = data.shape - - t2_maps_all = [] - mean_t2s = [] - - for pld_idx in range(n_plds): - logger.info(f'Processing PLD index {pld_idx} ({PLDs[pld_idx]} ms)') - t2_map_shared = Array('d', z_axis * y_axis * x_axis, lock=False) - log_processing_step( - 'Running voxel-wise T2 fitting', - 'this may take several minutes', - ) - with Pool( - processes=cores, - initializer=_t2_init_globals, - initargs=(t2_map_shared, mask, data, TEs), - ) as pool: - with Progress() as progress: - task = progress.add_task( - f'T2 fitting (PLD {PLDs[pld_idx]} ms)...', total=x_axis - ) - results = [ - pool.apply_async( - _t2_process_slice, - args=(i, x_axis, y_axis, z_axis, pld_idx), - callback=lambda _: progress.update( - task, advance=1 - ), + # Optionally suppress warnings + if suppress_warnings: + warnings_context = warnings.catch_warnings() + warnings_context.__enter__() + warnings.filterwarnings('ignore', category=RuntimeWarning) + warnings.filterwarnings('ignore', category=UserWarning) + logger.info('Warnings suppressed during T2 mapping') + + try: + data = self._asl_data('pcasl').get_as_numpy() + mask = self._brain_mask.get_as_numpy() + TEs = np.array(self._te_values) + PLDs = np.array(self._pld_values) + n_tes, n_plds, z_axis, y_axis, x_axis = data.shape + + t2_maps_all = [] + mean_t2s = [] + + for pld_idx in range(n_plds): + logger.info( + f'Processing PLD index {pld_idx} ({PLDs[pld_idx]} ms)' + ) + t2_map_shared = Array( + 'd', z_axis * y_axis * x_axis, lock=False + ) + log_processing_step( + 'Running voxel-wise T2 fitting', + 'this may take several minutes', + ) + with Pool( + processes=cores, + initializer=_t2_init_globals, + initargs=(t2_map_shared, mask, data, TEs), + ) as pool: + with Progress() as progress: + task = progress.add_task( + f'T2 fitting (PLD {PLDs[pld_idx]} ms)...', + total=x_axis, ) - for i in range(x_axis) - ] - for result in results: - result.wait() + results = [ + pool.apply_async( + _t2_process_slice, + args=(i, x_axis, y_axis, z_axis, pld_idx), + callback=lambda _: progress.update( + task, advance=1 + ), + ) + for i in range(x_axis) + ] + for result in results: + result.wait() + + t2_map = np.frombuffer(t2_map_shared).reshape( + z_axis, y_axis, x_axis + ) + t2_maps_all.append(t2_map) + mean_t2s.append(np.nanmean(t2_map)) + + t2_maps_stacked = np.array(t2_maps_all) # shape: (N_PLDS, Z, Y, X) + self._t2_maps = t2_maps_stacked + self._mean_t2s = mean_t2s + + logger.info('T2 mapping completed successfully') + logger.info( + f'T2 statistics - Mean: {np.mean(self._t2_maps):.4f}, Std: {np.std(self._t2_maps):.4f}' + ) - t2_map = np.frombuffer(t2_map_shared).reshape( - z_axis, y_axis, x_axis + # Prepare output maps + # TODO At the moment, the T2 maps and mean T2 maps are as ImageIO object, however, the Spacing, Dimension are not given as a 4D array. The m0 image is 3D... check if this is a problem for the T2 image properties + t2_maps_image = ImageIO( + image_array=np.array( + [ + self._asl_data('m0').get_as_numpy() + for _ in range(len(t2_maps_all)) + ] + ) ) - t2_maps_all.append(t2_map) - mean_t2s.append(np.nanmean(t2_map)) - - t2_maps_stacked = np.stack( - t2_maps_all, axis=0 - ) # shape: (N_PLDS, Z, Y, X) - self._t2_maps = t2_maps_stacked - self._mean_t2s = mean_t2s - - logger.info('T2 mapping completed successfully') - logger.info( - f'T2 statistics - Mean: {np.mean(self._t2_maps):.4f}, Std: {np.std(self._t2_maps):.4f}' - ) + t2_maps_image.update_image_data(self._t2_maps) - output_maps = { - 't2': self._t2_maps, - 'mean_t2': self._mean_t2s, - } + # Update the _t2_maps attribute to be an ImageIO object + self._t2_maps = t2_maps_image - return _apply_smoothing_to_maps( - output_maps, smoothing, smoothing_params - ) + output_maps = { + 't2': t2_maps_image, + 'mean_t2': self._mean_t2s, + } + + return _apply_smoothing_to_maps( + output_maps, smoothing, smoothing_params + ) + finally: + # Ensure warnings are restored if suppressed + if suppress_warnings: + warnings_context.__exit__(None, None, None) def _fit_voxel(signal, TEs): # pragma: no cover diff --git a/asltk/registration/__init__.py b/asltk/registration/__init__.py index 057328c..3ae191d 100644 --- a/asltk/registration/__init__.py +++ b/asltk/registration/__init__.py @@ -5,30 +5,16 @@ from asltk.asldata import ASLData from asltk.data.brain_atlas import BrainAtlas from asltk.logging_config import get_logger -from asltk.utils.image_manipulation import check_and_fix_orientation -from asltk.utils.io import load_image - -# TODO Montar classe para fazer o coregistro de ASL -class ASLRegistration: - - # Pipeline - # inputs: ASLData (com m0 e pcasl), BrainAtlas, resolution (1 or 2 mm) - # Tomar m0 e comparar orientação com o template - # Se necessÔrio, corrigir orientação do template para estar coerente com o m0 (salvar a transformação e aplicar para os labels) - # Realizar o registro do m0 no template - # com a transformação do m0, deixar salvo como parametro do objeto da classe - # Ter metodos para aplicar transformação para o pcasl, ou mapas gerados pelo CBFMapping, MultiTE, etc. - - def __init__(self): - pass +# from asltk.utils.image_manipulation import check_and_fix_orientation +from asltk.utils.io import ImageIO, clone_image def space_normalization( - moving_image: np.ndarray, + moving_image: ImageIO, template_image: BrainAtlas, - moving_mask: np.ndarray = None, - template_mask: np.ndarray = None, + moving_mask: ImageIO = None, + template_mask: ImageIO = None, transform_type: str = 'SyNBoldAff', **kwargs, ): @@ -47,7 +33,7 @@ def space_normalization( provided in the correct format. Note: - For more specfiic cases, such as ASL data normalization, one can + For more specific cases, such as ASL data normalization, one can use other methods, such as in `asl_normalization` module. Note: @@ -87,9 +73,6 @@ def space_normalization( no mask is used. transform_type : str, optional Type of transformation ('SyN', 'BSpline', etc.). Default is 'SyNBoldAff'. - check_orientation : bool, optional - Whether to automatically check and fix orientation mismatches between - moving and template images. Default is True. verbose : bool, optional Whether to print detailed orientation analysis. Default is False. @@ -100,73 +83,49 @@ def space_normalization( transform : list A list of transformation mapping from moving to template space. """ - if not isinstance(moving_image, np.ndarray) or not isinstance( - template_image, (BrainAtlas, str, np.ndarray) + if not isinstance(moving_image, ImageIO) or not isinstance( + template_image, (BrainAtlas, str, ImageIO) ): raise TypeError( - 'moving_image must be a numpy array and template_image must be a BrainAtlas object, a string with the atlas name, or a numpy array.' + 'moving_image must be an ImageIO object and template_image must be a BrainAtlas object, a string with the atlas name, or an ImageIO object.' ) - # Take optional parameters - check_orientation = kwargs.get('check_orientation', True) - verbose = kwargs.get('verbose', False) - - logger = get_logger('registration') - logger.info('Starting space normalization') - # Load template image first - # TODO PROBLEMA PRINCIPAL: A leitura de imagens para numpy faz a perda da origen e spacing, para fazer o corregistro é preciso acertar a orientação da imagem com relação a origem (flip pela origem) para que ambas estejam na mesma orientação visual - # TODO Pensar em como serÔ a utilização do corregistro para o ASLTK (assume que jÔ estÔ alinhado? ou tenta alinhar imagens check_orientation?) template_array = None if isinstance(template_image, BrainAtlas): template_file = template_image.get_atlas()['t1_data'] - template_array = load_image(template_file) + template_array = ImageIO(template_file) elif isinstance(template_image, str): template_file = BrainAtlas(template_image).get_atlas()['t1_data'] - template_array = load_image(template_file) + template_array = ImageIO(template_file) # template_array = ants.image_read('/home/antonio/Imagens/loamri-samples/20240909/mni_2mm.nii.gz') - elif isinstance(template_image, np.ndarray): + elif isinstance(template_image, ImageIO): template_array = template_image else: raise TypeError( - 'template_image must be a BrainAtlas object, a string with the atlas name, or a numpy array.' + 'template_image must be a BrainAtlas object, a string with the atlas name, or an ImageIO object.' ) - if moving_image.ndim != 3 or template_array.ndim != 3: + if ( + moving_image.get_as_numpy().ndim != 3 + or template_array.get_as_numpy().ndim != 3 + ): raise ValueError( 'Both moving_image and template_image must be 3D arrays.' ) - corrected_moving_image = moving_image - orientation_transform = None - - # TODO VERIICAR SE CHECK_ORIENTATION ESTA CERTO... USAR sitk.FlipImageFilter usando a Origen da image (Slicer da certo assim) - if check_orientation: - ( - corrected_moving_image, - orientation_transform, - ) = check_and_fix_orientation( - moving_image, template_array, verbose=verbose - ) - if verbose and orientation_transform: - print(f'Applied orientation correction: {orientation_transform}') - - # Convert to ANTs images - - moving = ants.from_numpy(corrected_moving_image) - template = ants.from_numpy(template_array) + corrected_moving_image = clone_image(moving_image) # Load masks if provided - if isinstance(moving_mask, np.ndarray): - moving_mask = ants.from_numpy(moving_mask) - if isinstance(template_mask, np.ndarray): - template_mask = ants.from_numpy(template_mask) + if isinstance(moving_mask, ImageIO): + moving_mask = moving_mask.get_as_ants() + if isinstance(template_mask, ImageIO): + template_mask = template_mask.get_as_ants() - # TODO Vericicar se ants.registration consegue colocar o TransformInit como Centro de Massa!' # Perform registration registration = ants.registration( - fixed=template, - moving=moving, + fixed=template_array.get_as_ants(), + moving=corrected_moving_image.get_as_ants(), type_of_transform=transform_type, mask=moving_mask, mask_fixed=template_mask, @@ -174,14 +133,18 @@ def space_normalization( ) # Passing the warped image and forward transforms - return registration['warpedmovout'].numpy(), registration['fwdtransforms'] + out_warped = clone_image(template_array) + ants_numpy = registration['warpedmovout'].numpy() + out_warped.update_image_data(np.transpose(ants_numpy, (2, 1, 0))) + + return out_warped, registration['fwdtransforms'] def rigid_body_registration( - fixed_image: np.ndarray, - moving_image: np.ndarray, - moving_mask: np.ndarray = None, - template_mask: np.ndarray = None, + fixed_image: ImageIO, + moving_image: ImageIO, + moving_mask: ImageIO = None, + template_mask: ImageIO = None, ): """ Register two images using a rigid body transformation. This methods applies @@ -219,15 +182,17 @@ def rigid_body_registration( transforms : list A list of transformation mapping from moving to template space. """ - if not isinstance(fixed_image, np.ndarray) or not isinstance( - moving_image, np.ndarray + if not isinstance(fixed_image, ImageIO) or not isinstance( + moving_image, ImageIO ): - raise Exception('fixed_image and moving_image must be a numpy array.') + raise Exception( + 'fixed_image and moving_image must be an ImageIO object.' + ) - if moving_mask is not None and not isinstance(moving_mask, np.ndarray): - raise Exception('moving_mask must be a numpy array.') - if template_mask is not None and not isinstance(template_mask, np.ndarray): - raise Exception('template_mask must be a numpy array.') + if moving_mask is not None and not isinstance(moving_mask, ImageIO): + raise Exception('moving_mask must be an ImageIO object.') + if template_mask is not None and not isinstance(template_mask, ImageIO): + raise Exception('template_mask must be an ImageIO object.') normalized_image, trans_maps = space_normalization( moving_image, @@ -241,10 +206,10 @@ def rigid_body_registration( def affine_registration( - fixed_image: np.ndarray, - moving_image: np.ndarray, - moving_mask: np.ndarray = None, - template_mask: np.ndarray = None, + fixed_image: ImageIO, + moving_image: ImageIO, + moving_mask: ImageIO = None, + template_mask: ImageIO = None, fast_method: bool = True, ): """ @@ -274,14 +239,16 @@ def affine_registration( transformation_matrix : np.ndarray The transformation matrix mapping from moving to template space. """ - if not isinstance(fixed_image, np.ndarray) or not isinstance( - moving_image, np.ndarray + if not isinstance(fixed_image, ImageIO) or not isinstance( + moving_image, ImageIO ): - raise Exception('fixed_image and moving_image must be a numpy array.') - if moving_mask is not None and not isinstance(moving_mask, np.ndarray): - raise Exception('moving_mask must be a numpy array.') - if template_mask is not None and not isinstance(template_mask, np.ndarray): - raise Exception('template_mask must be a numpy array.') + raise Exception( + 'fixed_image and moving_image must be an ImageIO object.' + ) + if moving_mask is not None and not isinstance(moving_mask, ImageIO): + raise Exception('moving_mask must be an ImageIO object.') + if template_mask is not None and not isinstance(template_mask, ImageIO): + raise Exception('template_mask must be an ImageIO object.') affine_type = 'AffineFast' if fast_method else 'Affine' warped_image, transformation_matrix = space_normalization( @@ -296,8 +263,8 @@ def affine_registration( def apply_transformation( - moving_image: np.ndarray, - reference_image: np.ndarray, + moving_image: ImageIO, + reference_image: ImageIO, transforms: list, **kwargs, ): @@ -314,6 +281,13 @@ def apply_transformation( obtained from a registration process. The transformations are applied in the order they are provided in the list. + Tip: + Additional parameters can be passed to the `ants.apply_transforms` + function using the `kwargs` parameter. This allows for customization of + the transformation process, such as specifying interpolation methods, + handling of missing data, etc. See more in the ANTsPy documentation: + https://antspy.readthedocs.io/en/latest/registration.html#ants.apply_transforms + Args: image: np.ndarray The image to be transformed. @@ -327,16 +301,16 @@ def apply_transformation( transformed_image: np.ndarray The transformed image. """ - # TODO handle kwargs for additional parameters based on ants.apply_transforms - if not isinstance(moving_image, np.ndarray): - raise TypeError('moving image must be numpy array.') + if not isinstance(moving_image, ImageIO): + raise TypeError('moving image must be an ImageIO object.') - if not isinstance(reference_image, (np.ndarray, BrainAtlas)): + if not isinstance(reference_image, (ImageIO, BrainAtlas)): raise TypeError( - 'reference_image must be a numpy array or a BrainAtlas object.' + 'reference_image must be an ImageIO object or a BrainAtlas object.' ) - elif isinstance(reference_image, BrainAtlas): - reference_image = load_image(reference_image.get_atlas()['t1_data']) + + if isinstance(reference_image, BrainAtlas): + reference_image = ImageIO(reference_image.get_atlas()['t1_data']) if not isinstance(transforms, list): raise TypeError( @@ -344,9 +318,13 @@ def apply_transformation( ) corr_image = ants.apply_transforms( - fixed=ants.from_numpy(reference_image), - moving=ants.from_numpy(moving_image), + fixed=reference_image.get_as_ants(), + moving=moving_image.get_as_ants(), transformlist=transforms, + **kwargs, # Additional parameters for ants.apply_transforms ) - return corr_image.numpy() + out_image = clone_image(reference_image) + out_image.update_image_data(np.transpose(corr_image.numpy(), (2, 1, 0))) + + return out_image diff --git a/asltk/registration/asl_normalization.py b/asltk/registration/asl_normalization.py index cefaa41..1663d6c 100644 --- a/asltk/registration/asl_normalization.py +++ b/asltk/registration/asl_normalization.py @@ -1,3 +1,5 @@ +from typing import List, Union + import ants import numpy as np from rich.progress import Progress @@ -17,13 +19,14 @@ calculate_mean_intensity, calculate_snr, ) -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO, clone_image def asl_template_registration( asl_data: ASLData, - asl_data_mask: np.ndarray = None, - atlas_name: str = 'MNI2009', + atlas_reference: Union[str, BrainAtlas] = 'MNI2009', + additional_maps: List[ImageIO] = None, + asl_data_mask: ImageIO = None, verbose: bool = False, ): """ @@ -66,130 +69,48 @@ def asl_template_registration( if not isinstance(asl_data, ASLData): raise TypeError('Input must be an ASLData object.') - # if not isinstance(ref_vol, int) or ref_vol < 0: - # raise ValueError('ref_vol must be a non-negative integer.') - - total_vols, orig_shape = collect_data_volumes(asl_data('pcasl')) - # if ref_vol >= len(total_vols): - # raise ValueError( - # 'ref_vol must be a valid index based on the total ASL data volumes.' - # ) - if asl_data('m0') is None: raise ValueError( 'M0 image is required for normalization. Please provide an ASLData with a valid M0 image.' ) - atlas = BrainAtlas(atlas_name) - # atlas_img = ants.image_read(atlas.get_atlas()['t1_data']).numpy() - atlas_img = load_image(atlas.get_atlas()['t1_data']) - - def norm_function(vol, _): - return space_normalization( - moving_image=vol, - template_image=atlas, - moving_mask=asl_data_mask, - template_mask=None, - transform_type='Affine', - check_orientation=True, + if not ( + isinstance(atlas_reference, BrainAtlas) + or isinstance(atlas_reference, str) + ): + raise TypeError( + 'atlas_reference must be a BrainAtlas object or a string.' ) - - # Create a new ASLData to allocate the normalized image - new_asl = asl_data.copy() - - tmp_vol_list = [asl_data('m0')] - orig_shape = asl_data('m0').shape - - m0_vol_corrected, trans_m0_mtx = __apply_array_normalization( - tmp_vol_list, 0, norm_function - ) - new_asl.set_image(m0_vol_corrected[0], 'm0') - - # Apply the normalization transformation to all pcasl volumes - pcasl_vols, _ = collect_data_volumes(asl_data('pcasl')) - normalized_pcasl_vols = [] - with Progress() as progress: - task = progress.add_task( - '[green]Applying normalization to pcasl volumes...', - total=len(pcasl_vols), + if ( + isinstance(atlas_reference, str) + and atlas_reference not in BrainAtlas('MNI2009').list_atlas() + ): + raise ValueError( + f"atlas_reference '{atlas_reference}' is not a valid atlas name. " + f"Available atlases: {BrainAtlas('MNI2009').list_atlas()}" ) - for vol in pcasl_vols: - norm_vol = apply_transformation( - moving_image=vol, - reference_image=atlas_img, - transforms=trans_m0_mtx, - ) - normalized_pcasl_vols.append(norm_vol) - progress.update(task, advance=1) - - new_asl.set_image(normalized_pcasl_vols, 'pcasl') - - return new_asl, trans_m0_mtx - - -def asl_template_registration( - asl_data: ASLData, - asl_data_mask: np.ndarray = None, - atlas_name: str = 'MNI2009', - verbose: bool = False, -): - """ - Register ASL data to common atlas space. - - This function applies a elastic normalization to fit the subject head - space into the atlas template space. - - - Note: - This method takes in consideration the ASLData object, which contains - the pcasl and/or m0 image. The registration is performed using primarily - the `m0`image if available, otherwise it uses the `pcasl` image. - Therefore, choose wisely the `ref_vol` parameter, which should be a valid index - for the best `pcasl`volume reference to be registered to the atlas. - - Args: - asl_data: ASLData - The ASLData object containing the pcasl and/or m0 image to be corrected. - ref_vol: (int, optional) - The index of the reference volume to which all other volumes will be registered. - Defaults to 0. - asl_data_mask: np.ndarray - A single volume image mask. This can assist the normalization method to converge - into the atlas space. If not provided, the full image is adopted. - atlas_name: str - The atlas type to be considered. The BrainAtlas class is applied, then choose - the `atlas_name` based on the ASLtk brain atlas list. - verbose: (bool, optional) - If True, prints progress messages. Defaults to False. - Raises: - TypeError: If the input is not an ASLData object. - ValueError: If ref_vol is not a valid index. - RuntimeError: If an error occurs during registration. - - Returns: - tuple: ASLData object with corrected volumes and a list of transformation matrices. - """ - if not isinstance(asl_data, ASLData): - raise TypeError('Input must be an ASLData object.') - - # if not isinstance(ref_vol, int) or ref_vol < 0: - # raise ValueError('ref_vol must be a non-negative integer.') - - total_vols, orig_shape = collect_data_volumes(asl_data('pcasl')) - # if ref_vol >= len(total_vols): - # raise ValueError( - # 'ref_vol must be a valid index based on the total ASL data volumes.' - # ) + if additional_maps is not None: + if not all( + [ + isinstance(additional_map, ImageIO) + and additional_map.get_as_numpy().shape + == asl_data('m0').get_as_numpy().shape + for additional_map in additional_maps + ] + ): + raise TypeError( + 'All additional_maps must be ImageIO objects and have the same shape as the M0 image.' + ) + else: + additional_maps = [] - if asl_data('m0') is None: - raise ValueError( - 'M0 image is required for normalization. Please provide an ASLData with a valid M0 image.' - ) + if isinstance(atlas_reference, BrainAtlas): + atlas = atlas_reference + else: + atlas = BrainAtlas(atlas_reference) - atlas = BrainAtlas(atlas_name) - # atlas_img = ants.image_read(atlas.get_atlas()['t1_data']).numpy() - atlas_img = load_image(atlas.get_atlas()['t1_data']) + atlas_img = ImageIO(atlas.get_atlas()['t1_data']) def norm_function(vol, _): return space_normalization( @@ -197,49 +118,64 @@ def norm_function(vol, _): template_image=atlas, moving_mask=asl_data_mask, template_mask=None, - transform_type='Affine', + transform_type='SyN', check_orientation=True, - orientation_verbose=verbose, + verbose=verbose, ) # Create a new ASLData to allocate the normalized image new_asl = asl_data.copy() tmp_vol_list = [asl_data('m0')] - orig_shape = asl_data('m0').shape + # Apply the normalization transformation to the M0 volume and update the new ASLData m0_vol_corrected, trans_m0_mtx = __apply_array_normalization( - tmp_vol_list, 0, orig_shape, norm_function, verbose + tmp_vol_list, 0, norm_function, None ) new_asl.set_image(m0_vol_corrected[0], 'm0') - # Apply the normalization transformation to all pcasl volumes - pcasl_vols, _ = collect_data_volumes(asl_data('pcasl')) - normalized_pcasl_vols = [] + # Apply the normalization transformation to all chosen volumes + raw_volumes, _ = collect_data_volumes(asl_data('pcasl')) + + additional_maps_normalized = [] + raw_volumes_normalized = [] with Progress() as progress: task = progress.add_task( - '[green]Applying normalization to pcasl volumes...', - total=len(pcasl_vols), + '[green]Applying normalization to chosen volumes...', + total=len(raw_volumes) + len(additional_maps), ) - for vol in pcasl_vols: + for raw in raw_volumes: norm_vol = apply_transformation( - moving_image=vol, + moving_image=raw, reference_image=atlas_img, transforms=trans_m0_mtx, ) - normalized_pcasl_vols.append(norm_vol) + raw_volumes_normalized.append(norm_vol) progress.update(task, advance=1) - new_asl.set_image(normalized_pcasl_vols, 'pcasl') + for additional_map in additional_maps: + norm_additional_map = apply_transformation( + moving_image=additional_map, + reference_image=atlas_img, + transforms=trans_m0_mtx, + ) + additional_maps_normalized.append(norm_additional_map) + progress.update(task, advance=1) + + # Update the new ASLData with the normalized volumes + norm_array = np.array( + [vol.get_as_numpy() for vol in raw_volumes_normalized] + ) + new_asl.set_image(norm_array, 'pcasl') - return new_asl, trans_m0_mtx + return new_asl, trans_m0_mtx, additional_maps_normalized def head_movement_correction( asl_data: ASLData, - ref_vol: np.ndarray = None, + ref_vol: ImageIO = None, method: str = 'snr', - roi: np.ndarray = None, + roi: ImageIO = None, verbose: bool = False, ): """ @@ -307,8 +243,9 @@ def head_movement_correction( # Check if the reference volume is a valid volume. if ( - not isinstance(ref_volume, np.ndarray) - or ref_volume.shape != total_vols[0].shape + not isinstance(ref_volume, ImageIO) + or ref_volume.get_as_numpy().shape + != total_vols[0].get_as_numpy().shape ): raise ValueError( 'ref_vol must be a valid volume from the total asl data volumes.' @@ -323,10 +260,13 @@ def norm_function(vol, ref_volume): new_asl_data = asl_data.copy() # Create the new ASLData object with the corrected volumes - corrected_vols_array = np.array(corrected_vols).reshape( - asl_data('pcasl').shape - ) - new_asl_data.set_image(corrected_vols_array, 'pcasl') + corrected_vols_array = np.array( + [vol.get_as_numpy() for vol in corrected_vols] + ).reshape(asl_data('pcasl').get_as_numpy().shape) + + adjusted_pcasl = clone_image(asl_data('pcasl')) + adjusted_pcasl.update_image_data(corrected_vols_array) + new_asl_data.set_image(adjusted_pcasl, 'pcasl') return new_asl_data, trans_mtx @@ -343,16 +283,33 @@ def __apply_array_normalization( ) for idx, vol in enumerate(total_vols): try: - _, trans_m = normalization_function(vol, ref_vol) + single_correction_vol, trans_m = normalization_function( + vol, ref_vol + ) # Adjust the transformation matrix - trans_path = trans_m[0] + # if len(trans_m) > 1: + # # Non-linear transformation is being applied + + trans_path = trans_m[-1] t_matrix = ants.read_transform(trans_path) - params = t_matrix.parameters * trans_proportions[idx] + if trans_proportions is None: + params = t_matrix.parameters + else: + params = t_matrix.parameters * trans_proportions[idx] + t_matrix.set_parameters(params) - ants.write_transform(t_matrix, trans_m[0]) + ants.write_transform(t_matrix, trans_m[-1]) + + if isinstance(ref_vol, ImageIO): + # Then the normalization is doing by rigid body registration + corrected_vol = apply_transformation(vol, ref_vol, trans_m) + else: + # Then the normalization is doing by asl_template_normalization + corrected_vol = apply_transformation( + vol, single_correction_vol, trans_m + ) - corrected_vol = apply_transformation(vol, ref_vol, trans_m) except Exception as e: raise RuntimeError( f'[red on white]Error during registration of volume {idx}: {e}[/]' @@ -366,6 +323,10 @@ def __apply_array_normalization( # orig_shape = orig_shape[1:4] # corrected_vols = np.stack(corrected_vols).reshape(orig_shape) + if isinstance(trans_mtx[0], list): + # If the transformation list has a inner list, then take the first one + trans_mtx = trans_mtx[0] + return corrected_vols, trans_mtx @@ -381,12 +342,16 @@ def _collect_transformation_proportions(total_vols, method, roi): Returns: list: List of calculated values based on the method. """ + if roi is None: + # Making a full mask if no ROI is provided + roi = np.ones_like(total_vols[0].get_as_numpy()) + method_values = [] for vol in total_vols: if method == 'snr': - value = calculate_snr(vol, roi=roi) + value = calculate_snr(vol, roi=ImageIO(image_array=roi)) elif method == 'mean': - value = calculate_mean_intensity(vol, roi=roi) + value = calculate_mean_intensity(vol, roi=ImageIO(image_array=roi)) else: raise ValueError(f'Unknown method: {method}') method_values.append(value) diff --git a/asltk/scripts/cbf.py b/asltk/scripts/cbf.py index a21b5e7..0f5a62a 100644 --- a/asltk/scripts/cbf.py +++ b/asltk/scripts/cbf.py @@ -3,14 +3,11 @@ from functools import * import numpy as np -import SimpleITK as sitk from rich import print -from rich.progress import track -from scipy.optimize import curve_fit from asltk.asldata import ASLData from asltk.reconstruction import CBFMapping -from asltk.utils.io import load_image, save_image +from asltk.utils.io import ImageIO parser = argparse.ArgumentParser( prog='CBF/ATT Mapping', @@ -71,6 +68,12 @@ default='nii', help='The file format that will be used to save the output images. It is not allowed image compression (ex: .gz, .zip, etc). Default is nii, but it can be choosen: mha, nrrd.', ) +optional.add_argument( + '--average_m0', + action='store_true', + default=False, + help='Whether to average the M0 images across the time series. Default is False.', +) args = parser.parse_args() @@ -107,12 +110,14 @@ def checkUpParameters(): return is_ok -asl_img = load_image(args.pcasl) -m0_img = load_image(args.m0) +asl_img = ImageIO(args.pcasl) +m0_img = ImageIO(args.m0) -mask_img = np.ones(asl_img[0, 0, :, :, :].shape) +average_m0 = args.average_m0 + +mask_img = ImageIO(image_array=np.ones(asl_img.get_as_numpy().shape[-3:])) if args.mask != '': - mask_img = load_image(args.mask) + mask_img = ImageIO(args.mask) try: @@ -132,16 +137,24 @@ def checkUpParameters(): if args.verbose: print(' --- Script Input Data ---') print('ASL file path: ' + args.pcasl) - print('ASL image dimension: ' + str(asl_img.shape)) + print('ASL image dimension: ' + str(asl_img.get_as_numpy().shape)) print('Mask file path: ' + args.mask) - print('Mask image dimension: ' + str(mask_img.shape)) + print('Mask image dimension: ' + str(mask_img.get_as_numpy().shape)) print('M0 file path: ' + args.m0) - print('M0 image dimension: ' + str(m0_img.shape)) + print('M0 image dimension: ' + str(m0_img.get_as_numpy().shape)) print('PLD: ' + str(pld)) print('LD: ' + str(ld)) print('Output file format: ' + str(args.file_fmt)) -data = ASLData(pcasl=args.pcasl, m0=args.m0, ld_values=ld, pld_values=pld) +print(average_m0) +data = ASLData( + pcasl=args.pcasl, + m0=args.m0, + ld_values=ld, + pld_values=pld, + average_m0=average_m0, +) + recon = CBFMapping(data) recon.set_brain_mask(mask_img) maps = recon.create_map() @@ -150,19 +163,34 @@ def checkUpParameters(): save_path = args.out_folder + os.path.sep + 'cbf_map.' + args.file_fmt if args.verbose: print('Saving CBF map - Path: ' + save_path) -save_image(maps['cbf'], save_path) +maps['cbf'].save_image(save_path) save_path = ( args.out_folder + os.path.sep + 'cbf_map_normalized.' + args.file_fmt ) if args.verbose: print('Saving normalized CBF map - Path: ' + save_path) -save_image(maps['cbf_norm'], save_path) +maps['cbf_norm'].save_image(save_path) save_path = args.out_folder + os.path.sep + 'att_map.' + args.file_fmt if args.verbose: print('Saving ATT map - Path: ' + save_path) -save_image(maps['att'], save_path) +maps['att'].save_image(save_path) if args.verbose: print('Execution: ' + parser.prog + ' finished successfully!') + + +def main(): + """ + Entry point function for the CBF Scalar ASL mapping command-line tool. + + This function is called when the `asltk_cbf` command is run. + All script logic is already defined at the module level. + """ + # Script logic is already defined at the module level + pass + + +if __name__ == '__main__': + main() diff --git a/asltk/scripts/t2_maps.py b/asltk/scripts/t2_maps.py index 602f1df..1eb5bab 100644 --- a/asltk/scripts/t2_maps.py +++ b/asltk/scripts/t2_maps.py @@ -12,7 +12,7 @@ log_processing_step, ) from asltk.reconstruction import T2Scalar_ASLMapping -from asltk.utils import load_image, save_image +from asltk.utils.io import ImageIO parser = argparse.ArgumentParser( prog='T2 Scalar Mapping from ASL Multi-TE ASLData', @@ -80,6 +80,12 @@ default='nii', help='The file format that will be used to save the output images. It is not allowed image compression (ex: .gz, .zip, etc). Default is nii, but it can be choosen: mha, nrrd.', ) +optional.add_argument( + '--average_m0', + action='store_true', + default=False, + help='Whether to average the M0 images across the time series. Default is False.', +) args = parser.parse_args() @@ -120,12 +126,14 @@ def checkUpParameters(): return is_ok -asl_img = load_image(args.pcasl) -m0_img = load_image(args.m0) +asl_img = ImageIO(args.pcasl) +m0_img = ImageIO(args.m0) -mask_img = np.ones(asl_img[0, 0, :, :, :].shape) +average_m0 = args.average_m0 + +mask_img = ImageIO(image_array=np.ones(asl_img.get_as_numpy().shape[-3:])) if args.mask != '': - mask_img = load_image(args.mask) + mask_img = ImageIO(args.mask) try: @@ -148,11 +156,11 @@ def checkUpParameters(): if args.verbose: print(' --- Script Input Data ---') print('ASL file path: ' + args.pcasl) - print('ASL image dimension: ' + str(asl_img.shape)) + print('ASL image dimension: ' + str(asl_img.get_as_numpy().shape)) print('Mask file path: ' + args.mask) - print('Mask image dimension: ' + str(mask_img.shape)) + print('Mask image dimension: ' + str(mask_img.get_as_numpy().shape)) print('M0 file path: ' + args.m0) - print('M0 image dimension: ' + str(m0_img.shape)) + print('M0 image dimension: ' + str(m0_img.get_as_numpy().shape)) print('PLD: ' + str(pld)) print('LD: ' + str(ld)) print('TE: ' + str(te)) @@ -166,7 +174,12 @@ def checkUpParameters(): 'Creating ASLData object', f'Multi-TE with {len(te)} echo times' ) data = ASLData( - pcasl=args.pcasl, m0=args.m0, ld_values=ld, pld_values=pld, te_values=te + pcasl=args.pcasl, + m0=args.m0, + ld_values=ld, + pld_values=pld, + te_values=te, + average_m0=average_m0, ) log_processing_step('Initializing T2 Scalar mapper') @@ -185,8 +198,23 @@ def checkUpParameters(): if args.verbose and maps['t2'] is not None: print('Saving T2 maps - Path: ' + save_path) logger.info(f'Saving T2 maps to: {save_path}') -save_image(maps['t2'], save_path) +maps['t2'].save_image(save_path) if args.verbose: print('Execution: ' + parser.prog + ' finished successfully!') logger.info('T2 Scalar ASL processing completed successfully') + + +def main(): + """ + Entry point function for the T2 Scalar ASL mapping command-line tool. + + This function is called when the `asltk_t2_asl` command is run. + All script logic is already defined at the module level. + """ + # Script logic is already defined at the module level + pass + + +if __name__ == '__main__': + main() diff --git a/asltk/scripts/te_asl.py b/asltk/scripts/te_asl.py index 0da38dd..b11a8fd 100644 --- a/asltk/scripts/te_asl.py +++ b/asltk/scripts/te_asl.py @@ -13,7 +13,7 @@ log_processing_step, ) from asltk.reconstruction import MultiTE_ASLMapping -from asltk.utils.io import load_image, save_image +from asltk.utils.io import ImageIO parser = argparse.ArgumentParser( prog='Multi-TE ASL Mapping', @@ -95,6 +95,12 @@ default='nii', help='The file format that will be used to save the output images. It is not allowed image compression (ex: .gz, .zip, etc). Default is nii, but it can be choosen: mha, nrrd.', ) +optional.add_argument( + '--average_m0', + action='store_true', + default=False, + help='Whether to average the M0 images across the time series. Default is False.', +) args = parser.parse_args() @@ -135,21 +141,23 @@ def checkUpParameters(): return is_ok -asl_img = load_image(args.pcasl) -m0_img = load_image(args.m0) +asl_img = ImageIO(args.pcasl) +m0_img = ImageIO(args.m0) -mask_img = np.ones(asl_img[0, 0, :, :, :].shape) +average_m0 = args.average_m0 + +mask_img = ImageIO(image_array=np.ones(asl_img.get_as_numpy().shape[-3:])) if args.mask != '': - mask_img = load_image(args.mask) + mask_img = ImageIO(args.mask) cbf_map = None if args.cbf is not None: - cbf_map = load_image(args.cbf) + cbf_map = ImageIO(args.cbf) att_map = None if args.att is not None: - att_map = load_image(args.att) + att_map = ImageIO(args.att) try: @@ -172,11 +180,11 @@ def checkUpParameters(): if args.verbose: print(' --- Script Input Data ---') print('ASL file path: ' + args.pcasl) - print('ASL image dimension: ' + str(asl_img.shape)) + print('ASL image dimension: ' + str(asl_img.get_as_numpy().shape)) print('Mask file path: ' + args.mask) - print('Mask image dimension: ' + str(mask_img.shape)) + print('Mask image dimension: ' + str(mask_img.get_as_numpy().shape)) print('M0 file path: ' + args.m0) - print('M0 image dimension: ' + str(m0_img.shape)) + print('M0 image dimension: ' + str(m0_img.get_as_numpy().shape)) print('PLD: ' + str(pld)) print('LD: ' + str(ld)) print('TE: ' + str(te)) @@ -194,14 +202,19 @@ def checkUpParameters(): 'Creating ASLData object', f'Multi-TE with {len(te)} echo times' ) data = ASLData( - pcasl=args.pcasl, m0=args.m0, ld_values=ld, pld_values=pld, te_values=te + pcasl=args.pcasl, + m0=args.m0, + ld_values=ld, + pld_values=pld, + te_values=te, + average_m0=average_m0, ) log_processing_step('Initializing Multi-TE ASL mapper') recon = MultiTE_ASLMapping(data) recon.set_brain_mask(mask_img) -if isinstance(cbf_map, np.ndarray) and isinstance(att_map, np.ndarray): +if isinstance(cbf_map, ImageIO) and isinstance(att_map, ImageIO): logger.info('Setting optional CBF and ATT maps') recon.set_cbf_map(cbf_map) recon.set_att_map(att_map) @@ -217,7 +230,7 @@ def checkUpParameters(): if args.verbose and cbf_map is not None: print('Saving CBF map - Path: ' + save_path) logger.info(f'Saving CBF map to: {save_path}') -save_image(maps['cbf'], save_path) +maps['cbf'].save_image(save_path) save_path = ( args.out_folder + os.path.sep + 'cbf_map_normalized.' + args.file_fmt @@ -225,20 +238,35 @@ def checkUpParameters(): if args.verbose and cbf_map is not None: print('Saving normalized CBF map - Path: ' + save_path) logger.info(f'Saving normalized CBF map to: {save_path}') -save_image(maps['cbf_norm'], save_path) +maps['cbf_norm'].save_image(save_path) save_path = args.out_folder + os.path.sep + 'att_map.' + args.file_fmt if args.verbose and att_map is not None: print('Saving ATT map - Path: ' + save_path) logger.info(f'Saving ATT map to: {save_path}') -save_image(maps['att'], save_path) +maps['att'].save_image(save_path) save_path = args.out_folder + os.path.sep + 'mte_t1blgm_map.' + args.file_fmt if args.verbose: print('Saving multiTE-ASL T1blGM map - Path: ' + save_path) logger.info(f'Saving T1blGM map to: {save_path}') -save_image(maps['t1blgm'], save_path) +maps['t1blgm'].save_image(save_path) if args.verbose: print('Execution: ' + parser.prog + ' finished successfully!') logger.info('Multi-TE ASL processing completed successfully') + + +def main(): + """ + Entry point function for the multi-TE Scalar ASL mapping command-line tool. + + This function is called when the `asltk_te_asl` command is run. + All script logic is already defined at the module level. + """ + # Script logic is already defined at the module level + pass + + +if __name__ == '__main__': + main() diff --git a/asltk/smooth/gaussian.py b/asltk/smooth/gaussian.py index c449c50..c21b57c 100644 --- a/asltk/smooth/gaussian.py +++ b/asltk/smooth/gaussian.py @@ -4,9 +4,10 @@ import SimpleITK as sitk from asltk.utils.image_manipulation import collect_data_volumes +from asltk.utils.io import ImageIO, clone_image -def isotropic_gaussian(data, sigma: float = 1.0): +def isotropic_gaussian(data: ImageIO, sigma: float = 1.0): """Smooth the data using a isotropic Gaussian kernel. This method assumes that the same kernal size will be applied over all the @@ -41,14 +42,14 @@ def isotropic_gaussian(data, sigma: float = 1.0): raise ValueError('sigma must be a positive number.') # Check if the input data is a numpy array - if not isinstance(data, np.ndarray): - raise TypeError(f'data is not a numpy array. Type {type(data)}') + if not isinstance(data, ImageIO): + raise TypeError(f'data is not an ImageIO object. Type {type(data)}') # Make the Gaussian instance using the kernel size based on sigma parameter gaussian = sitk.SmoothingRecursiveGaussianImageFilter() gaussian.SetSigma(sigma) - if data.ndim > 3: + if data.get_as_numpy().ndim > 3: warnings.warn( 'Input data is not a 3D volume. The filter will be applied for all volumes.', UserWarning, @@ -56,6 +57,13 @@ def isotropic_gaussian(data, sigma: float = 1.0): volumes, _ = collect_data_volumes(data) processed = [] for volume in volumes: - processed.append(gaussian.Execute(sitk.GetImageFromArray(volume))) + processed.append(gaussian.Execute(volume.get_as_sitk())) - return np.array(processed).reshape(data.shape) + smooth_array = np.array( + [sitk.GetArrayFromImage(vol) for vol in processed] + ).reshape(data.get_as_numpy().shape) + + out_data = clone_image(data) + out_data.update_image_data(smooth_array) + + return out_data diff --git a/asltk/smooth/median.py b/asltk/smooth/median.py index ec874a0..cd0b4a2 100644 --- a/asltk/smooth/median.py +++ b/asltk/smooth/median.py @@ -4,9 +4,10 @@ from scipy.ndimage import median_filter from asltk.utils.image_manipulation import collect_data_volumes +from asltk.utils.io import ImageIO, clone_image -def isotropic_median(data, size: int = 3): +def isotropic_median(data: ImageIO, size: int = 3): """Smooth the data using a median filter. This method applies a median filter with an isotropic kernel to reduce @@ -41,8 +42,8 @@ def isotropic_median(data, size: int = 3): raise ValueError('size must be a positive integer.') # Check if the input data is a numpy array - if not isinstance(data, np.ndarray): - raise TypeError(f'data is not a numpy array. Type {type(data)}') + if not isinstance(data, ImageIO): + raise TypeError(f'data is not an ImageIO object. Type {type(data)}') # Ensure size is odd if size % 2 == 0: @@ -52,7 +53,7 @@ def isotropic_median(data, size: int = 3): UserWarning, ) - if data.ndim > 3: + if data.get_as_numpy().ndim > 3: warnings.warn( 'Input data is not a 3D volume. The filter will be applied for all volumes.', UserWarning, @@ -61,7 +62,12 @@ def isotropic_median(data, size: int = 3): volumes, _ = collect_data_volumes(data) processed = [] for volume in volumes: - filtered_volume = median_filter(volume, size=size) + filtered_volume = median_filter(volume.get_as_numpy(), size=size) processed.append(filtered_volume) - return np.array(processed).reshape(data.shape) + smooth_array = np.array(processed).reshape(data.get_as_numpy().shape) + + out_data = clone_image(data) + out_data.update_image_data(smooth_array) + + return out_data diff --git a/asltk/utils/image_manipulation.py b/asltk/utils/image_manipulation.py index dbad775..b3b6ac7 100644 --- a/asltk/utils/image_manipulation.py +++ b/asltk/utils/image_manipulation.py @@ -12,6 +12,7 @@ calculate_mean_intensity, calculate_snr, ) +from asltk.utils.io import ImageIO, clone_image logger = get_logger(__name__) @@ -20,7 +21,9 @@ sitk.ProcessObject_SetGlobalDefaultNumberOfThreads(num_cores) -def collect_data_volumes(data: np.ndarray): +def collect_data_volumes( + data: ImageIO, +) -> Tuple[List[ImageIO], Tuple[int, ...]]: """Collect the data volumes from a higher dimension array. This method is used to collect the data volumes from a higher dimension @@ -36,297 +39,42 @@ def collect_data_volumes(data: np.ndarray): data (np.ndarray): The data to be separated. Returns: - list: A list of 3D arrays, each one representing a volume. + list: A list of ImageIO, each one representing a volume. tuple: The original shape of the data. """ - if not isinstance(data, np.ndarray): - raise TypeError('data is not a numpy array.') + if not isinstance(data, ImageIO): + raise TypeError('data is not an ImageIO object.') - if data.ndim < 3: - raise ValueError('data is a 3D volume or higher dimensions') + dimension = data.get_as_numpy().ndim + if dimension < 3: + raise ValueError('data is not a 3D volume or higher dimensions') volumes = [] # Calculate the number of volumes by multiplying all dimensions except the last three - num_volumes = int(np.prod(data.shape[:-3])) - reshaped_data = data.reshape((int(num_volumes),) + data.shape[-3:]) - for i in range(num_volumes): - volumes.append(reshaped_data[i]) - - return volumes, data.shape - - -def orientation_check( - moving_image: np.ndarray, fixed_image: np.ndarray, threshold: float = 0.1 -) -> Dict[str, any]: - """ - Quick orientation compatibility check between two images. - - This function provides a fast assessment of whether two images - have compatible orientations for registration without applying - any corrections. - - Parameters - ---------- - moving_image : np.ndarray - The moving image to be checked. - fixed_image : np.ndarray - The reference/fixed image. - threshold : float, optional - Correlation threshold to consider orientations compatible. Default is 0.1. - - Returns - ------- - dict - Dictionary containing: - - 'compatible': bool, whether orientations are compatible - - 'correlation': float, normalized correlation between images - - 'recommendation': str, action recommendation - """ - # Normalize images - moving_norm = _normalize_image_intensity(moving_image) - fixed_norm = _normalize_image_intensity(fixed_image) - - # Resize if needed for comparison - # Resize the larger image to match the smaller one to minimize memory overhead - if moving_norm.shape != fixed_norm.shape: - if np.prod(moving_norm.shape) > np.prod(fixed_norm.shape): - moving_norm = _resize_image_to_match(moving_norm, fixed_norm.shape) - else: - fixed_norm = _resize_image_to_match(fixed_norm, moving_norm.shape) - - # Compute correlation - correlation = _compute_normalized_correlation(moving_norm, fixed_norm) - - # Determine compatibility - compatible = correlation > threshold - - if compatible: - recommendation = 'Images appear to have compatible orientations. Registration should proceed normally.' - elif correlation > 0.05: - recommendation = 'Possible orientation mismatch detected. Consider using orientation correction.' - else: - recommendation = 'Strong orientation mismatch detected. Orientation correction is highly recommended.' - - return { - 'compatible': compatible, - 'correlation': correlation, - 'recommendation': recommendation, - } - - -# TODO Evaluate this method and decide if it is needed (or useful...) -# def preview_orientation_correction( -# moving_image: np.ndarray, -# fixed_image: np.ndarray, -# slice_index: Optional[int] = None -# ) -> Dict[str, np.ndarray]: -# """ -# Preview the effect of orientation correction on a specific slice. - -# This function shows the before and after effect of orientation -# correction on a 2D slice, useful for visual validation. - -# Parameters -# ---------- -# moving_image : np.ndarray -# The moving image to be corrected. -# fixed_image : np.ndarray -# The reference/fixed image. -# slice_index : int, optional -# Index of the axial slice to preview. If None, uses middle slice. - -# Returns -# ------- -# dict -# Dictionary containing: -# - 'original_slice': np.ndarray, original moving image slice -# - 'corrected_slice': np.ndarray, corrected moving image slice -# - 'fixed_slice': np.ndarray, corresponding fixed image slice -# - 'slice_index': int, the slice index used -# """ -# # Get orientation correction -# corrected_moving, _ = check_and_fix_orientation( -# moving_image, fixed_image, verbose=False -# ) - -# # Determine slice index -# if slice_index is None: -# slice_index = moving_image.shape[0] // 2 - -# # Ensure slice index is valid -# slice_index = max(0, min(slice_index, moving_image.shape[0] - 1)) -# corrected_slice_idx = max(0, min(slice_index, corrected_moving.shape[0] - 1)) -# fixed_slice_idx = max(0, min(slice_index, fixed_image.shape[0] - 1)) - -# return { -# 'original_slice': moving_image[slice_index, :, :], -# 'corrected_slice': corrected_moving[corrected_slice_idx, :, :], -# 'fixed_slice': fixed_image[fixed_slice_idx, :, :], -# 'slice_index': slice_index -# } - - -def check_and_fix_orientation( - moving_image: np.ndarray, - fixed_image: np.ndarray, - moving_spacing: tuple = None, - fixed_spacing: tuple = None, - verbose: bool = False, -): - """ - Check and fix orientation mismatches between moving and fixed images. - - This function analyzes the anatomical orientations of both images and - applies necessary transformations to align them before registration. - It handles common orientation issues like axial, sagittal, or coronal flips. - - The method uses both intensity-based and geometric approaches to determine - the best orientation alignment between images. - - Parameters - ---------- - moving_image : np.ndarray - The moving image that needs to be aligned. - fixed_image : np.ndarray - The reference/fixed image. - moving_spacing : tuple, optional - Voxel spacing for the moving image (x, y, z). If None, assumes isotropic. - fixed_spacing : tuple, optional - Voxel spacing for the fixed image (x, y, z). If None, assumes isotropic. - verbose : bool, optional - If True, prints detailed orientation analysis. Default is False. - - Returns - ------- - corrected_moving : np.ndarray - The moving image with corrected orientation. - orientation_transform : dict - Dictionary containing the applied transformations for reproducibility. - """ - if verbose: - print('Analyzing image orientations...') - - # Convert to SimpleITK images for orientation analysis - moving_sitk = sitk.GetImageFromArray(moving_image) - fixed_sitk = sitk.GetImageFromArray(fixed_image) - - # Set spacing if provided - if moving_spacing is not None: - moving_sitk.SetSpacing(moving_spacing) - if fixed_spacing is not None: - fixed_sitk.SetSpacing(fixed_spacing) - - # Get image dimensions and properties - moving_size = moving_sitk.GetSize() - fixed_size = fixed_sitk.GetSize() - - if verbose: - print(f'Moving image size: {moving_size}') - print(f'Fixed image size: {fixed_size}') - - # Analyze anatomical orientations using intensity patterns - orientation_transform = _analyze_anatomical_orientation( - moving_image, fixed_image, verbose - ) - - # Apply orientation corrections - corrected_moving = _apply_orientation_correction( - moving_image, orientation_transform, verbose + num_volumes = int(np.prod(data.get_as_numpy().shape[:-3])) + reshaped_data = data.get_as_numpy().reshape( + (int(num_volumes),) + data.get_as_numpy().shape[-3:] ) + for i in range(num_volumes): + base_data = ImageIO(image_array=reshaped_data[i]) + base_data.update_image_spacing(data._image_as_sitk.GetSpacing()[:3]) + base_data.update_image_origin(data._image_as_sitk.GetOrigin()[:3]) - # Verify the correction using cross-correlation - if verbose: - original_corr = _compute_normalized_correlation( - moving_image, fixed_image + tmp_dir_array = np.array(data._image_as_sitk.GetDirection()).reshape( + dimension, dimension ) - corrected_corr = _compute_normalized_correlation( - corrected_moving, fixed_image + base_data.update_image_direction( + tuple(tmp_dir_array[:3, :3].flatten()) ) - print(f'Original correlation: {original_corr:.4f}') - print(f'Corrected correlation: {corrected_corr:.4f}') - if corrected_corr > original_corr: - print('Orientation correction improved alignment') - else: - print('Orientation correction may not have improved alignment') - return corrected_moving, orientation_transform + volumes.append(base_data) - -def create_orientation_report( - moving_image: np.ndarray, - fixed_image: np.ndarray, - output_path: Optional[str] = None, -) -> str: - """ - Create a comprehensive orientation analysis report. - - Parameters - ---------- - moving_image : np.ndarray - The moving image to analyze. - fixed_image : np.ndarray - The reference/fixed image. - output_path : str, optional - Path to save the report. If None, returns the report as string. - - Returns - ------- - str - The orientation analysis report. - """ - # Perform analysis - quick_check = orientation_check(moving_image, fixed_image) - moving_props = analyze_image_properties(moving_image) - fixed_props = analyze_image_properties(fixed_image) - - # Get correction info - corrected_moving, orientation_transform = check_and_fix_orientation( - moving_image, fixed_image, verbose=False - ) - - # Generate report - report = f""" - ORIENTATION ANALYSIS REPORT - ============================ - - QUICK COMPATIBILITY CHECK: - - Orientation Compatible: {quick_check['compatible']} - - Correlation Score: {quick_check['correlation']:.4f} - - Recommendation: {quick_check['recommendation']} - - MOVING IMAGE PROPERTIES: - - Shape: {moving_props['shape']} - - Center of Mass: {moving_props['center_of_mass']} - - Intensity Range: {moving_props['intensity_stats']['min']:.2f} - {moving_props['intensity_stats']['max']:.2f} - - Mean Intensity: {moving_props['intensity_stats']['mean']:.2f} - - FIXED IMAGE PROPERTIES: - - Shape: {fixed_props['shape']} - - Center of Mass: {fixed_props['center_of_mass']} - - Intensity Range: {fixed_props['intensity_stats']['min']:.2f} - {fixed_props['intensity_stats']['max']:.2f} - - Mean Intensity: {fixed_props['intensity_stats']['mean']:.2f} - - ORIENTATION CORRECTION APPLIED: - - X-axis flip: {orientation_transform.get('flip_x', False)} - - Y-axis flip: {orientation_transform.get('flip_y', False)} - - Z-axis flip: {orientation_transform.get('flip_z', False)} - - Axis transpose: {orientation_transform.get('transpose_axes', 'None')} - - RECOMMENDATIONS: - {quick_check['recommendation']} - """.strip() - - if output_path: - with open(output_path, 'w') as f: - f.write(report) - print(f'Report saved to: {output_path}') - - return report + return volumes, data.get_as_numpy().shape def select_reference_volume( - asl_data: Union['ASLData', list[np.ndarray]], - roi: np.ndarray = None, + asl_data: Union['ASLData', list[ImageIO]], + roi: ImageIO = None, method: str = 'snr', ): from asltk.asldata import ASLData @@ -354,20 +102,20 @@ def select_reference_volume( raise ValueError(f'Invalid method: {method}') if roi is not None: - if not isinstance(roi, np.ndarray): - raise TypeError('ROI must be a numpy array.') - if roi.ndim != 3: + if not isinstance(roi, ImageIO): + raise TypeError('ROI must be an ImageIO object.') + if roi.get_as_numpy().ndim != 3: raise ValueError('ROI must be a 3D array.') if isinstance(asl_data, ASLData): volumes, _ = collect_data_volumes(asl_data('pcasl')) elif isinstance(asl_data, list) and all( - isinstance(vol, np.ndarray) for vol in asl_data + isinstance(vol, ImageIO) for vol in asl_data ): volumes = asl_data else: raise TypeError( - 'asl_data must be an ASLData object or a list of numpy arrays.' + 'asl_data must be an ASLData object or a list of ImageIO objects.' ) if method == 'snr': @@ -381,7 +129,7 @@ def select_reference_volume( logger.info('Estimating maximum mean from provided volumes...') ref_volume, vol_idx = _estimate_max_mean(volumes, roi=roi) logger.info( - f'Selected volume index: {vol_idx} with mean: {ref_volume.mean():.2f}' + f'Selected volume index: {vol_idx} with mean: {ref_volume.get_as_numpy().mean():.2f}' ) else: raise ValueError(f'Unknown method: {method}') @@ -390,13 +138,13 @@ def select_reference_volume( def _estimate_max_snr( - volumes: List[np.ndarray], roi: np.ndarray = None -) -> Tuple[np.ndarray, int]: # pragma: no cover + volumes: List[ImageIO], roi: ImageIO = None +) -> Tuple[ImageIO, int]: # pragma: no cover """ Estimate the maximum SNR from a list of volumes. Args: - volumes (List[np.ndarray]): A list of 3D numpy arrays representing the image volumes. + volumes (List[ImageIO]): A list of ImageIO objects representing the image volumes. Raises: TypeError: If any volume is not a numpy array. @@ -407,11 +155,15 @@ def _estimate_max_snr( max_snr_idx = 0 max_snr_value = 0 for idx, vol in enumerate(volumes): - if not isinstance(vol, np.ndarray): - logger.error(f'Volume at index {idx} is not a numpy array.') - raise TypeError('All volumes must be numpy arrays.') + if not isinstance(vol, ImageIO): + logger.error(f'Volume at index {idx} is not an ImageIO object.') + raise TypeError('All volumes must be ImageIO objects.') + + if roi is not None: + snr_value = calculate_snr(vol, roi=roi) + else: + snr_value = calculate_snr(vol) - snr_value = calculate_snr(vol, roi=roi) if snr_value > max_snr_value: max_snr_value = snr_value max_snr_idx = idx @@ -422,16 +174,16 @@ def _estimate_max_snr( def _estimate_max_mean( - volumes: List[np.ndarray], roi: np.ndarray = None -) -> Tuple[np.ndarray, int]: + volumes: List[ImageIO], roi: ImageIO = None +) -> Tuple[ImageIO, int]: """ Estimate the maximum mean from a list of volumes. Args: - volumes (List[np.ndarray]): A list of 3D numpy arrays representing the image volumes. + volumes (List[ImageIO]): A list of ImageIO objects representing the image volumes. Raises: - TypeError: If any volume is not a numpy array. + TypeError: If any volume is not an ImageIO object. Returns: Tuple[np.ndarray, int]: The reference volume and its index. @@ -439,11 +191,14 @@ def _estimate_max_mean( max_mean_idx = 0 max_mean_value = 0 for idx, vol in enumerate(volumes): - if not isinstance(vol, np.ndarray): - logger.error(f'Volume at index {idx} is not a numpy array.') - raise TypeError('All volumes must be numpy arrays.') + if not isinstance(vol, ImageIO): + logger.error(f'Volume at index {idx} is not an ImageIO object.') + raise TypeError('All volumes must be ImageIO objects.') - mean_value = calculate_mean_intensity(vol, roi=roi) + if roi is not None: + mean_value = calculate_mean_intensity(vol, roi=roi) + else: + mean_value = calculate_mean_intensity(vol) if mean_value > max_mean_value: max_mean_value = mean_value max_mean_idx = idx @@ -451,205 +206,3 @@ def _estimate_max_mean( ref_volume = volumes[max_mean_idx] return ref_volume, max_mean_idx - - -def _analyze_anatomical_orientation(moving_image, fixed_image, verbose=False): - """ - Analyze anatomical orientations by comparing intensity patterns - and geometric properties of brain images. - """ - orientation_transform = { - 'flip_x': False, - 'flip_y': False, - 'flip_z': False, - 'transpose_axes': None, - } - - # Normalize images for comparison - moving_norm = _normalize_image_intensity(moving_image) - fixed_norm = _normalize_image_intensity(fixed_image) - - # Determine the smaller shape for comparison - moving_size = np.prod(moving_norm.shape) - fixed_size = np.prod(fixed_norm.shape) - if moving_size <= fixed_size: - ref_shape = moving_norm.shape - else: - ref_shape = fixed_norm.shape - - # Test different orientation combinations - best_corr = -1 - best_transform = orientation_transform.copy() - - # Test axis flips - for flip_x in [False, True]: - for flip_y in [False, True]: - for flip_z in [False, True]: - # Apply test transformation - test_img = moving_norm.copy() - if flip_x: - test_img = np.flip(test_img, axis=2) # X axis - if flip_y: - test_img = np.flip(test_img, axis=1) # Y axis - if flip_z: - test_img = np.flip(test_img, axis=0) # Z axis - - # Resize to match reference shape if needed - if test_img.shape != ref_shape: - test_img = _resize_image_to_match(test_img, ref_shape) - - # Also resize fixed_norm if needed - ref_img = fixed_norm - if fixed_norm.shape != ref_shape: - ref_img = _resize_image_to_match(fixed_norm, ref_shape) - - # Compute correlation - corr = _compute_normalized_correlation(test_img, ref_img) - - if corr > best_corr: - best_corr = corr - best_transform = { - 'flip_x': flip_x, - 'flip_y': flip_y, - 'flip_z': flip_z, - 'transpose_axes': None, - } - - if verbose: - print( - f'Flip X:{flip_x}, Y:{flip_y}, Z:{flip_z} -> Correlation: {corr:.4f}' - ) - - # Test common axis permutations for different acquisition orientations - axis_permutations = [ - (0, 1, 2), # Original - (0, 2, 1), # Swap Y-Z - (1, 0, 2), # Swap X-Y - (1, 2, 0), # Rotate axes - (2, 0, 1), # Rotate axes - (2, 1, 0), # Swap X-Z - ] - - for axes in axis_permutations[1:]: # Skip original - try: - test_img = np.transpose(moving_norm, axes) - # Resize to match reference shape if needed - if test_img.shape != ref_shape: - test_img = _resize_image_to_match(test_img, ref_shape) - - # Also resize fixed_norm if needed - ref_img = fixed_norm - if fixed_norm.shape != ref_shape: - ref_img = _resize_image_to_match(fixed_norm, ref_shape) - - corr = _compute_normalized_correlation(test_img, ref_img) - - if corr > best_corr: - best_corr = corr - best_transform = { - 'flip_x': False, - 'flip_y': False, - 'flip_z': False, - 'transpose_axes': axes, - } - - if verbose: - print(f'Transpose {axes} -> Correlation: {corr:.4f}') - except Exception as e: - if verbose: - print(f'Failed transpose {axes}: {e}') - continue - - if verbose: - print(f'Best orientation transform: {best_transform}') - print(f'Best correlation: {best_corr:.4f}') - - return best_transform - - -def _apply_orientation_correction(image, orientation_transform, verbose=False): - """Apply the determined orientation corrections to the image.""" - corrected = image.copy() - - # Apply axis transposition first if needed - if orientation_transform['transpose_axes'] is not None: - corrected = np.transpose( - corrected, orientation_transform['transpose_axes'] - ) - if verbose: - print( - f"Applied transpose: {orientation_transform['transpose_axes']}" - ) - - # Apply axis flips - if orientation_transform['flip_x']: - corrected = np.flip(corrected, axis=2) - if verbose: - print('Applied X-axis flip') - - if orientation_transform['flip_y']: - corrected = np.flip(corrected, axis=1) - if verbose: - print('Applied Y-axis flip') - - if orientation_transform['flip_z']: - corrected = np.flip(corrected, axis=0) - if verbose: - print('Applied Z-axis flip') - - return corrected - - -def _normalize_image_intensity(image): - """Normalize image intensity to [0, 1] range for comparison.""" - img = image.astype(np.float64) - img_min, img_max = np.min(img), np.max(img) - if img_max > img_min: - img = (img - img_min) / (img_max - img_min) - return img - - -def _resize_image_to_match(source_image, resample_shape): - """Resize source image to match target shape using antsPy (ants).""" - - # Convert numpy array to ANTsImage (assume float32 for compatibility) - ants_img = ants.from_numpy(source_image.astype(np.float32)) - - # Resample to target shape - resampled_img = ants.resample_image( - ants_img, resample_shape, use_voxels=True, interp_type=0 - ) - - # Convert back to numpy array - return resampled_img.numpy() - - -def _compute_normalized_correlation(img1, img2): - """Compute normalized cross-correlation between two images.""" - # Ensure same shape - if img1.shape != img2.shape: - return -1 - - # Flatten images - img1_flat = img1.flatten() - img2_flat = img2.flatten() - - # Remove NaN and infinite values - valid_mask = np.isfinite(img1_flat) & np.isfinite(img2_flat) - if np.sum(valid_mask) == 0: - return -1 - - img1_valid = img1_flat[valid_mask] - img2_valid = img2_flat[valid_mask] - - # Compute correlation coefficient - try: - corr_matrix = np.corrcoef(img1_valid, img2_valid) - correlation = corr_matrix[0, 1] - if np.isnan(correlation): - return -1 - return abs( - correlation - ) # Use absolute value for orientation independence - except: - return -1 diff --git a/asltk/utils/image_statistics.py b/asltk/utils/image_statistics.py index 71b0585..3f78012 100644 --- a/asltk/utils/image_statistics.py +++ b/asltk/utils/image_statistics.py @@ -3,8 +3,10 @@ import numpy as np from scipy.ndimage import center_of_mass +from asltk.utils.io import ImageIO -def calculate_snr(image: np.ndarray, roi: np.ndarray = None) -> float: + +def calculate_snr(image: ImageIO, roi: ImageIO = None) -> float: """ Calculate the Signal-to-Noise Ratio (SNR) of a medical image. @@ -20,21 +22,21 @@ def calculate_snr(image: np.ndarray, roi: np.ndarray = None) -> float: float The SNR value of the image. """ - if not isinstance(image, np.ndarray): - raise ValueError('Input must be a numpy array.') + if not isinstance(image, ImageIO): + raise ValueError('Input must be an ImageIO object.') if roi is not None: - if not isinstance(roi, np.ndarray): - raise ValueError('ROI must be a numpy array.') - if roi.shape != image.shape: + if not isinstance(roi, ImageIO): + raise ValueError('ROI must be an ImageIO object.') + if roi.get_as_numpy().shape != image.get_as_numpy().shape: raise ValueError('ROI shape must match image shape.') - image_roi = image[roi > 0] + image_roi = image.get_as_numpy()[roi.get_as_numpy() > 0] mean_signal = np.mean(image_roi) noise = image_roi - mean_signal else: - mean_signal = np.mean(image) - noise = image - mean_signal + mean_signal = np.mean(image.get_as_numpy()) + noise = image.get_as_numpy() - mean_signal try: snr = mean_signal / np.std(noise) @@ -44,9 +46,7 @@ def calculate_snr(image: np.ndarray, roi: np.ndarray = None) -> float: return float(abs(snr)) if not np.isnan(snr) else 0.0 -def calculate_mean_intensity( - image: np.ndarray, roi: np.ndarray = None -) -> float: +def calculate_mean_intensity(image: ImageIO, roi: ImageIO = None) -> float: """ Calculate the mean intensity of a medical image. @@ -63,22 +63,24 @@ def calculate_mean_intensity( float The mean intensity value of the image or ROI. """ - if not isinstance(image, np.ndarray): - raise ValueError('Input must be a numpy array.') + if not isinstance(image, ImageIO): + raise ValueError('Input must be an ImageIO object.') if roi is not None: - if not isinstance(roi, np.ndarray): - raise ValueError('ROI must be a numpy array.') - if roi.shape != image.shape: + if not isinstance(roi, ImageIO): + raise ValueError('ROI must be an ImageIO object.') + if roi.get_as_numpy().shape != image.get_as_numpy().shape: raise ValueError('ROI shape must match image shape.') # Compute mean intensity if roi is not None: - return float(abs(np.mean(image[roi > 0]))) # Only consider ROI - return float(abs(np.mean(image))) + return float( + abs(np.mean(image.get_as_numpy()[roi.get_as_numpy() > 0])) + ) # Only consider ROI + return float(abs(np.mean(image.get_as_numpy()))) -def analyze_image_properties(image: np.ndarray) -> Dict[str, any]: +def analyze_image_properties(image: ImageIO) -> Dict[str, any]: """ Analyze basic properties of a medical image for orientation assessment. @@ -96,33 +98,35 @@ def analyze_image_properties(image: np.ndarray) -> Dict[str, any]: - 'intensity_stats': dict, intensity statistics - 'symmetry_axes': dict, symmetry analysis for each axis """ + image_array = image.get_as_numpy() + # Basic properties - shape = image.shape + shape = image_array.shape # Center of mass try: - com = center_of_mass(image > np.mean(image)) + com = center_of_mass(image_array > np.mean(image_array)) except ImportError: # pragma: no cover # Fallback calculation without scipy - coords = np.argwhere(image > np.mean(image)) + coords = np.argwhere(image_array > np.mean(image_array)) com = np.mean(coords, axis=0) if len(coords) > 0 else (0, 0, 0) # Intensity statistics intensity_stats = { - 'min': float(np.min(image)), - 'max': float(np.max(image)), - 'mean': float(np.mean(image)), - 'std': float(np.std(image)), - 'median': float(np.median(image)), + 'min': float(np.min(image_array)), + 'max': float(np.max(image_array)), + 'mean': float(np.mean(image_array)), + 'std': float(np.std(image_array)), + 'median': float(np.median(image_array)), } # Symmetry analysis symmetry_axes = {} for axis in range(3): # Flip along axis and compare - flipped = np.flip(image, axis=axis) - correlation = _compute_correlation_simple(image, flipped) + flipped = np.flip(image_array, axis=axis) + correlation = _compute_correlation_simple(image_array, flipped) symmetry_axes[f'axis_{axis}'] = { 'symmetry_correlation': correlation, 'likely_symmetric': correlation > 0.8, diff --git a/asltk/utils/io.py b/asltk/utils/io.py index d3f6ec9..97a601e 100644 --- a/asltk/utils/io.py +++ b/asltk/utils/io.py @@ -1,99 +1,733 @@ +import copy import fnmatch import os +import warnings +from typing import Union +import ants import dill import numpy as np import SimpleITK as sitk +from ants.utils.sitk_to_ants import from_sitk from bids import BIDSLayout +from rich import print from asltk import AVAILABLE_IMAGE_FORMATS, BIDS_IMAGE_FORMATS -def load_image( - full_path: str, - subject: str = None, - session: str = None, - modality: str = None, - suffix: str = None, - **kwargs, -): +class ImageIO: + """ImageIO is the base class in `asltk` for loading, manipulating, + and saving ASL images. + + The basic functionality includes: + - Loading images from a file path or a numpy array. + - Converting images to different representations (SimpleITK, ANTsPy, numpy). + - Saving images to a file path in various formats. """ - Load an image file from a BIDS directory or file using the SimpleITK API. - The output is always a numpy array, converted from the SimpleITK image object. + def __init__( + self, image_path: str = None, image_array: np.ndarray = None, **kwargs + ): + """The constructor initializes the ImageIO object with an image path or a numpy array. + + It is needed to provide either an image path or a numpy array to load the image. + If both are provided, an error will be raised because it is ambiguous which one to use. + + Note: + - If `image_path` is provided, the image will be loaded from the file. + - If `image_array` is provided, the image will be loaded as a numpy array. + - If both are provided, an error will be raised. + - If neither is provided, an error will be raised. + + Important: + The image path should be a valid file path to an image file or a directory containing BIDS-compliant images. + It is also recommended to provide the image path for complex image processing, as it allows to preserve the image metadata and properties, as seen for the SimpleITK and ANTsPy representations. + + Only the SimpleITK and Numpy representations are availble to manipulate higher dimensional images (4D, 5D, etc.). + The ANTsPy representation is limited up to 3D images, mainly due to the specificity to image normalization applications. + + Args: + image_path (str, optional): The file path to the image. Defaults to None. + image_array (np.ndarray, optional): The image as a numpy array. Defaults to None. + average_m0 (bool, optional): If True, averages the M0 image if it is provided. Defaults to False. + verbose (bool, optional): If True, prints additional information during loading. Defaults to False + """ + # Image parameters and objects + self._image_path = image_path + self._image_as_numpy = image_array + self._image_as_sitk = None + self._image_as_ants = None + + # BIDS standard parameters for saving/loading + self._subject = kwargs.get('subject', None) + self._session = kwargs.get('session', None) + self._modality = kwargs.get('modality', None) + self._suffix = kwargs.get('suffix', None) + + # Loading parameters + self._average_m0 = kwargs.get('average_m0', False) + self._verbose = kwargs.get('verbose', False) + + self._check_init_images() + + self.load_image() + + if kwargs.get('verbose', False): + print( + f'[bold green]ImageIO initialized with path:[/bold green] {self._image_path}' + ) + print(self) + + def __str__(self) -> str: + """Returns a string representation of the ImageIO object. + + Returns: + str: A summary of the image parameters, BIDS information, and loading parameters. + """ + # Section 1: Image parameters + image_ext = ( + os.path.splitext(self._image_path)[-1] + if self._image_path + else 'N/A' + ) + if self._image_as_sitk is not None: + img_dim = self._image_as_sitk.GetDimension() + img_spacing = self._image_as_sitk.GetSpacing() + img_origin = self._image_as_sitk.GetOrigin() + else: + img_dim = img_spacing = img_origin = 'N/A' + if self._image_as_numpy is not None: + img_max = np.max(self._image_as_numpy) + img_min = np.min(self._image_as_numpy) + img_mean = np.mean(self._image_as_numpy) + img_std = np.std(self._image_as_numpy) + else: + img_max = img_min = img_mean = img_std = 'N/A' + + image_section = [ + '[Image parameters]', + f' Path: {self._image_path}', + f' File extension: {image_ext}', + f' Dimension: {img_dim}', + f' Spacing: {img_spacing}', + f' Origin: {img_origin}', + f' Max value: {img_max}', + f' Min value: {img_min}', + f' Mean: {img_mean}', + f' Std: {img_std}', + ] + + # Section 2: BIDS information + bids_section = [ + '[BIDS information]', + f' Subject: {self._subject}', + f' Session: {self._session}', + f' Modality: {self._modality}', + f' Suffix: {self._suffix}', + ] + + # Section 3: Loading parameters + loading_section = [ + '[Loading parameters]', + f' average_m0: {self._average_m0}', + f' verbose: {self._verbose}', + ] + + return '\n'.join(image_section + bids_section + loading_section) + + def set_image_path(self, image_path: str): + """Set the image path for loading. + + Args: + image_path (str): Path to the image file. + """ + check_path(image_path) + self._image_path = image_path + + def get_image_path(self): + """Get the image path for loading. + + Returns: + str: Path to the image file. + """ + return self._image_path + + def get_as_sitk(self): + """Get the image as a SimpleITK image object. + + Important: + The methods returns a copy of the SimpleITK image object. + This is to ensure that the original image is not modified unintentionally. + + Returns: + SimpleITK.Image: The image as a SimpleITK image object. + """ + self._check_image_representation('sitk') + + return copy.deepcopy(self._image_as_sitk) + + def get_as_ants(self): + """Get the image as an ANTsPy image object. + + Important: + The methods returns a copy of the ANTsPy image object. + This is to ensure that the original image is not modified unintentionally. + + Returns: + ants.image: The image as an ANTsPy image object. + """ + self._check_image_representation('ants') + + return self._image_as_ants.clone() + + def get_as_numpy(self): + """Get the image as a NumPy array. + + Important: + The methods returns a copy of the NumPy array. + This is to ensure that the original image is not modified unintentionally. + Also, the image representation as numpy array does not preserve the image metadata, such as spacing, origin, and direction. + For a complete image representation, use the SimpleITK or ANTsPy representations. + + Returns: + numpy.ndarray: The image as a NumPy array. + """ + self._check_image_representation('numpy') + + return self._image_as_numpy.copy() + + def load_image(self): + """ + Load an image file from a BIDS directory or file using the SimpleITK and ANTsPy representation (if applicable). + + The output is allocated internaly to a ImageIO object that contains up to three image representations: a + SimpleITK image, a numpy array and (if applicable) a ANTsPy image. + + Note: + - The general image loading is done using SimpleITK, which supports a wide range of image formats. + - The image is loaded as a SimpleITK image, and then converted to a numpy array. + - If the image is 3D or lower, it is also converted to an ANTsPy image. + + Supported image formats include: .nii, .nii.gz, .nrrd, .mha, .tif, and other formats supported by SimpleITK. + + Note: + - The default values for `modality` and `suffix` are None. If not provided, the function will search for the first matching ASL image in the directory. + - If `full_path` is a file, it is loaded directly. If it is a directory, the function searches for a BIDS-compliant image using the provided parameters. + - If both a file and a BIDS directory are provided, the file takes precedence. + + Tip: + To validate your BIDS structure, use the `bids-validator` tool: https://bids-standard.github.io/bids-validator/ + For more details about ASL BIDS structure, see: https://bids-specification.readthedocs.io/en/latest + + Note: + The image file is assumed to be an ASL subtract image (control-label). If not, use helper functions in `asltk.utils` to create one. + + The information passed to the ImageIO constructor is used to load the image. + + Examples: + Load a single image file directly: + >>> data = ImageIO("./tests/files/pcasl_mte.nii.gz").get_as_numpy() + >>> type(data) + + >>> data.shape # Example: 5D ASL data + (8, 7, 5, 35, 35) + + Load M0 reference image: + >>> m0_data = ImageIO("./tests/files/m0.nii.gz").get_as_numpy() + >>> m0_data.shape # Example: 3D reference image + (5, 35, 35) + + Load from BIDS directory (automatic detection): + >>> data = ImageIO("./tests/files/bids-example/asl001").get_as_numpy() + >>> type(data) + + + Load specific BIDS data with detailed parameters: + >>> data = ImageIO("./tests/files/bids-example/asl001", subject='Sub103', suffix='asl').get_as_numpy() + >>> type(data) + + + # Load NRRD format + >>> nrrd_data = ImageIO("./tests/files/t1-mri.nrrd").get_as_numpy() + >>> type(nrrd_data) + + + Returns: + ImageIO: The loaded image as a ImageIO object. + """ + + if self._image_path is not None: + check_path(self._image_path) + + if self._image_path.endswith(AVAILABLE_IMAGE_FORMATS): + # If the full path is a file, then load the image directly + self._image_as_sitk = sitk.ReadImage(self._image_path) + self._image_as_numpy = sitk.GetArrayFromImage( + self._image_as_sitk + ) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + else: + # If the full path is a directory, then use BIDSLayout to find the file + selected_file = self._get_file_from_folder_layout() + self._image_as_sitk = sitk.ReadImage(selected_file) + self._image_as_numpy = sitk.GetArrayFromImage( + self._image_as_sitk + ) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + elif self._image_as_numpy is not None: + # If the image is already provided as a numpy array, convert it to SimpleITK + # is_vector = True + # if self._image_as_numpy.ndim > 3: + # is_vector = False + + self._image_as_sitk = sitk.GetImageFromArray( + self._image_as_numpy, isVector=False + ) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + else: + raise ValueError( + 'Either image_path or image_array must be provided to load the image.' + ) - Supported image formats include: .nii, .nii.gz, .nrrd, .mha, .tif, and other formats supported by SimpleITK. + # Check if there are additional parameters + if self._average_m0: + # If average_m0 is True, then average the M0 image + if self._image_as_numpy.ndim > 3: + avg_img = np.mean(self._image_as_numpy, axis=0) + self.update_image_data(avg_img, enforce_new_dimension=True) + + def update_image_spacing(self, new_spacing: tuple): + """ + Update the image spacing with a new tuple, preserving the original image metadata. + + Important: + - The new spacing must be a tuple of the same length as the original image dimension. + + Args: + new_spacing (tuple): The new spacing for the image. + """ + if not isinstance(new_spacing, tuple): + raise TypeError('new_spacing must be a tuple.') + + # Update spacing in SimpleITK image + self._image_as_sitk.SetSpacing(new_spacing) + + # Update internal numpy representation + self._image_as_numpy = sitk.GetArrayFromImage(self._image_as_sitk) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + + def update_image_origin(self, new_origin: tuple): + """ + Update the image origin with a new tuple, preserving the original image metadata. + + Important: + - The new origin must be a tuple of the same length as the original image dimension. + + Args: + new_origin (tuple): The new origin for the image. + """ + if not isinstance(new_origin, tuple): + raise TypeError('new_origin must be a tuple.') + + # Update origin in SimpleITK image + self._image_as_sitk.SetOrigin(new_origin) + + # Update internal numpy representation + self._image_as_numpy = sitk.GetArrayFromImage(self._image_as_sitk) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + + def update_image_direction(self, new_direction: tuple): + """ + Update the image direction with a new tuple, preserving the original image metadata. + + Important: + - The new direction must be a tuple of the same length as the original image dimension. + + Args: + new_direction (tuple): The new direction for the image. + """ + if not isinstance(new_direction, tuple): + raise TypeError('new_direction must be a tuple.') + + # Update direction in SimpleITK image + self._image_as_sitk.SetDirection(new_direction) + + # Update internal numpy representation + self._image_as_numpy = sitk.GetArrayFromImage(self._image_as_sitk) + if self._image_as_numpy.ndim <= 3: + self._image_as_ants = from_sitk(self._image_as_sitk) + + def update_image_data( + self, new_array: np.ndarray, enforce_new_dimension=False + ): + """ + Update the image data with a new numpy array, preserving the original image metadata. + + This is particularly useful for updating the image data after processing or when new data is available. + Hence, it allows to change the image data without losing the original metadata such as spacing, origin, and direction. + + Another application for this method is to create a new image using a processed numpy array and then copy the metadata from the original image that was loaded using a file path, which contains the original metadata. + + Examples: + >>> import numpy as np + >>> array = np.random.rand(5, 35, 35) + >>> image1 = ImageIO(image_array=array)# Example 3D image from a numpy array (without metadata) + >>> image2 = ImageIO(image_path="./tests/files/m0.nii.gz") # Example 3D image with metadata + >>> full_image = ImageIO(image_path="./tests/files/m0.nii.gz") # Example 3D image with metadata + + Both images has the same shape, so we can update the image data: + >>> image1.get_as_numpy().shape == image2.get_as_numpy().shape + True + + >>> image2.update_image_data(image1.get_as_numpy()) + + Now the `image2` has the same data as `image1`, but retains its original metadata. + + Important: + - The new array must match the shape of the original image unless `enforce_new_dimension` is set to True. + - If `enforce_new_dimension` is True, the new array can have a different shape than the original image, but + it will be assumed the first dimensions to get averaged. + + Args: + new_array (np.ndarray): The new image data array. Must match the shape of the original image. + enforce_new_dimension (bool): If True, allows the new array to have a different shape than the original image. + + """ + if not isinstance(new_array, np.ndarray): + raise TypeError('new_array must be a numpy array.') + if new_array.shape != self._image_as_numpy.shape: + if not enforce_new_dimension: + raise ValueError( + 'new_array must match the shape of the original image.' + ) + + # Get the dimension difference + dim_diff = self._image_as_numpy.ndim - new_array.ndim + + if dim_diff < 0 or abs(dim_diff) >= 2: + raise TypeError( + 'The new array is too much different from the original image. ' + 'The new array must have the same number of dimensions as the original image or at most one dimension less.' + ) - Note: - - The default values for `modality` and `suffix` are None. If not provided, the function will search for the first matching ASL image in the directory. - - If `full_path` is a file, it is loaded directly. If it is a directory, the function searches for a BIDS-compliant image using the provided parameters. - - If both a file and a BIDS directory are provided, the file takes precedence. + # Create new SimpleITK image from array + new_sitk_img = sitk.GetImageFromArray(new_array, isVector=False) + + if dim_diff != 0: + base_origin = self._image_as_sitk.GetOrigin()[:3] + base_spacing = self._image_as_sitk.GetSpacing()[:3] + base_direction = tuple( + np.array(self._image_as_sitk.GetDirection()) + .reshape(self._image_as_numpy.ndim, self._image_as_numpy.ndim)[ + :3, :3 + ] + .flatten() + ) + else: + base_origin = self._image_as_sitk.GetOrigin() + base_spacing = self._image_as_sitk.GetSpacing() + base_direction = self._image_as_sitk.GetDirection() + + # Copy metadata + # Copy all metadata from the original image + new_sitk_img.SetOrigin(base_origin) + new_sitk_img.SetSpacing(base_spacing) + new_sitk_img.SetDirection(base_direction) + # Copy all key-value metadata + for k in self._image_as_sitk.GetMetaDataKeys(): + new_sitk_img.SetMetaData(k, self._image_as_sitk.GetMetaData(k)) + + # Update internal representations + self._image_as_numpy = new_array + self._image_as_sitk = new_sitk_img + if new_array.ndim <= 3: + # ANTsPy does not support higher dimension images, so we skip conversion for lower than 3D arrays + self._image_as_ants = from_sitk(new_sitk_img) + + def save_image( + self, + full_path: str = None, + *, + bids_root: str = None, + subject: str = None, + session: str = None, + **kwargs, + ): + """ + Save the current image to a file path using SimpleITK. + + All available image formats provided in the SimpleITK API can be used here. Supported formats include: .nii, .nii.gz, .nrrd, .mha, .tif, and others. + + Note: + If the file extension is not recognized by SimpleITK, an error will be raised. + The image array should be 2D, 3D, or 4D. For 4D arrays, only the first volume may be saved unless handled explicitly. + + Args: + full_path (str): Full absolute path with image file name provided. + bids_root (str): Optional BIDS root directory to save in BIDS structure. + subject (str): Subject ID for BIDS saving. + session (str): Optional session ID for BIDS saving. + + Examples: + Save an image using a direct file path: + >>> import tempfile + >>> from asltk.utils.io import ImageIO + >>> import numpy as np + >>> img = np.random.rand(10, 10, 10) + >>> io = ImageIO(image_array=img) + >>> with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as f: + ... io.save_image(f.name) + + Save an image using BIDS structure: + >>> import tempfile + >>> from asltk.utils.io import ImageIO + >>> import numpy as np + >>> img = np.random.rand(10, 10, 10) + >>> io = ImageIO(image_array=img) + >>> with tempfile.TemporaryDirectory() as temp_dir: + ... io.save_image(bids_root=temp_dir, subject='001', session='01') + + Save processed ASL results: + >>> from asltk.asldata import ASLData + >>> from asltk.utils.io import ImageIO + >>> asl_data = ASLData(pcasl='./tests/files/pcasl_mte.nii.gz', m0='./tests/files/m0.nii.gz') + >>> processed_img = asl_data('pcasl').get_as_numpy()[0] # Get first volume + >>> io = ImageIO(image_array=processed_img) + >>> import tempfile + >>> with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as f: + ... io.save_image(f.name) + + Raises: + ValueError: If neither full_path nor (bids_root + subject) are provided. + RuntimeError: If the file extension is not recognized by SimpleITK. + """ + if bids_root and subject: + full_path = _make_bids_path(bids_root, subject, session) + + if not full_path: + raise ValueError( + 'Either full_path or bids_root + subject must be provided.' + ) - Tip: - To validate your BIDS structure, use the `bids-validator` tool: https://bids-standard.github.io/bids-validator/ - For more details about ASL BIDS structure, see: https://bids-specification.readthedocs.io/en/latest + if not os.path.exists(os.path.dirname(full_path)): + raise FileNotFoundError( + f'The directory of the full path {full_path} does not exist.' + ) - Note: - The image file is assumed to be an ASL subtract image (control-label). If not, use helper functions in `asltk.utils` to create one. + # sitk_img = sitk.GetImageFromArray(img) + useCompression = kwargs.get('useCompression', False) + compressionLevel = kwargs.get('compressionLevel', -1) + compressor = kwargs.get('compressor', '') + sitk.WriteImage( + self._image_as_sitk, + full_path, + useCompression=useCompression, + compressionLevel=compressionLevel, + compressor=compressor, + ) + + def _check_image_representation(self, representation): + if representation == 'sitk' and self._image_as_sitk is None: + raise ValueError( + 'Image is not loaded as SimpleITK. Please load the image first.' + ) + elif representation == 'ants' and self._image_as_ants is None: + raise ValueError( + 'Image is not loaded as ANTsPy. Please load the image first.' + ) + elif representation == 'numpy' and self._image_as_numpy is None: + raise ValueError( + 'Image is not loaded as numpy array. Please load the image first.' + ) + + def _get_file_from_folder_layout(self): + selected_file = None + layout = BIDSLayout(self._image_path) + if all( + param is None + for param in [ + self._subject, + self._session, + self._modality, + self._suffix, + ] + ): + for root, _, files in os.walk(self._image_path): + for file in files: + if '_asl' in file and file.endswith(BIDS_IMAGE_FORMATS): + selected_file = os.path.join(root, file) + else: + layout_files = layout.files.keys() + matching_files = [] + for f in layout_files: + search_pattern = '' + if self._subject: + search_pattern = f'*sub-*{self._subject}*' + if self._session: + search_pattern += search_pattern + f'*ses-*{self._session}' + if self._modality: + search_pattern += search_pattern + f'*{self._modality}*' + if self._suffix: + search_pattern += search_pattern + f'*{self._suffix}*' + + if fnmatch.fnmatch(f, search_pattern) and f.endswith( + BIDS_IMAGE_FORMATS + ): + matching_files.append(f) + + if not matching_files: + raise FileNotFoundError( + f'ASL image file is missing for subject {self._subject} in directory {self._image_path}' + ) + selected_file = matching_files[0] + + return selected_file + + def _check_init_images(self): + """ + Check if the image is initialized correctly. + If both image_path and image_array are None, raise an error. + """ + + if self._image_path is None and self._image_as_numpy is None: + raise ValueError( + 'Either image_path or image_array must be provided to initialize the ImageIO object.' + ) + if self._image_path is not None and self._image_as_numpy is not None: + raise ValueError( + 'Both image_path and image_array are provided. Please provide only one.' + ) + if self._image_path is None and self._image_as_numpy is not None: + warnings.warn( + 'image_array is provided but image_path is not set. The image will be loaded as a numpy array only and the image metadata will be set as default. For complex image processing it is better to provide the image_path instead.', + ) + + +def check_image_properties( + first_image: Union[sitk.Image, np.ndarray, ants.ANTsImage, ImageIO], + ref_image: ImageIO, +): + """Check the properties of two images to ensure they are compatible. + + The first image can be a SimpleITK image, a numpy array, an ANTsPy image, or an ImageIO object. + The reference image must be an ImageIO object. + + This function checks the size, spacing, origin, and direction of the first image against the reference image. Args: - full_path (str): Path to the image file or BIDS directory. - subject (str, optional): Subject identifier. Defaults to None. - session (str, optional): Session identifier. Defaults to None. - modality (str, optional): Modality folder name. Defaults to None. - suffix (str, optional): Suffix of the file to load. Defaults to None. + first_image (Union[sitk.Image, np.ndarray, ants.ANTsImage, ImageIO]): The first image to check. + ref_image (ImageIO): The reference image to compare against. - Examples: - Load a single image file directly: - >>> data = load_image("./tests/files/pcasl_mte.nii.gz") - >>> type(data) - - >>> data.shape # Example: 5D ASL data - (8, 7, 5, 35, 35) + Raises: + TypeError: If the reference image is not an ImageIO object. + ValueError: If the image properties (size, spacing, origin, direction) do not match. + ValueError: If the image properties (size, spacing, origin, direction) do not match. + """ + # Check the image size, dimension, spacing and all the properties to see if the first_image is equal to ref_image + if not isinstance(ref_image, ImageIO): + raise TypeError('Reference image must be a ImageIO object') + + if isinstance(first_image, sitk.Image): + # Compare with ref_image's sitk representation + ref_sitk = ref_image._image_as_sitk + + if first_image.GetSize() != ref_sitk.GetSize(): + raise ValueError('Image size mismatch.') + if first_image.GetSpacing() != ref_sitk.GetSpacing(): + raise ValueError('Image spacing mismatch.') + if first_image.GetOrigin() != ref_sitk.GetOrigin(): + raise ValueError('Image origin mismatch.') + if first_image.GetDirection() != ref_sitk.GetDirection(): + raise ValueError('Image direction mismatch.') + + elif isinstance(first_image, np.ndarray): + ref_np = ref_image._image_as_numpy + + if first_image.shape != ref_np.shape: + raise ValueError('Numpy array shape mismatch.') + if first_image.dtype != ref_np.dtype: + raise ValueError('Numpy array dtype mismatch.') + + warnings.warn( + 'Numpy arrays does not has spacing and origin image information.' + ) - Load M0 reference image: - >>> m0_data = load_image("./tests/files/m0.nii.gz") - >>> m0_data.shape # Example: 3D reference image - (5, 35, 35) + elif isinstance(first_image, ants.ANTsImage): + ref_ants = ( + ref_image._image_as_ants + if isinstance(ref_image, ImageIO) + else ref_image + ) + if not isinstance(ref_ants, ants.ANTsImage): + raise ValueError('Reference image is not an ANTsPy image.') + if first_image.shape != ref_ants.shape: + raise ValueError('ANTs image shape mismatch.') + if not np.allclose(first_image.spacing, ref_ants.spacing): + raise ValueError('ANTs image spacing mismatch.') + if not np.allclose(first_image.origin, ref_ants.origin): + raise ValueError('ANTs image origin mismatch.') + if not np.allclose(first_image.direction, ref_ants.direction): + raise ValueError('ANTs image direction mismatch.') + + elif isinstance(first_image, ImageIO): + # Recursively check using numpy representation + check_image_properties(first_image.get_as_sitk(), ref_image) + else: + raise TypeError('Unsupported image type for comparison.') + + +def clone_image(source: ImageIO, include_path: bool = False): + """Clone an image getting a deep copy. - Load from BIDS directory (automatic detection): - >>> data = load_image("./tests/files/bids-example/asl001") - >>> type(data) - + All the image properties are copied, including the image path if `include_path` is True. - Load specific BIDS data with detailed parameters: - >>> data = load_image("./tests/files/bids-example/asl001", subject='Sub103', suffix='asl') - >>> type(data) - + Tip: + This a useful method to create a copy of an image for processing without modifying the original image. + Also, after making a clone, you can modify the image properties without affecting the original image. + The image array representation can be modified, but the original image metadata will remain unchanged, + however the `update_image_data` method can be used to update the image data while preserving the original metadata. - # Load NRRD format - >>> nrrd_data = load_image("./tests/files/t1-mri.nrrd") - >>> type(nrrd_data) - + Args: + source (ImageIO): The source image to clone. + include_path (bool, optional): Whether to include the image path in the clone. Defaults to False. + + Raises: + TypeError: If the source image is not an ImageIO object. Returns: - numpy.ndarray: The loaded image array. + ImageIO: The cloned image. """ - _check_input_path(full_path) - img = None + if not isinstance(source, ImageIO): + raise TypeError('Source image must be a ImageIO object') - if full_path.endswith(AVAILABLE_IMAGE_FORMATS): - # If the full path is a file, then load the image directly - img = sitk.GetArrayFromImage(sitk.ReadImage(full_path)) - else: - # If the full path is a directory, then use BIDSLayout to find the file - selected_file = _get_file_from_folder_layout( - full_path, subject, session, modality, suffix - ) - img = sitk.GetArrayFromImage(sitk.ReadImage(selected_file)) + cloned = copy.deepcopy(source) + if not include_path: + cloned._image_path = None - # Check if there are additional parameters - if kwargs.get('average_m0', False): - # If average_m0 is True, then average the M0 image - if img.ndim > 3: - img = np.mean(img, axis=0) + return cloned - return img + +def check_path(path: str): + """Check if the image path is valid. + + Args: + path (str): The image path to check. + + Raises: + ValueError: If the image path is not set. + FileNotFoundError: If the image file does not exist. + """ + if path is None: + raise ValueError( + 'Image path is not set. Please set the image path first.' + ) + if not os.path.exists(path): + raise FileNotFoundError(f'The file {path} does not exist.') def _make_bids_path( @@ -118,68 +752,6 @@ def _make_bids_path( return os.path.join(out_dir, filename) -def save_image( - img: np.ndarray, - full_path: str = None, - *, - bids_root: str = None, - subject: str = None, - session: str = None, -): - """ - Save an image to a file path using SimpleITK. - - All available image formats provided in the SimpleITK API can be used here. Supported formats include: .nii, .nii.gz, .nrrd, .mha, .tif, and others. - - Note: - If the file extension is not recognized by SimpleITK, an error will be raised. - The input array should be 2D, 3D, or 4D. For 4D arrays, only the first volume may be saved unless handled explicitly. - - Args: - img (np.ndarray): The image array to be saved. Can be 2D, 3D, or 4D. - full_path (str): Full absolute path with image file name provided. - bids_root (str): Optional BIDS root directory to save in BIDS structure. - subject (str): Subject ID for BIDS saving. - session (str): Optional session ID for BIDS saving. - - Examples: - Save an image using a direct file path: - >>> import tempfile - >>> import numpy as np - >>> img = np.random.rand(10, 10, 10) - >>> with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as f: - ... save_image(img, f.name) - - Save an image using BIDS structure: - >>> import tempfile - >>> img = np.random.rand(10, 10, 10) - >>> with tempfile.TemporaryDirectory() as temp_dir: - ... save_image(img, bids_root=temp_dir, subject='001', session='01') - - Save processed ASL results: - >>> from asltk.asldata import ASLData - >>> asl_data = ASLData(pcasl='./tests/files/pcasl_mte.nii.gz', m0='./tests/files/m0.nii.gz') - >>> processed_img = asl_data('pcasl')[0] # Get first volume - >>> import tempfile - >>> with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as f: - ... save_image(processed_img, f.name) - - Raises: - ValueError: If neither full_path nor (bids_root + subject) are provided. - RuntimeError: If the file extension is not recognized by SimpleITK. - """ - if bids_root and subject: - full_path = _make_bids_path(bids_root, subject, session) - - if not full_path: - raise ValueError( - 'Either full_path or bids_root + subject must be provided.' - ) - - sitk_img = sitk.GetImageFromArray(img) - sitk.WriteImage(sitk_img, full_path) - - def save_asl_data( asldata, fullpath: str = None, @@ -259,55 +831,8 @@ def load_asl_data(fullpath: str): >>> loaded_asldata = load_asl_data(temp_file_path) >>> loaded_asldata.get_ld() [1.8, 1.8, 1.8] - >>> loaded_asldata('pcasl').shape + >>> loaded_asldata('pcasl').get_as_numpy().shape (8, 7, 5, 35, 35) """ - _check_input_path(fullpath) + check_path(fullpath) return dill.load(open(fullpath, 'rb')) - - -def _check_input_path(full_path: str): - if not os.path.exists(full_path): - raise FileNotFoundError(f'The file {full_path} does not exist.') - - -def _get_file_from_folder_layout( - full_path: str, - subject: str = None, - session: str = None, - modality: str = None, - suffix: str = None, -): - selected_file = None - layout = BIDSLayout(full_path) - if all(param is None for param in [subject, session, modality, suffix]): - for root, _, files in os.walk(full_path): - for file in files: - if '_asl' in file and file.endswith(BIDS_IMAGE_FORMATS): - selected_file = os.path.join(root, file) - else: - layout_files = layout.files.keys() - matching_files = [] - for f in layout_files: - search_pattern = '' - if subject: - search_pattern = f'*sub-*{subject}*' - if session: - search_pattern += search_pattern + f'*ses-*{session}' - if modality: - search_pattern += search_pattern + f'*{modality}*' - if suffix: - search_pattern += search_pattern + f'*{suffix}*' - - if fnmatch.fnmatch(f, search_pattern) and f.endswith( - BIDS_IMAGE_FORMATS - ): - matching_files.append(f) - - if not matching_files: - raise FileNotFoundError( - f'ASL image file is missing for subject {subject} in directory {full_path}' - ) - selected_file = matching_files[0] - - return selected_file diff --git a/pyproject.toml b/pyproject.toml index 4851e9e..a0192f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,10 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Image Processing", "Topic :: Scientific/Engineering :: Medical Science Apps.", @@ -24,14 +27,14 @@ classifiers = [ [tool.poetry.dependencies] -python = "^3.9" +python = "^3.10" SimpleITK = "^2.4.0" numpy = "^1.22.4" rich = "^13.8.1" scipy = "^1.13.1" dill = "^0.3.9" pybids = "^0.17.2" -antspyx = "^0.5.4" +antspyx = "^0.6.1" kagglehub = "^0.3.12" @@ -67,7 +70,7 @@ test = "pytest --ignore-glob='./asltk/scripts/*.py' -s -x --cov=asltk -vv --disa post_test = "coverage html" [build-system] -requires = ["poetry-core"] +requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] diff --git a/tests/reconstruction/test_cbf_mapping.py b/tests/reconstruction/test_cbf_mapping.py index 2400771..da69dff 100644 --- a/tests/reconstruction/test_cbf_mapping.py +++ b/tests/reconstruction/test_cbf_mapping.py @@ -6,7 +6,7 @@ from asltk.asldata import ASLData from asltk.reconstruction import CBFMapping -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep @@ -64,7 +64,7 @@ def test_cbf_object_set_mri_parameters_values(value, param): def test_cbf_add_brain_mask_success(): cbf = CBFMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) cbf.set_brain_mask(mask) assert isinstance(cbf._brain_mask, np.ndarray) @@ -79,7 +79,7 @@ def test_cbf_object_create_map_raise_error_if_ld_or_pld_are_not_provided(): def test_set_brain_mask_verify_if_input_is_a_label_mask(): cbf = CBFMapping(asldata_te) - not_mask = load_image(T1_MRI) + not_mask = ImageIO(T1_MRI) with pytest.warns(UserWarning): warnings.warn( 'Mask image is not a binary image. Any value > 0 will be assumed as brain label.', @@ -89,7 +89,7 @@ def test_set_brain_mask_verify_if_input_is_a_label_mask(): def test_set_brain_mask_set_label_value(): cbf = CBFMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) cbf.set_brain_mask(mask, label=1) assert np.unique(cbf._brain_mask).size == 2 assert np.max(cbf._brain_mask) == np.int8(1) @@ -100,7 +100,7 @@ def test_set_brain_mask_set_label_value_raise_error_value_not_found_in_mask( label, ): cbf = CBFMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) with pytest.raises(Exception) as e: cbf.set_brain_mask(mask, label=label) assert e.value.args[0] == 'Label value is not found in the mask provided.' @@ -111,7 +111,7 @@ def test_set_brain_mask_gives_binary_image_using_correct_label_value(): img = np.zeros((5, 35, 35)) img[1, 16:30, 16:30] = 250 img[1, 0:15, 0:15] = 1 - cbf.set_brain_mask(img, label=250) + cbf.set_brain_mask(ImageIO(image_array=img), label=250) assert np.unique(cbf._brain_mask).size == 2 assert np.max(cbf._brain_mask) == np.uint8(250) assert np.min(cbf._brain_mask) == np.uint8(0) @@ -119,19 +119,21 @@ def test_set_brain_mask_gives_binary_image_using_correct_label_value(): def test_set_brain_mask_raise_error_if_image_dimension_is_different_from_3d_volume(): cbf = CBFMapping(asldata_te) - pcasl_3d_vol = load_image(PCASL_MTE)[0, 0, :, :, :] + pcasl_3d_vol = ImageIO( + image_array=ImageIO(PCASL_MTE).get_as_numpy()[0, 0, :, :, :] + ) fake_mask = np.array(((1, 1, 1), (0, 1, 0))) with pytest.raises(Exception) as error: - cbf.set_brain_mask(fake_mask) + cbf.set_brain_mask(ImageIO(image_array=fake_mask)) assert ( error.value.args[0] - == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.shape} not equal to {pcasl_3d_vol.shape}' + == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.shape} not equal to {pcasl_3d_vol.get_as_numpy().shape}' ) def test_set_brain_mask_creates_3d_volume_of_ones_if_not_set_in_cbf_object(): cbf = CBFMapping(asldata_te) - vol_shape = asldata_te('m0').shape + vol_shape = asldata_te('m0').get_as_numpy().shape mask_shape = cbf._brain_mask.shape assert vol_shape == mask_shape @@ -142,7 +144,7 @@ def test_set_brain_mask_raise_error_mask_is_not_an_numpy_array(): cbf.set_brain_mask(M0_BRAIN_MASK) assert ( e.value.args[0] - == f'mask is not an numpy array. Type {type(M0_BRAIN_MASK)}' + == f'mask is not an ImageIO object. Type {type(M0_BRAIN_MASK)}' ) @@ -150,7 +152,7 @@ def test_cbf_mapping_get_brain_mask_return_adjusted_brain_mask_image_in_the_obje cbf = CBFMapping(asldata_te) assert np.mean(cbf.get_brain_mask()) == 1 - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) cbf.set_brain_mask(mask) assert np.unique(cbf.get_brain_mask()).tolist() == [0, 1] @@ -158,19 +160,19 @@ def test_cbf_mapping_get_brain_mask_return_adjusted_brain_mask_image_in_the_obje def test_cbf_object_create_map_success(): cbf = CBFMapping(asldata_te) out = cbf.create_map() - assert isinstance(out['cbf'], np.ndarray) - assert np.mean(out['cbf']) < 0.0001 - assert isinstance(out['att'], np.ndarray) - assert np.mean(out['att']) > 10 + assert isinstance(out['cbf'], ImageIO) + assert np.mean(out['cbf'].get_as_numpy()) < 0.0001 + assert isinstance(out['att'], ImageIO) + assert np.mean(out['att'].get_as_numpy()) > 10 def test_cbf_object_create_map_sucess_setting_single_core(): cbf = CBFMapping(asldata_te) out = cbf.create_map(cores=1) - assert isinstance(out['cbf'], np.ndarray) - assert np.mean(out['cbf']) < 0.0001 - assert isinstance(out['att'], np.ndarray) - assert np.mean(out['att']) > 10 + assert isinstance(out['cbf'], ImageIO) + assert np.mean(out['cbf'].get_as_numpy()) < 0.0001 + assert isinstance(out['att'], ImageIO) + assert np.mean(out['att'].get_as_numpy()) > 10 @pytest.mark.parametrize('core_value', [(100), (-1), (-10), (1.5), (-1.5)]) @@ -188,6 +190,7 @@ def test_cbf_raise_error_cores_not_valid(core_value): def test_cbf_map_normalized_flag_true_result_cbf_map_rescaled(): cbf = CBFMapping(asldata_te) out = cbf.create_map() - out['cbf_norm'][out['cbf_norm'] == 0] = np.nan - mean_px_value = np.nanmean(out['cbf_norm']) + out_norm_array = out['cbf_norm'].get_as_numpy() + out_norm_array[out_norm_array == 0] = np.nan + mean_px_value = np.nanmean(out_norm_array) assert mean_px_value < 500 and mean_px_value > 50 diff --git a/tests/reconstruction/test_multi_dw_mapping.py b/tests/reconstruction/test_multi_dw_mapping.py index f062ec1..fa7d35c 100644 --- a/tests/reconstruction/test_multi_dw_mapping.py +++ b/tests/reconstruction/test_multi_dw_mapping.py @@ -7,7 +7,7 @@ from asltk.asldata import ASLData from asltk.reconstruction import MultiDW_ASLMapping -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep @@ -49,35 +49,35 @@ def test_multi_dw_asl_object_constructor_created_sucessfully(): def test_multi_dw_asl_set_brain_mask_success(): mte = MultiDW_ASLMapping(asldata_dw) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mte.set_brain_mask(mask) assert isinstance(mte._brain_mask, np.ndarray) def test_multi_dw_asl_set_cbf_map_success(): mte = MultiDW_ASLMapping(asldata_dw) - fake_cbf = np.ones((10, 10)) * 20 + fake_cbf = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_cbf_map(fake_cbf) assert np.mean(mte._cbf_map) == 20 def test_multi_dw_asl_get_cbf_map_success(): mte = MultiDW_ASLMapping(asldata_dw) - fake_cbf = np.ones((10, 10)) * 20 + fake_cbf = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_cbf_map(fake_cbf) assert np.mean(mte.get_cbf_map()) == 20 def test_multi_dw_asl_set_att_map_success(): mte = MultiDW_ASLMapping(asldata_dw) - fake_att = np.ones((10, 10)) * 20 + fake_att = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_att_map(fake_att) assert np.mean(mte._att_map) == 20 def test_multi_dw_asl_get_att_map_success(): mte = MultiDW_ASLMapping(asldata_dw) - fake_att = np.ones((10, 10)) * 20 + fake_att = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_att_map(fake_att) assert np.mean(mte.get_att_map()) == 20 @@ -87,7 +87,7 @@ def test_multi_dw_asl_set_brain_mask_set_label_value_raise_error_value_not_found label, ): mte = MultiDW_ASLMapping(asldata_dw) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) with pytest.raises(Exception) as e: mte.set_brain_mask(mask, label=label) assert e.value.args[0] == 'Label value is not found in the mask provided.' @@ -95,9 +95,14 @@ def test_multi_dw_asl_set_brain_mask_set_label_value_raise_error_value_not_found def test_multi_dw_asl_set_brain_mask_verify_if_input_is_a_label_mask(): mte = MultiDW_ASLMapping(asldata_dw) - not_mask = load_image(M0) + not_mask = ImageIO(M0) with pytest.warns(UserWarning): - mte.set_brain_mask(not_mask / np.max(not_mask)) + mte.set_brain_mask( + ImageIO( + image_array=not_mask.get_as_numpy() + / np.max(not_mask.get_as_numpy()) + ) + ) warnings.warn( 'Mask image is not a binary image. Any value > 0 will be assumed as brain label.', UserWarning, @@ -106,13 +111,15 @@ def test_multi_dw_asl_set_brain_mask_verify_if_input_is_a_label_mask(): def test_multi_dw_asl_set_brain_mask_raise_error_if_image_dimension_is_different_from_3d_volume(): mte = MultiDW_ASLMapping(asldata_dw) - pcasl_3d_vol = load_image(PCASL_MDW)[0, 0, :, :, :] - fake_mask = np.array(((1, 1, 1), (0, 1, 0))) + pcasl_3d_vol = ImageIO( + image_array=ImageIO(PCASL_MDW).get_as_numpy()[0, 0, :, :, :] + ) + fake_mask = ImageIO(image_array=np.array(((1, 1, 1), (0, 1, 0)))) with pytest.raises(Exception) as error: mte.set_brain_mask(fake_mask) assert ( error.value.args[0] - == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.shape} not equal to {pcasl_3d_vol.shape}' + == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.get_as_numpy().shape} not equal to {pcasl_3d_vol.get_as_numpy().shape}' ) @@ -120,7 +127,7 @@ def test_multi_dw_mapping_get_brain_mask_return_adjusted_brain_mask_image_in_the mdw = MultiDW_ASLMapping(asldata_dw) assert np.mean(mdw.get_brain_mask()) == 1 - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mdw.set_brain_mask(mask) assert np.unique(mdw.get_brain_mask()).tolist() == [0, 1] @@ -166,7 +173,7 @@ def test_multi_dw_asl_object_set_cbf_and_att_maps_before_create_map(): mte = MultiDW_ASLMapping(asldata_dw) assert np.mean(mte.get_brain_mask()) == 1 - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mte.set_brain_mask(mask) assert np.mean(mte.get_brain_mask()) < 1 @@ -174,8 +181,8 @@ def test_multi_dw_asl_object_set_cbf_and_att_maps_before_create_map(): assert np.mean(mte.get_att_map()) == 0 and np.mean(mte.get_cbf_map()) == 0 # Update CBF/ATT maps and test if it changed in the obj - cbf = np.ones(mask.shape) * 100 - att = np.ones(mask.shape) * 1500 + cbf = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 100) + att = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 1500) mte.set_cbf_map(cbf) mte.set_att_map(att) assert ( @@ -186,9 +193,9 @@ def test_multi_dw_asl_object_set_cbf_and_att_maps_before_create_map(): def test_multi_dw_asl_object_create_map_using_provided_cbf_att_maps(capfd): mte = MultiDW_ASLMapping(asldata_dw) - mask = load_image(M0_BRAIN_MASK) - cbf = np.ones(mask.shape) * 100 - att = np.ones(mask.shape) * 1500 + mask = ImageIO(M0_BRAIN_MASK) + cbf = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 100) + att = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 1500) mte.set_brain_mask(mask) mte.set_cbf_map(cbf) diff --git a/tests/reconstruction/test_multi_te_mapping.py b/tests/reconstruction/test_multi_te_mapping.py index f619672..690b63b 100644 --- a/tests/reconstruction/test_multi_te_mapping.py +++ b/tests/reconstruction/test_multi_te_mapping.py @@ -7,7 +7,7 @@ from asltk.asldata import ASLData from asltk.reconstruction import CBFMapping, MultiTE_ASLMapping -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep @@ -46,35 +46,35 @@ def test_multite_asl_object_constructor_created_sucessfully(): def test_multite_asl_set_brain_mask_success(): mte = MultiTE_ASLMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mte.set_brain_mask(mask) assert isinstance(mte._brain_mask, np.ndarray) def test_multite_asl_set_cbf_map_success(): mte = MultiTE_ASLMapping(asldata_te) - fake_cbf = np.ones((10, 10)) * 20 + fake_cbf = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_cbf_map(fake_cbf) assert np.mean(mte._cbf_map) == 20 def test_multite_asl_get_cbf_map_success(): mte = MultiTE_ASLMapping(asldata_te) - fake_cbf = np.ones((10, 10)) * 20 + fake_cbf = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_cbf_map(fake_cbf) assert np.mean(mte.get_cbf_map()) == 20 def test_multite_asl_set_att_map_success(): mte = MultiTE_ASLMapping(asldata_te) - fake_att = np.ones((10, 10)) * 20 + fake_att = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_att_map(fake_att) assert np.mean(mte._att_map) == 20 def test_multite_asl_get_att_map_success(): mte = MultiTE_ASLMapping(asldata_te) - fake_att = np.ones((10, 10)) * 20 + fake_att = ImageIO(image_array=np.ones((10, 10)) * 20) mte.set_att_map(fake_att) assert np.mean(mte.get_att_map()) == 20 @@ -99,7 +99,7 @@ def test_multite_asl_set_brain_mask_set_label_value_raise_error_value_not_found_ label, ): mte = MultiTE_ASLMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) with pytest.raises(Exception) as e: mte.set_brain_mask(mask, label=label) assert e.value.args[0] == 'Label value is not found in the mask provided.' @@ -107,9 +107,13 @@ def test_multite_asl_set_brain_mask_set_label_value_raise_error_value_not_found_ def test_multite_asl_set_brain_mask_verify_if_input_is_a_label_mask(): mte = MultiTE_ASLMapping(asldata_te) - not_mask = load_image(M0) + not_mask = ImageIO(M0) with pytest.warns(UserWarning): - mte.set_brain_mask(not_mask / np.max(not_mask)) + not_mask_image = ImageIO( + image_array=not_mask.get_as_numpy() + / np.max(not_mask.get_as_numpy()) + ) + mte.set_brain_mask(not_mask_image) warnings.warn( 'Mask image is not a binary image. Any value > 0 will be assumed as brain label.', UserWarning, @@ -118,13 +122,13 @@ def test_multite_asl_set_brain_mask_verify_if_input_is_a_label_mask(): def test_multite_asl_set_brain_mask_raise_error_if_image_dimension_is_different_from_3d_volume(): mte = MultiTE_ASLMapping(asldata_te) - pcasl_3d_vol = load_image(PCASL_MTE)[0, 0, :, :, :] - fake_mask = np.array(((1, 1, 1), (0, 1, 0))) + pcasl_3d_vol = ImageIO(PCASL_MTE).get_as_numpy()[0, 0, :, :, :] + fake_mask = ImageIO(image_array=np.array(((1, 1, 1), (0, 1, 0)))) with pytest.raises(Exception) as error: mte.set_brain_mask(fake_mask) assert ( error.value.args[0] - == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.shape} not equal to {pcasl_3d_vol.shape}' + == f'Image mask dimension does not match with input 3D volume. Mask shape {fake_mask.get_as_numpy().shape} not equal to {pcasl_3d_vol.shape}' ) @@ -132,7 +136,7 @@ def test_multite_mapping_get_brain_mask_return_adjusted_brain_mask_image_in_the_ mte = MultiTE_ASLMapping(asldata_te) assert np.mean(mte.get_brain_mask()) == 1 - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mte.set_brain_mask(mask) assert np.unique(mte.get_brain_mask()).tolist() == [0, 1] @@ -140,12 +144,12 @@ def test_multite_mapping_get_brain_mask_return_adjusted_brain_mask_image_in_the_ def test_multite_asl_object_create_map_success(): mte = MultiTE_ASLMapping(asldata_te) out = mte.create_map() - assert isinstance(out['cbf'], np.ndarray) - assert np.mean(out['cbf']) < 0.0001 - assert isinstance(out['att'], np.ndarray) - assert np.mean(out['att']) > 10 - assert isinstance(out['t1blgm'], np.ndarray) - assert np.mean(out['t1blgm']) > 50 + assert isinstance(out['cbf'], ImageIO) + assert np.mean(out['cbf'].get_as_numpy()) < 0.0001 + assert isinstance(out['att'], ImageIO) + assert np.mean(out['att'].get_as_numpy()) > 10 + assert isinstance(out['t1blgm'], ImageIO) + assert np.mean(out['t1blgm'].get_as_numpy()) > 50 def test_multite_asl_object_raises_error_if_asldata_does_not_have_pcasl_or_m0_image(): @@ -178,7 +182,7 @@ def test_multite_asl_object_set_cbf_and_att_maps_before_create_map(): mte = MultiTE_ASLMapping(asldata_te) assert np.mean(mte.get_brain_mask()) == 1 - mask = load_image(M0_BRAIN_MASK) + mask = ImageIO(M0_BRAIN_MASK) mte.set_brain_mask(mask) assert np.mean(mte.get_brain_mask()) < 1 @@ -186,8 +190,8 @@ def test_multite_asl_object_set_cbf_and_att_maps_before_create_map(): assert np.mean(mte.get_att_map()) == 0 and np.mean(mte.get_cbf_map()) == 0 # Update CBF/ATT maps and test if it changed in the obj - cbf = np.ones(mask.shape) * 100 - att = np.ones(mask.shape) * 1500 + cbf = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 100) + att = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 1500) mte.set_cbf_map(cbf) mte.set_att_map(att) assert ( @@ -198,9 +202,9 @@ def test_multite_asl_object_set_cbf_and_att_maps_before_create_map(): def test_multite_asl_object_create_map_using_provided_cbf_att_maps(capfd): mte = MultiTE_ASLMapping(asldata_te) - mask = load_image(M0_BRAIN_MASK) - cbf = np.ones(mask.shape) * 100 - att = np.ones(mask.shape) * 1500 + mask = ImageIO(M0_BRAIN_MASK) + cbf = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 100) + att = ImageIO(image_array=np.ones(mask.get_as_numpy().shape) * 1500) mte.set_brain_mask(mask) mte.set_cbf_map(cbf) diff --git a/tests/reconstruction/test_t2_mapping.py b/tests/reconstruction/test_t2_mapping.py index 8881d5e..69ad44e 100644 --- a/tests/reconstruction/test_t2_mapping.py +++ b/tests/reconstruction/test_t2_mapping.py @@ -5,6 +5,7 @@ from asltk.asldata import ASLData from asltk.reconstruction.t2_mapping import T2Scalar_ASLMapping +from asltk.utils.io import ImageIO SEP = os.sep @@ -28,7 +29,7 @@ def test_t2_scalar_asl_mapping_initialization(): assert isinstance(t2_mapping, T2Scalar_ASLMapping) assert isinstance(t2_mapping._asl_data, ASLData) - assert isinstance(t2_mapping._brain_mask, np.ndarray) + assert isinstance(t2_mapping._brain_mask, ImageIO) assert t2_mapping._t2_maps is None assert t2_mapping._mean_t2s is None @@ -52,13 +53,15 @@ def test_t2_scalar_mapping_success_construction_t2_map(): out = t2_mapping.create_map() - assert isinstance(out['t2'], np.ndarray) - assert out['t2'].ndim == 4 # Expecting a 4D array + assert isinstance(out['t2'], ImageIO) + assert out['t2'].get_as_numpy().ndim == 4 # Expecting a 4D array assert out['mean_t2'] is not None assert len(out['mean_t2']) == len( asldata_te.get_pld() ) # One mean T2 per PLD - assert np.mean(out['t2']) > 0 # Ensure T2 values are positive + assert ( + np.mean(out['t2'].get_as_numpy()) > 0 + ) # Ensure T2 values are positive def test_t2_scalar_mapping_raise_error_with_dw_in_asldata(): @@ -87,39 +90,41 @@ def test_t2_scalar_mapping_get_t2_maps_and_mean_t2s_before_and_after_create_map( t2_maps = t2_mapping.get_t2_maps() mean_t2s = t2_mapping.get_mean_t2s() - assert isinstance(t2_maps, np.ndarray) - assert t2_maps.ndim == 4 # (N_PLDS, Z, Y, X) + assert isinstance(t2_maps, ImageIO) + assert t2_maps.get_as_numpy().ndim == 4 # (N_PLDS, Z, Y, X) assert isinstance(mean_t2s, list) assert len(mean_t2s) == len(asldata_te.get_pld()) assert all( isinstance(val, float) or isinstance(val, np.floating) for val in mean_t2s ) - assert np.all(t2_maps >= 0) + assert np.all(t2_maps.get_as_numpy() >= 0) def test_set_brain_mask_binary_and_label(): t2_mapping = T2Scalar_ASLMapping(asldata_te) - shape = t2_mapping._asl_data('m0').shape + shape = t2_mapping._asl_data('m0').get_as_numpy().shape # Binary mask (all ones) - binary_mask = np.ones(shape, dtype=np.uint8) + binary_mask = ImageIO(image_array=np.ones(shape, dtype=np.uint8)) t2_mapping.set_brain_mask(binary_mask) - assert np.all(t2_mapping._brain_mask == 1) - assert t2_mapping._brain_mask.shape == shape + assert np.all(t2_mapping._brain_mask.get_as_numpy() == 1) + assert t2_mapping._brain_mask.get_as_numpy().shape == shape # Mask with different label label = 2 mask_with_label = np.zeros(shape, dtype=np.uint8) mask_with_label[0, 0, 0] = label - t2_mapping.set_brain_mask(mask_with_label, label=label) - assert t2_mapping._brain_mask[0, 0, 0] == label - assert np.sum(t2_mapping._brain_mask == label) == 1 + t2_mapping.set_brain_mask( + ImageIO(image_array=mask_with_label), label=label + ) + assert t2_mapping._brain_mask.get_as_numpy()[0, 0, 0] == label + assert np.sum(t2_mapping._brain_mask.get_as_numpy() == label) == 1 def test_set_brain_mask_invalid_shape_raises(): t2_mapping = T2Scalar_ASLMapping(asldata_te) - wrong_shape_mask = np.ones((2, 2, 2), dtype=np.uint8) + wrong_shape_mask = ImageIO(image_array=np.ones((2, 2, 2), dtype=np.uint8)) with pytest.raises(Exception) as error: t2_mapping.set_brain_mask(wrong_shape_mask) @@ -130,8 +135,8 @@ def test_set_brain_mask_invalid_shape_raises(): def test_set_brain_mask_noninteger_label(): t2_mapping = T2Scalar_ASLMapping(asldata_te) - shape = t2_mapping._asl_data('m0').shape - mask = np.ones(shape, dtype=np.float32) + shape = t2_mapping._asl_data('m0').get_as_numpy().shape + mask = ImageIO(image_array=np.ones(shape, dtype=np.float32)) # Should still work, as mask == label will be True for 1.0 == 1 t2_mapping.set_brain_mask(mask, label=1) - assert np.all(t2_mapping._brain_mask == 1) + assert np.all(t2_mapping._brain_mask.get_as_numpy() == 1) diff --git a/tests/registration/test_asl_normalization.py b/tests/registration/test_asl_normalization.py index c9fa0c4..ffcb77b 100644 --- a/tests/registration/test_asl_normalization.py +++ b/tests/registration/test_asl_normalization.py @@ -1,157 +1,80 @@ -# import os - -# import numpy as np -# import pytest - -# from asltk.asldata import ASLData -# from asltk.registration.asl_normalization import ( -# asl_template_registration, -# head_movement_correction, -# ) - -# SEP = os.sep -# M0_ORIG = ( -# f'tests' + SEP + 'files' + SEP + 'registration' + SEP + 'm0_mean.nii.gz' -# ) -# M0_RIGID = ( -# f'tests' -# + SEP -# + 'files' -# + SEP -# + 'registration' -# + SEP -# + 'm0_mean-rigid-25degrees.nrrd' -# ) -# PCASL_MTE = f'tests' + SEP + 'files' + SEP + 'pcasl_mte.nii.gz' -# M0 = f'tests' + SEP + 'files' + SEP + 'm0.nii.gz' - - -# def test_head_movement_correction_build_asldata_success(): -# pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - -# asldata, _ = head_movement_correction(pcasl_orig) - -# assert asldata('pcasl').shape == pcasl_orig('pcasl').shape - - -# def test_head_movement_correction_error_input_is_not_ASLData_object(): -# with pytest.raises(TypeError) as e: -# head_movement_correction('invalid_input') - -# assert str(e.value) == 'Input must be an ASLData object.' - - -# def test_head_movement_correction_error_ref_vol_is_not_int(): -# pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - -# with pytest.raises(Exception) as e: -# head_movement_correction(pcasl_orig, ref_vol='invalid_ref_vol') - -# assert ( -# str(e.value) -# == 'ref_vol must be an positive integer based on the total asl data volumes.' -# ) - - -# def test_head_movement_correction_success(): -# pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - -# pcasl_corrected, trans_mtxs = head_movement_correction( -# pcasl_orig, verbose=True -# ) - -# assert pcasl_corrected('pcasl').shape == pcasl_orig('pcasl').shape -# # assert ( -# # np.abs( -# # np.mean(np.subtract(pcasl_corrected('pcasl'), pcasl_orig('pcasl'))) -# # ) -# # > np.abs(np.mean(pcasl_orig('pcasl')) * 0.01) -# # ) -# assert any(not np.array_equal(mtx, np.eye(4)) for mtx in trans_mtxs) - - -# def test_head_movement_correction_returns_asl_data_corrected(): -# pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - -# asl_data_corrected, _ = head_movement_correction(pcasl_orig) - -# assert isinstance(asl_data_corrected, ASLData) -# assert asl_data_corrected('pcasl').shape == pcasl_orig('pcasl').shape -# assert asl_data_corrected('pcasl').dtype == pcasl_orig('pcasl').dtype - - -# # TODO Arrumar o path do arquivo de template -# # def test_asl_template_registration_success(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) -# # # pcasl_orig = ASLData( -# # # pcasl='/home/antonio/Imagens/loamri-samples/20240909/pcasl.nii.gz', -# # # m0='/home/antonio/Imagens/loamri-samples/20240909/m0.nii.gz', -# # # ) -# # # asl_data_mask = np.ones_like(pcasl_orig('m0'), dtype=bool) - -# # asl_data_registered, trans_mtxs = asl_template_registration( -# # pcasl_orig, -# # atlas_name='MNI2009', -# # verbose=True, -# # ) - -# # assert isinstance(asl_data_registered, ASLData) -# # assert asl_data_registered('pcasl').shape == pcasl_orig('pcasl').shape -# # assert isinstance(trans_mtxs, list) -# # assert len(trans_mtxs) == pcasl_orig('pcasl').shape[0] - - -# def test_asl_template_registration_invalid_input_type(): -# with pytest.raises(TypeError) as e: -# asl_template_registration('not_asldata') -# assert str(e.value) == 'Input must be an ASLData object.' - - -# # def test_asl_template_registration_invalid_ref_vol_type(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) -# # with pytest.raises(ValueError) as e: -# # asl_template_registration(pcasl_orig, ref_vol='invalid') -# # assert str(e.value) == 'ref_vol must be a non-negative integer.' - - -# # def test_asl_template_registration_invalid_ref_vol_type_with_negative_volume(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) -# # with pytest.raises(ValueError) as e: -# # asl_template_registration(pcasl_orig, ref_vol=-1) -# # assert str(e.value) == 'ref_vol must be a non-negative integer.' - - -# # def test_asl_template_registration_invalid_ref_vol_index(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) -# # n_vols = 1000000 -# # with pytest.raises(ValueError) as e: -# # asl_template_registration(pcasl_orig, ref_vol=n_vols) -# # assert 'ref_vol must be a valid index' in str(e.value) - - -# # def test_asl_template_registration_create_another_asldata_object(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - -# # asl_data_registered, _ = asl_template_registration( -# # pcasl_orig, -# # ref_vol=0, -# # atlas_name='MNI2009', -# # verbose=True, -# # ) - -# # assert isinstance(asl_data_registered, ASLData) -# # assert asl_data_registered('pcasl').shape == pcasl_orig('pcasl').shape -# # assert asl_data_registered('m0').shape == pcasl_orig('m0').shape -# # assert asl_data_registered is not pcasl_orig - - -# # def test_asl_template_registration_returns_transforms(): -# # pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) -# # asl_data_mask = np.ones_like(pcasl_orig('pcasl')[0], dtype=bool) - -# # asl_data_registered, trans_mtxs = asl_template_registration( -# # pcasl_orig, ref_vol=0, asl_data_mask=asl_data_mask -# # ) - -# # assert isinstance(trans_mtxs, list) -# # assert all(isinstance(mtx, np.ndarray) for mtx in trans_mtxs) +import os + +import numpy as np +import pytest + +from asltk.asldata import ASLData +from asltk.data.brain_atlas import BrainAtlas +from asltk.registration.asl_normalization import asl_template_registration +from asltk.utils.io import ImageIO + +SEP = os.sep +M0_ORIG = ( + f'tests' + SEP + 'files' + SEP + 'registration' + SEP + 'm0_mean.nii.gz' +) +M0_RIGID = ( + f'tests' + + SEP + + 'files' + + SEP + + 'registration' + + SEP + + 'm0_mean-rigid-25degrees.nrrd' +) +PCASL_MTE = f'tests' + SEP + 'files' + SEP + 'pcasl_mte.nii.gz' +M0 = f'tests' + SEP + 'files' + SEP + 'm0.nii.gz' + + +def test_asl_template_registration_success(): + pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) + # Reducing pcasl size to not exceed memory limits + short_data = ImageIO( + image_array=pcasl_orig('pcasl').get_as_numpy()[:3, :, :, :] + ) + pcasl_orig.set_image(short_data, 'pcasl') + # pcasl_orig = ASLData( + # pcasl='/home/antonio/Imagens/loamri-samples/20240909/pcasl.nii.gz', + # m0='/home/antonio/Imagens/loamri-samples/20240909/m0.nii.gz', + # average_m0=True, + # ) + # asl_data_mask = np.ones_like(pcasl_orig('m0'), dtype=bool) + + ( + asl_data_registered, + trans_mtxs, + additional_maps_normalized, + ) = asl_template_registration( + pcasl_orig, + atlas_reference='MNI2009', + verbose=True, + ) + + assert isinstance(asl_data_registered, ASLData) + assert isinstance(trans_mtxs, list) + assert isinstance(additional_maps_normalized, list) + + +def test_asl_template_registration_invalid_input_type(): + with pytest.raises(TypeError) as e: + asl_template_registration('not_asldata') + assert str(e.value) == 'Input must be an ASLData object.' + + +def test_asl_template_registration_raise_error_if_m0_volume_not_present_at_input_asl_data(): + pcasl_orig = ASLData(pcasl=PCASL_MTE) + with pytest.raises(ValueError) as e: + asl_template_registration(pcasl_orig) + assert 'M0 image is required for normalization' in str(e.value) + + +@pytest.mark.parametrize( + 'atlas_reference', + [('invalid_atlas'), ('/tmp/invalid_path.nii.gz')], +) +def test_asl_template_registration_checks_atlas_reference_types( + atlas_reference, +): + pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) + with pytest.raises(Exception) as e: + asl_template_registration(pcasl_orig, atlas_reference=atlas_reference) + assert isinstance(str(e.value), str) diff --git a/tests/registration/test_registration.py b/tests/registration/test_registration.py index 9791422..0c12277 100644 --- a/tests/registration/test_registration.py +++ b/tests/registration/test_registration.py @@ -12,7 +12,7 @@ space_normalization, ) from asltk.registration.asl_normalization import head_movement_correction -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep M0_ORIG = ( @@ -33,10 +33,12 @@ def test_head_movement_correction_build_asldata_success(): pcasl_orig = ASLData(pcasl=PCASL_MTE, m0=M0) - asldata, _ = head_movement_correction(pcasl_orig) - assert asldata('pcasl').shape == pcasl_orig('pcasl').shape + assert ( + asldata('pcasl').get_as_numpy().shape + == pcasl_orig('pcasl').get_as_numpy().shape + ) def test_head_movement_correction_error_input_is_not_ASLData_object(): @@ -65,10 +67,18 @@ def test_head_movement_correction_success(): pcasl_orig, verbose=True ) - assert pcasl_corrected('pcasl').shape == pcasl_orig('pcasl').shape + assert ( + pcasl_corrected('pcasl').get_as_numpy().shape + == pcasl_orig('pcasl').get_as_numpy().shape + ) assert ( np.abs( - np.mean(np.subtract(pcasl_corrected('pcasl'), pcasl_orig('pcasl'))) + np.mean( + np.subtract( + pcasl_corrected('pcasl').get_as_numpy(), + pcasl_orig('pcasl').get_as_numpy(), + ) + ) ) != 0 ) @@ -76,31 +86,34 @@ def test_head_movement_correction_success(): def test_rigid_body_registration_run_sucess(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) resampled_image, _ = rigid_body_registration(img_orig, img_rot) - assert np.mean(np.subtract(img_orig, resampled_image)) < np.mean(img_orig) + assert np.mean( + np.subtract(img_orig.get_as_numpy(), resampled_image.get_as_numpy()) + ) < np.mean(img_orig.get_as_numpy()) @pytest.mark.parametrize( 'img_rot', [('invalid_image'), ([1, 2, 3]), (['a', 1, 5.23])] ) def test_rigid_body_registration_error_fixed_image_is_not_numpy_array(img_rot): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) with pytest.raises(Exception) as e: rigid_body_registration(img_orig, img_rot) assert ( - str(e.value) == 'fixed_image and moving_image must be a numpy array.' + str(e.value) + == 'fixed_image and moving_image must be an ImageIO object.' ) def test_rigid_body_registration_output_registration_matrix_success(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) _, trans_matrix = rigid_body_registration(img_orig, img_rot) @@ -108,25 +121,25 @@ def test_rigid_body_registration_output_registration_matrix_success(): def test_rigid_body_registration_raise_exception_if_moving_mask_not_numpy(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: rigid_body_registration(img_orig, img_rot, moving_mask='invalid_mask') - assert str(e.value) == 'moving_mask must be a numpy array.' + assert str(e.value) == 'moving_mask must be an ImageIO object.' def test_rigid_body_registration_raise_exception_if_template_mask_not_numpy(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: rigid_body_registration( img_orig, img_rot, template_mask='invalid_mask' ) - assert str(e.value) == 'template_mask must be a numpy array.' + assert str(e.value) == 'template_mask must be an ImageIO object.' def test_space_normalization_success(): @@ -139,8 +152,8 @@ def test_space_normalization_success(): verbose=True, ) - assert isinstance(normalized_image, np.ndarray) - assert normalized_image.shape == (182, 218, 182) + assert isinstance(normalized_image, ImageIO) + assert normalized_image.get_as_numpy().shape == (182, 218, 182) assert len(transform) == 1 @@ -152,8 +165,8 @@ def test_space_normalization_success_transform_type_Affine(): pcasl_orig('m0'), template_image='MNI2009', transform_type='Affine' ) - assert isinstance(normalized_image, np.ndarray) - assert normalized_image.shape == (182, 218, 182) + assert isinstance(normalized_image, ImageIO) + assert normalized_image.get_as_numpy().shape == (182, 218, 182) assert len(transform) == 1 @@ -165,8 +178,8 @@ def test_space_normalization_success_transform_type_Affine(): pcasl_orig('m0'), template_image='MNI2009', transform_type='Affine' ) - assert isinstance(normalized_image, np.ndarray) - assert normalized_image.shape == (182, 218, 182) + assert isinstance(normalized_image, ImageIO) + assert normalized_image.get_as_numpy().shape == (182, 218, 182) assert len(transform) == 1 @@ -175,13 +188,13 @@ def test_space_normalization_raise_exception_if_fixed_image_not_numpy(): space_normalization('invalid_image', template_image='MNI2009') assert ( - 'moving_image must be a numpy array and template_image must be a BrainAtlas object' + 'moving_image must be an ImageIO object and template_image must be a BrainAtlas object' in str(e.value) ) def test_space_normalization_raise_exception_if_template_image_not_a_valid_BrainAtlas_option(): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) with pytest.raises(Exception) as e: space_normalization(img_orig, template_image='invalid_image') @@ -190,20 +203,20 @@ def test_space_normalization_raise_exception_if_template_image_not_a_valid_Brain def test_space_normalization_success_passing_template_image_as_BrainAtlas_option(): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) # Use the BrainAtlas object directly normalized_image, transform = space_normalization( img_orig, template_image='MNI2009' ) - assert isinstance(normalized_image, np.ndarray) - assert normalized_image.shape == (182, 218, 182) + assert isinstance(normalized_image, ImageIO) + assert normalized_image.get_as_numpy().shape == (182, 218, 182) assert len(transform) == 2 def test_space_normalization_success_passing_template_image_as_BrainAtlas_object(): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) atlas = BrainAtlas(atlas_name='MNI2009') # Use the BrainAtlas object directly @@ -211,119 +224,129 @@ def test_space_normalization_success_passing_template_image_as_BrainAtlas_object img_orig, template_image=atlas ) - assert isinstance(normalized_image, np.ndarray) - assert normalized_image.shape == (182, 218, 182) + assert isinstance(normalized_image, ImageIO) + assert normalized_image.get_as_numpy().shape == (182, 218, 182) assert len(transform) == 2 def test_affine_registration_success(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) resampled_image, _ = affine_registration(img_orig, img_rot) - assert np.mean(np.subtract(img_orig, resampled_image)) < np.mean(img_orig) + assert np.mean( + np.subtract(img_orig.get_as_numpy(), resampled_image.get_as_numpy()) + ) < np.mean(img_orig.get_as_numpy()) def test_affine_registration_raise_exception_if_fixed_image_not_numpy(): - img_rot = load_image(M0_RIGID) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: affine_registration('invalid_image', img_rot) assert ( - str(e.value) == 'fixed_image and moving_image must be a numpy array.' + str(e.value) + == 'fixed_image and moving_image must be an ImageIO object.' ) def test_affine_registration_raise_exception_if_moving_image_not_numpy(): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) with pytest.raises(Exception) as e: affine_registration(img_orig, 'invalid_image') assert ( - str(e.value) == 'fixed_image and moving_image must be a numpy array.' + str(e.value) + == 'fixed_image and moving_image must be an ImageIO object.' ) def test_affine_registration_raise_exception_if_moving_mask_not_numpy(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: affine_registration(img_orig, img_rot, moving_mask='invalid_mask') - assert str(e.value) == 'moving_mask must be a numpy array.' + assert str(e.value) == 'moving_mask must be an ImageIO object.' def test_affine_registration_raise_exception_if_template_mask_not_numpy(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: affine_registration(img_orig, img_rot, template_mask='invalid_mask') - assert str(e.value) == 'template_mask must be a numpy array.' + assert str(e.value) == 'template_mask must be an ImageIO object.' def test_affine_registration_fast_method(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) resampled_image, _ = affine_registration( img_orig, img_rot, fast_method=True ) - assert isinstance(resampled_image, np.ndarray) - assert resampled_image.shape == img_rot.shape - assert np.mean(np.abs(img_orig - resampled_image)) < np.mean(img_orig) + assert isinstance(resampled_image, ImageIO) + assert resampled_image.get_as_numpy().shape == img_rot.get_as_numpy().shape + assert np.mean( + np.abs(img_orig.get_as_numpy() - resampled_image.get_as_numpy()) + ) < np.mean(img_orig.get_as_numpy()) def test_affine_registration_slow_method(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) resampled_image, _ = affine_registration( img_orig, img_rot, fast_method=False ) - assert isinstance(resampled_image, np.ndarray) - assert resampled_image.shape == img_rot.shape - assert np.mean(np.abs(img_orig - resampled_image)) < np.mean(img_orig) + assert isinstance(resampled_image, ImageIO) + assert resampled_image.get_as_numpy().shape == img_rot.get_as_numpy().shape + assert np.mean( + np.abs(img_orig.get_as_numpy() - resampled_image.get_as_numpy()) + ) < np.mean(img_orig.get_as_numpy()) def test_apply_transformation_success(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) # Get transformation matrix from rigid registration _, trans_matrix = rigid_body_registration(img_orig, img_rot) # Apply transformation transformed_img = apply_transformation(img_rot, img_orig, trans_matrix) - assert isinstance(transformed_img, np.ndarray) - assert transformed_img.shape == img_rot.shape - assert np.mean(np.abs(transformed_img - img_rot)) < np.mean(img_rot) + assert isinstance(transformed_img, ImageIO) + assert transformed_img.get_as_numpy().shape == img_rot.get_as_numpy().shape + assert np.mean( + np.abs(transformed_img.get_as_numpy() - img_rot.get_as_numpy()) + ) < np.mean(img_rot.get_as_numpy()) def test_apply_transformation_invalid_fixed_image(): - img_rot = load_image(M0_RIGID) + img_rot = ImageIO(M0_RIGID) _, trans_matrix = rigid_body_registration(img_rot, img_rot) with pytest.raises(Exception) as e: apply_transformation('invalid_image', img_rot, trans_matrix) - assert 'moving image must be numpy array' in str(e.value) + assert 'moving image must be an ImageIO object' in str(e.value) def test_apply_transformation_invalid_moving_image(): - img_orig = load_image(M0_ORIG) + img_orig = ImageIO(M0_ORIG) _, trans_matrix = rigid_body_registration(img_orig, img_orig) with pytest.raises(Exception) as e: apply_transformation(img_orig, 'invalid_image', trans_matrix) - assert 'reference_image must be a numpy array' in str(e.value) + assert 'reference_image must be an ImageIO object' in str(e.value) def test_apply_transformation_invalid_transformation_matrix(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) with pytest.raises(Exception) as e: apply_transformation(img_orig, img_rot, 'invalid_matrix') assert 'transforms must be a list of transformation matrices' in str( @@ -332,37 +355,37 @@ def test_apply_transformation_invalid_transformation_matrix(): def test_apply_transformation_with_mask(): - img_orig = load_image(M0_ORIG) - img_rot = load_image(M0_RIGID) + img_orig = ImageIO(M0_ORIG) + img_rot = ImageIO(M0_RIGID) mask = np.ones_like(img_orig, dtype=bool) _, trans_matrix = rigid_body_registration(img_orig, img_rot) transformed_img = apply_transformation( img_orig, img_rot, trans_matrix, mask=mask ) - assert isinstance(transformed_img, np.ndarray) - assert transformed_img.shape == img_rot.shape + assert isinstance(transformed_img, ImageIO) + assert transformed_img.get_as_numpy().shape == img_rot.get_as_numpy().shape def test_apply_transformation_with_BrainAtlas_reference_input_error(): - img_rot = load_image(M0_RIGID) - img_orig = load_image(M0_ORIG) + img_rot = ImageIO(M0_RIGID) + img_orig = ImageIO(M0_ORIG) _, trans_matrix = rigid_body_registration(img_orig, img_rot) with pytest.raises(Exception) as e: apply_transformation(img_rot, 'invalid atlas', trans_matrix) assert ( - 'reference_image must be a numpy array or a BrainAtlas object' + 'reference_image must be an ImageIO object or a BrainAtlas object' in str(e.value) ) def test_apply_transformation_with_BrainAtlas_reference_input_sucess(): - img_rot = load_image(M0_RIGID) - img_orig = load_image(M0_ORIG) + img_rot = ImageIO(M0_RIGID) + img_orig = ImageIO(M0_ORIG) _, trans_matrix = rigid_body_registration(img_orig, img_rot) atlas = BrainAtlas(atlas_name='MNI2009') - atlas_img = load_image(atlas.get_atlas()['t1_data']) + atlas_img = ImageIO(atlas.get_atlas()['t1_data']) corr_img = apply_transformation(img_rot, atlas, trans_matrix) - assert isinstance(corr_img, np.ndarray) - assert corr_img.shape == atlas_img.shape + assert isinstance(corr_img, ImageIO) + assert corr_img.get_as_numpy().shape == atlas_img.get_as_numpy().shape diff --git a/tests/test_asldata.py b/tests/test_asldata.py index 3e7d327..0eb01c3 100644 --- a/tests/test_asldata.py +++ b/tests/test_asldata.py @@ -4,7 +4,7 @@ import pytest from asltk import asldata -from asltk.utils.io import load_image, save_image +from asltk.utils.io import ImageIO SEP = os.sep T1_MRI = f'tests' + SEP + 'files' + SEP + 't1-mri.nrrd' @@ -23,8 +23,15 @@ def test_asldata_object_shows_warning_if_m0_has_more_than_3D_dimensions( ): tmp_file = tmp_path / 'temp_m0_4D.nii.gz' # Create a 4D M0 image - m0_4d = np.stack([load_image(M0), load_image(M0), load_image(M0)], axis=0) - save_image(m0_4d, str(tmp_file)) + m0_4d = np.stack( + [ + ImageIO(M0).get_as_numpy(), + ImageIO(M0).get_as_numpy(), + ImageIO(M0).get_as_numpy(), + ], + axis=0, + ) + ImageIO(image_array=m0_4d).save_image(str(tmp_file)) with pytest.warns(Warning) as record: obj = asldata.ASLData(m0=str(tmp_file)) assert len(record) == 1 @@ -91,38 +98,38 @@ def test_create_object_check_initial_parameters(): def test_create_object_with_m0_as_numpy_array(): - array = load_image(M0) + array = ImageIO(M0).get_as_numpy() obj = asldata.ASLData(m0=array) - assert obj('m0').shape == array.shape + assert obj('m0').get_as_numpy().shape == array.shape def test_create_object_with_m0_as_numpy_array(): - array = load_image(M0) + array = ImageIO(M0).get_as_numpy() obj = asldata.ASLData(m0=array) - assert obj('m0').shape == array.shape + assert obj('m0').get_as_numpy().shape == array.shape def test_create_object_with_m0_as_numpy_array(): - array = load_image(M0) + array = ImageIO(M0).get_as_numpy() obj = asldata.ASLData(m0=array) - assert obj('m0').shape == array.shape + assert obj('m0').get_as_numpy().shape == array.shape def test_create_object_with_m0_as_numpy_array(): - array = load_image(M0) + array = ImageIO(M0).get_as_numpy() obj = asldata.ASLData(m0=array) - assert obj('m0').shape == array.shape + assert obj('m0').get_as_numpy().shape == array.shape def test_create_object_with_pcasl_as_numpy_array(): - array = load_image(PCASL_MTE) + array = ImageIO(PCASL_MTE).get_as_numpy() obj = asldata.ASLData(pcasl=array) - assert obj('pcasl').shape == array.shape + assert obj('pcasl').get_as_numpy().shape == array.shape def test_get_ld_show_empty_list_for_new_object(): @@ -296,19 +303,19 @@ def test_set_dw_throw_error_input_is_list_of_negative_or_zero_numbers(input): def test_asldata_object_call_returns_image(): obj = asldata.ASLData(pcasl=T1_MRI) - assert isinstance(obj('pcasl'), np.ndarray) + assert isinstance(obj('pcasl'), ImageIO) def test_set_image_sucess_m0(): obj = asldata.ASLData(pcasl=T1_MRI) obj.set_image(M0, 'm0') - assert isinstance(obj('m0'), np.ndarray) + assert isinstance(obj('m0'), ImageIO) def test_set_image_sucess_pcasl(): obj = asldata.ASLData() obj.set_image(M0, 'pcasl') - assert isinstance(obj('pcasl'), np.ndarray) + assert isinstance(obj('pcasl'), ImageIO) @pytest.mark.parametrize( diff --git a/tests/test_smooth.py b/tests/test_smooth.py index eaadccd..c1e4bdc 100644 --- a/tests/test_smooth.py +++ b/tests/test_smooth.py @@ -5,7 +5,7 @@ from asltk.smooth.gaussian import isotropic_gaussian from asltk.smooth.median import isotropic_median -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep PCASL_MTE = f'tests' + SEP + 'files' + SEP + 'pcasl_mte.nii.gz' @@ -23,11 +23,11 @@ ], ) def test_isotropic_gaussian_smooth(sigma): - data = load_image(PCASL_MTE) + data = ImageIO(PCASL_MTE) smoothed = isotropic_gaussian(data, sigma) - assert smoothed.shape == data.shape - assert np.mean(smoothed) != np.mean(data) - assert np.std(smoothed) < np.std(data) + assert smoothed.get_as_numpy().shape == data.get_as_numpy().shape + assert np.mean(smoothed.get_as_numpy()) != np.mean(data.get_as_numpy()) + assert np.std(smoothed.get_as_numpy()) < np.std(data.get_as_numpy()) @pytest.mark.parametrize( @@ -39,7 +39,7 @@ def test_isotropic_gaussian_smooth(sigma): ], ) def test_isotropic_gaussian_smooth_wrong_sigma(sigma): - data = load_image(PCASL_MTE) + data = ImageIO(PCASL_MTE) with pytest.raises(Exception) as e: isotropic_gaussian(data, sigma) assert 'sigma must be a positive number.' in e.value.args[0] @@ -56,15 +56,15 @@ def test_isotropic_gaussian_smooth_wrong_sigma(sigma): def test_isotropic_gaussian_smooth_wrong_data(data): with pytest.raises(Exception) as e: isotropic_gaussian(data) - assert 'data is not a numpy array. Type' in e.value.args[0] + assert 'data is not an ImageIO object. Type' in e.value.args[0] def test_isotropic_gaussian_3D_volume_sucess(): - data = load_image(M0) + data = ImageIO(M0) smoothed = isotropic_gaussian(data) - assert smoothed.shape == data.shape - assert np.mean(smoothed) != np.mean(data) - assert np.std(smoothed) < np.std(data) + assert smoothed.get_as_numpy().shape == data.get_as_numpy().shape + assert np.mean(smoothed.get_as_numpy()) != np.mean(data.get_as_numpy()) + assert np.std(smoothed.get_as_numpy()) < np.std(data.get_as_numpy()) @pytest.mark.parametrize( @@ -76,11 +76,11 @@ def test_isotropic_gaussian_3D_volume_sucess(): ], ) def test_isotropic_median_smooth(size): - data = load_image(PCASL_MTE) + data = ImageIO(PCASL_MTE) smoothed = isotropic_median(data, size) - assert smoothed.shape == data.shape - assert np.mean(smoothed) != np.mean(data) - assert np.std(smoothed) < np.std(data) + assert smoothed.get_as_numpy().shape == data.get_as_numpy().shape + assert np.mean(smoothed.get_as_numpy()) != np.mean(data.get_as_numpy()) + assert np.std(smoothed.get_as_numpy()) < np.std(data.get_as_numpy()) @pytest.mark.parametrize( @@ -94,7 +94,7 @@ def test_isotropic_median_smooth(size): ], ) def test_isotropic_median_smooth_wrong_size(size): - data = load_image(PCASL_MTE) + data = ImageIO(PCASL_MTE) with pytest.raises(Exception) as e: isotropic_median(data, size) assert 'size must be a positive integer.' in e.value.args[0] @@ -111,20 +111,20 @@ def test_isotropic_median_smooth_wrong_size(size): def test_isotropic_median_smooth_wrong_data(data): with pytest.raises(Exception) as e: isotropic_median(data) - assert 'data is not a numpy array. Type' in e.value.args[0] + assert 'data is not an ImageIO object. Type' in e.value.args[0] def test_isotropic_median_3D_volume_success(): - data = load_image(M0) + data = ImageIO(M0) smoothed = isotropic_median(data) - assert smoothed.shape == data.shape - assert np.mean(smoothed) != np.mean(data) - assert np.std(smoothed) < np.std(data) + assert smoothed.get_as_numpy().shape == data.get_as_numpy().shape + assert np.mean(smoothed.get_as_numpy()) != np.mean(data.get_as_numpy()) + assert np.std(smoothed.get_as_numpy()) < np.std(data.get_as_numpy()) def test_isotropic_median_even_size_warning(): - data = load_image(M0) + data = ImageIO(M0) with pytest.warns(UserWarning) as warning: smoothed = isotropic_median(data, size=4) assert 'size was even, using 3 instead' in str(warning[0].message) - assert smoothed.shape == data.shape + assert smoothed.get_as_numpy().shape == data.get_as_numpy().shape diff --git a/tests/test_smooth_utils.py b/tests/test_smooth_utils.py index 176e00e..3266a9d 100644 --- a/tests/test_smooth_utils.py +++ b/tests/test_smooth_utils.py @@ -2,13 +2,14 @@ import pytest from asltk.aux_methods import _apply_smoothing_to_maps +from asltk.utils.io import ImageIO def test_apply_smoothing_to_maps_no_smoothing(): # Test no smoothing (default behavior) maps = { - 'cbf': np.random.random((10, 10, 10)), - 'att': np.random.random((10, 10, 10)), + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), + 'att': ImageIO(image_array=np.random.random((10, 10, 10))), } result = _apply_smoothing_to_maps(maps) @@ -21,25 +22,31 @@ def test_apply_smoothing_to_maps_no_smoothing(): def test_apply_smoothing_to_maps_gaussian(): # Test gaussian smoothing maps = { - 'cbf': np.random.random((10, 10, 10)), - 'att': np.random.random((10, 10, 10)), + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), + 'att': ImageIO(image_array=np.random.random((10, 10, 10))), } result = _apply_smoothing_to_maps(maps, smoothing='gaussian') # Should return different smoothed maps assert set(result.keys()) == set(maps.keys()) for key in maps.keys(): - assert result[key].shape == maps[key].shape - assert not np.array_equal(result[key], maps[key]) + assert ( + result[key].get_as_numpy().shape == maps[key].get_as_numpy().shape + ) + assert not np.array_equal( + result[key].get_as_numpy(), maps[key].get_as_numpy() + ) # Smoothing should reduce noise (typically lower std) - assert np.std(result[key]) <= np.std(maps[key]) + assert np.std(result[key].get_as_numpy()) <= np.std( + maps[key].get_as_numpy() + ) def test_apply_smoothing_to_maps_median(): # Test median smoothing maps = { - 'cbf': np.random.random((10, 10, 10)), - 'att': np.random.random((10, 10, 10)), + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), + 'att': ImageIO(image_array=np.random.random((10, 10, 10))), } result = _apply_smoothing_to_maps( maps, smoothing='median', smoothing_params={'size': 3} @@ -48,8 +55,12 @@ def test_apply_smoothing_to_maps_median(): # Should return different smoothed maps assert set(result.keys()) == set(maps.keys()) for key in maps.keys(): - assert result[key].shape == maps[key].shape - assert not np.array_equal(result[key], maps[key]) + assert ( + result[key].get_as_numpy().shape == maps[key].get_as_numpy().shape + ) + assert not np.array_equal( + result[key].get_as_numpy(), maps[key].get_as_numpy() + ) def test_apply_smoothing_to_maps_invalid_type(): @@ -64,7 +75,7 @@ def test_apply_smoothing_to_maps_invalid_type(): def test_apply_smoothing_to_maps_non_array_values(): # Test that non-array values are passed through unchanged maps = { - 'cbf': np.random.random((10, 10, 10)), + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), 'metadata': 'some_string', 'number': 42, } @@ -74,12 +85,14 @@ def test_apply_smoothing_to_maps_non_array_values(): assert result['metadata'] == maps['metadata'] assert result['number'] == maps['number'] # Array should be smoothed - assert not np.array_equal(result['cbf'], maps['cbf']) + assert not np.array_equal( + result['cbf'].get_as_numpy(), maps['cbf'].get_as_numpy() + ) def test_apply_smoothing_to_maps_custom_params(): # Test custom smoothing parameters - maps = {'cbf': np.random.random((10, 10, 10))} + maps = {'cbf': ImageIO(image_array=np.random.random((10, 10, 10)))} result1 = _apply_smoothing_to_maps( maps, smoothing='gaussian', smoothing_params={'sigma': 1.0} @@ -89,31 +102,39 @@ def test_apply_smoothing_to_maps_custom_params(): ) # Different parameters should produce different results - assert not np.array_equal(result1['cbf'], result2['cbf']) + assert not np.array_equal( + result1['cbf'].get_as_numpy(), result2['cbf'].get_as_numpy() + ) def test_apply_smoothing_to_maps_median_default_params(): # Test median smoothing with default parameters maps = { - 'cbf': np.random.random((10, 10, 10)), - 'att': np.random.random((10, 10, 10)), + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), + 'att': ImageIO(image_array=np.random.random((10, 10, 10))), } result = _apply_smoothing_to_maps(maps, smoothing='median') for key in maps.keys(): - assert result[key].shape == maps[key].shape - assert not np.array_equal(result[key], maps[key]) + assert ( + result[key].get_as_numpy().shape == maps[key].get_as_numpy().shape + ) + assert not np.array_equal( + result[key].get_as_numpy(), maps[key].get_as_numpy() + ) def test_apply_smoothing_to_maps_median_different_sizes(): # Test median smoothing with different kernel sizes - maps = {'cbf': np.random.random((10, 10, 10))} + maps = {'cbf': ImageIO(image_array=np.random.random((10, 10, 10)))} result1 = _apply_smoothing_to_maps( maps, smoothing='median', smoothing_params={'size': 3} ) result2 = _apply_smoothing_to_maps( maps, smoothing='median', smoothing_params={'size': 5} ) - assert not np.array_equal(result1['cbf'], result2['cbf']) + assert not np.array_equal( + result1['cbf'].get_as_numpy(), result2['cbf'].get_as_numpy() + ) def test_apply_smoothing_to_maps_median_invalid_param(): @@ -129,17 +150,26 @@ def test_apply_smoothing_to_maps_median_invalid_param(): def test_apply_smoothing_to_maps_median_non_array(): # Test median smoothing with non-array values in maps - maps = {'cbf': np.random.random((10, 10, 10)), 'meta': 'info'} + maps = { + 'cbf': ImageIO(image_array=np.random.random((10, 10, 10))), + 'meta': 'info', + } result = _apply_smoothing_to_maps(maps, smoothing='median') assert result['meta'] == maps['meta'] - assert not np.array_equal(result['cbf'], maps['cbf']) + assert not np.array_equal( + result['cbf'].get_as_numpy(), maps['cbf'].get_as_numpy() + ) def test_apply_smoothing_to_maps_median_1d_array(): # Test median smoothing with 1D array - maps = {'cbf': np.random.random((10, 10, 10))} + maps = {'cbf': ImageIO(image_array=np.random.random((10, 10, 10)))} result = _apply_smoothing_to_maps( maps, smoothing='median', smoothing_params={'size': 3} ) - assert result['cbf'].shape == maps['cbf'].shape - assert not np.array_equal(result['cbf'], maps['cbf']) + assert ( + result['cbf'].get_as_numpy().shape == maps['cbf'].get_as_numpy().shape + ) + assert not np.array_equal( + result['cbf'].get_as_numpy(), maps['cbf'].get_as_numpy() + ) diff --git a/tests/utils/test_image_manipulation.py b/tests/utils/test_image_manipulation.py index 03197b2..7c4134b 100644 --- a/tests/utils/test_image_manipulation.py +++ b/tests/utils/test_image_manipulation.py @@ -11,7 +11,7 @@ collect_data_volumes, select_reference_volume, ) -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep T1_MRI = f'tests' + SEP + 'files' + SEP + 't1-mri.nrrd' @@ -108,11 +108,12 @@ def test_collect_data_volumes_return_correct_list_of_volumes_4D_data(): data = np.ones((2, 30, 40, 15)) data[0, :, :, :] = data[0, :, :, :] * 10 data[1, :, :, :] = data[1, :, :, :] * 20 - collected_volumes, _ = collect_data_volumes(data) + image = ImageIO(image_array=data) + collected_volumes, _ = collect_data_volumes(image) assert len(collected_volumes) == 2 - assert collected_volumes[0].shape == (30, 40, 15) - assert np.mean(collected_volumes[0]) == 10 - assert np.mean(collected_volumes[1]) == 20 + assert collected_volumes[0].get_as_numpy().shape == (30, 40, 15) + assert np.mean(collected_volumes[0].get_as_numpy()) == 10 + assert np.mean(collected_volumes[1].get_as_numpy()) == 20 def test_collect_data_volumes_return_correct_list_of_volumes_5D_data(): @@ -121,27 +122,28 @@ def test_collect_data_volumes_return_correct_list_of_volumes_5D_data(): data[0, 1, :, :, :] = data[0, 1, :, :, :] * 10 data[1, 0, :, :, :] = data[1, 0, :, :, :] * 20 data[1, 1, :, :, :] = data[1, 1, :, :, :] * 20 + data = ImageIO(image_array=data) collected_volumes, _ = collect_data_volumes(data) assert len(collected_volumes) == 4 - assert collected_volumes[0].shape == (30, 40, 15) - assert np.mean(collected_volumes[0]) == 10 - assert np.mean(collected_volumes[1]) == 10 - assert np.mean(collected_volumes[2]) == 20 - assert np.mean(collected_volumes[3]) == 20 + assert collected_volumes[0].get_as_numpy().shape == (30, 40, 15) + assert np.mean(collected_volumes[0].get_as_numpy()) == 10 + assert np.mean(collected_volumes[1].get_as_numpy()) == 10 + assert np.mean(collected_volumes[2].get_as_numpy()) == 20 + assert np.mean(collected_volumes[3].get_as_numpy()) == 20 def test_collect_data_volumes_error_if_input_is_not_numpy_array(): data = [1, 2, 3] with pytest.raises(Exception) as e: collected_volumes, _ = collect_data_volumes(data) - assert 'data is not a numpy array' in e.value.args[0] + assert 'data is not an ImageIO object' in e.value.args[0] def test_collect_data_volumes_error_if_input_is_less_than_3D(): - data = np.ones((30, 40)) + data = ImageIO(image_array=np.ones((30, 40))) with pytest.raises(Exception) as e: collected_volumes, _ = collect_data_volumes(data) - assert 'data is a 3D volume or higher dimensions' in e.value.args[0] + assert 'data is not a 3D volume or higher dimensions' in e.value.args[0] @pytest.mark.parametrize('method', ['snr', 'mean']) @@ -152,7 +154,10 @@ def test_select_reference_volume_returns_correct_volume_and_index_with_sample_im ref_volume, idx = select_reference_volume(asl, method=method) - assert ref_volume.shape == asl('pcasl')[0][0].shape + assert ( + ref_volume.get_as_numpy().shape + == asl('pcasl').get_as_numpy()[0][0].shape + ) assert idx != 0 @@ -165,3 +170,40 @@ def test_select_reference_volume_raise_error_invalid_method(method): with pytest.raises(Exception) as e: select_reference_volume(asl, method=method) assert 'Invalid method' in e.value.args[0] + + +def test_select_reference_volume_raise_error_wrong_roi(): + asl = asldata.ASLData(pcasl=PCASL_MTE, m0=M0) + + with pytest.raises(Exception) as e: + select_reference_volume(asl, roi='invalid_roi') + assert 'ROI must be an ImageIO object' in e.value.args[0] + + +def test_select_reference_volume_raise_error_wrong_4D_roi(): + asl = asldata.ASLData(pcasl=PCASL_MTE, m0=M0) + roi = ImageIO( + image_array=np.array( + [asl('m0').get_as_numpy(), asl('m0').get_as_numpy()] + ) + ) + + with pytest.raises(Exception) as e: + select_reference_volume(asl, roi=roi) + assert 'ROI must be a 3D array' in e.value.args[0] + + +def test_select_reference_volume_raise_error_wrong_list_image_input_images(): + wrong_input_list = ['wrong_input1', 'wrong_input2'] + + with pytest.raises(Exception) as e: + select_reference_volume(wrong_input_list) + assert 'asl_data must be an ASLData object' in e.value.args[0] + + +def test_select_reference_volume_raise_error_wrong_method(): + asl = asldata.ASLData(pcasl=PCASL_MTE, m0=M0) + + with pytest.raises(Exception) as e: + select_reference_volume(asl, method='invalid_method') + assert 'Invalid method' in e.value.args[0] diff --git a/tests/utils/test_image_statistics.py b/tests/utils/test_image_statistics.py index 2d8f913..d254cc8 100644 --- a/tests/utils/test_image_statistics.py +++ b/tests/utils/test_image_statistics.py @@ -8,7 +8,7 @@ calculate_mean_intensity, calculate_snr, ) -from asltk.utils.io import load_image +from asltk.utils.io import ImageIO SEP = os.sep T1_MRI = f'tests{SEP}files{SEP}t1-mri.nrrd' @@ -20,7 +20,7 @@ @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_analyze_image_properties_returns_dict(image_path): """Test that analyze_image_properties returns a dictionary with expected keys.""" - img = load_image(image_path) + img = ImageIO(image_path) props = analyze_image_properties(img) assert isinstance(props, dict) assert 'shape' in props @@ -47,7 +47,7 @@ def test_analyze_image_properties_invalid_path(input): @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_snr_returns_float(image_path): """Test that calculate_snr returns a float for valid images.""" - img = load_image(image_path) + img = ImageIO(image_path) snr = calculate_snr(img) assert isinstance(snr, float) assert snr >= 0 @@ -56,8 +56,10 @@ def test_calculate_snr_returns_float(image_path): @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_snr_returns_float_using_valid_roi(image_path): """Test that calculate_snr returns a float for valid images.""" - img = load_image(image_path) - roi = np.ones(img.shape, dtype=bool) # Create a valid ROI + img = ImageIO(image_path) + roi = ImageIO( + image_array=np.ones(img.get_as_numpy().shape, dtype=np.uint8) + ) # Create a valid ROI snr = calculate_snr(img, roi=roi) assert isinstance(snr, float) assert snr >= 0 @@ -65,8 +67,10 @@ def test_calculate_snr_returns_float_using_valid_roi(image_path): def test_calculate_snr_make_zero_division_with_same_image_input(): """Test that calculate_snr handles zero division with same image input.""" - img = np.ones((10, 10, 10)) # Create a simple image - roi = np.ones(img.shape, dtype=bool) # Create a valid ROI + img = ImageIO(image_array=np.ones((10, 10, 10))) # Create a simple image + roi = ImageIO( + image_array=np.ones(img.get_as_numpy().shape, dtype=np.uint8) + ) # Create a valid ROI snr = calculate_snr(img, roi=roi) assert isinstance(snr, float) @@ -74,7 +78,12 @@ def test_calculate_snr_make_zero_division_with_same_image_input(): @pytest.mark.parametrize( - 'input', [np.zeros((10, 10)), np.ones((5, 5, 5)), np.full((3, 3), 7)] + 'input', + [ + ImageIO(image_array=np.zeros((10, 10))), + ImageIO(image_array=np.ones((5, 5, 5))), + ImageIO(image_array=np.full((3, 3), 7)), + ], ) def test_calculate_snr_known_arrays(input): """Test calculate_snr with known arrays.""" @@ -93,11 +102,11 @@ def test_calculate_snr_invalid_input(): @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_snr_raise_error_roi_different_shape(image_path): """Test that calculate_snr raises an error for ROI of different shape.""" - img = load_image(image_path) + img = ImageIO(image_path) # Add an extra dimension to img and create a mismatched ROI - img = np.expand_dims(img, axis=0) - roi = np.ones( - img.shape[1:], dtype=bool + img = ImageIO(image_array=img.get_as_numpy()[:, :]) + roi = ImageIO( + image_array=np.ones(img.get_as_numpy().shape[1:], dtype=np.uint8) ) # ROI shape does not match img shape with pytest.raises(ValueError) as error: calculate_snr(img, roi=roi) @@ -108,18 +117,18 @@ def test_calculate_snr_raise_error_roi_different_shape(image_path): @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_snr_raise_error_roi_not_numpy_array(image_path): """Test that calculate_snr raises an error for ROI not being a numpy array.""" - img = load_image(image_path) + img = ImageIO(image_path) roi = 'invalid_roi' with pytest.raises(ValueError) as error: calculate_snr(img, roi=roi) - assert 'ROI must be a numpy array' in str(error.value) + assert 'ROI must be an ImageIO object' in str(error.value) @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_mean_intensity_returns_float(image_path): """Test that calculate_mean_intensity returns a float for valid images.""" - img = load_image(image_path) + img = ImageIO(image_path) mean_intensity = calculate_mean_intensity(img) assert isinstance(mean_intensity, float) assert mean_intensity >= 0 @@ -128,32 +137,33 @@ def test_calculate_mean_intensity_returns_float(image_path): @pytest.mark.parametrize('image_path', [T1_MRI, PCASL_MTE, M0]) def test_calculate_mean_intensity_with_valid_roi(image_path): """Test that calculate_mean_intensity returns a float for valid ROI.""" - img = load_image(image_path) - roi = np.ones(img.shape, dtype=bool) + img = ImageIO(image_path) + roi = ImageIO( + image_array=np.ones(img.get_as_numpy().shape, dtype=np.uint8) + ) mean_intensity = calculate_mean_intensity(img, roi=roi) assert isinstance(mean_intensity, float) assert mean_intensity >= 0 -def test_calculate_mean_intensity_known_arrays(): +@pytest.mark.parametrize( + 'image,answer', + [ + (ImageIO(image_array=np.ones((5, 5, 5))), 1.0), + (ImageIO(image_array=np.full((3, 3), 7)), 7.0), + (ImageIO(image_array=np.array([[1, 2], [3, 4]])), 2.5), + ], +) +def test_calculate_mean_intensity_known_arrays(image, answer): """Test calculate_mean_intensity with known arrays.""" - arr = np.ones((5, 5, 5)) - mean_intensity = calculate_mean_intensity(arr) - assert mean_intensity == 1.0 - - arr = np.full((3, 3), 7) - mean_intensity = calculate_mean_intensity(arr) - assert mean_intensity == 7.0 - - arr = np.array([[1, 2], [3, 4]]) - mean_intensity = calculate_mean_intensity(arr) - assert mean_intensity == 2.5 + mean_intensity = calculate_mean_intensity(image) + assert mean_intensity == answer def test_calculate_mean_intensity_with_roi_mask(): """Test calculate_mean_intensity with ROI mask.""" - arr = np.array([[1, 2], [3, 4]]) - roi = np.array([[0, 1], [1, 0]]) + arr = ImageIO(image_array=np.array([[1, 2], [3, 4]])) + roi = ImageIO(image_array=np.array([[0, 1], [1, 0]])) mean_intensity = calculate_mean_intensity(arr, roi=roi) assert mean_intensity == 2.5 # mean of [2, 3] @@ -162,22 +172,22 @@ def test_calculate_mean_intensity_invalid_input(): """Test that calculate_mean_intensity raises an error for invalid input.""" with pytest.raises(ValueError) as error: calculate_mean_intensity('invalid_input') - assert 'Input must be a numpy array' in str(error.value) + assert 'Input must be an ImageIO object' in str(error.value) def test_calculate_mean_intensity_roi_not_numpy_array(): """Test that calculate_mean_intensity raises an error for ROI not being a numpy array.""" - arr = np.ones((5, 5)) + arr = ImageIO(image_array=np.ones((5, 5))) roi = 'invalid_roi' with pytest.raises(ValueError) as error: calculate_mean_intensity(arr, roi=roi) - assert 'ROI must be a numpy array' in str(error.value) + assert 'ROI must be an ImageIO object' in str(error.value) def test_calculate_mean_intensity_roi_shape_mismatch(): """Test that calculate_mean_intensity raises an error for ROI shape mismatch.""" - arr = np.ones((5, 5)) - roi = np.ones((4, 4), dtype=bool) + arr = ImageIO(image_array=np.ones((5, 5))) + roi = ImageIO(image_array=np.ones((4, 4), dtype=np.uint8)) with pytest.raises(ValueError) as error: calculate_mean_intensity(arr, roi=roi) assert 'ROI shape must match image shape' in str(error.value) diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index 5fbf706..05ff514 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -1,13 +1,21 @@ import os import tempfile +import ants import numpy as np import pytest import SimpleITK as sitk from asltk import asldata from asltk.models import signal_dynamic -from asltk.utils.io import load_asl_data, load_image, save_asl_data, save_image +from asltk.utils.io import ( + ImageIO, + check_image_properties, + check_path, + clone_image, + load_asl_data, + save_asl_data, +) SEP = os.sep T1_MRI = f'tests' + SEP + 'files' + SEP + 't1-mri.nrrd' @@ -17,23 +25,27 @@ def test_load_image_pcasl_type_update_object_image_reference(): - img = load_image(PCASL_MTE) - assert isinstance(img, np.ndarray) + img = ImageIO(PCASL_MTE) + assert isinstance(img, ImageIO) def test_load_image_m0_type_update_object_image_reference(): - img = load_image(M0) - assert isinstance(img, np.ndarray) + img = ImageIO(M0) + assert isinstance(img, ImageIO) def test_load_image_m0_with_average_m0_option(tmp_path): - multi_M0 = np.stack([load_image(M0), load_image(M0)], axis=0) + img_4d = np.array( + [ImageIO(M0).get_as_numpy(), ImageIO(M0).get_as_numpy()], + dtype=np.float32, + ) + multi_M0 = ImageIO(image_array=img_4d) tmp_file = tmp_path / 'temp_m0.nii.gz' - save_image(multi_M0, str(tmp_file)) - img = load_image(str(tmp_file), average_m0=True) + multi_M0.save_image(str(tmp_file)) + img = ImageIO(str(tmp_file), average_m0=True) - assert isinstance(img, np.ndarray) - assert len(img.shape) == 3 + assert isinstance(img, ImageIO) + assert len(img.get_as_numpy().shape) == 3 @pytest.mark.parametrize( @@ -46,7 +58,7 @@ def test_load_image_m0_with_average_m0_option(tmp_path): ) def test_load_image_attest_fullpath_is_valid(input): with pytest.raises(Exception) as e: - load_image(input) + ImageIO(input) assert 'does not exist.' in e.value.args[0] @@ -54,9 +66,9 @@ def test_load_image_attest_fullpath_is_valid(input): 'input', [('out.nrrd'), ('out.nii'), ('out.mha'), ('out.tif')] ) def test_save_image_success(input, tmp_path): - img = load_image(T1_MRI) + img = ImageIO(T1_MRI) full_path = tmp_path.as_posix() + os.sep + input - save_image(img, full_path) + img.save_image(full_path) assert os.path.exists(full_path) read_file = sitk.ReadImage(full_path) assert read_file.GetSize() == sitk.ReadImage(T1_MRI).GetSize() @@ -66,10 +78,10 @@ def test_save_image_success(input, tmp_path): 'input', [('out.nrr'), ('out.n'), ('out.m'), ('out.zip')] ) def test_save_image_throw_error_invalid_formatt(input, tmp_path): - img = load_image(T1_MRI) + img = ImageIO(T1_MRI) full_path = tmp_path.as_posix() + os.sep + input with pytest.raises(Exception) as e: - save_image(img, full_path) + img.save_image(full_path) def test_asl_model_buxton_return_sucess_list_of_values(): @@ -200,7 +212,10 @@ def test_load_asl_data_sucess(input_data, filename, tmp_path): save_asl_data(obj, out_file) loaded_obj = load_asl_data(out_file) assert isinstance(loaded_obj, asldata.ASLData) - assert loaded_obj('pcasl').shape == obj('pcasl').shape + assert ( + loaded_obj('pcasl').get_as_numpy().shape + == obj('pcasl').get_as_numpy().shape + ) @pytest.mark.parametrize( @@ -213,14 +228,14 @@ def test_load_asl_data_sucess(input_data, filename, tmp_path): ], ) def test_load_image_using_BIDS_input_sucess(input_bids, sub, sess, mod, suff): - loaded_obj = load_image( - full_path=input_bids, + loaded_obj = ImageIO( + image_path=input_bids, subject=sub, session=sess, modality=mod, suffix=suff, ) - assert isinstance(loaded_obj, np.ndarray) + assert isinstance(loaded_obj, ImageIO) @pytest.mark.parametrize( @@ -229,7 +244,7 @@ def test_load_image_using_BIDS_input_sucess(input_bids, sub, sess, mod, suff): ) def test_load_image_using_not_valid_BIDS_input_raise_error(input_data): with pytest.raises(Exception) as e: - loaded_obj = load_image(input_data) + loaded_obj = ImageIO(input_data) assert 'is missing' in e.value.args[0] @@ -245,8 +260,8 @@ def test_load_image_raise_FileNotFoundError_not_matching_image_file( input_bids, sub, sess, mod, suff ): with pytest.raises(Exception) as e: - loaded_obj = load_image( - full_path=input_bids, + loaded_obj = ImageIO( + image_path=input_bids, subject=sub, session=sess, modality=mod, @@ -262,8 +277,8 @@ def test_load_image_from_bids_structure_returns_valid_array(): modality = 'asl' suffix = None # m0 is deleted, because it does not exist - img = load_image( - full_path=bids_root, + img = ImageIO( + image_path=bids_root, subject=subject, session=session, modality=modality, @@ -271,3 +286,259 @@ def test_load_image_from_bids_structure_returns_valid_array(): ) assert img is not None + + +@pytest.mark.parametrize( + 'input_data, type', + [ + (np.random.rand(10, 10, 10), 'array'), + (np.random.rand(10, 10, 10, 5), 'array'), + (np.random.rand(5, 2, 10, 10, 10), 'array'), + (T1_MRI, 'path'), + (PCASL_MTE, 'path'), + (M0, 'path'), + ], +) +def test_ImageIO_constructor_success_with_image_array(input_data, type): + """Test ImageIO constructor with an image array.""" + if type == 'array': + img_array = input_data + io = ImageIO(image_array=img_array) + assert isinstance(io, ImageIO) + assert np.array_equal(io.get_as_numpy(), img_array) + elif type == 'path': + img_path = input_data + io = ImageIO(image_path=img_path) + assert isinstance(io, ImageIO) + assert io.get_as_numpy() is not None + + +def test_ImageIO_str_representation(): + """Test the __str__ method of ImageIO.""" + img = ImageIO(T1_MRI) + representation = str(img) + assert 'Path: ' in representation + assert 'Dimension: 3' in representation + assert ( + 'Spacing: (15.000015000015, 15.000015000015, 14.884615384615385)' + in representation + ) + assert 'average_m0: False' in representation + assert 'verbose: False' in representation + assert 'Subject: None' in representation + assert 'Session: None' in representation + assert 'Modality: None' in representation + assert 'Suffix: None' in representation + + +def test_ImageIO_set_image_path_sucess(): + """Test setting a new image path.""" + img = ImageIO(T1_MRI) + new_path = PCASL_MTE + img.set_image_path(new_path) + assert img.get_image_path() == new_path + assert img.get_as_numpy() is not None + + +def test_ImageIO_set_image_path_invalid_path(): + """Test setting an invalid image path.""" + img = ImageIO(T1_MRI) + invalid_path = 'invalid/path/to/image.nii' + with pytest.raises(Exception) as e: + img.set_image_path(invalid_path) + assert 'does not exist.' in e.value.args[0] + + +def test_ImageIO_get_image_path(): + """Test getting the image path.""" + img = ImageIO(T1_MRI) + assert img.get_image_path() == T1_MRI + + +def test_ImageIO_get_as_sitk_sucess(): + """Test getting the image as a SimpleITK object.""" + img = ImageIO(T1_MRI) + sitk_img = img.get_as_sitk() + assert isinstance(sitk_img, sitk.Image) + assert sitk_img.GetSize() == sitk.ReadImage(T1_MRI).GetSize() + + +def test_ImageIO_get_as_sitk_raise_error_no_image_loaded(): + """Test getting the image as SimpleITK when no image is loaded.""" + img = ImageIO(image_array=np.ones((5, 5, 5))) + img._image_as_sitk = None # Force no image loaded + with pytest.raises(Exception) as e: + img.get_as_sitk() + assert ( + e.value.args[0] + == 'Image is not loaded as SimpleITK. Please load the image first.' + ) + + +def test_ImageIO_get_as_ants_sucess(): + """Test getting the image as an ANTs object.""" + img = ImageIO(T1_MRI) + ants_img = img.get_as_ants() + assert ants_img is not None + assert ants_img.dimension == 3 + assert isinstance(ants_img, ants.ANTsImage) + + +def test_ImageIO_get_as_ants_raise_error_no_image_loaded(): + """Test getting the image as ANTs when no image is loaded.""" + img = ImageIO(image_array=np.ones((5, 5, 5))) + img._image_as_ants = None # Force no image loaded + with pytest.raises(Exception) as e: + img.get_as_ants() + assert ( + e.value.args[0] + == 'Image is not loaded as ANTsPy. Please load the image first.' + ) + + +def test_ImageIO_get_as_numpy_sucess(): + """Test getting the image as a numpy array.""" + img = ImageIO(T1_MRI) + np_array = img.get_as_numpy() + assert isinstance(np_array, np.ndarray) + assert ( + np_array.shape == sitk.ReadImage(T1_MRI).GetSize()[::-1] + ) # Reverse for numpy shape + + +def test_ImageIO_get_as_numpy_raise_error_no_image_loaded(): + """Test getting the image as numpy when no image is loaded.""" + img = ImageIO(image_array=np.ones((5, 5, 5))) + img._image_as_numpy = None # Force no image loaded + with pytest.raises(Exception) as e: + img.get_as_numpy() + assert ( + e.value.args[0] + == 'Image is not loaded as numpy array. Please load the image first.' + ) + + +def test_ImageIO_update_image_spacing_sucess(): + """Test updating the image spacing.""" + img = ImageIO(T1_MRI) + new_spacing = (2.0, 2.5, 3.0) + img.update_image_spacing(new_spacing) + sitk_img = img.get_as_sitk() + assert sitk_img.GetSpacing() == new_spacing + + +def test_ImageIO_update_image_origin_sucess(): + """Test updating the image origin.""" + img = ImageIO(T1_MRI) + new_origin = (5.0, 10.0, 15.0) + img.update_image_origin(new_origin) + sitk_img = img.get_as_sitk() + assert sitk_img.GetOrigin() == new_origin + + +def test_ImageIO_update_image_direction_sucess(): + """Test updating the image direction.""" + img = ImageIO(T1_MRI) + new_direction = (1.0, 2.0, 1.1, 0.0, 1.0, 2.0, 4.0, 3.0, 1.0) + img.update_image_direction(new_direction) + sitk_img = img.get_as_sitk() + assert sitk_img.GetDirection() == new_direction + + +def test_ImageIO_update_image_data_sucess_with_enforce_new_dimension(): + """Test updating the image data.""" + img = ImageIO(T1_MRI) + new_data = np.random.rand(10, 10, 10) + img.update_image_data(new_data, enforce_new_dimension=True) + np_array = img.get_as_numpy() + assert np.array_equal(np_array, new_data) + + +def test_ImageIO_save_image_sucess(tmp_path): + """Test saving the image to a new path.""" + img = ImageIO(T1_MRI) + save_path = tmp_path / 'saved_image.nii.gz' + img.save_image(str(save_path)) + assert os.path.exists(save_path) + saved_img = sitk.ReadImage(str(save_path)) + assert saved_img.GetSize() == sitk.ReadImage(T1_MRI).GetSize() + + +def test_ImageIO_save_image_raise_error_no_image_loaded(): + """Test saving the image when no image is loaded.""" + img = ImageIO(image_array=np.ones((5, 5, 5))) + img._image_as_sitk = None # Force no image loaded + save_path = os.path.join('directory', 'not', 'found', 'saved_image.nii.gz') + with pytest.raises(Exception) as e: + img.save_image(str(save_path)) + assert 'The directory of the full path' in e.value.args[0] + + +@pytest.mark.parametrize( + 'input_data, ref_data', + [ + ( + np.random.rand(10, 10, 10), + ImageIO(image_array=np.random.rand(10, 10, 10)), + ), + ( + ImageIO(image_array=np.random.rand(10, 10, 10, 5)), + ImageIO(image_array=np.random.rand(10, 10, 10, 5)), + ), + ( + ImageIO(image_array=np.random.rand(10, 10, 10, 5)).get_as_sitk(), + ImageIO(image_array=np.random.rand(10, 10, 10, 5)), + ), + ( + ImageIO(image_array=np.random.rand(10, 10, 10)).get_as_ants(), + ImageIO(image_array=np.random.rand(10, 10, 10)), + ), + (ImageIO(T1_MRI), ImageIO(image_path=T1_MRI)), + (ImageIO(PCASL_MTE), ImageIO(image_path=PCASL_MTE)), + (ImageIO(M0), ImageIO(image_path=M0)), + ], +) +def test_check_image_properties_does_not_raises_errors_for_valid_image( + input_data, ref_data +): + """Test check_image_properties with a valid image.""" + check_image_properties(input_data, ref_data) + assert True # If no exception is raised, the test passes + + +def test_clone_image_sucess(): + """Test cloning an image.""" + img = ImageIO(T1_MRI) + cloned_img = clone_image(img) + assert isinstance(cloned_img, ImageIO) + assert cloned_img.get_image_path() == None + assert np.array_equal(cloned_img.get_as_numpy(), img.get_as_numpy()) + assert cloned_img.get_as_sitk().GetSize() == img.get_as_sitk().GetSize() + assert cloned_img.get_as_ants().dimension == img.get_as_ants().dimension + + +def test_clone_image_sucess_with_copied_path(): + """Test cloning an image.""" + img = ImageIO(T1_MRI) + cloned_img = clone_image(img, include_path=True) + assert isinstance(cloned_img, ImageIO) + assert cloned_img.get_image_path() == img.get_image_path() + assert np.array_equal(cloned_img.get_as_numpy(), img.get_as_numpy()) + assert cloned_img.get_as_sitk().GetSize() == img.get_as_sitk().GetSize() + assert cloned_img.get_as_ants().dimension == img.get_as_ants().dimension + + +def test_check_path_sucess(): + """Test check_path with a valid path.""" + valid_path = T1_MRI + check_path(valid_path) + assert True # If no exception is raised, the test passes + + +def test_check_path_failure(): + """Test check_path with an invalid path.""" + invalid_path = os.path.join('invalid', 'path', 'to', 'image.nii.gz') + with pytest.raises(FileNotFoundError) as e: + check_path(invalid_path) + + assert 'The file' in e.value.args[0]