@@ -4,28 +4,135 @@ use diffsol::{
44 NonLinearOpJacobian , OdeEquations , OdeEquationsRef , Op , Vector , VectorCommon ,
55} ;
66use nalgebra:: DVector ;
7- use std:: cell:: RefCell ;
7+ use std:: { cell:: RefCell , cmp :: Ordering } ;
88type M = NalgebraMat < f64 > ;
99type V = <M as MatrixCommon >:: V ;
1010type C = <M as MatrixCommon >:: C ;
1111type T = <M as MatrixCommon >:: T ;
1212
13+ #[ derive( Debug , Clone ) ]
14+ struct InfusionChannel {
15+ input : usize ,
16+ event_times : Vec < f64 > ,
17+ cumulative_rates : Vec < f64 > ,
18+ }
19+
20+ impl InfusionChannel {
21+ fn new ( input : usize , mut events : Vec < ( f64 , f64 ) > ) -> Self {
22+ events. sort_by ( |a, b| a. 0 . partial_cmp ( & b. 0 ) . unwrap_or ( Ordering :: Equal ) ) ;
23+
24+ let mut event_times = Vec :: with_capacity ( events. len ( ) ) ;
25+ let mut cumulative_rates = Vec :: with_capacity ( events. len ( ) ) ;
26+ let mut current_rate = 0.0 ;
27+
28+ for ( time, delta) in events {
29+ current_rate += delta;
30+ event_times. push ( time) ;
31+ cumulative_rates. push ( current_rate) ;
32+ }
33+
34+ Self {
35+ input,
36+ event_times,
37+ cumulative_rates,
38+ }
39+ }
40+
41+ fn rate_at ( & self , time : f64 ) -> f64 {
42+ if self . event_times . is_empty ( ) {
43+ return 0.0 ;
44+ }
45+
46+ match self
47+ . event_times
48+ . binary_search_by ( |probe| probe. partial_cmp ( & time) . unwrap_or ( Ordering :: Less ) )
49+ {
50+ Ok ( mut idx) => {
51+ while idx + 1 < self . event_times . len ( )
52+ && self . event_times [ idx + 1 ] == self . event_times [ idx]
53+ {
54+ idx += 1 ;
55+ }
56+ self . cumulative_rates [ idx]
57+ }
58+ Err ( 0 ) => 0.0 ,
59+ Err ( idx) => self . cumulative_rates [ idx - 1 ] ,
60+ }
61+ }
62+ }
63+
64+ #[ derive( Debug , Clone , Default ) ]
65+ struct InfusionSchedule {
66+ channels : Vec < InfusionChannel > ,
67+ }
68+
69+ impl InfusionSchedule {
70+ fn new ( nstates : usize , infusions : & [ & Infusion ] ) -> Self {
71+ if nstates == 0 || infusions. is_empty ( ) {
72+ return Self {
73+ channels : Vec :: new ( ) ,
74+ } ;
75+ }
76+
77+ let mut per_input: Vec < Vec < ( f64 , f64 ) > > = vec ! [ Vec :: new( ) ; nstates] ;
78+ for infusion in infusions {
79+ if infusion. duration ( ) <= 0.0 {
80+ continue ;
81+ }
82+
83+ let input = infusion. input ( ) ;
84+ if input >= nstates {
85+ continue ;
86+ }
87+
88+ let rate = infusion. amount ( ) / infusion. duration ( ) ;
89+ per_input[ input] . push ( ( infusion. time ( ) , rate) ) ;
90+ per_input[ input] . push ( ( infusion. time ( ) + infusion. duration ( ) , -rate) ) ;
91+ }
92+
93+ let channels = per_input
94+ . into_iter ( )
95+ . enumerate ( )
96+ . filter_map ( |( input, events) | {
97+ if events. is_empty ( ) {
98+ None
99+ } else {
100+ Some ( InfusionChannel :: new ( input, events) )
101+ }
102+ } )
103+ . collect ( ) ;
104+
105+ Self { channels }
106+ }
107+
108+ fn fill_rate_vector ( & self , time : f64 , rateiv : & mut V ) {
109+ rateiv. fill ( 0.0 ) ;
110+ for channel in & self . channels {
111+ let rate = channel. rate_at ( time) ;
112+ if rate != 0.0 {
113+ rateiv[ channel. input ] = rate;
114+ }
115+ }
116+ }
117+ }
118+
13119pub struct PmRhs < ' a , F >
14120where
15- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) ,
121+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) ,
16122{
17123 nstates : usize ,
18124 nparams : usize ,
19- infusions : & ' a [ & ' a Infusion ] , // Change from Vec to slice reference
125+ infusion_schedule : & ' a InfusionSchedule ,
20126 covariates : & ' a Covariates ,
21- p : & ' a Vec < f64 > ,
127+ p_as_v : & ' a V ,
22128 func : & ' a F ,
23129 rateiv_buffer : & ' a RefCell < V > ,
130+ zero_bolus : & ' a V ,
24131}
25132
26133impl < F > Op for PmRhs < ' _ , F >
27134where
28- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) ,
135+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) ,
29136{
30137 type T = T ;
31138 type V = V ;
@@ -154,85 +261,36 @@ impl Op for PmOut {
154261
155262impl < F > NonLinearOp for PmRhs < ' _ , F >
156263where
157- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) ,
264+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) ,
158265{
159266 fn call_inplace ( & self , x : & Self :: V , t : Self :: T , y : & mut Self :: V ) {
160- // Compute rate IV at the current time
161267 let mut rateiv_ref = self . rateiv_buffer . borrow_mut ( ) ;
162- rateiv_ref. fill ( 0.0 ) ;
163-
164- for infusion in self . infusions {
165- if t >= infusion. time ( ) && t <= infusion. duration ( ) + infusion. time ( ) {
166- rateiv_ref[ infusion. input ( ) ] += infusion. amount ( ) / infusion. duration ( ) ;
167- }
168- }
169-
170- // We need to drop the mutable borrow before calling the function
171- // to avoid potential conflicts with future borrows in the function
172- let rateiv = rateiv_ref. clone ( ) ;
173- drop ( rateiv_ref) ;
174-
175- // Avoid creating a new DVector when possible
176- let p_len = self . p . len ( ) ;
177- let mut p_dvector: DVector < f64 > ;
178- let p_ref: & DVector < f64 > ;
179-
180- // Use stack allocation for small parameter vectors
181- if p_len <= 16 {
182- let mut stack_p = [ 0.0 ; 16 ] ;
183- stack_p[ ..p_len] . copy_from_slice ( self . p ) ;
184- p_dvector = DVector :: from_row_slice ( & stack_p[ ..p_len] ) ;
185- p_ref = & p_dvector;
186- } else {
187- // For larger vectors, use the more efficient approach with unsafe
188- p_dvector = DVector :: zeros ( p_len) ;
189- unsafe {
190- std:: ptr:: copy_nonoverlapping ( self . p . as_ptr ( ) , p_dvector. as_mut_ptr ( ) , p_len) ;
191- }
192- p_ref = & p_dvector;
193- }
194-
195- let pnew = p_ref. to_owned ( ) . into ( ) ;
268+ self . infusion_schedule . fill_rate_vector ( t, & mut rateiv_ref) ;
196269
197- let bolus = V :: zeros ( self . nstates , NalgebraContext ) ;
198-
199- ( self . func ) ( x, & pnew, t, y, bolus, rateiv, self . covariates ) ;
270+ ( self . func ) (
271+ x,
272+ self . p_as_v ,
273+ t,
274+ y,
275+ self . zero_bolus ,
276+ & rateiv_ref,
277+ self . covariates ,
278+ ) ;
200279 }
201280}
202281
203282impl < F > NonLinearOpJacobian for PmRhs < ' _ , F >
204283where
205- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) ,
284+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) ,
206285{
207286 fn jac_mul_inplace ( & self , _x : & Self :: V , t : Self :: T , v : & Self :: V , y : & mut Self :: V ) {
208- let rateiv = V :: zeros ( self . nstates , NalgebraContext ) ;
209-
210- // Avoid creating a new DVector when possible
211- let p_len = self . p . len ( ) ;
212- let mut p_dvector: DVector < f64 > ;
213-
214- // Use stack allocation for small parameter vectors
215- if p_len <= 16 {
216- let mut stack_p = [ 0.0 ; 16 ] ;
217- stack_p[ ..p_len] . copy_from_slice ( self . p ) ;
218- p_dvector = DVector :: from_row_slice ( & stack_p[ ..p_len] ) ;
219- } else {
220- // For larger vectors, use the more efficient approach with unsafe
221- p_dvector = DVector :: zeros ( p_len) ;
222- unsafe {
223- std:: ptr:: copy_nonoverlapping ( self . p . as_ptr ( ) , p_dvector. as_mut_ptr ( ) , p_len) ;
224- }
225- }
226-
227- let bolus = V :: zeros ( self . nstates , NalgebraContext ) ;
228-
229287 ( self . func ) (
230288 v,
231- & p_dvector . to_owned ( ) . into ( ) ,
289+ self . p_as_v ,
232290 t,
233291 y,
234- bolus ,
235- rateiv ,
292+ self . zero_bolus ,
293+ self . zero_bolus ,
236294 self . covariates ,
237295 ) ;
238296 }
@@ -253,49 +311,59 @@ impl NonLinearOp for PmOut {
253311// Completely revised PMProblem to fix lifetime issues and improve performance
254312pub struct PMProblem < ' a , F >
255313where
256- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) + ' a ,
314+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) + ' a ,
257315{
258316 func : F ,
259317 nstates : usize ,
260318 nparams : usize ,
261319 init : V ,
262320 p : Vec < f64 > ,
321+ p_as_v : V ,
322+ zero_bolus : V ,
263323 covariates : & ' a Covariates ,
264- infusions : Vec < & ' a Infusion > ,
324+ infusion_schedule : InfusionSchedule ,
265325 rateiv_buffer : RefCell < V > ,
266326}
267327
268328impl < ' a , F > PMProblem < ' a , F >
269329where
270- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) + ' a ,
330+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) + ' a ,
271331{
272- pub fn new (
332+ /// Creates a new PMProblem with a pre-converted parameter vector.
333+ /// This avoids an allocation when the caller already has a V representation.
334+ pub fn with_params_v (
273335 func : F ,
274336 nstates : usize ,
275337 p : Vec < f64 > ,
338+ p_as_v : V ,
276339 covariates : & ' a Covariates ,
277- infusions : Vec < & ' a Infusion > ,
340+ infusions : & [ & ' a Infusion ] ,
278341 init : V ,
279342 ) -> Self {
280343 let nparams = p. len ( ) ;
281344 let rateiv_buffer = RefCell :: new ( V :: zeros ( nstates, NalgebraContext ) ) ;
345+ let infusion_schedule = InfusionSchedule :: new ( nstates, infusions) ;
346+ // Pre-allocate zero bolus vector
347+ let zero_bolus = V :: zeros ( nstates, NalgebraContext ) ;
282348
283349 Self {
284350 func,
285351 nstates,
286352 nparams,
287353 init,
288354 p,
355+ p_as_v,
356+ zero_bolus,
289357 covariates,
290- infusions ,
358+ infusion_schedule ,
291359 rateiv_buffer,
292360 }
293361 }
294362}
295363
296364impl < ' a , F > Op for PMProblem < ' a , F >
297365where
298- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) + ' a ,
366+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) + ' a ,
299367{
300368 type T = T ;
301369 type V = V ;
@@ -318,7 +386,7 @@ where
318386// Implement OdeEquationsRef for PMProblem for any lifetime 'b
319387impl < ' a , ' b , F > OdeEquationsRef < ' b > for PMProblem < ' a , F >
320388where
321- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) + ' a ,
389+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) + ' a ,
322390{
323391 type Rhs = PmRhs < ' b , F > ;
324392 type Mass = PmMass ;
@@ -330,17 +398,18 @@ where
330398// Implement OdeEquations with correct lifetime handling
331399impl < ' a , F > OdeEquations for PMProblem < ' a , F >
332400where
333- F : Fn ( & V , & V , T , & mut V , V , V , & Covariates ) + ' a ,
401+ F : Fn ( & V , & V , T , & mut V , & V , & V , & Covariates ) + ' a ,
334402{
335403 fn rhs ( & self ) -> PmRhs < ' _ , F > {
336404 PmRhs {
337405 nstates : self . nstates ,
338406 nparams : self . nparams ,
339- infusions : & self . infusions , // Use reference instead of clone
407+ infusion_schedule : & self . infusion_schedule ,
340408 covariates : self . covariates ,
341- p : & self . p ,
409+ p_as_v : & self . p_as_v ,
342410 func : & self . func ,
343411 rateiv_buffer : & self . rateiv_buffer ,
412+ zero_bolus : & self . zero_bolus ,
344413 }
345414 }
346415
@@ -380,9 +449,11 @@ where
380449 if self . p . len ( ) == p. len ( ) {
381450 for i in 0 ..p. len ( ) {
382451 self . p [ i] = p[ i] ;
452+ self . p_as_v [ i] = p[ i] ;
383453 }
384454 } else {
385455 self . p = p. inner ( ) . iter ( ) . cloned ( ) . collect ( ) ;
456+ self . p_as_v = p. clone ( ) ;
386457 }
387458 }
388459}
0 commit comments