Skip to content
Snippets Groups Projects
Commit 7e822c45 authored by Brian Groenke's avatar Brian Groenke
Browse files

Revert "Fix issues with new save caching system"

This reverts commit 7f4422ff.
parent 0d8fd6b8
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,6 @@
#
# TODO: add more detail/background
using CryoGrid
CryoGrid.debug(true)
# Set up forcings and boundary conditions similarly to other examples:
forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
......@@ -14,7 +13,7 @@ soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve))
grid = CryoGrid.DefaultGrid_5cm
initT = initializer(:T, tempprofile)
tile = CryoGrid.SoilHeatTile(
:H,
:T,
TemperatureBC(Input(:Tair), NFactor(nf=Param(0.5), nt=Param(0.9))),
GeothermalHeatFlux(0.053u"W/m^2"),
soilprofile,
......@@ -22,30 +21,17 @@ tile = CryoGrid.SoilHeatTile(
initT;
grid=grid
)
tspan = (DateTime(2010,9,1),DateTime(2010,10,1))
tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
u0, du0 = @time initialcondition!(tile, tspan);
# We can retrieve the parameters of the system from `tile`:
para = CryoGrid.parameters(tile)
# Create the `CryoGridProblem`.
prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0, savevars=(:T,))
function testfunc(prob)
function(p)
newprob = remake(prob, p=p)
sol = solve(newprob, Euler(), dt=300.0)
out = CryoGridOutput(sol)
y = mean(ustrip.(Array(out.T)))
return y
end
end
f = testfunc(prob)
grad = @time ForwardDiff.gradient(f, vec(prob.p))
prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
# Solve the forward problem with default parameter settings:
sol = @time solve(prob);
sol = @time solve(prob)
out = CryoGridOutput(sol)
# Import relevant packages for automatic differentiation.
......
......@@ -99,10 +99,3 @@ function reset!(cache::SaveCache)
resize!(cache.t, 0)
resize!(cache.vals, 0)
end
struct SaveConfig
savevars::Tuple
saveat::Vector{Float64}
save_start::Bool
save_everystep::Bool
end
......@@ -81,14 +81,12 @@ function CommonSolve.init(
kwargs...
)
ode_prob = ODEProblem(prob)
ode_integrator = init(ode_prob, alg, args...; kwargs...)
integrator = CryoGridDiffEqIntegrator(prob, ode_integrator)
return integrator
integrator = init(ode_prob, alg, args...; kwargs...)
return CryoGridDiffEqIntegrator(prob, integrator)
end
function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...)
rv = step!(integrator.integrator, args...; kwargs...)
return rv
return step!(integrator.integrator, args...; kwargs...)
end
function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator)
......
......@@ -37,7 +37,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0,
)
p = isnothing(prob.p) ? prob.p : collect(prob.p)
if isnothing(saveat)
saveat = prob.savecfg.saveat
saveat = prob.saveat
elseif isa(saveat, Number)
saveat = collect(prob.tspan[1]:saveat:prob.tspan[2])
end
......
......@@ -97,8 +97,7 @@ function CommonSolve.solve!(integrator::CryoGridIntegrator)
integrator.sol.retcode = ReturnCode.Success
end
# if no save points are specified, save final state
prob = integrator.sol.prob
if isempty(prob.savecfg.saveat)
if isempty(integrator.sol.prob.saveat)
prob.savefunc(integrator.u, integrator.t, integrator)
push!(integrator.sol.u, integrator.u)
push!(integrator.sol.t, integrator.t)
......@@ -162,7 +161,7 @@ function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple
save_interval = ClosedInterval(tspan...)
tile = Tile(sol.prob.f) # Tile
grid = Grid(tile.grid.*u"m")
savecache = sol.prob.savecache
savecache = sol.prob.savefunc.cache
ts = savecache.t # use save cache time points
# check if last value is duplicated
ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts
......
......@@ -3,15 +3,14 @@
Represents a CryoGrid discretized PDE forward model configuration using the `SciMLBase`/`DiffEqBase` problem interface.
"""
struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
f::TT
u0::Tu
tspan::NTuple{2,Tt}
p::Tp
callbacks::Tcb
savecfg::Tsv
saveat::Tsv
savefunc::Tsf
savecache::Tsc
isoutofdomain::Tdf
kwargs::Tkw
CryoGridProblem{iip}(
......@@ -20,13 +19,12 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.Abs
tspan::NTuple{2,Tt},
p::Tp,
cbs::Tcb,
savecfg::Tsv,
saveat::Tsv,
savefunc::Tsf,
savecache::Tsc,
iood::Tdf,
kwargs::Tkw
) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} =
new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, savecfg, savefunc, savecache, iood, kwargs)
) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} =
new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs)
end
"""
......@@ -83,8 +81,13 @@ function CryoGridProblem(
du0 = zero(u0)
# remove units
tile = stripunits(tile)
# set up saving callback
saveat = expandtstep(saveat, tspan)
savecache = InputOutput.SaveCache(Float64[], [])
savefunc = saving_function(savecache, savevars...)
savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep)
diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!)
defaultcallbacks = (diagnostic_step_callback,)
defaultcallbacks = (savingcallback, diagnostic_step_callback)
# add step limiter to default callbacks, if defined
if !isnothing(step_limiter)
defaultcallbacks = (
......@@ -96,18 +99,13 @@ function CryoGridProblem(
layercallbacks = _makecallbacks(tile)
# add user callbacks
usercallbacks = isnothing(callback) ? () : callback
callbacks = Callbacks(defaultcallbacks, layercallbacks, usercallbacks)
callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...)
# build mass matrix
mass_matrix = Numerics.build_mass_matrix(tile.state)
# get params
p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p
func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...)
# set up saving config
saveat = expandtstep(saveat, tspan)
saveconfig = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
savecache = InputOutput.SaveCache(Float64[], [])
savefunc = saving_function(savecache, savevars...)
return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveconfig, savefunc, savecache, isoutofdomain, prob_kwargs)
return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs)
end
function SciMLBase.remake(
......@@ -116,13 +114,11 @@ function SciMLBase.remake(
u0=nothing,
tspan=prob.tspan,
p=prob.p,
savevars=prob.savecfg.savevars,
saveat=prob.savecfg.saveat,
save_start=prob.savecfg.save_start,
save_everystep=prob.savecfg.save_everystep,
callbacks=prob.callbacks,
saveat=prob.saveat,
savefunc=prob.savefunc,
isoutofdomain=prob.isoutofdomain,
kwargs=prob.kwargs,
callbacks=nothing,
) where iip
# always re-run initialcondition! with the given tspan and parameters
_u0, du0 = initialcondition!(Tile(f), tspan, p)
......@@ -133,15 +129,7 @@ function SciMLBase.remake(
else
u0 = _u0
end
# create new save cache
saveat = expandtstep(saveat, tspan)
savecfg = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
savecache = InputOutput.SaveCache(Float64[], [])
savefunc = saving_function(savecache, savevars...)
# rebuild Callbacks struct with new user callbacks if provided
callbacks = isnothing(callbacks) ? prob.callbacks : Callbacks(prob.callbacks.default, prob.callbacks.layer, callbacks)
# construct new CryoGridProblem
return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, savecfg, savefunc, savecache, isoutofdomain, kwargs)
return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs)
end
CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))")
......@@ -151,30 +139,15 @@ function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...)
return solve!(integrator)
end
function SciMLBase.ODEProblem(prob::CryoGridProblem)
savingcallback = FunctionCallingCallback(
prob.savefunc;
funcat=prob.savecfg.saveat,
func_start=prob.savecfg.save_start,
func_everystep=prob.savecfg.save_everystep
)
callbacks = CallbackSet(
savingcallback,
prob.callbacks.default...,
prob.callbacks.layer...,
prob.callbacks.user...
)
odeprob = ODEProblem(
prob.f,
prob.u0,
prob.tspan,
prob.p;
callback=callbacks,
isoutofdomain=prob.isoutofdomain,
prob.kwargs...
)
return odeprob
end
SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem(
prob.f,
prob.u0,
prob.tspan,
prob.p;
callback=prob.callbacks,
isoutofdomain=prob.isoutofdomain,
prob.kwargs...
)
DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob
......@@ -231,14 +204,7 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator)
DiffEqBase.u_modified!(integrator, u_modified)
end
# Callback utilities
struct Callbacks
default
layer
user
end
# callback building functions
function _makecallbacks(tile::Tile)
eventname(::Event{name}) where name = name
isgridevent(::GridContinuousEvent) = true
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment