99from time import time
1010import warnings
1111from tqdm import tqdm
12+ from scipy .linalg import svd
1213
1314
1415class Grouper (BaseClass ):
@@ -97,6 +98,7 @@ def _residual_function(
9798 parameters : np .ndarray [float ],
9899 times : np .ndarray [float ],
99100 counts : np .ndarray [float ],
101+ count_err : np .ndarray [float ],
100102 fit_func : Callable ) -> float :
101103 """
102104 Calculate the residual of the current set of parameters
@@ -108,7 +110,9 @@ def _residual_function(
108110 times : np.ndarray[float]
109111 List of times
110112 counts : np.ndarray[float]
111- List of nominal times
113+ List of delayed neutron counts
114+ count_err : np.ndarray[float]
115+ List of count errors
112116 fit_func : Callable
113117 Function that takes times and parameters to return list of counts
114118
@@ -117,7 +121,7 @@ def _residual_function(
117121 residual : float
118122 Value of the residual
119123 """
120- residual = (counts - fit_func (times , parameters )) / (counts + 1e-12 )
124+ residual = (counts - fit_func (times , parameters )) / (counts )
121125 return residual
122126
123127 def _pulse_fit_function (self ,
@@ -267,20 +271,24 @@ def _nonlinear_least_squares(self,
267271 xtol = 1e-12 ,
268272 verbose = 0 ,
269273 max_nfev = 1e5 ,
270- args = (times , counts , fit_function ))
271-
274+ args = (times , counts , count_err , fit_function ))
275+ J = result .jac
276+ s = svd (J , compute_uv = False )
277+ condition_number = s [0 ] / s [- 1 ]
278+ self .logger .info (f'{ condition_number = } ' )
272279 sampled_params : list [float ] = list ()
273280 tracked_counts : list [float ] = list ()
274281 sorted_params = self ._sort_params_by_half_life (result .x )
275282 sampled_params .append (sorted_params )
276283 countrate = CountRate (self .input_path )
277284 self .logger .info (f'Currently using { self .sample_func } sampling' )
278285 for _ in tqdm (range (1 , self .MC_samples ), desc = 'Solving least-squares' ):
279- data = countrate .calculate_count_rate (
280- MC_run = True , sampler_func = self .sample_func )
281- count_sample = data ['counts' ]
282286 with warnings .catch_warnings ():
283287 warnings .simplefilter ('ignore' )
288+ data = countrate .calculate_count_rate (
289+ MC_run = True , sampler_func = self .sample_func )
290+ count_sample = data ['counts' ]
291+ count_sample_err = data ['sigma counts' ]
284292 result = least_squares (
285293 self ._residual_function ,
286294 result .x ,
@@ -294,6 +302,7 @@ def _nonlinear_least_squares(self,
294302 args = (
295303 times ,
296304 count_sample ,
305+ count_sample_err ,
297306 fit_function ))
298307 tracked_counts .append ([i for i in count_sample ])
299308 sorted_params = self ._sort_params_by_half_life (result .x )
0 commit comments