From 7f4422ff5c75bbbe9cc9005a0f5ba8014f9ad356 Mon Sep 17 00:00:00 2001
From: Brian Groenke <brian.groenke@awi.de>
Date: Wed, 8 Jan 2025 01:22:16 +0100
Subject: [PATCH] Fix issues with new save caching system

---
 examples/heat_simple_autodiff_grad.jl | 22 +++++--
 src/IO/output.jl                      |  7 +++
 src/Solvers/DiffEq/ode_solvers.jl     |  8 ++-
 src/Solvers/basic_solvers.jl          |  2 +-
 src/Solvers/integrator.jl             |  5 +-
 src/problem.jl                        | 88 +++++++++++++++++++--------
 6 files changed, 95 insertions(+), 37 deletions(-)

diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl
index 3c5d8681..a7318439 100644
--- a/examples/heat_simple_autodiff_grad.jl
+++ b/examples/heat_simple_autodiff_grad.jl
@@ -4,6 +4,7 @@
 #
 # TODO: add more detail/background
 using CryoGrid
+CryoGrid.debug(true)
 
 # Set up forcings and boundary conditions similarly to other examples:
 forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
@@ -13,7 +14,7 @@ soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve))
 grid = CryoGrid.DefaultGrid_5cm
 initT = initializer(:T, tempprofile)
 tile = CryoGrid.SoilHeatTile(
-    :T,
+    :H,
     TemperatureBC(Input(:Tair), NFactor(nf=Param(0.5), nt=Param(0.9))),
     GeothermalHeatFlux(0.053u"W/m^2"),
     soilprofile,
@@ -21,17 +22,30 @@ tile = CryoGrid.SoilHeatTile(
     initT;
     grid=grid
 )
-tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
+tspan = (DateTime(2010,9,1),DateTime(2010,10,1))
 u0, du0 = @time initialcondition!(tile, tspan);
 
 # We can retrieve the parameters of the system from `tile`:
 para = CryoGrid.parameters(tile)
 
 # Create the `CryoGridProblem`.
-prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
+prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0, savevars=(:T,))
+
+function testfunc(prob)
+    function(p)
+        newprob = remake(prob, p=p)
+        sol = solve(newprob, Euler(), dt=300.0)
+        out = CryoGridOutput(sol)
+        y = mean(ustrip.(Array(out.T)))
+        return y
+    end
+end
+
+f = testfunc(prob)
+grad = @time ForwardDiff.gradient(f, vec(prob.p))
 
 # Solve the forward problem with default parameter settings:
-sol = @time solve(prob)
+sol = @time solve(prob);
 out = CryoGridOutput(sol)
 
 # Import relevant packages for automatic differentiation.
diff --git a/src/IO/output.jl b/src/IO/output.jl
index 82d7e32f..630f91a1 100644
--- a/src/IO/output.jl
+++ b/src/IO/output.jl
@@ -99,3 +99,10 @@ function reset!(cache::SaveCache)
     resize!(cache.t, 0)
     resize!(cache.vals, 0)
 end
+
+struct SaveConfig
+    savevars::Tuple
+    saveat::Vector{Float64}
+    save_start::Bool
+    save_everystep::Bool
+end
diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl
index cd7c4e9d..f86006db 100644
--- a/src/Solvers/DiffEq/ode_solvers.jl
+++ b/src/Solvers/DiffEq/ode_solvers.jl
@@ -81,12 +81,14 @@ function CommonSolve.init(
     kwargs...
 )
     ode_prob = ODEProblem(prob)
-    integrator = init(ode_prob, alg, args...; kwargs...)
-    return CryoGridDiffEqIntegrator(prob, integrator)
+    ode_integrator = init(ode_prob, alg, args...; kwargs...)
+    integrator = CryoGridDiffEqIntegrator(prob, ode_integrator)
+    return integrator
 end
 
 function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...)
-    return step!(integrator.integrator, args...; kwargs...)
+    rv = step!(integrator.integrator, args...; kwargs...)
+    return rv
 end
 
 function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator)
diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl
index f93da9b3..b7c1e036 100644
--- a/src/Solvers/basic_solvers.jl
+++ b/src/Solvers/basic_solvers.jl
@@ -37,7 +37,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0,
     )
     p = isnothing(prob.p) ? prob.p : collect(prob.p)
     if isnothing(saveat)
-        saveat = prob.saveat
+        saveat = prob.savecfg.saveat
     elseif isa(saveat, Number)
         saveat = collect(prob.tspan[1]:saveat:prob.tspan[2])
     end
diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl
index ba7b2bfe..de23bec5 100755
--- a/src/Solvers/integrator.jl
+++ b/src/Solvers/integrator.jl
@@ -97,7 +97,8 @@ function CommonSolve.solve!(integrator::CryoGridIntegrator)
         integrator.sol.retcode = ReturnCode.Success
     end
     # if no save points are specified, save final state
-    if isempty(integrator.sol.prob.saveat)
+    prob = integrator.sol.prob
+    if isempty(prob.savecfg.saveat)
         prob.savefunc(integrator.u, integrator.t, integrator)
         push!(integrator.sol.u, integrator.u)
         push!(integrator.sol.t, integrator.t)
@@ -161,7 +162,7 @@ function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple
     save_interval = ClosedInterval(tspan...)
     tile = Tile(sol.prob.f) # Tile
     grid = Grid(tile.grid.*u"m")
-    savecache = sol.prob.savefunc.cache
+    savecache = sol.prob.savecache
     ts = savecache.t # use save cache time points
     # check if last value is duplicated
     ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts
diff --git a/src/problem.jl b/src/problem.jl
index a144ef8a..8c90f8e3 100644
--- a/src/problem.jl
+++ b/src/problem.jl
@@ -3,14 +3,15 @@
 
 Represents a CryoGrid discretized PDE forward model configuration using the `SciMLBase`/`DiffEqBase` problem interface.
 """
-struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
+struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
     f::TT
     u0::Tu
     tspan::NTuple{2,Tt}
     p::Tp
     callbacks::Tcb
-    saveat::Tsv
+    savecfg::Tsv
     savefunc::Tsf
+    savecache::Tsc
     isoutofdomain::Tdf
     kwargs::Tkw
     CryoGridProblem{iip}(
@@ -19,12 +20,13 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.Abstrac
         tspan::NTuple{2,Tt},
         p::Tp,
         cbs::Tcb,
-        saveat::Tsv,
+        savecfg::Tsv,
         savefunc::Tsf,
+        savecache::Tsc,
         iood::Tdf,
         kwargs::Tkw
-    ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} =
-        new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs)
+    ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} =
+        new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, savecfg, savefunc, savecache, iood, kwargs)
 end
 
 """
@@ -81,13 +83,8 @@ function CryoGridProblem(
     du0 = zero(u0)
     # remove units
     tile = stripunits(tile)
-    # set up saving callback
-    saveat = expandtstep(saveat, tspan)
-    savecache = InputOutput.SaveCache(Float64[], [])
-    savefunc = saving_function(savecache, savevars...)
-    savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep)
     diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!)
-    defaultcallbacks = (savingcallback, diagnostic_step_callback)
+    defaultcallbacks = (diagnostic_step_callback,)
     # add step limiter to default callbacks, if defined
     if !isnothing(step_limiter)
         defaultcallbacks = (
@@ -99,13 +96,18 @@ function CryoGridProblem(
     layercallbacks = _makecallbacks(tile)
     # add user callbacks
     usercallbacks = isnothing(callback) ? () : callback
-    callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...)
+    callbacks = Callbacks(defaultcallbacks, layercallbacks, usercallbacks)
     # build mass matrix
     mass_matrix = Numerics.build_mass_matrix(tile.state)
     # get params
     p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p
 	func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...)
-	return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs)
+    # set up saving config
+    saveat = expandtstep(saveat, tspan)
+    saveconfig = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
+    savecache = InputOutput.SaveCache(Float64[], [])
+    savefunc = saving_function(savecache, savevars...)
+	return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveconfig, savefunc, savecache, isoutofdomain, prob_kwargs)
 end
 
 function SciMLBase.remake(
@@ -114,11 +116,13 @@ function SciMLBase.remake(
     u0=nothing,
     tspan=prob.tspan,
     p=prob.p,
-    callbacks=prob.callbacks,
-    saveat=prob.saveat,
-    savefunc=prob.savefunc,
+    savevars=prob.savecfg.savevars,
+    saveat=prob.savecfg.saveat,
+    save_start=prob.savecfg.save_start,
+    save_everystep=prob.savecfg.save_everystep,
     isoutofdomain=prob.isoutofdomain,
     kwargs=prob.kwargs,
+    callbacks=nothing,
 ) where iip
     # always re-run initialcondition! with the given tspan and parameters
     _u0, du0 = initialcondition!(Tile(f), tspan, p)
@@ -129,7 +133,15 @@ function SciMLBase.remake(
     else
         u0 = _u0
     end
-    return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs)
+    # create new save cache
+    saveat = expandtstep(saveat, tspan)
+    savecfg = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
+    savecache = InputOutput.SaveCache(Float64[], [])
+    savefunc = saving_function(savecache, savevars...)
+    # rebuild Callbacks struct with new user callbacks if provided
+    callbacks = isnothing(callbacks) ? prob.callbacks : Callbacks(prob.callbacks.default, prob.callbacks.layer, callbacks)
+    # construct new CryoGridProblem
+    return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, savecfg, savefunc, savecache, isoutofdomain, kwargs)
 end
 
 CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))")
@@ -139,15 +151,30 @@ function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...)
     return solve!(integrator)
 end
 
-SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem(
-    prob.f,
-    prob.u0,
-    prob.tspan,
-    prob.p;
-    callback=prob.callbacks,
-    isoutofdomain=prob.isoutofdomain,
-    prob.kwargs...
-)
+function SciMLBase.ODEProblem(prob::CryoGridProblem)
+    savingcallback = FunctionCallingCallback(
+        prob.savefunc;
+        funcat=prob.savecfg.saveat,
+        func_start=prob.savecfg.save_start,
+        func_everystep=prob.savecfg.save_everystep
+    )
+    callbacks = CallbackSet(
+        savingcallback,
+        prob.callbacks.default...,
+        prob.callbacks.layer...,
+        prob.callbacks.user...
+    )
+    odeprob = ODEProblem(
+        prob.f,
+        prob.u0,
+        prob.tspan,
+        prob.p;
+        callback=callbacks,
+        isoutofdomain=prob.isoutofdomain,
+        prob.kwargs...
+    )
+    return odeprob
+end
 
 DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob
 
@@ -204,7 +231,14 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator)
     DiffEqBase.u_modified!(integrator, u_modified)
 end
 
-# callback building functions
+# Callback utilities
+
+struct Callbacks
+    default
+    layer
+    user
+end
+
 function _makecallbacks(tile::Tile)
     eventname(::Event{name}) where name = name
     isgridevent(::GridContinuousEvent) = true
-- 
GitLab