From 0ecc92f3fb0669784a2e72fb69608d43abaf09bf Mon Sep 17 00:00:00 2001 From: Brian Groenke <brian.groenke@awi.de> Date: Wed, 1 Jan 2025 17:36:19 +0100 Subject: [PATCH] Refactor problem and built-in solvers to use CommonSolve --- Project.toml | 3 +- src/CryoGrid.jl | 8 +- src/IO/InputOutput.jl | 2 +- src/IO/output.jl | 98 ++++----------------- src/Numerics/caches.jl | 2 + src/Solvers/DiffEq/DiffEq.jl | 2 + src/Solvers/DiffEq/ode_solvers.jl | 98 +++++++++++++++++++-- src/Solvers/LiteImplicit/LiteImplicit.jl | 1 + src/Solvers/LiteImplicit/cglite_types.jl | 13 ++- src/Solvers/Solvers.jl | 2 +- src/Solvers/basic_solvers.jl | 13 ++- src/Solvers/integrator.jl | 107 ++++++++++++++++++++--- src/Tiles/tile.jl | 33 +------ src/Tiles/tile_base.jl | 5 -- src/Utils/Utils.jl | 10 ++- src/problem.jl | 63 +++++++++---- 16 files changed, 288 insertions(+), 172 deletions(-) diff --git a/Project.toml b/Project.toml index 8a45647e..72c50630 100755 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "CryoGrid" uuid = "a535b82e-5f3d-4d97-8b0b-d6483f5bebd5" authors = ["Brian Groenke <brian.groenke@awi.de>", "Jan Nitzbon <jan.nitzbon@awi.de>", "Moritz Langer <moritz.langer@awi.de>"] -version = "0.23.1" +version = "0.23.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" diff --git a/src/CryoGrid.jl b/src/CryoGrid.jl index 9c227ea0..39e9ff95 100755 --- a/src/CryoGrid.jl +++ b/src/CryoGrid.jl @@ -42,6 +42,7 @@ using ModelParameters using Reexport using Requires +import CommonSolve import Flatten import Interpolations @@ -110,9 +111,6 @@ export ConstantBC, PeriodicBC, ConstantValue, PeriodicValue, ConstantFlux, Perio export volumetricfractions include("Physics/Physics.jl") -include("Diagnostics/Diagnostics.jl") -@reexport using .Diagnostics - # Coupling include("coupling.jl") @@ -124,6 +122,10 @@ include("problem.jl") export CGEuler, CryoGridIntegrator, CryoGridSolution include("Solvers/Solvers.jl") +# Diagnostics +include("Diagnostics/Diagnostics.jl") +@reexport using .Diagnostics + # Built-in model definitions export SoilHeatTile, SamoylovDefault include("models.jl") diff --git a/src/IO/InputOutput.jl b/src/IO/InputOutput.jl index aa76df7b..32dd3407 100644 --- a/src/IO/InputOutput.jl +++ b/src/IO/InputOutput.jl @@ -7,7 +7,7 @@ using CryoGrid.Utils using Base: @propagate_inbounds using ComponentArrays using ConstructionBase -using DataStructures: DefaultDict, OrderedDict +using DataStructures: DefaultDict using Dates using DimensionalData using Downloads diff --git a/src/IO/output.jl b/src/IO/output.jl index b9edcee3..82d7e32f 100644 --- a/src/IO/output.jl +++ b/src/IO/output.jl @@ -51,85 +51,6 @@ dimstr(::Ti) = "time" dimstr(::Z) = "depth" dimstr(::Dim{name}) where name = string(name) -""" - CryoGridOutput(sol::TSol, tspan::NTuple{2,Float64}=(-Inf,Inf)) where {TSol<:SciMLBase.AbstractODESolution} - -Constructs a `CryoGridOutput` from the given `ODESolution`. Optional argument `tspan` restricts the time span of the output. -""" -CryoGridOutput(sol::SciMLBase.AbstractODESolution, tspan::NTuple{2,DateTime}) = CryoGridOutput(sol, convert_tspan(tspan)) -function CryoGridOutput(sol::SciMLBase.AbstractODESolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) - # Helper functions for mapping variables to appropriate DimArrays by grid/shape. - withdims(var::Var{name,<:CryoGrid.OnGrid{Cells}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts))) - withdims(var::Var{name,<:CryoGrid.OnGrid{Edges}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts))) - withdims(var::Var{name}, arr, zs, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Dim{name}(1:size(arr,1)),Ti(ts))) - save_interval = ClosedInterval(tspan...) - tile = Tile(sol.prob.f) # Tile - grid = Grid(tile.grid.*u"m") - ts = tile.data.outputs.t # use save callback time points - # check if last value is duplicated - ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts - t_mask = map(∈(save_interval), ts) # indices within t interval - u_all = sol.(ts[t_mask]) - u_mat = reduce(hcat, u_all) # build prognostic state from continuous solution - pax = ComponentArrays.indexmap(getaxes(tile.state.uproto)[1]) - # get saved diagnostic states and timestamps only in given interval - savedstates = tile.data.outputs.saveval[1:length(ts)][t_mask] - ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0)) - allvars = variables(tile) - progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars)) - diagvars = filter(isdiagnostic, allvars) - fluxvars = filter(isflux, allvars) - outputs = OrderedDict() - # add all on-grid prognostic variables - for var in filter(isongrid, progvars) - name = varname(var) - outputs[name] = withdims(var, u_mat[pax[name],:], grid, ts_datetime) - end - # add all on-grid diagnostic variables - for var in filter(isongrid, tuplejoin(diagvars, fluxvars)) - name = varname(var) - states = collect(skipmissing([name ∈ keys(state) ? state[name] : missing for state in savedstates])) - if length(states) == length(ts_datetime) - arr = reduce(hcat, states) - outputs[name] = withdims(var, arr, grid, ts_datetime) - end - end - # handle per-layer variables - for layer in layernames(tile.strat) - # if layer name appears in saved states or prognostic state axes, then add these variables to the output. - if haskey(savedstates[1], layer) || haskey(pax, layer) - # map over all savedstates and create named tuples for each time step - layerouts = map(u_all, savedstates) do u, state - layerout = OrderedDict() - if haskey(state, layer) - layerstate = state[layer] - for var in keys(layerstate) - layerout[var] = layerstate[var] - end - else - end - # convert to named tuple - diagnostic_output = (;layerout...) - if haskey(u, layer) - u_layer = u[layer] - prognostic_output = (;map(name -> name => u_layer[name], keys(u_layer))...) - return merge(prognostic_output, diagnostic_output) - else - return diagnostic_output - end - end - layerouts_combined = reduce(layerouts[2:end]; init=layerouts[1]) do out1, out2 - map(vcat, out1, out2) - end - # for each variable in the named tuple, find the corresponding variables - layervars = (; map(name -> name => first(filter(var -> varname(var) == name, allvars)), keys(layerouts_combined))...) - # map each output to a variable and call withdims to wrap in a DimArray - outputs[layer] = map((var,out) -> withdims(var, reshape(out,1,:), nothing, ts_datetime), layervars, layerouts_combined) - end - end - return CryoGridOutput(ts_datetime, sol, (;outputs...)) -end - function write_netcdf!(filename::String, out::CryoGridOutput, filemode="c") NCD.Dataset(filename, filemode) do ds # this assumes that the primary state variable has time and depth axes @@ -159,3 +80,22 @@ function _write_ncd_var!(ds::NCD.Dataset, key::Symbol, nt::NamedTuple) _write_ncd_var!(ds, Symbol("$key.$var"), nt[var]) end end + +# Integrator saving cache + +struct SaveCache{Tt} + t::Vector{Tt} + vals::Vector{Any} +end + +function save!(cache::SaveCache, state, t) + if length(cache.t) == 0 || cache.t[end] < t + push!(cache.t, t) + push!(cache.vals, state) + end +end + +function reset!(cache::SaveCache) + resize!(cache.t, 0) + resize!(cache.vals, 0) +end diff --git a/src/Numerics/caches.jl b/src/Numerics/caches.jl index e848e005..39c72fb3 100644 --- a/src/Numerics/caches.jl +++ b/src/Numerics/caches.jl @@ -1,5 +1,7 @@ import PreallocationTools as Prealloc +# State variable caches + abstract type StateVarCache{T} end """ diff --git a/src/Solvers/DiffEq/DiffEq.jl b/src/Solvers/DiffEq/DiffEq.jl index 5066cc22..848d1b78 100644 --- a/src/Solvers/DiffEq/DiffEq.jl +++ b/src/Solvers/DiffEq/DiffEq.jl @@ -28,6 +28,8 @@ using DiffEqBase # re-export DiffEqCallbacks @reexport using DiffEqCallbacks +import CommonSolve + export TDMASolver include("linsolve.jl") diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl index c2aeb567..cd7c4e9d 100644 --- a/src/Solvers/DiffEq/ode_solvers.jl +++ b/src/Solvers/DiffEq/ode_solvers.jl @@ -1,11 +1,97 @@ -# solve/init interface -function DiffEqBase.__solve(prob::CryoGridProblem, alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; saveat=prob.saveat, kwargs...) - ode_prob = ODEProblem(prob) - return DiffEqBase.solve(ode_prob, alg, args...; saveat, kwargs...) +""" + CryoGridDiffEqSolution + +Wrapper types for ODE solutions from `OrdinaryDiffEq`. +""" +struct CryoGridDiffEqSolution{ + TT, + N, + Tu, + solType<:SciMLBase.AbstractODESolution{TT,N,Tu} +} <: CryoGrid.AbstractCryoGridSolution{TT,N,Tu} + prob::CryoGridProblem + sol::solType +end + +Base.propertynames(sol::CryoGridDiffEqSolution) = ( + :prob, + Base.propertynames(sol.sol)... +) + +function Base.getproperty(sol::CryoGridDiffEqSolution, name::Symbol) + if name == :prob + return getfield(sol, name) + else + return getproperty(getfield(sol, :sol), name) + end +end + +""" + CryoGridDiffEqIntegrator + +Wrapper type for ODE integrators from `OrdinaryDiffEq`. +""" +struct CryoGridDiffEqIntegrator{ + algType, + Tu, + Tt, + intType<:SciMLBase.AbstractODEIntegrator{algType,true,Tu,Tt} +} <: SciMLBase.AbstractODEIntegrator{algType,true,Tu,Tt} + prob::CryoGridProblem + integrator::intType +end + +SciMLBase.done(solver::CryoGridDiffEqIntegrator) = SciMLBase.done(solver.integrator) +SciMLBase.get_du(integrator::CryoGridDiffEqIntegrator) = get_du(integrator.integrator) +SciMLBase.add_tstop!(integrator::CryoGridDiffEqIntegrator, t) = push!(integrator.opts.tstops, t) +SciMLBase.postamble!(solver::CryoGridDiffEqIntegrator) = SciMLBase.postamble!(solver.integrator) + +function Base.show(io::IO, mime::MIME"text/plain", integrator::CryoGridDiffEqIntegrator) + println(io, "CryoGrid ODE integrator:") + show(io, mime, integrator.integrator) end -function DiffEqBase.__init(prob::CryoGridProblem, alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; saveat=prob.saveat, kwargs...) + +Base.propertynames(integrator::CryoGridDiffEqIntegrator) = ( + :prob, + :integrator, + Base.propertynames(integrator.integrator)... +) + +function Base.getproperty(integrator::CryoGridDiffEqIntegrator, name::Symbol) + inner = getfield(integrator, :integrator) + if name ∈ (:prob, :integrator) + return getfield(integrator, name) + elseif name == :sol + return CryoGridDiffEqSolution(integrator.prob, inner.sol) + else + return getproperty(inner, name) + end +end + +function Base.setproperty!(integrator::CryoGridDiffEqIntegrator, name::Symbol, value) + return setproperty!(getfield(integrator, :integrator), name, value) +end + +# CommonSolve solve/init interface + +function CommonSolve.init( + prob::CryoGridProblem, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, + args...; + kwargs... +) ode_prob = ODEProblem(prob) - return DiffEqBase.init(ode_prob, alg, args...; saveat, kwargs...) + integrator = init(ode_prob, alg, args...; kwargs...) + return CryoGridDiffEqIntegrator(prob, integrator) +end + +function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...) + return step!(integrator.integrator, args...; kwargs...) +end + +function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator) + ode_sol = solve!(integrator.integrator) + return CryoGridDiffEqSolution(integrator.prob, ode_sol) end # custom nonlinear solvers diff --git a/src/Solvers/LiteImplicit/LiteImplicit.jl b/src/Solvers/LiteImplicit/LiteImplicit.jl index 4f58ab67..a55356b5 100644 --- a/src/Solvers/LiteImplicit/LiteImplicit.jl +++ b/src/Solvers/LiteImplicit/LiteImplicit.jl @@ -8,6 +8,7 @@ using DataStructures using DiffEqBase, DiffEqCallbacks using Interpolations +import CommonSolve import SciMLBase export LiteImplicitEuler diff --git a/src/Solvers/LiteImplicit/cglite_types.jl b/src/Solvers/LiteImplicit/cglite_types.jl index 493584bc..82eace98 100644 --- a/src/Solvers/LiteImplicit/cglite_types.jl +++ b/src/Solvers/LiteImplicit/cglite_types.jl @@ -22,7 +22,7 @@ struct LiteImplicitEulerCache{Tu,TA} <: SciMLBase.DECache D::TA end -function DiffEqBase.__init( +function CommonSolve.init( prob::CryoGridProblem, alg::LiteImplicitEuler, args...; @@ -43,12 +43,6 @@ function DiffEqBase.__init( # evaluate tile at initial condition tile = Tiles.materialize(Tile(prob.f), prob.p, t0) tile(zero(u0), u0, prob.p, t0, dt) - # reset SavedValues on tile.data - initialsave = prob.savefunc(tile, u0, similar(u0)) - savevals = SavedValues(Float64, typeof(initialsave)) - push!(savevals.saveval, initialsave) - push!(savevals.t, t0) - tile.data.outputs = savevals sol = CryoGridSolution(prob, u_storage, t_storage, alg, ReturnCode.Default) cache = LiteImplicitEulerCache( similar(prob.u0), # should have ComponentArray type @@ -64,5 +58,8 @@ function DiffEqBase.__init( ) p = isnothing(prob.p) ? prob.p : collect(prob.p) opts = CryoGridIntegratorOptions(; saveat=CryoGrid.expandtstep(saveat, prob.tspan), dtmax, dtmin, kwargs...) - return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0, convert(eltype(prob.tspan), dt), 1, 1) + integrator = CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0, convert(eltype(prob.tspan), dt), 1, 1) + # save initial state + prob.savefunc(u0, t0, integrator) + return integrator end diff --git a/src/Solvers/Solvers.jl b/src/Solvers/Solvers.jl index bc11f318..2ad726fd 100644 --- a/src/Solvers/Solvers.jl +++ b/src/Solvers/Solvers.jl @@ -1,7 +1,7 @@ using CryoGrid using CryoGrid.Utils +using DataStructures: OrderedDict using ForwardDiff - using Reexport export CryoGridIntegrator, CryoGridIntegratorOptions, CryoGridSolution diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl index c03a4b99..f93da9b3 100644 --- a/src/Solvers/basic_solvers.jl +++ b/src/Solvers/basic_solvers.jl @@ -19,7 +19,7 @@ struct CGEulerCache{Tu} <: SciMLBase.DECache du::Tu end -function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...) +function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...) tile = Tile(prob.f) u0 = copy(prob.u0) du0 = zero(u0) @@ -30,12 +30,6 @@ function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0 # evaluate tile at initial condition tile = Tiles.materialize(Tile(prob.f), prob.p, t0) tile(du0, u0, prob.p, t0, dt) - # reset SavedValues on tile.data - initialsave = prob.savefunc(tile, u0, similar(u0)) - savevals = SavedValues(Float64, typeof(initialsave)) - push!(savevals.saveval, initialsave) - push!(savevals.t, t0) - tile.data.outputs = savevals sol = CryoGridSolution(prob, u_storage, t_storage, alg, ReturnCode.Default) cache = CGEulerCache( similar(prob.u0), @@ -48,7 +42,10 @@ function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0 saveat = collect(prob.tspan[1]:saveat:prob.tspan[2]) end opts = CryoGridIntegratorOptions(; saveat=expandtstep(saveat, prob.tspan), kwargs...) - return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1) + integrator = CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1) + # save initial state + prob.savefunc(u0, t0, integrator) + return integrator end function perform_step!(integrator::CryoGridIntegrator{CGEuler}) diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl index 195bab2b..ba7b2bfe 100755 --- a/src/Solvers/integrator.jl +++ b/src/Solvers/integrator.jl @@ -2,7 +2,9 @@ using DataStructures: SortedSet abstract type CryoGridODEAlgorithm <: SciMLBase.AbstractODEAlgorithm end -mutable struct CryoGridSolution{TT,Tu<:AbstractVector{TT},Tt,Talg,Tprob} <: SciMLBase.AbstractODESolution{TT,1,Tu} +abstract type AbstractCryoGridSolution{TT,N,Tu} <: SciMLBase.AbstractODESolution{TT,N,Tu} end + +mutable struct CryoGridSolution{TT,Tu<:AbstractVector{TT},Tt,Talg,Tprob} <: AbstractCryoGridSolution{TT,1,Tu} prob::Tprob u::Vector{Tu} t::Vector{Tt} @@ -81,7 +83,7 @@ SciMLBase.postamble!(integrator::CryoGridIntegrator) = nothing # add tstop by default because we don't support fancy interpolation DiffEqBase.step!(integrator::CryoGridIntegrator, dt) = step!(integrator, dt, true) -function DiffEqBase.step!(integrator::CryoGridIntegrator) +function CommonSolve.step!(integrator::CryoGridIntegrator) handle_tstops!(integrator) perform_step!(integrator) saveat!(integrator) @@ -89,37 +91,34 @@ function DiffEqBase.step!(integrator::CryoGridIntegrator) integrator.dt = min(integrator.opts.dtmax, integrator.dt) end -function DiffEqBase.__solve(prob::CryoGridProblem, alg::CryoGridODEAlgorithm, args...; kwargs...) - integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) +function CommonSolve.solve!(integrator::CryoGridIntegrator) for i in integrator end if integrator.sol.retcode == ReturnCode.Default integrator.sol.retcode = ReturnCode.Success end # if no save points are specified, save final state - if isempty(prob.saveat) - tile = Tile(integrator) - push!(tile.data.outputs.saveval, prob.savefunc(tile, integrator.u, get_du(integrator))) - push!(tile.data.outputs.t, ForwardDiff.value(integrator.t)) + if isempty(integrator.sol.prob.saveat) + prob.savefunc(integrator.u, integrator.t, integrator) push!(integrator.sol.u, integrator.u) push!(integrator.sol.t, integrator.t) end return integrator.sol end +# CryoGridIntegrator interface + perform_step!(integrator::CryoGridIntegrator) = error("perform_step! not implemented for algorithm $(integrator.alg)") function saveat!(integrator::CryoGridIntegrator) - tile = Tile(integrator) - du = get_du(integrator) + prob = integrator.sol.prob saveat = integrator.opts.saveat t_saves = integrator.sol.t u_saves = integrator.sol.u res = searchsorted(saveat, integrator.t) i_next = first(res) i_prev = last(res) - dtsave = if i_next == i_prev - push!(tile.data.outputs.saveval, integrator.sol.prob.savefunc(tile, integrator.u, du)) - push!(tile.data.outputs.t, ForwardDiff.value(integrator.t)) + dtsave = if i_next == i_prev + prob.savefunc(integrator.u, integrator.t, integrator) push!(u_saves, copy(integrator.u)) push!(t_saves, integrator.t) Inf @@ -145,3 +144,85 @@ end expandtstep(tstep::Number, tspan) = tspan[1]:tstep:tspan[end] expandtstep(tstep::AbstractVector, tspan) = tstep + +# Output + +""" + CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) + +Constructs a `CryoGridOutput` from the given `ODESolution`. Optional argument `tspan` restricts the time span of the output. +""" +InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,DateTime}) = CryoGridOutput(sol, convert_tspan(tspan)) +function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) + # Helper functions for mapping variables to appropriate DimArrays by grid/shape. + withdims(var::Var{name,<:CryoGrid.OnGrid{Cells}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts))) + withdims(var::Var{name,<:CryoGrid.OnGrid{Edges}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts))) + withdims(var::Var{name}, arr, zs, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Dim{name}(1:size(arr,1)),Ti(ts))) + save_interval = ClosedInterval(tspan...) + tile = Tile(sol.prob.f) # Tile + grid = Grid(tile.grid.*u"m") + savecache = sol.prob.savefunc.cache + ts = savecache.t # use save cache time points + # check if last value is duplicated + ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts + t_mask = map(∈(save_interval), ts) # indices within t interval + u_all = sol.(ts[t_mask]) + u_mat = reduce(hcat, u_all) # build prognostic state from continuous solution + pax = ComponentArrays.indexmap(getaxes(tile.state.uproto)[1]) + # get saved diagnostic states and timestamps only in given interval + savedstates = savecache.vals[1:length(ts)][t_mask] + ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0)) + allvars = variables(tile) + progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars)) + diagvars = filter(isdiagnostic, allvars) + fluxvars = filter(isflux, allvars) + outputs = OrderedDict() + # add all on-grid prognostic variables + for var in filter(isongrid, progvars) + name = varname(var) + outputs[name] = withdims(var, u_mat[pax[name],:], grid, ts_datetime) + end + # add all on-grid diagnostic variables + for var in filter(isongrid, tuplejoin(diagvars, fluxvars)) + name = varname(var) + states = collect(skipmissing([name ∈ keys(state) ? state[name] : missing for state in savedstates])) + if length(states) == length(ts_datetime) + arr = reduce(hcat, states) + outputs[name] = withdims(var, arr, grid, ts_datetime) + end + end + # handle per-layer variables + for layer in layernames(tile.strat) + # if layer name appears in saved states or prognostic state axes, then add these variables to the output. + if haskey(savedstates[1], layer) || haskey(pax, layer) + # map over all savedstates and create named tuples for each time step + layerouts = map(u_all, savedstates) do u, state + layerout = OrderedDict() + if haskey(state, layer) + layerstate = state[layer] + for var in keys(layerstate) + layerout[var] = layerstate[var] + end + else + end + # convert to named tuple + diagnostic_output = (;layerout...) + if haskey(u, layer) + u_layer = u[layer] + prognostic_output = (;map(name -> name => u_layer[name], keys(u_layer))...) + return merge(prognostic_output, diagnostic_output) + else + return diagnostic_output + end + end + layerouts_combined = reduce(layerouts[2:end]; init=layerouts[1]) do out1, out2 + map(vcat, out1, out2) + end + # for each variable in the named tuple, find the corresponding variables + layervars = (; map(name -> name => first(filter(var -> varname(var) == name, allvars)), keys(layerouts_combined))...) + # map each output to a variable and call withdims to wrap in a DimArray + outputs[layer] = map((var,out) -> withdims(var, reshape(out,1,:), nothing, ts_datetime), layervars, layerouts_combined) + end + end + return CryoGridOutput(ts_datetime, sol, (;outputs...)) +end diff --git a/src/Tiles/tile.jl b/src/Tiles/tile.jl index d68bdc7f..a327913a 100644 --- a/src/Tiles/tile.jl +++ b/src/Tiles/tile.jl @@ -10,7 +10,6 @@ struct Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} <: AbstractTile{iip inits::TInits # initializers events::TEvents # events inputs::TInputs # inputs - data::TileData # output data metadata::Dict # metadata function Tile( strat::TStrat, @@ -19,15 +18,14 @@ struct Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} <: AbstractTile{iip inits::TInits, events::TEvents, inputs::TInputs, - data::TileData=TileData(), metadata::Dict=Dict(), iip::Bool=true) where {TStrat<:Stratigraphy,TGrid<:Grid{Edges},TStates<:StateVars,TInits<:Tuple,TEvents<:NamedTuple,TInputs<:InputProvider} - new{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}(strat, grid, state, inits, events, inputs, data, metadata) + new{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}(strat, grid, state, inits, events, inputs, metadata) end end ConstructionBase.constructorof(::Type{Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}}) where {TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} = - (strat, grid, state, inits, events, inputs, data, metadata) -> Tile(strat, grid, state, inits, events, inputs, data, metadata, iip) + (strat, grid, state, inits, events, inputs, metadata) -> Tile(strat, grid, state, inits, events, inputs, metadata, iip) # mark only stratigraphy and initializers fields as flattenable Flatten.flattenable(::Type{<:Tile}, ::Type{Val{:strat}}) = true Flatten.flattenable(::Type{<:Tile}, ::Type{Val{:inits}}) = true @@ -92,7 +90,7 @@ function Tile( _addlayerfield(init, Symbol(:init)) end end - tile = Tile(strat, grid, states, inits, (;events...), inputs, TileData(), metadata, iip) + tile = Tile(strat, grid, states, inits, (;events...), inputs, metadata, iip) _validate_inputs(tile, inputs) return tile end @@ -282,29 +280,6 @@ getstate(integrator::SciMLBase.DEIntegrator) = Tiles.getstate(Tile(integrator), """ Numerics.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator; interp=true) = Numerics.getvar(Val{var}(), Tile(integrator), integrator.u; interp) -# """ -# parameterize(tile::Tile) - -# Adds parameter information to all nested types in `tile` by recursively calling `parameterize`. -# """ -# function CryoGrid.parameterize(tile::Tile) -# ctor = ConstructionBase.constructorof(typeof(tile)) -# new_layers = map(namedlayers(tile.strat)) do named_layer -# name = nameof(named_layer) -# layer = CryoGrid.parameterize(named_layer.val) -# name => _addlayerfield(layer, name) -# end -# new_inits = map(tile.inits) do init -# _addlayerfield(CryoGrid.parameterize(init), :init) -# end -# new_events = map(keys(tile.events)) do name -# evs = map(CryoGrid.parameterize, getproperty(tile.events, name)) -# name => _addlayerfield(evs, name) -# end -# new_strat = Stratigraphy(boundaries(tile.strat), (;new_layers...)) -# return ctor(new_strat, tile.grid, tile.state, new_inits, (;new_events...), tile.data, tile.metadata) -# end - """ variables(tile::Tile) @@ -413,6 +388,6 @@ function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvide end # helper method that returns a Union type of all types that should be ignored by Flatten.flatten -@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual} +@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,Unitful.Quantity,Numerics.ForwardDiff.Dual} # ===================================================================== # diff --git a/src/Tiles/tile_base.jl b/src/Tiles/tile_base.jl index be02dd91..d29ed699 100644 --- a/src/Tiles/tile_base.jl +++ b/src/Tiles/tile_base.jl @@ -1,8 +1,3 @@ -mutable struct TileData - outputs::Any - TileData() = new(missing) -end - """ AbstractTile{iip} diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 4c20a1e7..2ed52d6f 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -28,7 +28,7 @@ export @UFloat_str, @UT_str, @setscalar, @threaded, @sym_str, @pstrip include("macros.jl") export StrictlyPositive, StrictlyNegative, Nonnegative, Nonpositive -export applyunits, normalize_units, normalize_temperature, pstrip +export applyunits, normalize_units, normalize_temperature, pstrip, adstrip export fastmap, fastiterate, structiterate, getscalar, tuplejoin, convert_t, convert_tspan, haskeys # Variable/parameter domains @@ -161,6 +161,14 @@ function ffill!(x::AbstractVector{T}) where {E,T<:Union{Missing,E}} return x end +""" + adstrip(x) + +Strips autodiff type wrappers from `x`. +""" +adstrip(x) = x +adstrip(x::ForwardDiff.Dual) = ForwardDiff.value(x) + """ pstrip(obj; keep_units=false) diff --git a/src/problem.jl b/src/problem.jl index db3ebf19..a144ef8a 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -13,7 +13,17 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.Abstrac savefunc::Tsf isoutofdomain::Tdf kwargs::Tkw - CryoGridProblem{iip}(f::TF, u0::Tu, tspan::NTuple{2,Tt}, p::Tp, cbs::Tcb, saveat::Tsv, savefunc::Tsf, iood::Tdf, kwargs::Tkw) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = + CryoGridProblem{iip}( + f::TF, + u0::Tu, + tspan::NTuple{2,Tt}, + p::Tp, + cbs::Tcb, + saveat::Tsv, + savefunc::Tsf, + iood::Tdf, + kwargs::Tkw + ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs) end @@ -21,6 +31,7 @@ end CryoGridProblem(tile::Tile, u0::ComponentVector, tspan::NTuple{2,DateTime}, args...;kwargs...) """ CryoGridProblem(tile::Tile, u0::ComponentVector, tspan::NTuple{2,DateTime}, args...;kwargs...) = CryoGridProblem(tile, u0, convert_tspan(tspan), args...;kwargs...) + """ CryoGridProblem( tile::Tile, @@ -51,10 +62,9 @@ function CryoGridProblem( p::Union{Nothing,AbstractVector}=nothing; diagnostic_stepsize=3600.0, saveat=3600.0, - savevars=(), - save_everystep=false, save_start=true, - save_end=true, + save_everystep=false, + savevars=(), step_limiter=timestep, safety_factor=1, max_step=true, @@ -64,8 +74,6 @@ function CryoGridProblem( function_kwargs=(), prob_kwargs... ) - getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...)) - savefunc(u, t, integrator) = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator)) # strip all "fixed" parameters tile = stripparams(FixedParam, tile) # retrieve variable parameters @@ -74,10 +82,10 @@ function CryoGridProblem( # remove units tile = stripunits(tile) # set up saving callback - stateproto = getsavestate(tile, u0, du0) - savevals = SavedValues(Float64, typeof(stateproto)) saveat = expandtstep(saveat, tspan) - savingcallback = SavingCallback(savefunc, savevals; saveat=saveat, save_start=save_start, save_end=save_end, save_everystep=save_everystep) + savecache = InputOutput.SaveCache(Float64[], []) + savefunc = saving_function(savecache, savevars...) + savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep) diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!) defaultcallbacks = (savingcallback, diagnostic_step_callback) # add step limiter to default callbacks, if defined @@ -92,14 +100,12 @@ function CryoGridProblem( # add user callbacks usercallbacks = isnothing(callback) ? () : callback callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...) - # note that this implicitly discards any existing saved values in the model setup's state history - tile.data.outputs = savevals # build mass matrix mass_matrix = Numerics.build_mass_matrix(tile.state) # get params p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...) - return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, getsavestate, isoutofdomain, prob_kwargs) + return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs) end function SciMLBase.remake( @@ -126,6 +132,34 @@ function SciMLBase.remake( return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs) end +CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))") + +function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...) + integrator = init(prob, alg, args...; kwargs...) + return solve!(integrator) +end + +SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem( + prob.f, + prob.u0, + prob.tspan, + prob.p; + callback=prob.callbacks, + isoutofdomain=prob.isoutofdomain, + prob.kwargs... +) + +DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob + +function saving_function(cache::InputOutput.SaveCache, savevars...) + getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...)) + return function(u, t, integrator) + state = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator)) + InputOutput.save!(cache, state, adstrip(t)) + return state + end +end + """ odefunction(setup::Tile, u0, p, tspan; kwargs...) @@ -170,11 +204,6 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator) DiffEqBase.u_modified!(integrator, u_modified) end -# overrides to make SciML problem interface work -SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem(prob.f, prob.u0, prob.tspan, prob.p; callback=prob.callbacks, isoutofdomain=prob.isoutofdomain, prob.kwargs...) - -DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob - # callback building functions function _makecallbacks(tile::Tile) eventname(::Event{name}) where name = name -- GitLab