diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl index a7318439138f257471a99916cdacbfe76bea79af..3c5d86814234960b7c450cfc1c1d627e9f09058c 100644 --- a/examples/heat_simple_autodiff_grad.jl +++ b/examples/heat_simple_autodiff_grad.jl @@ -4,7 +4,6 @@ # # TODO: add more detail/background using CryoGrid -CryoGrid.debug(true) # Set up forcings and boundary conditions similarly to other examples: forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044); @@ -14,7 +13,7 @@ soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve)) grid = CryoGrid.DefaultGrid_5cm initT = initializer(:T, tempprofile) tile = CryoGrid.SoilHeatTile( - :H, + :T, TemperatureBC(Input(:Tair), NFactor(nf=Param(0.5), nt=Param(0.9))), GeothermalHeatFlux(0.053u"W/m^2"), soilprofile, @@ -22,30 +21,17 @@ tile = CryoGrid.SoilHeatTile( initT; grid=grid ) -tspan = (DateTime(2010,9,1),DateTime(2010,10,1)) +tspan = (DateTime(2010,9,1),DateTime(2011,10,1)) u0, du0 = @time initialcondition!(tile, tspan); # We can retrieve the parameters of the system from `tile`: para = CryoGrid.parameters(tile) # Create the `CryoGridProblem`. -prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0, savevars=(:T,)) - -function testfunc(prob) - function(p) - newprob = remake(prob, p=p) - sol = solve(newprob, Euler(), dt=300.0) - out = CryoGridOutput(sol) - y = mean(ustrip.(Array(out.T))) - return y - end -end - -f = testfunc(prob) -grad = @time ForwardDiff.gradient(f, vec(prob.p)) +prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0) # Solve the forward problem with default parameter settings: -sol = @time solve(prob); +sol = @time solve(prob) out = CryoGridOutput(sol) # Import relevant packages for automatic differentiation. diff --git a/src/IO/output.jl b/src/IO/output.jl index 630f91a17b0ffcc75750ae1f03d00b5618c55132..82d7e32fe8058aafde2720382810cfdaff2fa817 100644 --- a/src/IO/output.jl +++ b/src/IO/output.jl @@ -99,10 +99,3 @@ function reset!(cache::SaveCache) resize!(cache.t, 0) resize!(cache.vals, 0) end - -struct SaveConfig - savevars::Tuple - saveat::Vector{Float64} - save_start::Bool - save_everystep::Bool -end diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl index f86006dbd99c0143c711737a4e29b62478319465..cd7c4e9d8819a7f284072ea4df01e5bd9bdb24c8 100644 --- a/src/Solvers/DiffEq/ode_solvers.jl +++ b/src/Solvers/DiffEq/ode_solvers.jl @@ -81,14 +81,12 @@ function CommonSolve.init( kwargs... ) ode_prob = ODEProblem(prob) - ode_integrator = init(ode_prob, alg, args...; kwargs...) - integrator = CryoGridDiffEqIntegrator(prob, ode_integrator) - return integrator + integrator = init(ode_prob, alg, args...; kwargs...) + return CryoGridDiffEqIntegrator(prob, integrator) end function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...) - rv = step!(integrator.integrator, args...; kwargs...) - return rv + return step!(integrator.integrator, args...; kwargs...) end function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator) diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl index b7c1e0364c9d6d870597af5aa1685d903706b11e..f93da9b37e754e9b76dc57e5013c30fccce259d2 100644 --- a/src/Solvers/basic_solvers.jl +++ b/src/Solvers/basic_solvers.jl @@ -37,7 +37,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, ) p = isnothing(prob.p) ? prob.p : collect(prob.p) if isnothing(saveat) - saveat = prob.savecfg.saveat + saveat = prob.saveat elseif isa(saveat, Number) saveat = collect(prob.tspan[1]:saveat:prob.tspan[2]) end diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl index de23bec5974c2610d6381b62159d384832fb7f0e..ba7b2bfed2a231ba1431e1767b1e24420d0c011e 100755 --- a/src/Solvers/integrator.jl +++ b/src/Solvers/integrator.jl @@ -97,8 +97,7 @@ function CommonSolve.solve!(integrator::CryoGridIntegrator) integrator.sol.retcode = ReturnCode.Success end # if no save points are specified, save final state - prob = integrator.sol.prob - if isempty(prob.savecfg.saveat) + if isempty(integrator.sol.prob.saveat) prob.savefunc(integrator.u, integrator.t, integrator) push!(integrator.sol.u, integrator.u) push!(integrator.sol.t, integrator.t) @@ -162,7 +161,7 @@ function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple save_interval = ClosedInterval(tspan...) tile = Tile(sol.prob.f) # Tile grid = Grid(tile.grid.*u"m") - savecache = sol.prob.savecache + savecache = sol.prob.savefunc.cache ts = savecache.t # use save cache time points # check if last value is duplicated ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts diff --git a/src/problem.jl b/src/problem.jl index 8c90f8e357f241c7ee49c23a7917e3fb92a443e9..a144ef8aee3e09514021b0f7ace8b37316928b96 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -3,15 +3,14 @@ Represents a CryoGrid discretized PDE forward model configuration using the `SciMLBase`/`DiffEqBase` problem interface. """ -struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip} +struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip} f::TT u0::Tu tspan::NTuple{2,Tt} p::Tp callbacks::Tcb - savecfg::Tsv + saveat::Tsv savefunc::Tsf - savecache::Tsc isoutofdomain::Tdf kwargs::Tkw CryoGridProblem{iip}( @@ -20,13 +19,12 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.Abs tspan::NTuple{2,Tt}, p::Tp, cbs::Tcb, - savecfg::Tsv, + saveat::Tsv, savefunc::Tsf, - savecache::Tsc, iood::Tdf, kwargs::Tkw - ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} = - new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, savecfg, savefunc, savecache, iood, kwargs) + ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = + new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs) end """ @@ -83,8 +81,13 @@ function CryoGridProblem( du0 = zero(u0) # remove units tile = stripunits(tile) + # set up saving callback + saveat = expandtstep(saveat, tspan) + savecache = InputOutput.SaveCache(Float64[], []) + savefunc = saving_function(savecache, savevars...) + savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep) diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!) - defaultcallbacks = (diagnostic_step_callback,) + defaultcallbacks = (savingcallback, diagnostic_step_callback) # add step limiter to default callbacks, if defined if !isnothing(step_limiter) defaultcallbacks = ( @@ -96,18 +99,13 @@ function CryoGridProblem( layercallbacks = _makecallbacks(tile) # add user callbacks usercallbacks = isnothing(callback) ? () : callback - callbacks = Callbacks(defaultcallbacks, layercallbacks, usercallbacks) + callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...) # build mass matrix mass_matrix = Numerics.build_mass_matrix(tile.state) # get params p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...) - # set up saving config - saveat = expandtstep(saveat, tspan) - saveconfig = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep) - savecache = InputOutput.SaveCache(Float64[], []) - savefunc = saving_function(savecache, savevars...) - return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveconfig, savefunc, savecache, isoutofdomain, prob_kwargs) + return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs) end function SciMLBase.remake( @@ -116,13 +114,11 @@ function SciMLBase.remake( u0=nothing, tspan=prob.tspan, p=prob.p, - savevars=prob.savecfg.savevars, - saveat=prob.savecfg.saveat, - save_start=prob.savecfg.save_start, - save_everystep=prob.savecfg.save_everystep, + callbacks=prob.callbacks, + saveat=prob.saveat, + savefunc=prob.savefunc, isoutofdomain=prob.isoutofdomain, kwargs=prob.kwargs, - callbacks=nothing, ) where iip # always re-run initialcondition! with the given tspan and parameters _u0, du0 = initialcondition!(Tile(f), tspan, p) @@ -133,15 +129,7 @@ function SciMLBase.remake( else u0 = _u0 end - # create new save cache - saveat = expandtstep(saveat, tspan) - savecfg = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep) - savecache = InputOutput.SaveCache(Float64[], []) - savefunc = saving_function(savecache, savevars...) - # rebuild Callbacks struct with new user callbacks if provided - callbacks = isnothing(callbacks) ? prob.callbacks : Callbacks(prob.callbacks.default, prob.callbacks.layer, callbacks) - # construct new CryoGridProblem - return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, savecfg, savefunc, savecache, isoutofdomain, kwargs) + return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs) end CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))") @@ -151,30 +139,15 @@ function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...) return solve!(integrator) end -function SciMLBase.ODEProblem(prob::CryoGridProblem) - savingcallback = FunctionCallingCallback( - prob.savefunc; - funcat=prob.savecfg.saveat, - func_start=prob.savecfg.save_start, - func_everystep=prob.savecfg.save_everystep - ) - callbacks = CallbackSet( - savingcallback, - prob.callbacks.default..., - prob.callbacks.layer..., - prob.callbacks.user... - ) - odeprob = ODEProblem( - prob.f, - prob.u0, - prob.tspan, - prob.p; - callback=callbacks, - isoutofdomain=prob.isoutofdomain, - prob.kwargs... - ) - return odeprob -end +SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem( + prob.f, + prob.u0, + prob.tspan, + prob.p; + callback=prob.callbacks, + isoutofdomain=prob.isoutofdomain, + prob.kwargs... +) DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob @@ -231,14 +204,7 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator) DiffEqBase.u_modified!(integrator, u_modified) end -# Callback utilities - -struct Callbacks - default - layer - user -end - +# callback building functions function _makecallbacks(tile::Tile) eventname(::Event{name}) where name = name isgridevent(::GridContinuousEvent) = true