Skip to content

Commit 9285c3d

Browse files
authored
feat: optimize ODE (#167)
* optimize ODE infusions with a precokputed schedule * avoid cloning * preallocation of the output vector, resue of the spp vector also avoid redundant vec->V conversion * take boluses and ivs as reference instead of owned structs * extra tests * use slices for .occasions() and .events(), precalculating the capacity of the likelihood vector
1 parent a515ee2 commit 9285c3d

File tree

6 files changed

+1314
-122
lines changed

6 files changed

+1314
-122
lines changed

src/data/structs.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,9 @@ impl Subject {
355355
///
356356
/// # Returns
357357
///
358-
/// Vector of references to all occasions
359-
pub fn occasions(&self) -> Vec<&Occasion> {
360-
self.occasions.iter().collect()
358+
/// A slice of all occasions for this subject
359+
pub fn occasions(&self) -> &[Occasion] {
360+
&self.occasions
361361
}
362362

363363
/// Get the ID of the subject
@@ -518,9 +518,9 @@ impl Occasion {
518518
///
519519
/// # Returns
520520
///
521-
/// Vector of references to all events
522-
pub fn events(&self) -> Vec<&Event> {
523-
self.events.iter().collect()
521+
/// A slice of all events in this occasion
522+
pub fn events(&self) -> &[Event] {
523+
&self.events
524524
}
525525

526526
/// Get the index of the occasion
@@ -1130,7 +1130,7 @@ mod tests {
11301130
.find(|e| matches!(e, Event::Bolus(_)))
11311131
{
11321132
let mut event_count = 0;
1133-
for event in *bolus_event {
1133+
for event in bolus_event {
11341134
assert_eq!(event.time(), 2.0); // Bolus time from sample data
11351135
assert!(matches!(event, Event::Bolus(_)));
11361136
if let Event::Bolus(bolus) = event {
@@ -1158,7 +1158,7 @@ mod tests {
11581158
.iter()
11591159
.find(|e| matches!(e, Event::Observation(_)))
11601160
{
1161-
for event in *obs_event {
1161+
for event in obs_event {
11621162
assert_eq!(event.time(), 1.0); // Observation time from sample data
11631163
if let Event::Observation(observation) = event {
11641164
assert_eq!(observation.value(), Some(10.0)); // Value from sample data

src/simulator/equation/ode/closure.rs

Lines changed: 151 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,135 @@ use diffsol::{
44
NonLinearOpJacobian, OdeEquations, OdeEquationsRef, Op, Vector, VectorCommon,
55
};
66
use nalgebra::DVector;
7-
use std::cell::RefCell;
7+
use std::{cell::RefCell, cmp::Ordering};
88
type M = NalgebraMat<f64>;
99
type V = <M as MatrixCommon>::V;
1010
type C = <M as MatrixCommon>::C;
1111
type T = <M as MatrixCommon>::T;
1212

13+
#[derive(Debug, Clone)]
14+
struct InfusionChannel {
15+
input: usize,
16+
event_times: Vec<f64>,
17+
cumulative_rates: Vec<f64>,
18+
}
19+
20+
impl InfusionChannel {
21+
fn new(input: usize, mut events: Vec<(f64, f64)>) -> Self {
22+
events.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
23+
24+
let mut event_times = Vec::with_capacity(events.len());
25+
let mut cumulative_rates = Vec::with_capacity(events.len());
26+
let mut current_rate = 0.0;
27+
28+
for (time, delta) in events {
29+
current_rate += delta;
30+
event_times.push(time);
31+
cumulative_rates.push(current_rate);
32+
}
33+
34+
Self {
35+
input,
36+
event_times,
37+
cumulative_rates,
38+
}
39+
}
40+
41+
fn rate_at(&self, time: f64) -> f64 {
42+
if self.event_times.is_empty() {
43+
return 0.0;
44+
}
45+
46+
match self
47+
.event_times
48+
.binary_search_by(|probe| probe.partial_cmp(&time).unwrap_or(Ordering::Less))
49+
{
50+
Ok(mut idx) => {
51+
while idx + 1 < self.event_times.len()
52+
&& self.event_times[idx + 1] == self.event_times[idx]
53+
{
54+
idx += 1;
55+
}
56+
self.cumulative_rates[idx]
57+
}
58+
Err(0) => 0.0,
59+
Err(idx) => self.cumulative_rates[idx - 1],
60+
}
61+
}
62+
}
63+
64+
#[derive(Debug, Clone, Default)]
65+
struct InfusionSchedule {
66+
channels: Vec<InfusionChannel>,
67+
}
68+
69+
impl InfusionSchedule {
70+
fn new(nstates: usize, infusions: &[&Infusion]) -> Self {
71+
if nstates == 0 || infusions.is_empty() {
72+
return Self {
73+
channels: Vec::new(),
74+
};
75+
}
76+
77+
let mut per_input: Vec<Vec<(f64, f64)>> = vec![Vec::new(); nstates];
78+
for infusion in infusions {
79+
if infusion.duration() <= 0.0 {
80+
continue;
81+
}
82+
83+
let input = infusion.input();
84+
if input >= nstates {
85+
continue;
86+
}
87+
88+
let rate = infusion.amount() / infusion.duration();
89+
per_input[input].push((infusion.time(), rate));
90+
per_input[input].push((infusion.time() + infusion.duration(), -rate));
91+
}
92+
93+
let channels = per_input
94+
.into_iter()
95+
.enumerate()
96+
.filter_map(|(input, events)| {
97+
if events.is_empty() {
98+
None
99+
} else {
100+
Some(InfusionChannel::new(input, events))
101+
}
102+
})
103+
.collect();
104+
105+
Self { channels }
106+
}
107+
108+
fn fill_rate_vector(&self, time: f64, rateiv: &mut V) {
109+
rateiv.fill(0.0);
110+
for channel in &self.channels {
111+
let rate = channel.rate_at(time);
112+
if rate != 0.0 {
113+
rateiv[channel.input] = rate;
114+
}
115+
}
116+
}
117+
}
118+
13119
pub struct PmRhs<'a, F>
14120
where
15-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates),
121+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates),
16122
{
17123
nstates: usize,
18124
nparams: usize,
19-
infusions: &'a [&'a Infusion], // Change from Vec to slice reference
125+
infusion_schedule: &'a InfusionSchedule,
20126
covariates: &'a Covariates,
21-
p: &'a Vec<f64>,
127+
p_as_v: &'a V,
22128
func: &'a F,
23129
rateiv_buffer: &'a RefCell<V>,
130+
zero_bolus: &'a V,
24131
}
25132

26133
impl<F> Op for PmRhs<'_, F>
27134
where
28-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates),
135+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates),
29136
{
30137
type T = T;
31138
type V = V;
@@ -154,85 +261,36 @@ impl Op for PmOut {
154261

155262
impl<F> NonLinearOp for PmRhs<'_, F>
156263
where
157-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates),
264+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates),
158265
{
159266
fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
160-
// Compute rate IV at the current time
161267
let mut rateiv_ref = self.rateiv_buffer.borrow_mut();
162-
rateiv_ref.fill(0.0);
163-
164-
for infusion in self.infusions {
165-
if t >= infusion.time() && t <= infusion.duration() + infusion.time() {
166-
rateiv_ref[infusion.input()] += infusion.amount() / infusion.duration();
167-
}
168-
}
169-
170-
// We need to drop the mutable borrow before calling the function
171-
// to avoid potential conflicts with future borrows in the function
172-
let rateiv = rateiv_ref.clone();
173-
drop(rateiv_ref);
174-
175-
// Avoid creating a new DVector when possible
176-
let p_len = self.p.len();
177-
let mut p_dvector: DVector<f64>;
178-
let p_ref: &DVector<f64>;
179-
180-
// Use stack allocation for small parameter vectors
181-
if p_len <= 16 {
182-
let mut stack_p = [0.0; 16];
183-
stack_p[..p_len].copy_from_slice(self.p);
184-
p_dvector = DVector::from_row_slice(&stack_p[..p_len]);
185-
p_ref = &p_dvector;
186-
} else {
187-
// For larger vectors, use the more efficient approach with unsafe
188-
p_dvector = DVector::zeros(p_len);
189-
unsafe {
190-
std::ptr::copy_nonoverlapping(self.p.as_ptr(), p_dvector.as_mut_ptr(), p_len);
191-
}
192-
p_ref = &p_dvector;
193-
}
194-
195-
let pnew = p_ref.to_owned().into();
268+
self.infusion_schedule.fill_rate_vector(t, &mut rateiv_ref);
196269

197-
let bolus = V::zeros(self.nstates, NalgebraContext);
198-
199-
(self.func)(x, &pnew, t, y, bolus, rateiv, self.covariates);
270+
(self.func)(
271+
x,
272+
self.p_as_v,
273+
t,
274+
y,
275+
self.zero_bolus,
276+
&rateiv_ref,
277+
self.covariates,
278+
);
200279
}
201280
}
202281

203282
impl<F> NonLinearOpJacobian for PmRhs<'_, F>
204283
where
205-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates),
284+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates),
206285
{
207286
fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
208-
let rateiv = V::zeros(self.nstates, NalgebraContext);
209-
210-
// Avoid creating a new DVector when possible
211-
let p_len = self.p.len();
212-
let mut p_dvector: DVector<f64>;
213-
214-
// Use stack allocation for small parameter vectors
215-
if p_len <= 16 {
216-
let mut stack_p = [0.0; 16];
217-
stack_p[..p_len].copy_from_slice(self.p);
218-
p_dvector = DVector::from_row_slice(&stack_p[..p_len]);
219-
} else {
220-
// For larger vectors, use the more efficient approach with unsafe
221-
p_dvector = DVector::zeros(p_len);
222-
unsafe {
223-
std::ptr::copy_nonoverlapping(self.p.as_ptr(), p_dvector.as_mut_ptr(), p_len);
224-
}
225-
}
226-
227-
let bolus = V::zeros(self.nstates, NalgebraContext);
228-
229287
(self.func)(
230288
v,
231-
&p_dvector.to_owned().into(),
289+
self.p_as_v,
232290
t,
233291
y,
234-
bolus,
235-
rateiv,
292+
self.zero_bolus,
293+
self.zero_bolus,
236294
self.covariates,
237295
);
238296
}
@@ -253,49 +311,59 @@ impl NonLinearOp for PmOut {
253311
// Completely revised PMProblem to fix lifetime issues and improve performance
254312
pub struct PMProblem<'a, F>
255313
where
256-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates) + 'a,
314+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates) + 'a,
257315
{
258316
func: F,
259317
nstates: usize,
260318
nparams: usize,
261319
init: V,
262320
p: Vec<f64>,
321+
p_as_v: V,
322+
zero_bolus: V,
263323
covariates: &'a Covariates,
264-
infusions: Vec<&'a Infusion>,
324+
infusion_schedule: InfusionSchedule,
265325
rateiv_buffer: RefCell<V>,
266326
}
267327

268328
impl<'a, F> PMProblem<'a, F>
269329
where
270-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates) + 'a,
330+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates) + 'a,
271331
{
272-
pub fn new(
332+
/// Creates a new PMProblem with a pre-converted parameter vector.
333+
/// This avoids an allocation when the caller already has a V representation.
334+
pub fn with_params_v(
273335
func: F,
274336
nstates: usize,
275337
p: Vec<f64>,
338+
p_as_v: V,
276339
covariates: &'a Covariates,
277-
infusions: Vec<&'a Infusion>,
340+
infusions: &[&'a Infusion],
278341
init: V,
279342
) -> Self {
280343
let nparams = p.len();
281344
let rateiv_buffer = RefCell::new(V::zeros(nstates, NalgebraContext));
345+
let infusion_schedule = InfusionSchedule::new(nstates, infusions);
346+
// Pre-allocate zero bolus vector
347+
let zero_bolus = V::zeros(nstates, NalgebraContext);
282348

283349
Self {
284350
func,
285351
nstates,
286352
nparams,
287353
init,
288354
p,
355+
p_as_v,
356+
zero_bolus,
289357
covariates,
290-
infusions,
358+
infusion_schedule,
291359
rateiv_buffer,
292360
}
293361
}
294362
}
295363

296364
impl<'a, F> Op for PMProblem<'a, F>
297365
where
298-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates) + 'a,
366+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates) + 'a,
299367
{
300368
type T = T;
301369
type V = V;
@@ -318,7 +386,7 @@ where
318386
// Implement OdeEquationsRef for PMProblem for any lifetime 'b
319387
impl<'a, 'b, F> OdeEquationsRef<'b> for PMProblem<'a, F>
320388
where
321-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates) + 'a,
389+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates) + 'a,
322390
{
323391
type Rhs = PmRhs<'b, F>;
324392
type Mass = PmMass;
@@ -330,17 +398,18 @@ where
330398
// Implement OdeEquations with correct lifetime handling
331399
impl<'a, F> OdeEquations for PMProblem<'a, F>
332400
where
333-
F: Fn(&V, &V, T, &mut V, V, V, &Covariates) + 'a,
401+
F: Fn(&V, &V, T, &mut V, &V, &V, &Covariates) + 'a,
334402
{
335403
fn rhs(&self) -> PmRhs<'_, F> {
336404
PmRhs {
337405
nstates: self.nstates,
338406
nparams: self.nparams,
339-
infusions: &self.infusions, // Use reference instead of clone
407+
infusion_schedule: &self.infusion_schedule,
340408
covariates: self.covariates,
341-
p: &self.p,
409+
p_as_v: &self.p_as_v,
342410
func: &self.func,
343411
rateiv_buffer: &self.rateiv_buffer,
412+
zero_bolus: &self.zero_bolus,
344413
}
345414
}
346415

@@ -380,9 +449,11 @@ where
380449
if self.p.len() == p.len() {
381450
for i in 0..p.len() {
382451
self.p[i] = p[i];
452+
self.p_as_v[i] = p[i];
383453
}
384454
} else {
385455
self.p = p.inner().iter().cloned().collect();
456+
self.p_as_v = p.clone();
386457
}
387458
}
388459
}

0 commit comments

Comments
 (0)