From d0c411907ed8f6d2ede9ef24ead39ac8fb35a19d Mon Sep 17 00:00:00 2001 From: Brian Groenke <brian.groenke@awi.de> Date: Fri, 21 Feb 2025 17:10:39 +0100 Subject: [PATCH] Revert "Refactor problem and built-in solvers to use CommonSolve" This reverts commit 0ecc92f3fb0669784a2e72fb69608d43abaf09bf. --- Project.toml | 3 +- src/CryoGrid.jl | 8 +- src/IO/InputOutput.jl | 2 +- src/IO/output.jl | 98 +++++++++++++++++---- src/Numerics/caches.jl | 2 - src/Solvers/DiffEq/DiffEq.jl | 2 - src/Solvers/DiffEq/ode_solvers.jl | 98 ++------------------- src/Solvers/LiteImplicit/LiteImplicit.jl | 1 - src/Solvers/LiteImplicit/cglite_types.jl | 13 +-- src/Solvers/Solvers.jl | 2 +- src/Solvers/basic_solvers.jl | 13 +-- src/Solvers/integrator.jl | 107 +++-------------------- src/Tiles/tile.jl | 33 ++++++- src/Tiles/tile_base.jl | 5 ++ src/Utils/Utils.jl | 10 +-- src/problem.jl | 63 ++++--------- 16 files changed, 172 insertions(+), 288 deletions(-) diff --git a/Project.toml b/Project.toml index 72c50630..8a45647e 100755 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,10 @@ name = "CryoGrid" uuid = "a535b82e-5f3d-4d97-8b0b-d6483f5bebd5" authors = ["Brian Groenke <brian.groenke@awi.de>", "Jan Nitzbon <jan.nitzbon@awi.de>", "Moritz Langer <moritz.langer@awi.de>"] -version = "0.23.2" +version = "0.23.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" diff --git a/src/CryoGrid.jl b/src/CryoGrid.jl index 39e9ff95..9c227ea0 100755 --- a/src/CryoGrid.jl +++ b/src/CryoGrid.jl @@ -42,7 +42,6 @@ using ModelParameters using Reexport using Requires -import CommonSolve import Flatten import Interpolations @@ -111,6 +110,9 @@ export ConstantBC, PeriodicBC, ConstantValue, PeriodicValue, ConstantFlux, Perio export volumetricfractions include("Physics/Physics.jl") +include("Diagnostics/Diagnostics.jl") +@reexport using .Diagnostics + # Coupling include("coupling.jl") @@ -122,10 +124,6 @@ include("problem.jl") export CGEuler, CryoGridIntegrator, CryoGridSolution include("Solvers/Solvers.jl") -# Diagnostics -include("Diagnostics/Diagnostics.jl") -@reexport using .Diagnostics - # Built-in model definitions export SoilHeatTile, SamoylovDefault include("models.jl") diff --git a/src/IO/InputOutput.jl b/src/IO/InputOutput.jl index 32dd3407..aa76df7b 100644 --- a/src/IO/InputOutput.jl +++ b/src/IO/InputOutput.jl @@ -7,7 +7,7 @@ using CryoGrid.Utils using Base: @propagate_inbounds using ComponentArrays using ConstructionBase -using DataStructures: DefaultDict +using DataStructures: DefaultDict, OrderedDict using Dates using DimensionalData using Downloads diff --git a/src/IO/output.jl b/src/IO/output.jl index 82d7e32f..b9edcee3 100644 --- a/src/IO/output.jl +++ b/src/IO/output.jl @@ -51,6 +51,85 @@ dimstr(::Ti) = "time" dimstr(::Z) = "depth" dimstr(::Dim{name}) where name = string(name) +""" + CryoGridOutput(sol::TSol, tspan::NTuple{2,Float64}=(-Inf,Inf)) where {TSol<:SciMLBase.AbstractODESolution} + +Constructs a `CryoGridOutput` from the given `ODESolution`. Optional argument `tspan` restricts the time span of the output. +""" +CryoGridOutput(sol::SciMLBase.AbstractODESolution, tspan::NTuple{2,DateTime}) = CryoGridOutput(sol, convert_tspan(tspan)) +function CryoGridOutput(sol::SciMLBase.AbstractODESolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) + # Helper functions for mapping variables to appropriate DimArrays by grid/shape. + withdims(var::Var{name,<:CryoGrid.OnGrid{Cells}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts))) + withdims(var::Var{name,<:CryoGrid.OnGrid{Edges}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts))) + withdims(var::Var{name}, arr, zs, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Dim{name}(1:size(arr,1)),Ti(ts))) + save_interval = ClosedInterval(tspan...) + tile = Tile(sol.prob.f) # Tile + grid = Grid(tile.grid.*u"m") + ts = tile.data.outputs.t # use save callback time points + # check if last value is duplicated + ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts + t_mask = map(∈(save_interval), ts) # indices within t interval + u_all = sol.(ts[t_mask]) + u_mat = reduce(hcat, u_all) # build prognostic state from continuous solution + pax = ComponentArrays.indexmap(getaxes(tile.state.uproto)[1]) + # get saved diagnostic states and timestamps only in given interval + savedstates = tile.data.outputs.saveval[1:length(ts)][t_mask] + ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0)) + allvars = variables(tile) + progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars)) + diagvars = filter(isdiagnostic, allvars) + fluxvars = filter(isflux, allvars) + outputs = OrderedDict() + # add all on-grid prognostic variables + for var in filter(isongrid, progvars) + name = varname(var) + outputs[name] = withdims(var, u_mat[pax[name],:], grid, ts_datetime) + end + # add all on-grid diagnostic variables + for var in filter(isongrid, tuplejoin(diagvars, fluxvars)) + name = varname(var) + states = collect(skipmissing([name ∈ keys(state) ? state[name] : missing for state in savedstates])) + if length(states) == length(ts_datetime) + arr = reduce(hcat, states) + outputs[name] = withdims(var, arr, grid, ts_datetime) + end + end + # handle per-layer variables + for layer in layernames(tile.strat) + # if layer name appears in saved states or prognostic state axes, then add these variables to the output. + if haskey(savedstates[1], layer) || haskey(pax, layer) + # map over all savedstates and create named tuples for each time step + layerouts = map(u_all, savedstates) do u, state + layerout = OrderedDict() + if haskey(state, layer) + layerstate = state[layer] + for var in keys(layerstate) + layerout[var] = layerstate[var] + end + else + end + # convert to named tuple + diagnostic_output = (;layerout...) + if haskey(u, layer) + u_layer = u[layer] + prognostic_output = (;map(name -> name => u_layer[name], keys(u_layer))...) + return merge(prognostic_output, diagnostic_output) + else + return diagnostic_output + end + end + layerouts_combined = reduce(layerouts[2:end]; init=layerouts[1]) do out1, out2 + map(vcat, out1, out2) + end + # for each variable in the named tuple, find the corresponding variables + layervars = (; map(name -> name => first(filter(var -> varname(var) == name, allvars)), keys(layerouts_combined))...) + # map each output to a variable and call withdims to wrap in a DimArray + outputs[layer] = map((var,out) -> withdims(var, reshape(out,1,:), nothing, ts_datetime), layervars, layerouts_combined) + end + end + return CryoGridOutput(ts_datetime, sol, (;outputs...)) +end + function write_netcdf!(filename::String, out::CryoGridOutput, filemode="c") NCD.Dataset(filename, filemode) do ds # this assumes that the primary state variable has time and depth axes @@ -80,22 +159,3 @@ function _write_ncd_var!(ds::NCD.Dataset, key::Symbol, nt::NamedTuple) _write_ncd_var!(ds, Symbol("$key.$var"), nt[var]) end end - -# Integrator saving cache - -struct SaveCache{Tt} - t::Vector{Tt} - vals::Vector{Any} -end - -function save!(cache::SaveCache, state, t) - if length(cache.t) == 0 || cache.t[end] < t - push!(cache.t, t) - push!(cache.vals, state) - end -end - -function reset!(cache::SaveCache) - resize!(cache.t, 0) - resize!(cache.vals, 0) -end diff --git a/src/Numerics/caches.jl b/src/Numerics/caches.jl index b0cfe0de..5a1c2b03 100644 --- a/src/Numerics/caches.jl +++ b/src/Numerics/caches.jl @@ -1,7 +1,5 @@ import PreallocationTools as Prealloc -# State variable caches - abstract type StateVarCache{T} end """ diff --git a/src/Solvers/DiffEq/DiffEq.jl b/src/Solvers/DiffEq/DiffEq.jl index 848d1b78..5066cc22 100644 --- a/src/Solvers/DiffEq/DiffEq.jl +++ b/src/Solvers/DiffEq/DiffEq.jl @@ -28,8 +28,6 @@ using DiffEqBase # re-export DiffEqCallbacks @reexport using DiffEqCallbacks -import CommonSolve - export TDMASolver include("linsolve.jl") diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl index cd7c4e9d..c2aeb567 100644 --- a/src/Solvers/DiffEq/ode_solvers.jl +++ b/src/Solvers/DiffEq/ode_solvers.jl @@ -1,97 +1,11 @@ -""" - CryoGridDiffEqSolution - -Wrapper types for ODE solutions from `OrdinaryDiffEq`. -""" -struct CryoGridDiffEqSolution{ - TT, - N, - Tu, - solType<:SciMLBase.AbstractODESolution{TT,N,Tu} -} <: CryoGrid.AbstractCryoGridSolution{TT,N,Tu} - prob::CryoGridProblem - sol::solType -end - -Base.propertynames(sol::CryoGridDiffEqSolution) = ( - :prob, - Base.propertynames(sol.sol)... -) - -function Base.getproperty(sol::CryoGridDiffEqSolution, name::Symbol) - if name == :prob - return getfield(sol, name) - else - return getproperty(getfield(sol, :sol), name) - end -end - -""" - CryoGridDiffEqIntegrator - -Wrapper type for ODE integrators from `OrdinaryDiffEq`. -""" -struct CryoGridDiffEqIntegrator{ - algType, - Tu, - Tt, - intType<:SciMLBase.AbstractODEIntegrator{algType,true,Tu,Tt} -} <: SciMLBase.AbstractODEIntegrator{algType,true,Tu,Tt} - prob::CryoGridProblem - integrator::intType -end - -SciMLBase.done(solver::CryoGridDiffEqIntegrator) = SciMLBase.done(solver.integrator) -SciMLBase.get_du(integrator::CryoGridDiffEqIntegrator) = get_du(integrator.integrator) -SciMLBase.add_tstop!(integrator::CryoGridDiffEqIntegrator, t) = push!(integrator.opts.tstops, t) -SciMLBase.postamble!(solver::CryoGridDiffEqIntegrator) = SciMLBase.postamble!(solver.integrator) - -function Base.show(io::IO, mime::MIME"text/plain", integrator::CryoGridDiffEqIntegrator) - println(io, "CryoGrid ODE integrator:") - show(io, mime, integrator.integrator) -end - -Base.propertynames(integrator::CryoGridDiffEqIntegrator) = ( - :prob, - :integrator, - Base.propertynames(integrator.integrator)... -) - -function Base.getproperty(integrator::CryoGridDiffEqIntegrator, name::Symbol) - inner = getfield(integrator, :integrator) - if name ∈ (:prob, :integrator) - return getfield(integrator, name) - elseif name == :sol - return CryoGridDiffEqSolution(integrator.prob, inner.sol) - else - return getproperty(inner, name) - end -end - -function Base.setproperty!(integrator::CryoGridDiffEqIntegrator, name::Symbol, value) - return setproperty!(getfield(integrator, :integrator), name, value) -end - -# CommonSolve solve/init interface - -function CommonSolve.init( - prob::CryoGridProblem, - alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, - args...; - kwargs... -) +# solve/init interface +function DiffEqBase.__solve(prob::CryoGridProblem, alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; saveat=prob.saveat, kwargs...) ode_prob = ODEProblem(prob) - integrator = init(ode_prob, alg, args...; kwargs...) - return CryoGridDiffEqIntegrator(prob, integrator) -end - -function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...) - return step!(integrator.integrator, args...; kwargs...) + return DiffEqBase.solve(ode_prob, alg, args...; saveat, kwargs...) end - -function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator) - ode_sol = solve!(integrator.integrator) - return CryoGridDiffEqSolution(integrator.prob, ode_sol) +function DiffEqBase.__init(prob::CryoGridProblem, alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, args...; saveat=prob.saveat, kwargs...) + ode_prob = ODEProblem(prob) + return DiffEqBase.init(ode_prob, alg, args...; saveat, kwargs...) end # custom nonlinear solvers diff --git a/src/Solvers/LiteImplicit/LiteImplicit.jl b/src/Solvers/LiteImplicit/LiteImplicit.jl index a55356b5..4f58ab67 100644 --- a/src/Solvers/LiteImplicit/LiteImplicit.jl +++ b/src/Solvers/LiteImplicit/LiteImplicit.jl @@ -8,7 +8,6 @@ using DataStructures using DiffEqBase, DiffEqCallbacks using Interpolations -import CommonSolve import SciMLBase export LiteImplicitEuler diff --git a/src/Solvers/LiteImplicit/cglite_types.jl b/src/Solvers/LiteImplicit/cglite_types.jl index 82eace98..493584bc 100644 --- a/src/Solvers/LiteImplicit/cglite_types.jl +++ b/src/Solvers/LiteImplicit/cglite_types.jl @@ -22,7 +22,7 @@ struct LiteImplicitEulerCache{Tu,TA} <: SciMLBase.DECache D::TA end -function CommonSolve.init( +function DiffEqBase.__init( prob::CryoGridProblem, alg::LiteImplicitEuler, args...; @@ -43,6 +43,12 @@ function CommonSolve.init( # evaluate tile at initial condition tile = Tiles.materialize(Tile(prob.f), prob.p, t0) tile(zero(u0), u0, prob.p, t0, dt) + # reset SavedValues on tile.data + initialsave = prob.savefunc(tile, u0, similar(u0)) + savevals = SavedValues(Float64, typeof(initialsave)) + push!(savevals.saveval, initialsave) + push!(savevals.t, t0) + tile.data.outputs = savevals sol = CryoGridSolution(prob, u_storage, t_storage, alg, ReturnCode.Default) cache = LiteImplicitEulerCache( similar(prob.u0), # should have ComponentArray type @@ -58,8 +64,5 @@ function CommonSolve.init( ) p = isnothing(prob.p) ? prob.p : collect(prob.p) opts = CryoGridIntegratorOptions(; saveat=CryoGrid.expandtstep(saveat, prob.tspan), dtmax, dtmin, kwargs...) - integrator = CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0, convert(eltype(prob.tspan), dt), 1, 1) - # save initial state - prob.savefunc(u0, t0, integrator) - return integrator + return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0, convert(eltype(prob.tspan), dt), 1, 1) end diff --git a/src/Solvers/Solvers.jl b/src/Solvers/Solvers.jl index 2ad726fd..bc11f318 100644 --- a/src/Solvers/Solvers.jl +++ b/src/Solvers/Solvers.jl @@ -1,7 +1,7 @@ using CryoGrid using CryoGrid.Utils -using DataStructures: OrderedDict using ForwardDiff + using Reexport export CryoGridIntegrator, CryoGridIntegratorOptions, CryoGridSolution diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl index f93da9b3..c03a4b99 100644 --- a/src/Solvers/basic_solvers.jl +++ b/src/Solvers/basic_solvers.jl @@ -19,7 +19,7 @@ struct CGEulerCache{Tu} <: SciMLBase.DECache du::Tu end -function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...) +function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...) tile = Tile(prob.f) u0 = copy(prob.u0) du0 = zero(u0) @@ -30,6 +30,12 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, # evaluate tile at initial condition tile = Tiles.materialize(Tile(prob.f), prob.p, t0) tile(du0, u0, prob.p, t0, dt) + # reset SavedValues on tile.data + initialsave = prob.savefunc(tile, u0, similar(u0)) + savevals = SavedValues(Float64, typeof(initialsave)) + push!(savevals.saveval, initialsave) + push!(savevals.t, t0) + tile.data.outputs = savevals sol = CryoGridSolution(prob, u_storage, t_storage, alg, ReturnCode.Default) cache = CGEulerCache( similar(prob.u0), @@ -42,10 +48,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat = collect(prob.tspan[1]:saveat:prob.tspan[2]) end opts = CryoGridIntegratorOptions(; saveat=expandtstep(saveat, prob.tspan), kwargs...) - integrator = CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1) - # save initial state - prob.savefunc(u0, t0, integrator) - return integrator + return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1) end function perform_step!(integrator::CryoGridIntegrator{CGEuler}) diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl index ba7b2bfe..195bab2b 100755 --- a/src/Solvers/integrator.jl +++ b/src/Solvers/integrator.jl @@ -2,9 +2,7 @@ using DataStructures: SortedSet abstract type CryoGridODEAlgorithm <: SciMLBase.AbstractODEAlgorithm end -abstract type AbstractCryoGridSolution{TT,N,Tu} <: SciMLBase.AbstractODESolution{TT,N,Tu} end - -mutable struct CryoGridSolution{TT,Tu<:AbstractVector{TT},Tt,Talg,Tprob} <: AbstractCryoGridSolution{TT,1,Tu} +mutable struct CryoGridSolution{TT,Tu<:AbstractVector{TT},Tt,Talg,Tprob} <: SciMLBase.AbstractODESolution{TT,1,Tu} prob::Tprob u::Vector{Tu} t::Vector{Tt} @@ -83,7 +81,7 @@ SciMLBase.postamble!(integrator::CryoGridIntegrator) = nothing # add tstop by default because we don't support fancy interpolation DiffEqBase.step!(integrator::CryoGridIntegrator, dt) = step!(integrator, dt, true) -function CommonSolve.step!(integrator::CryoGridIntegrator) +function DiffEqBase.step!(integrator::CryoGridIntegrator) handle_tstops!(integrator) perform_step!(integrator) saveat!(integrator) @@ -91,34 +89,37 @@ function CommonSolve.step!(integrator::CryoGridIntegrator) integrator.dt = min(integrator.opts.dtmax, integrator.dt) end -function CommonSolve.solve!(integrator::CryoGridIntegrator) +function DiffEqBase.__solve(prob::CryoGridProblem, alg::CryoGridODEAlgorithm, args...; kwargs...) + integrator = DiffEqBase.__init(prob, alg, args...; kwargs...) for i in integrator end if integrator.sol.retcode == ReturnCode.Default integrator.sol.retcode = ReturnCode.Success end # if no save points are specified, save final state - if isempty(integrator.sol.prob.saveat) - prob.savefunc(integrator.u, integrator.t, integrator) + if isempty(prob.saveat) + tile = Tile(integrator) + push!(tile.data.outputs.saveval, prob.savefunc(tile, integrator.u, get_du(integrator))) + push!(tile.data.outputs.t, ForwardDiff.value(integrator.t)) push!(integrator.sol.u, integrator.u) push!(integrator.sol.t, integrator.t) end return integrator.sol end -# CryoGridIntegrator interface - perform_step!(integrator::CryoGridIntegrator) = error("perform_step! not implemented for algorithm $(integrator.alg)") function saveat!(integrator::CryoGridIntegrator) - prob = integrator.sol.prob + tile = Tile(integrator) + du = get_du(integrator) saveat = integrator.opts.saveat t_saves = integrator.sol.t u_saves = integrator.sol.u res = searchsorted(saveat, integrator.t) i_next = first(res) i_prev = last(res) - dtsave = if i_next == i_prev - prob.savefunc(integrator.u, integrator.t, integrator) + dtsave = if i_next == i_prev + push!(tile.data.outputs.saveval, integrator.sol.prob.savefunc(tile, integrator.u, du)) + push!(tile.data.outputs.t, ForwardDiff.value(integrator.t)) push!(u_saves, copy(integrator.u)) push!(t_saves, integrator.t) Inf @@ -144,85 +145,3 @@ end expandtstep(tstep::Number, tspan) = tspan[1]:tstep:tspan[end] expandtstep(tstep::AbstractVector, tspan) = tstep - -# Output - -""" - CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) - -Constructs a `CryoGridOutput` from the given `ODESolution`. Optional argument `tspan` restricts the time span of the output. -""" -InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,DateTime}) = CryoGridOutput(sol, convert_tspan(tspan)) -function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple{2,Float64}=(-Inf,Inf)) - # Helper functions for mapping variables to appropriate DimArrays by grid/shape. - withdims(var::Var{name,<:CryoGrid.OnGrid{Cells}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts))) - withdims(var::Var{name,<:CryoGrid.OnGrid{Edges}}, arr, grid, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts))) - withdims(var::Var{name}, arr, zs, ts) where {name} = DimArray(arr*one(vartype(var))*varunits(var), (Dim{name}(1:size(arr,1)),Ti(ts))) - save_interval = ClosedInterval(tspan...) - tile = Tile(sol.prob.f) # Tile - grid = Grid(tile.grid.*u"m") - savecache = sol.prob.savefunc.cache - ts = savecache.t # use save cache time points - # check if last value is duplicated - ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts - t_mask = map(∈(save_interval), ts) # indices within t interval - u_all = sol.(ts[t_mask]) - u_mat = reduce(hcat, u_all) # build prognostic state from continuous solution - pax = ComponentArrays.indexmap(getaxes(tile.state.uproto)[1]) - # get saved diagnostic states and timestamps only in given interval - savedstates = savecache.vals[1:length(ts)][t_mask] - ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0)) - allvars = variables(tile) - progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars)) - diagvars = filter(isdiagnostic, allvars) - fluxvars = filter(isflux, allvars) - outputs = OrderedDict() - # add all on-grid prognostic variables - for var in filter(isongrid, progvars) - name = varname(var) - outputs[name] = withdims(var, u_mat[pax[name],:], grid, ts_datetime) - end - # add all on-grid diagnostic variables - for var in filter(isongrid, tuplejoin(diagvars, fluxvars)) - name = varname(var) - states = collect(skipmissing([name ∈ keys(state) ? state[name] : missing for state in savedstates])) - if length(states) == length(ts_datetime) - arr = reduce(hcat, states) - outputs[name] = withdims(var, arr, grid, ts_datetime) - end - end - # handle per-layer variables - for layer in layernames(tile.strat) - # if layer name appears in saved states or prognostic state axes, then add these variables to the output. - if haskey(savedstates[1], layer) || haskey(pax, layer) - # map over all savedstates and create named tuples for each time step - layerouts = map(u_all, savedstates) do u, state - layerout = OrderedDict() - if haskey(state, layer) - layerstate = state[layer] - for var in keys(layerstate) - layerout[var] = layerstate[var] - end - else - end - # convert to named tuple - diagnostic_output = (;layerout...) - if haskey(u, layer) - u_layer = u[layer] - prognostic_output = (;map(name -> name => u_layer[name], keys(u_layer))...) - return merge(prognostic_output, diagnostic_output) - else - return diagnostic_output - end - end - layerouts_combined = reduce(layerouts[2:end]; init=layerouts[1]) do out1, out2 - map(vcat, out1, out2) - end - # for each variable in the named tuple, find the corresponding variables - layervars = (; map(name -> name => first(filter(var -> varname(var) == name, allvars)), keys(layerouts_combined))...) - # map each output to a variable and call withdims to wrap in a DimArray - outputs[layer] = map((var,out) -> withdims(var, reshape(out,1,:), nothing, ts_datetime), layervars, layerouts_combined) - end - end - return CryoGridOutput(ts_datetime, sol, (;outputs...)) -end diff --git a/src/Tiles/tile.jl b/src/Tiles/tile.jl index a327913a..d68bdc7f 100644 --- a/src/Tiles/tile.jl +++ b/src/Tiles/tile.jl @@ -10,6 +10,7 @@ struct Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} <: AbstractTile{iip inits::TInits # initializers events::TEvents # events inputs::TInputs # inputs + data::TileData # output data metadata::Dict # metadata function Tile( strat::TStrat, @@ -18,14 +19,15 @@ struct Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} <: AbstractTile{iip inits::TInits, events::TEvents, inputs::TInputs, + data::TileData=TileData(), metadata::Dict=Dict(), iip::Bool=true) where {TStrat<:Stratigraphy,TGrid<:Grid{Edges},TStates<:StateVars,TInits<:Tuple,TEvents<:NamedTuple,TInputs<:InputProvider} - new{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}(strat, grid, state, inits, events, inputs, metadata) + new{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}(strat, grid, state, inits, events, inputs, data, metadata) end end ConstructionBase.constructorof(::Type{Tile{TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip}}) where {TStrat,TGrid,TStates,TInits,TEvents,TInputs,iip} = - (strat, grid, state, inits, events, inputs, metadata) -> Tile(strat, grid, state, inits, events, inputs, metadata, iip) + (strat, grid, state, inits, events, inputs, data, metadata) -> Tile(strat, grid, state, inits, events, inputs, data, metadata, iip) # mark only stratigraphy and initializers fields as flattenable Flatten.flattenable(::Type{<:Tile}, ::Type{Val{:strat}}) = true Flatten.flattenable(::Type{<:Tile}, ::Type{Val{:inits}}) = true @@ -90,7 +92,7 @@ function Tile( _addlayerfield(init, Symbol(:init)) end end - tile = Tile(strat, grid, states, inits, (;events...), inputs, metadata, iip) + tile = Tile(strat, grid, states, inits, (;events...), inputs, TileData(), metadata, iip) _validate_inputs(tile, inputs) return tile end @@ -280,6 +282,29 @@ getstate(integrator::SciMLBase.DEIntegrator) = Tiles.getstate(Tile(integrator), """ Numerics.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator; interp=true) = Numerics.getvar(Val{var}(), Tile(integrator), integrator.u; interp) +# """ +# parameterize(tile::Tile) + +# Adds parameter information to all nested types in `tile` by recursively calling `parameterize`. +# """ +# function CryoGrid.parameterize(tile::Tile) +# ctor = ConstructionBase.constructorof(typeof(tile)) +# new_layers = map(namedlayers(tile.strat)) do named_layer +# name = nameof(named_layer) +# layer = CryoGrid.parameterize(named_layer.val) +# name => _addlayerfield(layer, name) +# end +# new_inits = map(tile.inits) do init +# _addlayerfield(CryoGrid.parameterize(init), :init) +# end +# new_events = map(keys(tile.events)) do name +# evs = map(CryoGrid.parameterize, getproperty(tile.events, name)) +# name => _addlayerfield(evs, name) +# end +# new_strat = Stratigraphy(boundaries(tile.strat), (;new_layers...)) +# return ctor(new_strat, tile.grid, tile.state, new_inits, (;new_events...), tile.data, tile.metadata) +# end + """ variables(tile::Tile) @@ -388,6 +413,6 @@ function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvide end # helper method that returns a Union type of all types that should be ignored by Flatten.flatten -@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,Unitful.Quantity,Numerics.ForwardDiff.Dual} +@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual} # ===================================================================== # diff --git a/src/Tiles/tile_base.jl b/src/Tiles/tile_base.jl index d29ed699..be02dd91 100644 --- a/src/Tiles/tile_base.jl +++ b/src/Tiles/tile_base.jl @@ -1,3 +1,8 @@ +mutable struct TileData + outputs::Any + TileData() = new(missing) +end + """ AbstractTile{iip} diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 2ed52d6f..4c20a1e7 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -28,7 +28,7 @@ export @UFloat_str, @UT_str, @setscalar, @threaded, @sym_str, @pstrip include("macros.jl") export StrictlyPositive, StrictlyNegative, Nonnegative, Nonpositive -export applyunits, normalize_units, normalize_temperature, pstrip, adstrip +export applyunits, normalize_units, normalize_temperature, pstrip export fastmap, fastiterate, structiterate, getscalar, tuplejoin, convert_t, convert_tspan, haskeys # Variable/parameter domains @@ -161,14 +161,6 @@ function ffill!(x::AbstractVector{T}) where {E,T<:Union{Missing,E}} return x end -""" - adstrip(x) - -Strips autodiff type wrappers from `x`. -""" -adstrip(x) = x -adstrip(x::ForwardDiff.Dual) = ForwardDiff.value(x) - """ pstrip(obj; keep_units=false) diff --git a/src/problem.jl b/src/problem.jl index a144ef8a..db3ebf19 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -13,17 +13,7 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.Abstrac savefunc::Tsf isoutofdomain::Tdf kwargs::Tkw - CryoGridProblem{iip}( - f::TF, - u0::Tu, - tspan::NTuple{2,Tt}, - p::Tp, - cbs::Tcb, - saveat::Tsv, - savefunc::Tsf, - iood::Tdf, - kwargs::Tkw - ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = + CryoGridProblem{iip}(f::TF, u0::Tu, tspan::NTuple{2,Tt}, p::Tp, cbs::Tcb, saveat::Tsv, savefunc::Tsf, iood::Tdf, kwargs::Tkw) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} = new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs) end @@ -31,7 +21,6 @@ end CryoGridProblem(tile::Tile, u0::ComponentVector, tspan::NTuple{2,DateTime}, args...;kwargs...) """ CryoGridProblem(tile::Tile, u0::ComponentVector, tspan::NTuple{2,DateTime}, args...;kwargs...) = CryoGridProblem(tile, u0, convert_tspan(tspan), args...;kwargs...) - """ CryoGridProblem( tile::Tile, @@ -62,9 +51,10 @@ function CryoGridProblem( p::Union{Nothing,AbstractVector}=nothing; diagnostic_stepsize=3600.0, saveat=3600.0, - save_start=true, - save_everystep=false, savevars=(), + save_everystep=false, + save_start=true, + save_end=true, step_limiter=timestep, safety_factor=1, max_step=true, @@ -74,6 +64,8 @@ function CryoGridProblem( function_kwargs=(), prob_kwargs... ) + getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...)) + savefunc(u, t, integrator) = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator)) # strip all "fixed" parameters tile = stripparams(FixedParam, tile) # retrieve variable parameters @@ -82,10 +74,10 @@ function CryoGridProblem( # remove units tile = stripunits(tile) # set up saving callback + stateproto = getsavestate(tile, u0, du0) + savevals = SavedValues(Float64, typeof(stateproto)) saveat = expandtstep(saveat, tspan) - savecache = InputOutput.SaveCache(Float64[], []) - savefunc = saving_function(savecache, savevars...) - savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep) + savingcallback = SavingCallback(savefunc, savevals; saveat=saveat, save_start=save_start, save_end=save_end, save_everystep=save_everystep) diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!) defaultcallbacks = (savingcallback, diagnostic_step_callback) # add step limiter to default callbacks, if defined @@ -100,12 +92,14 @@ function CryoGridProblem( # add user callbacks usercallbacks = isnothing(callback) ? () : callback callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...) + # note that this implicitly discards any existing saved values in the model setup's state history + tile.data.outputs = savevals # build mass matrix mass_matrix = Numerics.build_mass_matrix(tile.state) # get params p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...) - return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs) + return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, getsavestate, isoutofdomain, prob_kwargs) end function SciMLBase.remake( @@ -132,34 +126,6 @@ function SciMLBase.remake( return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs) end -CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))") - -function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...) - integrator = init(prob, alg, args...; kwargs...) - return solve!(integrator) -end - -SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem( - prob.f, - prob.u0, - prob.tspan, - prob.p; - callback=prob.callbacks, - isoutofdomain=prob.isoutofdomain, - prob.kwargs... -) - -DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob - -function saving_function(cache::InputOutput.SaveCache, savevars...) - getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...)) - return function(u, t, integrator) - state = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator)) - InputOutput.save!(cache, state, adstrip(t)) - return state - end -end - """ odefunction(setup::Tile, u0, p, tspan; kwargs...) @@ -204,6 +170,11 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator) DiffEqBase.u_modified!(integrator, u_modified) end +# overrides to make SciML problem interface work +SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem(prob.f, prob.u0, prob.tspan, prob.p; callback=prob.callbacks, isoutofdomain=prob.isoutofdomain, prob.kwargs...) + +DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob + # callback building functions function _makecallbacks(tile::Tile) eventname(::Event{name}) where name = name -- GitLab