Skip to content

Commit 83d8440

Browse files
Merge pull request #136 from SciML/myb/more
Add optimal_parameter_intervention_for_reach
2 parents 027d9bb + 0594f50 commit 83d8440

File tree

3 files changed

+113
-1
lines changed

3 files changed

+113
-1
lines changed

src/EasyModelAnalysis.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export get_sensitivity, create_sensitivity_plot, get_sensitivity_of_maximum
2525
export stop_at_threshold, get_threshold
2626
export model_forecast_score
2727
export optimal_threshold_intervention, prob_violating_threshold,
28-
optimal_parameter_intervention_for_threshold, optimal_parameter_threshold
28+
optimal_parameter_intervention_for_threshold, optimal_parameter_threshold,
29+
optimal_parameter_intervention_for_reach
2930

3031
end

src/intervention.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,107 @@ function optimal_parameter_intervention_for_threshold(prob, obs, threshold,
165165
ss = duration_constraint(optx, [], Val(true))
166166
Dict(ps .=> optx), ss, ret
167167
end
168+
169+
"""
170+
optimal_parameter_intervention_for_reach(prob, obs, reach, cost, ps,
171+
lb, ub, intervention_tspan, duration; ineq_cons = nothing, maxtime=60)
172+
173+
## Arguments
174+
175+
- `prob`: An ODEProblem.
176+
- `obs`: The observation symbolic expression.
177+
- `reach`: The reach for the observation, i.e., the constraint enforces that `obs` reaches `reach`.
178+
- `cost`: the cost function for minimization, e.g. `α + 20 * β`.
179+
- `ps`: the parameters that appear in the cost, e.g. `[α, β]`.
180+
- `lb`: the lower bounds of the parameters e.g. `[-10, -5]`.
181+
- `ub`: the uppwer bounds of the parameters e.g. `[5, 10]`.
182+
- `intervention_tspan`: intervention time span, e.g. `(20.0, 30.0)`. Defaults to `prob.tspan`.
183+
- `duration`: Duration for the evaluation of intervention. Defaults to `prob.tspan[2] - prob.tspan[1]`.
184+
185+
186+
## Keyword Arguments
187+
188+
- `maxtime`: Maximum optimzation time. Defaults to `60`.
189+
- `ineq_cons`: a vector of symbolic expressions in terms of symbolic
190+
parameters. The optimizer will enforce `ineq_cons .< 0`.
191+
192+
# Returns
193+
194+
- `opt_p`: Optimal intervention parameters.
195+
- `(s1, s2, s3)`: Pre-intervention, intervention, post-intervention solutions.
196+
- `ret`: Return code from the optimization.
197+
"""
198+
function optimal_parameter_intervention_for_reach(prob, obs, reach,
199+
symbolic_cost, ps, lb, ub,
200+
intervention_tspan = prob.tspan,
201+
duration = abs(-(prob.tspan...));
202+
maxtime = 60, ineq_cons = nothing,
203+
kw...)
204+
t0 = prob.tspan[1]
205+
ti_start, ti_end = intervention_tspan
206+
symbolic_cost = Symbolics.unwrap(symbolic_cost)
207+
#ps = collect(ModelingToolkit.vars(symbolic_cost))
208+
_cost = Symbolics.build_function(symbolic_cost, ps, expression = Val{false})
209+
_cost(prob.p) # just throw when something is wrong during the setup.
210+
211+
cost = let _cost = _cost
212+
(x, grad) -> _cost(x)
213+
end
214+
215+
function duration_constraint(x::Vector, grad::Vector, ::Val{p} = Val(false)) where {p}
216+
tf = t0 + duration
217+
prob_preintervention = remake(prob, tspan = (t0, ti_start))
218+
if p
219+
sol_preintervention = solve(prob_preintervention; kw...)
220+
else
221+
sol_preintervention = stop_at_threshold(prob_preintervention, obs, reach; kw...)
222+
reach_time = ti_start - sol_preintervention.t[end]
223+
reach_time > 0 && return sol_preintervention.t[end] - tf
224+
end
225+
226+
prob_intervention = remake(prob, u0 = sol_preintervention.u[end], p = ps .=> x,
227+
tspan = (ti_start, ti_end))
228+
if p
229+
sol_intervention = solve(prob_intervention; kw...)
230+
else
231+
sol_intervention = stop_at_threshold(prob_intervention, obs, reach; kw...)
232+
reach_time = ti_end - sol_intervention.t[end]
233+
reach_time > 0 && return sol_intervention.t[end] - tf
234+
end
235+
236+
prob_postintervention = remake(prob, u0 = sol_intervention.u[end],
237+
tspan = (ti_end, t0 + duration))
238+
if p
239+
sol_postintervention = solve(prob_postintervention; kw...)
240+
sol_preintervention, sol_intervention, sol_postintervention
241+
else
242+
sol_postintervention = stop_at_threshold(prob_postintervention, obs, reach;
243+
kw...)
244+
reach_time = tf - sol_postintervention.t[end]
245+
10.0
246+
end
247+
end
248+
249+
opt = Opt(:GN_ISRES, length(ps))
250+
opt.lower_bounds = lb
251+
opt.upper_bounds = ub
252+
opt.xtol_rel = 1e-4
253+
254+
opt.min_objective = cost
255+
init_x = @. (lb + ub) / 2
256+
duration_constraint(init_x, [])
257+
inequality_constraint!(opt, duration_constraint, 1e-16)
258+
if ineq_cons !== nothing
259+
for con in ineq_cons
260+
_con = Symbolics.build_function(Symbolics.unwrap(con), ps,
261+
expression = Val{false})
262+
_con(init_x)
263+
ineq_con = (x, _) -> _con(x)
264+
inequality_constraint!(opt, ineq_con, 1e-16)
265+
end
266+
end
267+
opt.maxtime = maxtime
268+
(optf, optx, ret) = NLopt.optimize(opt, init_x)
269+
ss = duration_constraint(optx, [], Val(true))
270+
Dict(ps .=> optx), ss, ret
271+
end

test/threshold.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ opt_ps, s2, ret = optimal_parameter_threshold(prob, x, 3,
5959
opt_ps, s2, ret = optimal_parameter_threshold(prob, x, 3,
6060
-p, [p],
6161
[-1.0], [1.0]);
62+
opt_ps, (s1, s2, s3), ret = optimal_parameter_intervention_for_reach(remake(prob,
63+
tspan = (0,
64+
1.0)),
65+
x, 300,
66+
p, [p],
67+
[0.0], [30.0]);
68+
@test 10 < opt_ps[p] < 11
6269
@variables t x(t) y(t)
6370
@parameters p1 p2
6471
D = Differential(t)

0 commit comments

Comments
 (0)