Skip to content
Snippets Groups Projects
Commit a09dd1db authored by Brian Groenke's avatar Brian Groenke
Browse files

Merge branch 'feature/cached-freezecurve' into 'master'

Add simple tabulation scheme for freeze curve

See merge request sparcs/cryogrid/cryogridjulia!61
parents b66e8c45 c0af009f
No related branches found
No related tags found
1 merge request!61Add simple tabulation scheme for freeze curve
name = "CryoGrid"
uuid = "a535b82e-5f3d-4d97-8b0b-d6483f5bebd5"
authors = ["Brian Groenke <brian.groenke@awi.de>", "Moritz Langer <moritz.langer@awi.de>"]
version = "0.5.6"
version = "0.5.7"
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
......
......@@ -25,6 +25,8 @@ function CryoGridProblem(
saveat=3600.0,
savevars=(),
save_everystep=false,
save_start=true,
save_end=true,
callback=nothing,
kwargs...
)
......@@ -40,7 +42,7 @@ function CryoGridProblem(
# set up saving callback
stateproto = getsavestate(tile, u0, du0)
savevals = SavedValues(Float64, typeof(stateproto))
savingcallback = SavingCallback(savefunc, savevals; saveat=expandtstep(saveat), save_everystep=save_everystep)
savingcallback = SavingCallback(savefunc, savevals; saveat=expandtstep(saveat), save_start=save_start, save_end=save_end, save_everystep=save_everystep)
layercallbacks = tuplejoin((_getcallbacks(comp) for comp in tile.strat)...)
usercallbacks = isnothing(callback) ? () : callback
callbacks = CallbackSet(savingcallback, layercallbacks..., usercallbacks...)
......@@ -109,25 +111,28 @@ end
Builds the state named tuple for `layername` given an initialized integrator.
"""
Strat.getstate(layername::Symbol, integrator::SciMLBase.DEIntegrator) = getstate(Val{layername}(), integrator)
Strat.getstate(::Val{layername}, integrator::SciMLBase.DEIntegrator) where {layername} = Strat.getstate(integrator.f.f, integrator.u, get_du(integrator), integrator.t)
Strat.getstate(::Val{layername}, integrator::SciMLBase.DEIntegrator) where {layername} = Strat.getstate(Tile(integrator), integrator.u, get_du(integrator), integrator.t)
"""
getvar(var::Symbol, integrator::SciMLBase.DEIntegrator)
"""
Strat.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator) = Strat.getvar(Val{var}(), integrator.f.f, integrator.u)
Strat.getvar(var::Symbol, integrator::SciMLBase.DEIntegrator) = Strat.getvar(Val{var}(), Tile(integrator), integrator.u)
"""
Constructs a `CryoGridOutput` from the given `ODESolution`.
Constructs a `CryoGridOutput` from the given `ODESolution`. Optional `tspan`
"""
function InputOutput.CryoGridOutput(sol::TSol) where {TSol <: SciMLBase.AbstractODESolution}
function InputOutput.CryoGridOutput(sol::TSol; tspan=nothing) where {TSol <: SciMLBase.AbstractODESolution}
# Helper functions for mapping variables to appropriate DimArrays by grid/shape.
withdims(::Var{name,T,<:OnGrid{Cells}}, arr, grid, ts) where {name,T} = DimArray(arr*oneunit(T), (Z(round.(typeof(1.0u"m"), cells(grid), digits=5)),Ti(ts)))
withdims(::Var{name,T,<:OnGrid{Edges}}, arr, grid, ts) where {name,T} = DimArray(arr*oneunit(T), (Z(round.(typeof(1.0u"m"), edges(grid), digits=5)),Ti(ts)))
withdims(::Var{name,T}, arr, zs, ts) where {name,T} = DimArray(arr*oneunit(T), (Ti(ts),))
save_interval = isnothing(tspan) ? -Inf..Inf : ClosedInterval(convert_tspan(tspan)...)
model = sol.prob.f.f # Tile
ts = model.hist.vals.t # use save callback time points
ts_datetime = Dates.epochms2datetime.(round.(ts*1000.0))
u_all = reduce(hcat, sol.(ts))
t_mask = ts . save_interval # indices within t interval
u_all = reduce(hcat, sol.(ts)) # build prognostic state from continuous solution
pax = ComponentArrays.indexmap(getaxes(model.state.uproto)[1])
savedstates = model.hist.vals.saveval
# get saved diagnostic states and timestamps only in given interval
savedstates = model.hist.vals.saveval[t_mask]
ts_datetime = Dates.epochms2datetime.(round.(ts[t_mask]*1000.0))
allvars = variables(model)
progvars = tuplejoin(filter(isprognostic, allvars), filter(isalgebraic, allvars))
diagvars = filter(isdiagnostic, allvars)
......@@ -140,7 +145,7 @@ function InputOutput.CryoGridOutput(sol::TSol) where {TSol <: SciMLBase.Abstract
for var in filter(isongrid, tuplejoin(diagvars, fluxvars))
name = varname(var)
states = collect(skipmissing([name keys(state) ? state[name] : missing for state in savedstates]))
if length(states) == length(ts)
if length(states) == length(ts_datetime)
arr = reduce(hcat, states)
outputs[name] = withdims(var, arr, model.grid, ts_datetime)
end
......@@ -173,7 +178,7 @@ function _criterionfunc(::Val{name}, cb::Callback, layer, process) where name
(u,t,integrator) -> let layer=layer,
process=process,
cb=cb,
tile=integrator.f.f,
tile=Tile(integrator),
u = Strat.withaxes(u, tile),
du = Strat.withaxes(get_du(integrator), tile),
t = t;
......@@ -184,7 +189,7 @@ function _affectfunc(::Val{name}, cb::Callback, layer, process) where name
integrator -> let layer=layer,
process=process,
cb=cb,
tile=integrator.f.f,
tile=Tile(integrator),
u = Strat.withaxes(integrator.u, tile),
du = Strat.withaxes(get_du(integrator), tile),
t = integrator.t;
......
module Numerics
import Base.==
import ExprTools
import ForwardDiff
import PreallocationTools as Prealloc
......@@ -13,7 +12,7 @@ using ComponentArrays
using DimensionalData: AbstractDimArray, DimArray, Dim, At, dims, Z
using Flatten
using IfElse
using Interpolations: Interpolations, Gridded, Linear, Flat, Line, interpolate, extrapolate
using Interpolations
using IntervalSets
using LinearAlgebra
using LoopVectorization
......@@ -36,7 +35,7 @@ struct Cells <: GridSpec end
abstract type Geometry end
struct UnitVolume <: Geometry end
export
export , Tabulated
include("math.jl")
export Grid, cells, edges, indexmap, subgridinds, Δ, volume, area
......
......@@ -132,6 +132,7 @@ softplusinv(x) = let x = clamp(x, eps(), Inf); IfElse.ifelse(x > 34, x, log(exp(
minusone(x) = x .- one.(x)
plusone(x) = x .+ one.(x)
# Symbolic differentiation
"""
∇(f, dvar::Symbol)
......@@ -158,11 +159,7 @@ f(x,y) = 2*x + x*y
```
"""
function(f, dvar::Symbol; choosefn=first, context_module=Numerics)
# Parse function parameter names using ExprTools
fms = ExprTools.methods(f)
symbol(arg::Symbol) = arg
symbol(expr::Expr) = expr.args[1]
argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args])
argnames = Utils.argnames(f, choosefn)
@assert dvar in argnames "function must have $dvar as an argument"
dind = findfirst(s -> s == dvar, argnames)
# Convert to symbols
......@@ -174,3 +171,38 @@ function ∇(f, dvar::Symbol; choosefn=first, context_module=Numerics)
∇f = @RuntimeGeneratedFunction(context_module, ∇f_expr)
return ∇f
end
# Function tabulation
"""
Tabulated(f, argknots...)
Alias for `tabulate` intended for function types.
"""
Tabulated(f, argknots...) = tabulate(f, argknots...)
"""
tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...)
Tabulates the given function `f` using a linear, multi-dimensional interpolant.
Knots should be given as pairs `:arg => A` where `A` is a `StepRange` or `Vector`
of input values (knots) at which to evaluate the function. `A` may also be a
`Number`, in which case a pseudo-point interpolant will be used (i.e valid on
`[A,A+ϵ]`). No extrapolation is provided by default but can be configured via
`Interpolations.extrapolate`.
"""
function tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...)
initknots(a::AbstractArray) = Interpolations.deduplicate_knots!(a)
initknots(x::Number) = initknots([x,x])
names = map(first, argknots)
# get knots for each argument, duplicating if only one value is provided
knots = map(initknots, map(last, argknots))
f_argnames = Utils.argnames(f)
@assert all(map(name -> name names, f_argnames)) "Missing one or more arguments $f_argnames in $f"
arggrid = Iterators.product(knots...)
# evaluate function construct interpolant
interp = interpolate(Tuple(knots), map(Base.splat(f), arggrid), Gridded(Linear()))
return interp
end
function(f::AbstractInterpolation)
gradient(args...) = Interpolations.gradient(f, args...)
return gradient
end
......@@ -163,6 +163,28 @@ function (f::Westermann)(T,Tₘ,θres,θsat,θtot,δ)
IfElse.ifelse(T<=Tₘ, θres - (θsat-θres)*(δ/(T-δ)), θtot)
end
end
struct SFCCTable{F,I} <: SFCCFunction
f::F
f_tab::I
end
(f::SFCCTable)(args...) = f.f_tab(args...)
"""
Tabulated(f::SFCCFunction, args...)
Produces an `SFCCTable` function which is a tabulation of `f`.
"""
Numerics.Tabulated(f::SFCCFunction, args...) = SFCCTable(f, Numerics.tabulate(f, args...))
"""
SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver())
Constructs a SFCC from the precomputed `SFCCTable`. The derivative is generated using the
`gradient` function provided by `Interpolations`.
"""
function SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver())
# we wrap ∇f with Base.splat here to avoid a weird issue with in-place splatting causing allocations
# when applied to runtime generated functions.
SFCC(f, Base.splat((f.f_tab)), s)
end
"""
Specialized implementation of Newton's method with backtracking line search for resolving
......
......@@ -11,6 +11,7 @@ using StructTypes
using Unitful
import CryoGrid
import ExprTools
import ForwardDiff
export @xu_str, @Float_str, @Real_str, @Number_str, @UFloat_str, @UT_str, @setscalar, @threaded
......@@ -93,6 +94,22 @@ Convenience method for converting between `Dates.DateTime` and solver time.
convert_tspan(tspan::NTuple{2,DateTime}) = Dates.datetime2epochms.(tspan) ./ 1000.0
convert_tspan(tspan::NTuple{2,Float64}) = Dates.epochms2datetime.(tspan.*1000.0)
"""
argnames(f, choosefn=first)
Retrieves the argument names of function `f` via metaprogramming and `ExprTools`.
The optional argument `choosefn` allows for customization of which method instance
of `f` (if there is more than one) is chosen.
"""
function argnames(f, choosefn=first)
# Parse function parameter names using ExprTools
fms = ExprTools.methods(f)
symbol(arg::Symbol) = arg
symbol(expr::Expr) = expr.args[1]
argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args])
return argnames
end
"""
@generated selectat(i::Int, f, args::T) where {T<:Tuple}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment