From 7f4422ff5c75bbbe9cc9005a0f5ba8014f9ad356 Mon Sep 17 00:00:00 2001 From: Brian Groenke <brian.groenke@awi.de> Date: Wed, 8 Jan 2025 01:22:16 +0100 Subject: [PATCH] Fix issues with new save caching system --- examples/heat_simple_autodiff_grad.jl | 22 +++++-- src/IO/output.jl | 7 +++ src/Solvers/DiffEq/ode_solvers.jl | 8 ++- src/Solvers/basic_solvers.jl | 2 +- src/Solvers/integrator.jl | 5 +- src/problem.jl | 88 +++++++++++++++++++-------- 6 files changed, 95 insertions(+), 37 deletions(-) diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl index 3c5d8681..a7318439 100644 --- a/examples/heat_simple_autodiff_grad.jl +++ b/examples/heat_simple_autodiff_grad.jl @@ -4,6 +4,7 @@ # # TODO: add more detail/background using CryoGrid +CryoGrid.debug(true) # Set up forcings and boundary conditions similarly to other examples: forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044); @@ -13,7 +14,7 @@ soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve)) grid = CryoGrid.DefaultGrid_5cm initT = initializer(:T, tempprofile) tile = CryoGrid.SoilHeatTile( - :T, + :H, TemperatureBC(Input(:Tair), NFactor(nf=Param(0.5), nt=Param(0.9))), GeothermalHeatFlux(0.053u"W/m^2"), soilprofile, @@ -21,17 +22,30 @@ tile = CryoGrid.SoilHeatTile( initT; grid=grid ) -tspan = (DateTime(2010,9,1),DateTime(2011,10,1)) +tspan = (DateTime(2010,9,1),DateTime(2010,10,1)) u0, du0 = @time initialcondition!(tile, tspan); # We can retrieve the parameters of the system from `tile`: para = CryoGrid.parameters(tile) # Create the `CryoGridProblem`. -prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0) +prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0, savevars=(:T,)) + +function testfunc(prob) + function(p) + newprob = remake(prob, p=p) + sol = solve(newprob, Euler(), dt=300.0) + out = CryoGridOutput(sol) + y = mean(ustrip.(Array(out.T))) + return y + end +end + +f = testfunc(prob) +grad = @time ForwardDiff.gradient(f, vec(prob.p)) # Solve the forward problem with default parameter settings: -sol = @time solve(prob) +sol = @time solve(prob); out = CryoGridOutput(sol) # Import relevant packages for automatic differentiation. diff --git a/src/IO/output.jl b/src/IO/output.jl index 82d7e32f..630f91a1 100644 --- a/src/IO/output.jl +++ b/src/IO/output.jl @@ -99,3 +99,10 @@ function reset!(cache::SaveCache) resize!(cache.t, 0) resize!(cache.vals, 0) end + +struct SaveConfig + savevars::Tuple + saveat::Vector{Float64} + save_start::Bool + save_everystep::Bool +end diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl index cd7c4e9d..f86006db 100644 --- a/src/Solvers/DiffEq/ode_solvers.jl +++ b/src/Solvers/DiffEq/ode_solvers.jl @@ -81,12 +81,14 @@ function CommonSolve.init( kwargs... ) ode_prob = ODEProblem(prob) - integrator = init(ode_prob, alg, args...; kwargs...) - return CryoGridDiffEqIntegrator(prob, integrator) + ode_integrator = init(ode_prob, alg, args...; kwargs...) + integrator = CryoGridDiffEqIntegrator(prob, ode_integrator) + return integrator end function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...) - return step!(integrator.integrator, args...; kwargs...) + rv = step!(integrator.integrator, args...; kwargs...) + return rv end function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator) diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl index f93da9b3..b7c1e036 100644 --- a/src/Solvers/basic_solvers.jl +++ b/src/Solvers/basic_solvers.jl @@ -37,7 +37,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, ) p = isnothing(prob.p) ? prob.p : collect(prob.p) if isnothing(saveat) - saveat = prob.saveat + saveat = prob.savecfg.saveat elseif isa(saveat, Number) saveat = collect(prob.tspan[1]:saveat:prob.tspan[2]) end diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl index ba7b2bfe..de23bec5 100755 --- a/src/Solvers/integrator.jl +++ b/src/Solvers/integrator.jl @@ -97,7 +97,8 @@ function CommonSolve.solve!(integrator::CryoGridIntegrator) integrator.sol.retcode = ReturnCode.Success end # if no save points are specified, save final state - if isempty(integrator.sol.prob.saveat) + prob = integrator.sol.prob + if isempty(prob.savecfg.saveat) prob.savefunc(integrator.u, integrator.t, integrator) push!(integrator.sol.u, integrator.u) push!(integrator.sol.t, integrator.t) @@ -161,7 +162,7 @@ function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple save_interval = ClosedInterval(tspan...) tile = Tile(sol.prob.f) # Tile grid = Grid(tile.grid.*u"m") - savecache = sol.prob.savefunc.cache + savecache = sol.prob.savecache ts = savecache.t # use save cache time points # check if last value is duplicated ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts diff --git a/src/problem.jl b/src/problem.jl index a144ef8a..8c90f8e3 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -3,14 +3,15 @@ Represents a CryoGrid discretized PDE forward model configuration using the `SciMLBase`/`DiffEqBase` problem interface. """ -struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip} +struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip} f::TT u0::Tu tspan::NTuple{2,Tt} p::Tp callbacks::Tcb - saveat::Tsv + savecfg::Tsv savefunc::Tsf + savecache::Tsc isoutofdomain::Tdf kwargs::Tkw CryoGridProblem{iip}( @@ -19,12 +20,13 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.Abstrac tspan::NTuple{2,Tt}, p::Tp, cbs::Tcb, - saveat::Tsv, + savecfg::Tsv, savefunc::Tsf, + savecache::Tsc, iood::Tdf, kwargs::Tkw - ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = - new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs) + ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} = + new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, savecfg, savefunc, savecache, iood, kwargs) end """ @@ -81,13 +83,8 @@ function CryoGridProblem( du0 = zero(u0) # remove units tile = stripunits(tile) - # set up saving callback - saveat = expandtstep(saveat, tspan) - savecache = InputOutput.SaveCache(Float64[], []) - savefunc = saving_function(savecache, savevars...) - savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep) diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!) - defaultcallbacks = (savingcallback, diagnostic_step_callback) + defaultcallbacks = (diagnostic_step_callback,) # add step limiter to default callbacks, if defined if !isnothing(step_limiter) defaultcallbacks = ( @@ -99,13 +96,18 @@ function CryoGridProblem( layercallbacks = _makecallbacks(tile) # add user callbacks usercallbacks = isnothing(callback) ? () : callback - callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...) + callbacks = Callbacks(defaultcallbacks, layercallbacks, usercallbacks) # build mass matrix mass_matrix = Numerics.build_mass_matrix(tile.state) # get params p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...) - return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs) + # set up saving config + saveat = expandtstep(saveat, tspan) + saveconfig = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep) + savecache = InputOutput.SaveCache(Float64[], []) + savefunc = saving_function(savecache, savevars...) + return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveconfig, savefunc, savecache, isoutofdomain, prob_kwargs) end function SciMLBase.remake( @@ -114,11 +116,13 @@ function SciMLBase.remake( u0=nothing, tspan=prob.tspan, p=prob.p, - callbacks=prob.callbacks, - saveat=prob.saveat, - savefunc=prob.savefunc, + savevars=prob.savecfg.savevars, + saveat=prob.savecfg.saveat, + save_start=prob.savecfg.save_start, + save_everystep=prob.savecfg.save_everystep, isoutofdomain=prob.isoutofdomain, kwargs=prob.kwargs, + callbacks=nothing, ) where iip # always re-run initialcondition! with the given tspan and parameters _u0, du0 = initialcondition!(Tile(f), tspan, p) @@ -129,7 +133,15 @@ function SciMLBase.remake( else u0 = _u0 end - return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs) + # create new save cache + saveat = expandtstep(saveat, tspan) + savecfg = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep) + savecache = InputOutput.SaveCache(Float64[], []) + savefunc = saving_function(savecache, savevars...) + # rebuild Callbacks struct with new user callbacks if provided + callbacks = isnothing(callbacks) ? prob.callbacks : Callbacks(prob.callbacks.default, prob.callbacks.layer, callbacks) + # construct new CryoGridProblem + return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, savecfg, savefunc, savecache, isoutofdomain, kwargs) end CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))") @@ -139,15 +151,30 @@ function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...) return solve!(integrator) end -SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem( - prob.f, - prob.u0, - prob.tspan, - prob.p; - callback=prob.callbacks, - isoutofdomain=prob.isoutofdomain, - prob.kwargs... -) +function SciMLBase.ODEProblem(prob::CryoGridProblem) + savingcallback = FunctionCallingCallback( + prob.savefunc; + funcat=prob.savecfg.saveat, + func_start=prob.savecfg.save_start, + func_everystep=prob.savecfg.save_everystep + ) + callbacks = CallbackSet( + savingcallback, + prob.callbacks.default..., + prob.callbacks.layer..., + prob.callbacks.user... + ) + odeprob = ODEProblem( + prob.f, + prob.u0, + prob.tspan, + prob.p; + callback=callbacks, + isoutofdomain=prob.isoutofdomain, + prob.kwargs... + ) + return odeprob +end DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob @@ -204,7 +231,14 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator) DiffEqBase.u_modified!(integrator, u_modified) end -# callback building functions +# Callback utilities + +struct Callbacks + default + layer + user +end + function _makecallbacks(tile::Tile) eventname(::Event{name}) where name = name isgridevent(::GridContinuousEvent) = true -- GitLab