Skip to content

Commit d577c6c

Browse files
authored
Merge pull request #71 from weitse-hsu/forced-swap
Add additional swapping methods options for different default coordinate exchange function.
2 parents d07f447 + 4833013 commit d577c6c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+9339
-378
lines changed

ensemble_md/analysis/analyze_free_energy.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
189189
A list of lists free energy differences between adjacent states for all replicas.
190190
state_ranges : list
191191
A list of lists of showing the state indices sampled by each replica.
192+
n_tot : int
193+
Number of lambda states
192194
df_err_adjacent : list, Optional
193195
A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
194196
:code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
@@ -247,7 +249,47 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
247249
return df, df_err, overlap_bool
248250

249251

250-
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None):
252+
def _calculate_df(estimators):
253+
"""
254+
An internal function used in :func:`calculate_free_energy` to calculate a list of free energies between adjacent
255+
states for all replicas.
256+
257+
Parameters
258+
----------
259+
estimators : list
260+
A list of estimators fitting the input data for all replicas. With this, the user
261+
can access all the free energies and their associated uncertainties for all states and replicas.
262+
In our code, these estimators come from the function :func:`_apply_estimators`.
263+
264+
Returns
265+
-------
266+
df : float
267+
Free energy differences between for specified replica.
268+
df_err : float
269+
Uncertainties corresponding to the values in :code:`df`.
270+
271+
See also
272+
--------
273+
:func:`calculate_free_energy`
274+
"""
275+
# Compute FE estimate
276+
df = estimators[0].delta_f_
277+
lam = np.linspace(0, 1, num=len(df.index))
278+
df.index = lam
279+
df.columns = lam
280+
est = df.loc[0, 1]
281+
282+
# Compute FE extimate error
283+
df_err = estimators[0].d_delta_f_
284+
lam = np.linspace(0, 1, num=len(df_err.index))
285+
df_err.index = lam
286+
df_err.columns = lam
287+
err = df_err.loc[0, 1]
288+
289+
return est, err
290+
291+
292+
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False): # noqa: E501
251293
"""
252294
Caculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
253295
obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -275,6 +317,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
275317
seed : int, Optional
276318
The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
277319
The default is :code:`None`.
320+
MTREXEE : bool
321+
Whether this is a MT-REXEE simulation or not
278322
279323
Returns
280324
-------
@@ -299,10 +343,17 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
299343
>>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
300344
"""
301345
n_sim = len(data)
302-
n_tot = state_ranges[-1][-1] + 1
346+
if MTREXEE is False:
347+
n_tot = state_ranges[-1][-1] + 1
348+
else:
349+
n_tot = state_ranges[-1] + 1
303350
estimators = _apply_estimators(data, df_method)
304-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
305-
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate')
351+
print(estimators)
352+
if MTREXEE is False:
353+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
354+
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # noqa: E501
355+
else:
356+
df, df_err = _calculate_df(estimators)
306357

307358
if err_method == 'bootstrap':
308359
if seed is not None:
@@ -314,26 +365,33 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
314365
for b in range(n_bootstrap):
315366
sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)]
316367
bootstrap_estimators = _apply_estimators(sampled_data, df_method)
317-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
318-
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
368+
if MTREXEE is False:
369+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
370+
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
371+
else:
372+
df_sampled, _ = _calculate_df(bootstrap_estimators)
319373
df_bootstrap.append(df_sampled)
320374
error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1)
321375

322376
# Replace the value in df_err with value in error_bootstrap if df_err corresponds to
323377
# the df between overlapping states
324378
for i in range(n_tot - 1):
325-
if overlap_bool[i] is True:
379+
if MTREXEE is True or overlap_bool[i] is True:
326380
print(f'Replaced the propagated error with the bootstrapped error for states {i} and {i + 1}: {df_err[i]:.5f} -> {error_bootstrap[i]:.5f}.') # noqa: E501
327381
df_err[i] = error_bootstrap[i]
328382
elif err_method == 'propagate':
329383
pass
330384
else:
331385
raise ParameterError('Specified err_method not available.')
332386

333-
df.insert(0, 0)
334-
df_err.insert(0, 0)
335-
f = [sum(df[:(i + 1)]) for i in range(len(df))]
336-
f_err = [np.sqrt(sum([x**2 for x in df_err[:(i+1)]])) for i in range(len(df_err))]
387+
if MTREXEE is False:
388+
df.insert(0, 0)
389+
df_err.insert(0, 0)
390+
f = [sum(df[:(i + 1)]) for i in range(len(df))]
391+
f_err = [np.sqrt(sum([x**2 for x in df_err[:(i+1)]])) for i in range(len(df_err))]
392+
else:
393+
f = df
394+
f_err = df_err
337395

338396
return f, f_err, estimators
339397

ensemble_md/analysis/analyze_traj.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import matplotlib.pyplot as plt
1616
from itertools import chain
1717
from matplotlib.ticker import MaxNLocator
18-
1918
from alchemlyb.parsing.gmx import _get_headers as get_headers
2019
from alchemlyb.parsing.gmx import _extract_dataframe as extract_dataframe
2120
from ensemble_md.utils import utils
21+
import os
2222

2323

2424
def extract_state_traj(dhdl):
@@ -106,6 +106,7 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav
106106
# files_sorted[i] contains the dhdl/plumed output files for starting configuration i sorted
107107
# based on iteration indices
108108
files_sorted = [[] for i in range(n_configs)]
109+
print(n_iter)
109110
for i in range(n_configs):
110111
for j in range(n_iter):
111112
files_sorted[i].append(files[rep_trajs[i][j]][j])
@@ -539,7 +540,15 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
539540
hist, bins = np.histogram(traj, bins=np.arange(lower_bound, upper_bound + 1, 1))
540541
hist_data.append(hist)
541542
if save_hist is True:
542-
np.save('hist_data.npy', hist_data)
543+
if len(fig_name.split('/')) > 1:
544+
dir_list = []
545+
for i in fig_name.split('/')[:-1]:
546+
dir_list.append(i)
547+
dir_list.append('/')
548+
dir_path = ''.join(dir_list)
549+
np.save(f'{dir_path}/hist_data.npy', hist_data)
550+
else:
551+
np.save('hist_data.npy', hist_data)
543552

544553
# Use the same bins for all histograms
545554
bins = bins[:-1] # Remove the last bin edge because there are n+1 bin edges for n bins
@@ -685,6 +694,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
685694
units : str
686695
The units of the time.
687696
"""
697+
import pandas as pd
698+
688699
if dt is None:
689700
x = np.arange(len(trajs[0]))
690701
units = 'step'
@@ -824,6 +835,14 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
824835
plt.savefig(f'{folder}/hist_{fig_names[t]}', dpi=600)
825836
else:
826837
plt.savefig(f'{folder}/{fig_prefix}_hist_{fig_names[t]}', dpi=600)
838+
# Save to csv
839+
sim_list, rt_list = [], []
840+
for n in range(len(t_roundtrip_list)):
841+
for rt in t_roundtrip_list[n]:
842+
sim_list.append(n)
843+
rt_list.append(rt)
844+
df_rt = pd.DataFrame({'Sim': sim_list, 'Round Trip Time': rt_list})
845+
df_rt.to_csv(f'{folder}/roundtrip_times.csv')
827846

828847
return t_0k_list, t_k0_list, t_roundtrip_list, units
829848

@@ -1330,3 +1349,73 @@ def get_delta_w_updates(log_file, plot=False):
13301349
plt.savefig('delta_w_updates.png', dpi=600)
13311350

13321351
return t_updates, delta_w_updates, equil
1352+
1353+
1354+
def concat_sim_traj(working_dir, n_sim, n_iter, gro):
1355+
"""
1356+
Create a trajectory which is a concatenation off each iterations trajectory
1357+
1358+
Parameters
1359+
----------
1360+
working_dir : str
1361+
path for the current working directory
1362+
n_sim : int
1363+
the number of simulations run
1364+
n_iter : int
1365+
the number of iterations run
1366+
1367+
Returns
1368+
-------
1369+
None
1370+
"""
1371+
import mdtraj as md
1372+
import os
1373+
from tqdm import tqdm
1374+
1375+
# Create output directory if needed
1376+
if not os.path.exists(f'{working_dir}/analysis/traj'):
1377+
os.makedirs(f'{working_dir}/analysis/traj')
1378+
1379+
for rep in range(n_sim):
1380+
if not os.path.exists(f'{working_dir}/analysis/traj/sim{rep}_concat.xtc'):
1381+
if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
1382+
name = 'confout_backup'
1383+
else:
1384+
name = 'confout'
1385+
gro_ref = md.load(f'{working_dir}/{gro[rep]}')
1386+
traj = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
1387+
traj.superpose(gro_ref, frame=0)
1388+
for iteration in tqdm(range(1, n_iter)):
1389+
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
1390+
traj_add.superpose(gro_ref, frame=0)
1391+
traj = md.join([traj, traj_add[1:]])
1392+
print(traj)
1393+
traj.save_xtc(f'{working_dir}/analysis/traj/sim{rep}_concat.xtc')
1394+
1395+
1396+
def concat_xvg(n_sim, n_iter, working_dir):
1397+
for s in range(n_sim):
1398+
if os.path.exists(f'{working_dir}/analysis/sim_{s}.xvg'):
1399+
continue
1400+
output_file = open(f'{working_dir}/analysis/sim_{s}.xvg', 'w')
1401+
for i in range(n_iter):
1402+
input_file = open(f'{working_dir}/sim_{s}/iteration_{i}/dhdl.xvg').readlines()
1403+
if i == 0:
1404+
for line in input_file:
1405+
output_file.write(line)
1406+
time_value = float(input_file[-1].split(' ')[0])
1407+
time_step = np.round(time_value - float(input_file[-2].split(' ')[0]), 4)
1408+
else:
1409+
skipped_first = False
1410+
for line in input_file:
1411+
if line[0] != '#' and line[0] != '@':
1412+
if skipped_first is False:
1413+
skipped_first = True
1414+
else:
1415+
time_value += time_step
1416+
time_str = f'{time_value:.4f}'
1417+
n = len(line.split(' ')[0])
1418+
new_line = time_str + line[n:]
1419+
new_line = time_str + line[n:]
1420+
output_file.write(new_line)
1421+
output_file.close()

0 commit comments

Comments
 (0)