Skip to content

SimpleAdaptiveTauLeaping solver #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end

struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter
end

SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
Expand All @@ -14,6 +20,19 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
jump_prob.regular_jump !== nothing
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing,
dt = error("dt is required for SimpleTauLeaping."))
Expand Down Expand Up @@ -62,6 +81,164 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
interp = DiffEqBase.ConstantInterpolation(t, u))
end

function compute_hor(reactant_stoch, numjumps)
hor = zeros(Int, numjumps)
for j in 1:numjumps
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0)
if order > 3
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
end
return hor
end

function compute_gi(u, nu, hor, i)
max_gi = 1
for j in 1:size(nu, 2)
if nu[i, j] < 0 # Species i is a substrate
if hor[j] == 1
max_gi = max(max_gi, 1)
elseif hor[j] == 2 || hor[j] == 3
stoch = abs(nu[i, j])
if stoch >= 2
gi = 2 / stoch + 1 / (stoch - 1)
max_gi = max(max_gi, ceil(Int, gi))
elseif stoch == 1
max_gi = max(max_gi, hor[j])
end
end
end
end
return max_gi
end
Comment on lines +96 to +114
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't reviewed the full PR carefully, but this function still looks wrong to me. Look at Section IV of the paper at the text between (27) and (28). x_i there is not stoch it is the ith species (i.e. u_i).

Also, you should be precalculating if any reaction with HOR[i] == 2 has species i with substrate stoichiometry of 2 (and the related conditions when HOR[i] == 3). This is determined from reactant_stoch not the net stoichiometry matrix nu.

If I can make a suggestion, perhaps it makes sense for you to go back and work on finishing the GPU method PR. It seems like you are having some conceptual issues with the paper here, whereas that PR just involves some basic code modifications to get finished and so should be an easier task. Once that is done and merged perhaps you can more carefully read this paper, particularly Section IV, and then return to this PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if you have specific questions on the method of the paper please ask on Slack! That might help clarify your understanding / fix any confusion.


function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin)
rate(rate_cache, u, p, t)
tau = Inf
for i in 1:length(u)
mu = zero(eltype(u))
sigma2 = zero(eltype(u))
for j in 1:size(nu, 2)
mu += nu[i, j] * rate_cache[j]
sigma2 += nu[i, j]^2 * rate_cache[j]
end
gi = compute_gi(u, nu, hor, i)
bound = max(epsilon * u[i] / gi, 1.0)
mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf
sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf
tau = min(tau, mu_term, sigma_term)
end
return max(tau, dtmin)
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
seed = nothing,
dtmin = 1e-10,
saveat = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")
prob = jump_prob.prob
rng = DEFAULT_RNG
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate :
(out, u, p, t) -> begin
for j in 1:numjumps
out[j] = evalrxrate(u, j, maj)
end
end
c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

# Initialize current state and saved history
u_current = copy(u0)
t_current = tspan[1]
usave = [copy(u0)]
tsave = [tspan[1]]
rate_cache = zeros(float(eltype(u0)), numjumps)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon

# Extract stochiometry once from MassActionJump
nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
# Extract reactant stochiometry for hor
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)

# Set up saveat_times
saveat_times = nothing
if isnothing(saveat)
saveat_times = Vector{typeof(tspan[1])}()
elseif saveat isa Number
saveat_times = collect(range(tspan[1], tspan[2], step=saveat))
else
saveat_times = collect(saveat)
end

save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
du .= 0
if c !== nothing
c(du, u_current, p, t_current, counts, nothing)
else
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
end
u_new = u_current + du
if any(<(0), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
tau /= 2
continue
end
for i in eachindex(u_new)
u_new[i] = max(u_new[i], 0)
end
t_new = t_current + tau

# Save state if at a saveat time or if saveat is empty
if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, u_new)
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current = u_new
t_current = t_new
end

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error=false,
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand All @@ -74,3 +251,5 @@ end
function EnsembleGPUKernel()
EnsembleGPUKernel(nothing, 0.0)
end

export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping
Loading
Loading