@@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
189189 A list of lists of free energy differences between adjacent states for all replicas.
190190 state_ranges : list
191191 A list of lists showing the state indices sampled by each replica.
192+ n_tot : int
193+ Number of lambda states
192194 df_err_adjacent : list, Optional
193195 A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
194196 :code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
@@ -247,7 +249,47 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
247249 return df , df_err , overlap_bool
248250
249251
250- def calculate_free_energy (data , state_ranges , df_method = "MBAR" , err_method = "propagate" , n_bootstrap = None , seed = None ):
def _calculate_df(estimators):
    """
    An internal function used in :func:`calculate_free_energy` to get the free energy difference between the
    first and the last state, and its uncertainty, from the estimator of the first replica.

    Parameters
    ----------
    estimators : list
        A list of estimators fitting the input data for all replicas. Only the first estimator
        (:code:`estimators[0]`) is read, through its :code:`delta_f_` and :code:`d_delta_f_` attributes,
        which are expected to be square pandas DataFrames (states x states). In our code, these estimators
        come from the function :func:`_apply_estimators`.

    Returns
    -------
    est : float
        The free energy difference between the first and the last state, i.e., the entry in the first
        row and last column of :code:`estimators[0].delta_f_`.
    err : float
        The uncertainty corresponding to :code:`est`, i.e., the entry in the first row and last column
        of :code:`estimators[0].d_delta_f_`.

    See also
    --------
    :func:`calculate_free_energy`
    """
    first_estimator = estimators[0]

    # Select the (first state, last state) entry by position. Relabeling the index/columns of the
    # estimator's own DataFrames just to look the entry up would silently overwrite the state labels
    # of the estimators that the caller keeps using (and returns) afterwards.
    est = first_estimator.delta_f_.iloc[0, -1]
    err = first_estimator.d_delta_f_.iloc[0, -1]

    return est, err
290+
291+
292+ def calculate_free_energy (data , state_ranges , df_method = "MBAR" , err_method = "propagate" , n_bootstrap = None , seed = None , MTREXEE = False ): # noqa: E501
251293 """
252294 Calculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
253295 obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -275,6 +317,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
275317 seed : int, Optional
276318 The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
277319 The default is :code:`None`.
320+ MTREXEE : bool, Optional
321+ Whether this is an MT-REXEE simulation or not. The default is :code:`False`.
278322
279323 Returns
280324 -------
@@ -299,10 +343,17 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
299343 >>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
300344 """
301345 n_sim = len (data )
302- n_tot = state_ranges [- 1 ][- 1 ] + 1
346+ if MTREXEE is False :
347+ n_tot = state_ranges [- 1 ][- 1 ] + 1
348+ else :
349+ n_tot = state_ranges [- 1 ] + 1
303350 estimators = _apply_estimators (data , df_method )
304- df_adjacent , df_err_adjacent = _calculate_df_adjacent (estimators )
305- df , df_err , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' )
351+ print (estimators )
352+ if MTREXEE is False :
353+ df_adjacent , df_err_adjacent = _calculate_df_adjacent (estimators )
354+ df , df_err , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # noqa: E501
355+ else :
356+ df , df_err = _calculate_df (estimators )
306357
307358 if err_method == 'bootstrap' :
308359 if seed is not None :
@@ -314,26 +365,33 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
314365 for b in range (n_bootstrap ):
315366 sampled_data = [sampled_data_all [i ].iloc [b * len (data [i ]):(b + 1 ) * len (data [i ])] for i in range (n_sim )]
316367 bootstrap_estimators = _apply_estimators (sampled_data , df_method )
317- df_adjacent , df_err_adjacent = _calculate_df_adjacent (bootstrap_estimators )
318- df_sampled , _ , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # doesn't matter what value err_type here is # noqa: E501
368+ if MTREXEE is False :
369+ df_adjacent , df_err_adjacent = _calculate_df_adjacent (bootstrap_estimators )
370+ df_sampled , _ , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # doesn't matter what value err_type here is # noqa: E501
371+ else :
372+ df_sampled , _ = _calculate_df (bootstrap_estimators )
319373 df_bootstrap .append (df_sampled )
320374 error_bootstrap = np .std (df_bootstrap , axis = 0 , ddof = 1 )
321375
322376 # Replace the value in df_err with value in error_bootstrap if df_err corresponds to
323377 # the df between overlapping states
324378 for i in range (n_tot - 1 ):
325- if overlap_bool [i ] is True :
379+ if MTREXEE is True or overlap_bool [i ] is True :
326380 print (f'Replaced the propagated error with the bootstrapped error for states { i } and { i + 1 } : { df_err [i ]:.5f} -> { error_bootstrap [i ]:.5f} .' ) # noqa: E501
327381 df_err [i ] = error_bootstrap [i ]
328382 elif err_method == 'propagate' :
329383 pass
330384 else :
331385 raise ParameterError ('Specified err_method not available.' )
332386
333- df .insert (0 , 0 )
334- df_err .insert (0 , 0 )
335- f = [sum (df [:(i + 1 )]) for i in range (len (df ))]
336- f_err = [np .sqrt (sum ([x ** 2 for x in df_err [:(i + 1 )]])) for i in range (len (df_err ))]
387+ if MTREXEE is False :
388+ df .insert (0 , 0 )
389+ df_err .insert (0 , 0 )
390+ f = [sum (df [:(i + 1 )]) for i in range (len (df ))]
391+ f_err = [np .sqrt (sum ([x ** 2 for x in df_err [:(i + 1 )]])) for i in range (len (df_err ))]
392+ else :
393+ f = df
394+ f_err = df_err
337395
338396 return f , f_err , estimators
339397
0 commit comments