11from multiprocessing import Array , Pool , cpu_count
22
33import numpy as np
4- import SimpleITK as sitk
54from rich .progress import Progress
65from scipy .optimize import curve_fit
76
87from asltk .asldata import ASLData
9- from asltk .aux_methods import _apply_smoothing_to_maps , _check_mask_values
8+ from asltk .aux_methods import (
9+ _apply_smoothing_to_maps ,
10+ _check_mask_values ,
11+ estimate_memory_usage ,
12+ get_optimal_core_count ,
13+ )
1014from asltk .logging_config import get_logger , log_processing_step
1115from asltk .models .signal_dynamic import asl_model_buxton
1216from asltk .mri_parameters import MRIParameters
@@ -180,7 +184,7 @@ def create_map(
180184 ub = [1.0 , 5000.0 ],
181185 lb = [0.0 , 0.0 ],
182186 par0 = [1e-5 , 1000 ],
183- cores : int = int ( cpu_count () / 2 ) ,
187+ cores = 'auto' ,
184188 smoothing = None ,
185189 smoothing_params = None ,
186190 ):
@@ -280,12 +284,28 @@ def create_map(
280284 logger = get_logger ('cbf_mapping' )
281285 logger .info ('Starting CBF map creation' )
282286
283- if (cores < 0 ) or (cores > cpu_count ()) or not isinstance (cores , int ):
284- error_msg = 'Number of proecess must be at least 1 and less than maximum cores availble.'
285- logger .error (
286- f'{ error_msg } Requested: { cores } , Available: { cpu_count ()} '
287+ if not isinstance (cores , str ):
288+ if (
289+ (cores < 0 )
290+ or (cores > cpu_count ())
291+ or not isinstance (cores , int )
292+ ):
293+ error_msg = 'Number of CPU cores must be at least 1 and less than maximum cores available.'
294+ logger .error (
295+ f'{ error_msg } Requested: { cores } , Available: { cpu_count ()} '
296+ )
297+ raise ValueError (error_msg )
298+ elif isinstance (cores , str ):
299+ if cores not in ['auto' ]:
300+ error_msg = (
301+ 'Cores parameter must be either "auto" or a integer.'
302+ )
303+ logger .error (error_msg )
304+ raise ValueError (error_msg )
305+ else :
306+ raise ValueError (
307+ 'Cores parameter must be either "auto" or a integer.'
287308 )
288- raise ValueError (error_msg )
289309
290310 if (
291311 len (self ._asl_data .get_ld ()) == 0
@@ -314,14 +334,31 @@ def create_map(
314334 f'Processing volume dimensions: { z_axis } x{ y_axis } x{ x_axis } '
315335 )
316336
317- cbf_map_shared = Array ('d' , z_axis * y_axis * x_axis , lock = False )
318- att_map_shared = Array ('d' , z_axis * y_axis * x_axis , lock = False )
337+ cbf_map_shared = Array ('f' , z_axis * y_axis * x_axis , lock = False )
338+ att_map_shared = Array ('f' , z_axis * y_axis * x_axis , lock = False )
339+
340+ # Estimate all the memory usage needed for each core processing
341+ asldata_memory = estimate_memory_usage (
342+ self ._asl_data ('pcasl' ).get_as_numpy ()
343+ )
344+ brain_mask_memory = estimate_memory_usage (self ._brain_mask )
345+ cbf_memory = estimate_memory_usage (self ._cbf_map )
346+ att_memory = estimate_memory_usage (self ._att_map )
347+
348+ actual_cores = get_optimal_core_count (
349+ cores ,
350+ sum ([asldata_memory , brain_mask_memory , cbf_memory , att_memory ]),
351+ )
352+
353+ # Make a copy of base information
354+ m0_array = asl_data ('m0' ).get_as_numpy ()
355+ pcasl_array = asl_data ('pcasl' ).get_as_numpy ()
319356
320357 log_processing_step (
321358 'Running voxel-wise CBF fitting' , 'this may take several minutes'
322359 )
323360 with Pool (
324- processes = cores ,
361+ processes = actual_cores ,
325362 initializer = _cbf_init_globals ,
326363 initargs = (cbf_map_shared , att_map_shared , brain_mask , asl_data ),
327364 ) as pool :
@@ -339,6 +376,8 @@ def create_map(
339376 par0 ,
340377 lb ,
341378 ub ,
379+ m0_array ,
380+ pcasl_array ,
342381 ),
343382 callback = lambda _ : progress .update (task , advance = 1 ),
344383 )
@@ -347,12 +386,12 @@ def create_map(
347386 for result in results :
348387 result .wait ()
349388
350- self ._cbf_map = np .frombuffer (cbf_map_shared ). reshape (
351- z_axis , y_axis , x_axis
352- )
353- self ._att_map = np .frombuffer (att_map_shared ). reshape (
354- z_axis , y_axis , x_axis
355- )
389+ self ._cbf_map = np .frombuffer (
390+ cbf_map_shared , dtype = np . float32
391+ ). reshape ( z_axis , y_axis , x_axis )
392+ self ._att_map = np .frombuffer (
393+ att_map_shared , dtype = np . float32
394+ ). reshape ( z_axis , y_axis , x_axis )
356395
357396 # Log completion statistics
358397 cbf_values = self ._cbf_map [brain_mask > 0 ]
@@ -400,20 +439,20 @@ def _cbf_init_globals(
400439
401440
402441def _cbf_process_slice (
403- i , x_axis , y_axis , z_axis , BuxtonX , par0 , lb , ub
442+ i , x_axis , y_axis , z_axis , BuxtonX , par0 , lb , ub , m0 , pcasl
404443): # pragma: no cover
405444 # indirect call method by CBFMapping().create_map()
406445 for j in range (y_axis ):
407446 for k in range (z_axis ):
408447 if brain_mask [k , j , i ] != 0 :
409- m0_px = asl_data ( 'm0' ). get_as_numpy () [k , j , i ]
448+ m0_px = m0 [k , j , i ]
410449
411450 def mod_buxton (Xdata , par1 , par2 ):
412451 return asl_model_buxton (
413452 Xdata [0 ], Xdata [1 ], m0_px , par1 , par2
414453 )
415454
416- Ydata = asl_data ( ' pcasl' ). get_as_numpy () [0 , :, k , j , i ]
455+ Ydata = pcasl [0 , :, k , j , i ]
417456
418457 # Calculate the processing index for the 3D space
419458 index = k * (y_axis * x_axis ) + j * x_axis + i
@@ -422,8 +461,8 @@ def mod_buxton(Xdata, par1, par2):
422461 par_fit , _ = curve_fit (
423462 mod_buxton , BuxtonX , Ydata , p0 = par0 , bounds = (lb , ub )
424463 )
425- cbf_map [index ] = par_fit [0 ]
426- att_map [index ] = par_fit [1 ]
464+ cbf_map [index ] = np . float32 ( par_fit [0 ])
465+ att_map [index ] = np . float32 ( par_fit [1 ])
427466 except RuntimeError :
428467 cbf_map [index ] = 0.0
429468 att_map [index ] = 0.0
0 commit comments