diff --git a/Project.toml b/Project.toml index e4f68010c6ac3c36770846b0d7127a469ba988b4..544ab68a8539b99efe518c98abd007d24db50342 100755 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CryoGrid" uuid = "a535b82e-5f3d-4d97-8b0b-d6483f5bebd5" authors = ["Brian Groenke <brian.groenke@awi.de>", "Moritz Langer <moritz.langer@awi.de>"] -version = "0.5.6" +version = "0.5.7" [deps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" diff --git a/src/Drivers/diffeq.jl b/src/Drivers/diffeq.jl index b9a5908e1a13f6ae4f2728995828da059f615882..fa2905e405d6e21a9da360e1d9c6320aace26634 100644 --- a/src/Drivers/diffeq.jl +++ b/src/Drivers/diffeq.jl @@ -25,6 +25,8 @@ function CryoGridProblem( saveat=3600.0, savevars=(), save_everystep=false, + save_start=true, + save_end=true, callback=nothing, kwargs... ) @@ -40,7 +42,7 @@ function CryoGridProblem( # set up saving callback stateproto = getsavestate(tile, u0, du0) savevals = SavedValues(Float64, typeof(stateproto)) - savingcallback = SavingCallback(savefunc, savevals; saveat=expandtstep(saveat), save_everystep=save_everystep) + savingcallback = SavingCallback(savefunc, savevals; saveat=expandtstep(saveat), save_start=save_start, save_end=save_end, save_everystep=save_everystep) layercallbacks = tuplejoin((_getcallbacks(comp) for comp in tile.strat)...) usercallbacks = isnothing(callback) ? () : callback callbacks = CallbackSet(savingcallback, layercallbacks..., usercallbacks...) @@ -109,25 +111,28 @@ end Builds the state named tuple for `layername` given an initialized integrator. """ Strat.getstate(layername::Symbol, integrator::SciMLBase.DEIntegrator) = getstate(Val{layername}(), integrator) -Strat.getstate(::Val{layername}, integrator::SciMLBase.DEIntegrator) where {layername} = Strat.getstate(integrator.f.f, integrator.u, get_du(integrator), integrator.t) +Strat.getstate(::Val{layername}, integrator::SciMLBase.DEIntegrator) where {layername} = Strat.getstate(Tile(integrator), integrator.u, get_du(integrator), integrator.t) """ getvar(var::Symbol, integrator::SciMLBase.DEIntegrator) """ -Strat.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator) = Strat.getvar(Val{var}(), integrator.f.f, integrator.u) +Strat.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator) = Strat.getvar(Val{var}(), Tile(integrator), integrator.u) """ -Constructs a `CryoGridOutput` from the given `ODESolution`. +Constructs a `CryoGridOutput` from the given `ODESolution`. Optional `tspan` """ -function InputOutput.CryoGridOutput(sol::TSol) where {TSol <: SciMLBase.AbstractODESolution} +function InputOutput.CryoGridOutput(sol::TSol; tspan=nothing) where {TSol <: SciMLBase.AbstractODESolution} # Helper functions for mapping variables to appropriate DimArrays by grid/shape. withdims(::Var{name,T,<:OnGrid{Cells}}, arr, grid, ts) where {name,T} = DimArray(arr*oneunit(T), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts))) withdims(::Var{name,T,<:OnGrid{Edges}}, arr, grid, ts) where {name,T} = DimArray(arr*oneunit(T), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts))) withdims(::Var{name,T}, arr, zs, ts) where {name,T} = DimArray(arr*oneunit(T), (Ti(ts),)) + save_interval = isnothing(tspan) ? -Inf..Inf : ClosedInterval(convert_tspan(tspan)...) model = sol.prob.f.f # Tile ts = model.hist.vals.t # use save callback time points - ts_datetime = Dates.epochms2datetime.(round.(ts*1000.0)) - u_all = reduce(hcat, sol.(ts)) + t_mask = ts .∈ save_interval # indices within t interval + u_all = reduce(hcat, sol.(ts)) # build prognostic state from continuous solution pax = ComponentArrays.indexmap(getaxes(model.state.uproto)[1]) - savedstates = model.hist.vals.saveval + # get saved diagnostic states and timestamps only in given interval + savedstates = model.hist.vals.saveval[t_mask] + ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0)) allvars = variables(model) progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars)) diagvars = filter(isdiagnostic, allvars) @@ -140,7 +145,7 @@ function InputOutput.CryoGridOutput(sol::TSol) where {TSol <: SciMLBase.Abstract for var in filter(isongrid, tuplejoin(diagvars, fluxvars)) name = varname(var) states = collect(skipmissing([name ∈ keys(state) ? state[name] : missing for state in savedstates])) - if length(states) == length(ts) + if length(states) == length(ts_datetime) arr = reduce(hcat, states) outputs[name] = withdims(var, arr, model.grid, ts_datetime) end @@ -173,7 +178,7 @@ function _criterionfunc(::Val{name}, cb::Callback, layer, process) where name (u,t,integrator) -> let layer=layer, process=process, cb=cb, - tile=integrator.f.f, + tile=Tile(integrator), u = Strat.withaxes(u, tile), du = Strat.withaxes(get_du(integrator), tile), t = t; @@ -184,7 +189,7 @@ function _affectfunc(::Val{name}, cb::Callback, layer, process) where name integrator -> let layer=layer, process=process, cb=cb, - tile=integrator.f.f, + tile=Tile(integrator), u = Strat.withaxes(integrator.u, tile), du = Strat.withaxes(get_du(integrator), tile), t = integrator.t; diff --git a/src/Numerics/Numerics.jl b/src/Numerics/Numerics.jl index 89b4290ef328fb2bb233a3ae8ad4000655c25c1c..8c65d2a399a8d4d785da18f8c72601f1f836a10f 100644 --- a/src/Numerics/Numerics.jl +++ b/src/Numerics/Numerics.jl @@ -1,7 +1,6 @@ module Numerics import Base.== -import ExprTools import ForwardDiff import PreallocationTools as Prealloc @@ -13,7 +12,7 @@ using ComponentArrays using DimensionalData: AbstractDimArray, DimArray, Dim, At, dims, Z using Flatten using IfElse -using Interpolations: Interpolations, Gridded, Linear, Flat, Line, interpolate, extrapolate +using Interpolations using IntervalSets using LinearAlgebra using LoopVectorization @@ -36,7 +35,7 @@ struct Cells <: GridSpec end abstract type Geometry end struct UnitVolume <: Geometry end -export ∇ +export ∇, Tabulated include("math.jl") export Grid, cells, edges, indexmap, subgridinds, Δ, volume, area diff --git a/src/Numerics/math.jl b/src/Numerics/math.jl index b9acda09f444d9c51d0e2955c2a8b87d3e15d5a3..8b0d7d262db735e2e32f90d01b46548d87b42a93 100644 --- a/src/Numerics/math.jl +++ b/src/Numerics/math.jl @@ -132,6 +132,7 @@ softplusinv(x) = let x = clamp(x, eps(), Inf); IfElse.ifelse(x > 34, x, log(exp( minusone(x) = x .- one.(x) plusone(x) = x .+ one.(x) +# Symbolic differentiation """ ∇(f, dvar::Symbol) @@ -158,11 +159,7 @@ f(x,y) = 2*x + x*y ``` """ function ∇(f, dvar::Symbol; choosefn=first, context_module=Numerics) - # Parse function parameter names using ExprTools - fms = ExprTools.methods(f) - symbol(arg::Symbol) = arg - symbol(expr::Expr) = expr.args[1] - argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args]) + argnames = Utils.argnames(f, choosefn) @assert dvar in argnames "function must have $dvar as an argument" dind = findfirst(s -> s == dvar, argnames) # Convert to symbols @@ -174,3 +171,38 @@ function ∇(f, dvar::Symbol; choosefn=first, context_module=Numerics) ∇f = @RuntimeGeneratedFunction(context_module, ∇f_expr) return ∇f end + +# Function tabulation +""" + Tabulated(f, argknots...) + +Alias for `tabulate` intended for function types. +""" +Tabulated(f, argknots...) = tabulate(f, argknots...) +""" + tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...) + +Tabulates the given function `f` using a linear, multi-dimensional interpolant. +Knots should be given as pairs `:arg => A` where `A` is a `StepRange` or `Vector` +of input values (knots) at which to evaluate the function. `A` may also be a +`Number`, in which case a pseudo-point interpolant will be used (i.e valid on +`[A,A+ϵ]`). No extrapolation is provided by default but can be configured via +`Interpolations.extrapolate`. +""" +function tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...) + initknots(a::AbstractArray) = Interpolations.deduplicate_knots!(a) + initknots(x::Number) = initknots([x,x]) + names = map(first, argknots) + # get knots for each argument, duplicating if only one value is provided + knots = map(initknots, map(last, argknots)) + f_argnames = Utils.argnames(f) + @assert all(map(name -> name ∈ names, f_argnames)) "Missing one or more arguments $f_argnames in $f" + arggrid = Iterators.product(knots...) + # evaluate function construct interpolant + interp = interpolate(Tuple(knots), map(Base.splat(f), arggrid), Gridded(Linear())) + return interp +end +function ∇(f::AbstractInterpolation) + gradient(args...) = Interpolations.gradient(f, args...) + return gradient +end diff --git a/src/Physics/HeatConduction/soil/sfcc.jl b/src/Physics/HeatConduction/soil/sfcc.jl index df04723e6b92e4f8383c4469d0751275f3c6aa42..23d4e01f2d89435baf41bb6fe0e9b3e92614f652 100644 --- a/src/Physics/HeatConduction/soil/sfcc.jl +++ b/src/Physics/HeatConduction/soil/sfcc.jl @@ -163,6 +163,28 @@ function (f::Westermann)(T,Tₘ,θres,θsat,θtot,δ) IfElse.ifelse(T<=Tₘ, θres - (θsat-θres)*(δ/(T-δ)), θtot) end end +struct SFCCTable{F,I} <: SFCCFunction + f::F + f_tab::I +end +(f::SFCCTable)(args...) = f.f_tab(args...) +""" + Tabulated(f::SFCCFunction, args...) + +Produces an `SFCCTable` function which is a tabulation of `f`. +""" +Numerics.Tabulated(f::SFCCFunction, args...) = SFCCTable(f, Numerics.tabulate(f, args...)) +""" + SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver()) + +Constructs a SFCC from the precomputed `SFCCTable`. The derivative is generated using the +`gradient` function provided by `Interpolations`. +""" +function SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver()) + # we wrap ∇f with Base.splat here to avoid a weird issue with in-place splatting causing allocations + # when applied to runtime generated functions. + SFCC(f, Base.splat(∇(f.f_tab)), s) +end """ Specialized implementation of Newton's method with backtracking line search for resolving diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 2592e668a956f977c943ca3c70342b8ca0f99357..49fcf40974a262def55cf9925fa8fd69f0c4cfaf 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -11,6 +11,7 @@ using StructTypes using Unitful import CryoGrid +import ExprTools import ForwardDiff export @xu_str, @Float_str, @Real_str, @Number_str, @UFloat_str, @UT_str, @setscalar, @threaded @@ -93,6 +94,22 @@ Convenience method for converting between `Dates.DateTime` and solver time. convert_tspan(tspan::NTuple{2,DateTime}) = Dates.datetime2epochms.(tspan) ./ 1000.0 convert_tspan(tspan::NTuple{2,Float64}) = Dates.epochms2datetime.(tspan.*1000.0) +""" + argnames(f, choosefn=first) + +Retrieves the argument names of function `f` via metaprogramming and `ExprTools`. +The optional argument `choosefn` allows for customization of which method instance +of `f` (if there is more than one) is chosen. +""" +function argnames(f, choosefn=first) + # Parse function parameter names using ExprTools + fms = ExprTools.methods(f) + symbol(arg::Symbol) = arg + symbol(expr::Expr) = expr.args[1] + argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args]) + return argnames +end + """ @generated selectat(i::Int, f, args::T) where {T<:Tuple}