Skip to content

Commit 62830be

Browse files
Initial outline for Trajectory refactoring
1 parent 7a49bc6 commit 62830be

File tree

1 file changed

+72
-64
lines changed

1 file changed

+72
-64
lines changed

algorithms.py

Lines changed: 72 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Python version of the simulation algorithm.
33
"""
44
import argparse
5+
import dataclasses
56
import heapq
67
import logging
78
import math
@@ -385,55 +386,34 @@ def num_lineages(self):
385386
return sum(len(s) for s in self.segments)
386387

387388

388-
class TrajectorySimulator:
389-
"""
390-
Class to simulate an allele frequency trajectory on which to condition
391-
the coalescent simulation.
392-
"""
389+
class Trajectory:
390+
def next_frequency(self, x, dt, rand):
391+
"""
392+
Compute the next allele frequency in the trajectory given the
393+
current value x, after time dt using the specified random
394+
value between 0 and 1.
395+
"""
396+
raise NotImplementedError() # pragma: no cover
393397

394-
def __init__(self, initial_freq, end_freq, alpha, time_slice):
395-
self._initial_freq = initial_freq
396-
self._end_freq = end_freq
397-
self._alpha = alpha
398-
self._time_slice = time_slice
399-
self._reset()
400398

401-
def _reset(self):
402-
self._allele_freqs = []
403-
self._times = []
399+
@dataclasses.dataclass()
400+
class GenicSelectionTrajectory:
401+
alpha: float
404402

405-
def _genic_selection_stochastic_forwards(self, dt, freq, alpha):
406-
ux = (alpha * freq * (1 - freq)) / np.tanh(alpha * freq)
407-
sign = 1 if random.random() < 0.5 else -1
408-
freq += (ux * dt) + sign * np.sqrt(freq * (1.0 - freq) * dt)
409-
return freq
403+
def next_frequency(self, x, dt, current_size, rand):
404+
alpha = current_size * self.alpha
405+
ux = (alpha * x * (1 - x)) / np.tanh(alpha * x)
406+
sign = 1 if rand < 0.5 else -1
407+
return x + (ux * dt) + sign * np.sqrt(x * (1.0 - x) * dt)
410408

411-
def _simulate(self):
412-
"""
413-
Proposes a sweep trajectory and returns the acceptance probability.
414-
"""
415-
x = self._end_freq # backward time
416-
current_size = 1
417-
t_inc = self._time_slice
418-
t = 0
419-
while x > self._initial_freq:
420-
self._allele_freqs.append(max(x, self._initial_freq))
421-
self._times.append(t)
422-
# just a note below
423-
# current_size = self._size_calculator(t)
424-
#
425-
x = 1.0 - self._genic_selection_stochastic_forwards(
426-
t_inc, 1.0 - x, self._alpha * current_size
427-
)
428-
t += self._time_slice
429-
# will want to return current_size / N_max
430-
# for prototype this always equals 1
431-
return 1
432409

433-
def run(self):
434-
while random.random() > self._simulate():
435-
self.reset()
436-
return self._allele_freqs, self._times
410+
@dataclasses.dataclass()
411+
class Sweep:
412+
position: float
413+
initial_frequency: float
414+
end_frequency: float
415+
dt: float
416+
trajectory: Trajectory
437417

438418

439419
class RateMap:
@@ -587,9 +567,8 @@ def __init__(
587567
model="hudson",
588568
max_segments=100,
589569
num_labels=1,
590-
sweep_trajectory=None,
570+
sweep=None,
591571
full_arg=False,
592-
time_slice=None,
593572
gene_conversion_rate=0.0,
594573
gene_conversion_length=1,
595574
discrete_genome=True,
@@ -615,6 +594,7 @@ def __init__(
615594
self.num_labels = num_labels
616595
self.num_populations = N
617596
self.max_segments = max_segments
597+
self.sweep = sweep
618598
self.full_arg = full_arg
619599
self.pedigree = None
620600
self.segment_stack = []
@@ -648,11 +628,6 @@ def __init__(
648628
self.num_re_events = 0
649629
self.num_gc_events = 0
650630

651-
# Sweep variables
652-
self.sweep_site = (self.L // 2) - 1 # need to add options here
653-
self.sweep_trajectory = sweep_trajectory
654-
self.time_slice = time_slice
655-
656631
self.modifier_events = [(sys.float_info.max, None, None)]
657632
for time, pop_id, new_size in population_size_changes:
658633
self.modifier_events.append(
@@ -989,13 +964,41 @@ def hudson_simulate(self, end_time):
989964
assert non_empty_pops == X
990965
return self.finalise()
991966

992-
def single_sweep_simulate(self):
967+
def propose_trajectory(self):
968+
"""
969+
Proposes a sweep trajectory and returns the acceptance probability.
993970
"""
994-
Does a structured coalescent until end_freq is reached, using
995-
information in self.sweep_trajectory.
971+
allele_freqs = []
972+
times = []
973+
sweep = self.sweep
974+
x = sweep.end_frequency
975+
current_size = 1
976+
t = 0
977+
while x > sweep.initial_frequency:
978+
allele_freqs.append(max(x, sweep.initial_frequency))
979+
times.append(t)
980+
# just a note below
981+
# current_size = self._size_calculator(t)
982+
#
983+
x = 1 - sweep.trajectory.next_frequency(
984+
1 - x, sweep.dt, current_size, random.random()
985+
)
986+
t += sweep.dt
987+
# will want to return current_size / N_max
988+
# for prototype this always equals 1
989+
return allele_freqs, times, 1
990+
991+
def simulate_sweep_trajectory(self):
992+
allele_freqs, times, acceptance = self.propose_trajectory()
993+
while random.random() > acceptance:
994+
allele_freqs, times, acceptance = self.propose_trajectory()
995+
return allele_freqs, times
996996

997+
def single_sweep_simulate(self):
998+
"""
999+
Does a structured coalescent until end_freq is reached.
9971000
"""
998-
allele_freqs, times = self.sweep_trajectory
1001+
allele_freqs, times = self.simulate_sweep_trajectory()
9991002
sweep_traj_step = 0
10001003
x = allele_freqs[sweep_traj_step]
10011004

@@ -1018,7 +1021,7 @@ def single_sweep_simulate(self):
10181021
self.P[0].add(tmp, 1)
10191022

10201023
# main loop time
1021-
t_inc_orig = self.time_slice
1024+
t_inc_orig = self.sweep.dt
10221025
e_time = 0.0
10231026
while self.ancestors_remain() and sweep_traj_step < len(times) - 1:
10241027
self.verify()
@@ -1083,12 +1086,12 @@ def single_sweep_simulate(self):
10831086
if r < e_sum / sweep_pop_tot_rate:
10841087
# recomb in B
10851088
self.hudson_recombination_event_sweep_phase(
1086-
1, self.sweep_site, x
1089+
1, self.sweep.position, x
10871090
)
10881091
else:
10891092
# recomb in b
10901093
self.hudson_recombination_event_sweep_phase(
1091-
0, self.sweep_site, 1.0 - x
1094+
0, self.sweep.position, 1.0 - x
10921095
)
10931096
# clean up the labels at end
10941097
for idx, u in enumerate(self.P[0].iter_label(1)):
@@ -2163,16 +2166,22 @@ def run_simulate(args):
21632166
rates = args.recomb_rates
21642167
recombination_map = RateMap(positions, rates)
21652168
num_labels = 1
2166-
sweep_trajectory = None
2169+
sweep = None
21672170
if args.model == "single_sweep":
21682171
if num_populations > 1:
21692172
raise ValueError("Multiple populations not currently supported")
21702173
# Compute the trajectory
21712174
if args.trajectory is None:
21722175
raise ValueError("Must provide trajectory (init_freq, end_freq, alpha)")
2173-
init_freq, end_freq, alpha = args.trajectory
2174-
traj_sim = TrajectorySimulator(init_freq, end_freq, alpha, args.time_slice)
2175-
sweep_trajectory = traj_sim.run()
2176+
initial_frequency, end_frequency, alpha = args.trajectory
2177+
trajectory = GenicSelectionTrajectory(alpha)
2178+
sweep = Sweep(
2179+
position=m // 2 + 1,
2180+
initial_frequency=initial_frequency,
2181+
end_frequency=end_frequency,
2182+
dt=args.time_slice,
2183+
trajectory=trajectory,
2184+
)
21762185
num_labels = 2
21772186
random.seed(args.random_seed)
21782187
np.random.seed(args.random_seed + 1)
@@ -2204,8 +2213,7 @@ def run_simulate(args):
22042213
max_segments=100000,
22052214
num_labels=num_labels,
22062215
full_arg=args.full_arg,
2207-
sweep_trajectory=sweep_trajectory,
2208-
time_slice=args.time_slice,
2216+
sweep=sweep,
22092217
gene_conversion_rate=gc_rate,
22102218
gene_conversion_length=mean_tract_length,
22112219
discrete_genome=args.discrete,

0 commit comments

Comments
 (0)