From 2df13aa7facf6c8be285fb5646a45b47acb8319a Mon Sep 17 00:00:00 2001 From: Brian Groenke <brian.groenke@awi.de> Date: Fri, 6 Dec 2024 17:24:08 +0100 Subject: [PATCH] Add param method with global auto-param switch Adds a new method `param` and type `FixedParam` which represents a parameter of the system that is assumed constant at simulation time. This is useful for tracking all parameters of the system but still permitting the user to individually specify which values to vary as parameters. The global switch `AUTOPARAM` can be enabled to mark all parameters as variable by default. --- examples/heat_simple_autodiff_grad.jl | 21 ++++++++--------- src/CryoGrid.jl | 34 ++++++++++++++++++++++----- src/Diagnostics/spinup.jl | 6 ++--- src/IO/InputOutput.jl | 3 --- src/IO/params/param_types.jl | 25 ++++++++++++++++++++ src/IO/params/parameterizations.jl | 5 ---- src/IO/params/params.jl | 13 ++++++++-- src/Physics/Heat/heat_bc.jl | 4 ++-- src/Physics/Hydrology/water_ET.jl | 6 ++--- src/Physics/Hydrology/water_types.jl | 2 +- src/Physics/Salt/salt_types.jl | 4 ++-- src/Physics/Soils/para/simple.jl | 14 +++++------ src/Solvers/basic_solvers.jl | 8 +++++-- src/Tiles/stratigraphy.jl | 2 +- src/Tiles/tile.jl | 10 ++++---- src/Utils/Utils.jl | 20 +++++++++++++--- src/methods.jl | 15 ++++++++++++ src/problem.jl | 19 +++++---------- 18 files changed, 142 insertions(+), 69 deletions(-) create mode 100644 src/IO/params/param_types.jl diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl index 99f74eed..f644a103 100644 --- a/examples/heat_simple_autodiff_grad.jl +++ b/examples/heat_simple_autodiff_grad.jl @@ -3,9 +3,9 @@ # two parameters (summer and winter n-factors) using forward-mode automatic simulation. # # TODO: add more detail/background +using CryoGrid # Set up forcings and boundary conditions similarly to other examples: -using CryoGrid forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044); soilprofile, tempprofile = CryoGrid.SamoylovDefault grid = CryoGrid.DefaultGrid_5cm @@ -22,17 +22,14 @@ tile = CryoGrid.SoilHeatTile( tspan = (DateTime(2010,10,1),DateTime(2010,10,2)) u0, du0 = @time initialcondition!(tile, tspan); -# Collect model parameters -p = CryoGrid.parameters(tile) - # Create the `CryoGridProblem`. -prob = CryoGridProblem(tile, u0, tspan, p, saveat=3600.0); +prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0); # Solve the forward problem with default parameter settings: sol = @time solve(prob) out = CryoGridOutput(sol) -# ForwardDiff provides tools for forward-mode automatic differentiation. +# Import relevant packages for automatic differentiation. using ForwardDiff using SciMLSensitivity using Zygote @@ -40,9 +37,9 @@ using Zygote # Define a "loss" function; here we'll just take the mean over the final temperature field. using Statistics function loss(prob::CryoGridProblem, p) - # local u0, _ = initialcondition!(tile, tspan, p) - # local prob = CryoGridProblem(tile, u0, tspan, p, saveat=24*3600.0) newprob = remake(prob, p=p) + # autojacvec = true uses ForwardDiff to calculate the jacobian; + # enabling checkpointing (theroetically) reduces the memory cost of the backwards pass. sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true) newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg); newout = CryoGridOutput(newsol) @@ -50,6 +47,8 @@ function loss(prob::CryoGridProblem, p) end # Compute gradient with forward diff: -pvec = vec(p) -grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec) -grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec) +pvec = prob.p +fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec) +zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec) +@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-4 "Forward and reverse gradients don't match!" +@show fd_grad diff --git a/src/CryoGrid.jl b/src/CryoGrid.jl index 1f13ec07..9c227ea0 100755 --- a/src/CryoGrid.jl +++ b/src/CryoGrid.jl @@ -1,12 +1,34 @@ module CryoGrid -global CRYOGRID_DEBUG = haskey(ENV,"CG_DEBUG") && ENV["CG_DEBUG"] == "true" -function debug(debug::Bool) - global CRYOGRID_DEBUG = debug +global DEBUG = haskey(ENV,"CG_DEBUG") && ENV["CG_DEBUG"] == "true" +@deprecate debug(x) debug!(x) + +""" + debug!(debug::Bool) + +Enable or disable global debug mode for CryoGrid. Debug mode disables certain optimizations (e.g. loop vectorizations) +which sometimes ineterfere with the debugger. It also enables additional numerical stability checks. +""" +function debug!(debug::Bool) + global DEBUG = debug # disable loop vectorization in debug mode Numerics.turbo(!debug) - CRYOGRID_DEBUG && @warn "Debug mode enabled! Some performance features such as loop vectorization are now turned off by default." - return CRYOGRID_DEBUG + DEBUG && @warn "Debug mode enabled! Some performance features such as loop vectorization are now turned off by default." + return DEBUG +end + +global AUTOPARA = haskey(ENV,"CG_AUTOPARA") && ENV["CG_AUTOPARA"] == "true" + +""" + autoparam!(parameterize::Bool) + +Enable or disable automatic parameterization mode. When enabled, model parameters in all layer/process types will be initialized +as free `Param`s rather than `FixedParam`s. +""" +function autoparam!(parameterize::Bool) + global AUTOPARA = parameterize + AUTOPARA && @warn "Automatic parameterization enabled!" || @warn "Automatic parameterization disabled!" + return AUTOPARA end using Adapt @@ -57,7 +79,7 @@ export BCKind include("traits.jl") export initialcondition!, computediagnostic!, interact!, interactmaybe!, computeprognostic!, resetfluxes!, diagnosticstep! -export variables, processes, initializers, timestep, isactive, caninteract +export variables, processes, initializers, timestep, isactive, caninteract, param export boundaryflux, boundaryvalue, criterion, criterion!, trigger! include("methods.jl") diff --git a/src/Diagnostics/spinup.jl b/src/Diagnostics/spinup.jl index 3049132f..2da0fbb5 100644 --- a/src/Diagnostics/spinup.jl +++ b/src/Diagnostics/spinup.jl @@ -5,9 +5,9 @@ Implements a simple, iterative spin-up procedure. Runs the model specified by `setup` over `tspan` until the profile mean up to `maxdepth` over the whole time span changes only within the given tolerance `tol`. Returns the `ODESolution` generated by the final iteration. """ -function spinup(setup::Tile, tspan::NTuple{2,DateTime}, p, tol, layername; maxdepth=100u"m", maxiter=1000, saveat=3*3600.0, solver=CGEuler(), dt=60.0, solve_args...) - u0, du0 = initialcondition!(tile, tspan, p) - prob = CryoGridProblem(setup, u0, tspan, p) +function spinup(setup::Tile, tspan::NTuple{2,DateTime}, tol, layername; maxdepth=100u"m", maxiter=1000, saveat=3*3600.0, solver=CGEuler(), dt=60.0, solve_args...) + u0, du0 = initialcondition!(tile, tspan) + prob = CryoGridProblem(setup, u0, tspan) @info "Running initial solve ..." sol = solve(prob, solver, dt=dt, saveat=saveat, solve_args...) out = CryoGridOutput(sol) diff --git a/src/IO/InputOutput.jl b/src/IO/InputOutput.jl index aeb237e9..aa76df7b 100644 --- a/src/IO/InputOutput.jl +++ b/src/IO/InputOutput.jl @@ -60,9 +60,6 @@ include("ioutils.jl") export CryoGridParams include("params/params.jl") -export ParamsJSON, ParamsYAML -include("params/params_loaders.jl") - export Input, InputProvider, InputFunctionProvider export inputs include("input.jl") diff --git a/src/IO/params/param_types.jl b/src/IO/params/param_types.jl new file mode 100644 index 00000000..f1fb4e6f --- /dev/null +++ b/src/IO/params/param_types.jl @@ -0,0 +1,25 @@ +""" + FixedParam(p::NamedTuple) + FixedParam(; kw...) + FixedParam(val) + +Subtype of `AbstractParam` that rerpesents a "fixed" parameter which +should not be included in the set of free parameters for the model. +""" +struct FixedParam{T<:Number} <: AbstractParam{T} + parent::NamedTuple + function FixedParam{T}(nt::NamedTuple) where {T<:Number} + @assert :val ∈ keys(nt) + new{T}(nt) + end +end + +function FixedParam(nt::NT) where {NT<:NamedTuple} + FixedParam{typeof(nt.val)}(nt) +end + +FixedParam(val; kwargs...) = FixedParam((; val, kwargs...)) + +Base.parent(param::FixedParam) = getfield(param, :parent) + +ModelParameters.rebuild(param::FixedParam, nt::NamedTuple) = FixedParam(nt) diff --git a/src/IO/params/parameterizations.jl b/src/IO/params/parameterizations.jl index 83878c6c..8a0ee5c5 100644 --- a/src/IO/params/parameterizations.jl +++ b/src/IO/params/parameterizations.jl @@ -127,11 +127,6 @@ function binindex(values::Tuple, st, en, x) end end -CryoGrid.parameters(pw::PiecewiseLinear) = (; - initialvalue = CryoGrid.parameters(pw.initialvalue), - knots = CryoGrid.parameters(pw.knots), -) - # Transformed parameterization """ diff --git a/src/IO/params/params.jl b/src/IO/params/params.jl index cd2e93f9..227c3885 100644 --- a/src/IO/params/params.jl +++ b/src/IO/params/params.jl @@ -1,3 +1,6 @@ +export FixedParam +include("param_types.jl") + """ CryoGridParams{T,TM} <: DenseArray{T,1} @@ -6,7 +9,7 @@ type directly in math or linear algebra operations but rather to use `Base.value of parameter values. """ struct CryoGridParams{T,TM} <: DenseArray{T,1} - obj::TM # param obj + obj::TM # parameter "model" CryoGridParams(m::AbstractModel) = new{eltype(m[:val]),typeof(m)}(m) end @@ -66,7 +69,7 @@ function Base.show(io::IO, ::MIME"text/plain", ps::CryoGridParams{T}) where T ModelParameters.printparams(io, ps.obj) end -paramname(p::Param, component::Type{T}, fieldname) where {T} = fieldname +paramname(p::AbstractParam, component::Type{T}, fieldname) where {T} = fieldname Tables.columns(ps::CryoGridParams) = Tables.columns(ps.obj) Tables.rows(ps::CryoGridParams) = Tables.rows(ps.obj) @@ -91,3 +94,9 @@ function _setparafields(m::Model) newparent = Flatten.reconstruct(parent(m), updated_parameterizations, CryoGrid.Parameterization) return Model(newparent) end + +export ParamsJSON, ParamsYAML +include("params_loaders.jl") + +export PiecewiseConstant, PiecewiseLinear, LinearTrend, Transformed +include("parameterizations.jl") diff --git a/src/Physics/Heat/heat_bc.jl b/src/Physics/Heat/heat_bc.jl index 0a358042..0f05f1d5 100644 --- a/src/Physics/Heat/heat_bc.jl +++ b/src/Physics/Heat/heat_bc.jl @@ -66,8 +66,8 @@ function CryoGrid.computediagnostic!(::Bottom, bc::TemperatureBC, state) end Base.@kwdef struct NFactor{W,S} <: CryoGrid.BoundaryEffect - nf::W = Param(1.0, domain=0..1) # applied when Tair <= 0 - nt::S = Param(1.0, domain=0..1) # applied when Tair > 0 + nf::W = param(1.0, domain=0..1) # applied when Tair <= 0 + nt::S = param(1.0, domain=0..1) # applied when Tair > 0 end CryoGrid.variables(::Top, bc::TemperatureBC{<:NFactor}) = ( diff --git a/src/Physics/Hydrology/water_ET.jl b/src/Physics/Hydrology/water_ET.jl index fe6c3e7e..40c99376 100644 --- a/src/Physics/Hydrology/water_ET.jl +++ b/src/Physics/Hydrology/water_ET.jl @@ -11,9 +11,9 @@ struct EvapTop <: Evapotranspiration end Corresponds to evapotranspiration scheme 2 described in section 2.2.4 of Westermann et al. (2022). """ Base.@kwdef struct DampedET{Tftr,Tdtr,Tdev} <: Evapotranspiration - f_tr::Tftr = Param(0.5, domain=0..1, desc="Factor between 0 and 1 weighting transpirative vs. evaporative fluxes.") - d_tr::Tdtr = Param(0.5, units=u"m", domain=0..Inf, desc="Damping depth for transpiration.") - d_ev::Tdev = Param(0.1, units=u"m", domain=0..Inf, desc="Damping depth for evaporation.") + f_tr::Tftr = param(0.5, domain=0..1, desc="Factor between 0 and 1 weighting transpirative vs. evaporative fluxes.") + d_tr::Tdtr = param(0.5, units=u"m", domain=0..Inf, desc="Damping depth for transpiration.") + d_ev::Tdev = param(0.1, units=u"m", domain=0..Inf, desc="Damping depth for evaporation.") end """ diff --git a/src/Physics/Hydrology/water_types.jl b/src/Physics/Hydrology/water_types.jl index f5b2c72d..04c718bd 100644 --- a/src/Physics/Hydrology/water_types.jl +++ b/src/Physics/Hydrology/water_types.jl @@ -15,7 +15,7 @@ end Default material hydraulic properties. """ Utils.@properties HydraulicProperties( - kw_sat = Param(1e-5, units=u"m/s"), + kw_sat = param(1e-5, units=u"m/s"), ) """ diff --git a/src/Physics/Salt/salt_types.jl b/src/Physics/Salt/salt_types.jl index 88176b45..c7df4b8a 100644 --- a/src/Physics/Salt/salt_types.jl +++ b/src/Physics/Salt/salt_types.jl @@ -1,6 +1,6 @@ SaltProperties( - τ = Param(1.5), # Turtuosity - dₛ₀ = Param(8.0e-10, units=u"m^2/s"), # salt diffusion coefficient + τ = param(1.5), # Turtuosity + dₛ₀ = param(8.0e-10, units=u"m^2/s"), # salt diffusion coefficient ) = (; τ, dₛ₀) abstract type SaltOperator end diff --git a/src/Physics/Soils/para/simple.jl b/src/Physics/Soils/para/simple.jl index 13c7b825..aaa13fdf 100644 --- a/src/Physics/Soils/para/simple.jl +++ b/src/Physics/Soils/para/simple.jl @@ -6,9 +6,9 @@ i.e. natural porosity, saturation, and organic solid fraction. This is the stand of a discrete soil volume. """ Base.@kwdef struct SimpleSoil{Tfc,Tpor,Tsat,Torg,Thp,Twp} <: SoilParameterization - por::Tpor = Param(0.5, domain=0..1, desc="Natural porosity of the soil volume.") - sat::Tsat = Param(1.0, domain=0..1, desc="Initial water+ice saturation level of the soil volume.") - org::Torg = Param(0.0, domain=0..1, desc="Organic solid fraction of the soil volume.") + por::Tpor = param(0.5, domain=0..1, desc="Natural porosity of the soil volume.") + sat::Tsat = param(1.0, domain=0..1, desc="Initial water+ice saturation level of the soil volume.") + org::Torg = param(0.0, domain=0..1, desc="Organic solid fraction of the soil volume.") freezecurve::Tfc = FreeWater() heat::Thp = SoilThermalProperties(SimpleSoil) water::Twp = SoilHydraulicProperties(SimpleSoil, fieldcapacity=0.20) @@ -34,13 +34,13 @@ SoilThermalProperties( kh_w = ThermalProperties().kh_w, kh_i = ThermalProperties().kh_i, kh_a = ThermalProperties().kh_a, - kh_o=Param(0.25, units=u"W/m/K", domain=0..Inf), # organic [Hillel (1982)] - kh_m=Param(3.8, units=u"W/m/K", domain=0..Inf), # mineral [Hillel (1982)] + kh_o=param(0.25, units=u"W/m/K", domain=0..Inf), # organic [Hillel (1982)] + kh_m=param(3.8, units=u"W/m/K", domain=0..Inf), # mineral [Hillel (1982)] ch_w = ThermalProperties().ch_w, ch_i = ThermalProperties().ch_i, ch_a = ThermalProperties().ch_a, - ch_o=Param(2.5e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity organic - ch_m=Param(2.0e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity mineral + ch_o=param(2.5e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity organic + ch_m=param(2.0e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity mineral kwargs..., ) = ThermalProperties(; kh_w, kh_i, kh_a, kh_m, kh_o, ch_w, ch_i, ch_a, ch_m, ch_o, kwargs...) diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl index 0a6c3a53..5e9afe2d 100644 --- a/src/Solvers/basic_solvers.jl +++ b/src/Solvers/basic_solvers.jl @@ -19,7 +19,7 @@ struct CGEulerCache{Tu} <: SciMLBase.DECache du::Tu end -function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=3*3600.0, kwargs...) +function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...) tile = Tile(prob.f) u0 = copy(prob.u0) du0 = zero(u0) @@ -42,7 +42,11 @@ function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0 similar(prob.u0), ) p = isnothing(prob.p) ? prob.p : collect(prob.p) - saveat = isa(saveat, Number) ? collect(prob.tspan[1]:saveat:prob.tspan[2]) : saveat + if isnothing(saveat) + saveat = prob.saveat + elseif isa(saveat, Number) + saveat = collect(prob.tspan[1]:saveat:prob.tspan[2]) + end opts = CryoGridIntegratorOptions(; saveat=expandtstep(saveat, prob.tspan), kwargs...) return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1) end diff --git a/src/Tiles/stratigraphy.jl b/src/Tiles/stratigraphy.jl index 2d2d65e3..cf1861c3 100644 --- a/src/Tiles/stratigraphy.jl +++ b/src/Tiles/stratigraphy.jl @@ -33,7 +33,7 @@ struct Stratigraphy{N,TLayers<:NamedTuple,TBoundaries} @nospecialize(sub::Tuple{Vararg{Pair{<:Number}}}), @nospecialize(bot::Pair{<:Number,<:Bottom}) ) - updateparam(p::Param) = Param(merge(parent(p), (layer=:strat,))) + updateparam(p::paraType) where {paraType<:AbstractParam} = paraType(merge(parent(p), (layer=:strat,))) updateparam(x) = x # check subsurface layers @assert length(sub) > 0 "At least one subsurface layer must be specified" diff --git a/src/Tiles/tile.jl b/src/Tiles/tile.jl index 96f2cd4f..797db01d 100644 --- a/src/Tiles/tile.jl +++ b/src/Tiles/tile.jl @@ -339,7 +339,7 @@ Materializes the given `tile` by: Returns the reconstructed `Tile` instance. """ function materialize(tile::Tile, p::AbstractVector, t::Number) - IgnoreTypes = _ignored_types(tile) + IgnoreTypes = Utils.ignored_types(tile) # ==== Update parameter values ==== # # unfortunately, reconstruct causes allocations due to a mysterious dynamic dispatch when returning the result of _reconstruct; # I really don't know why, could be a compiler bug, but it doesn't happen if we call the internal _reconstruct method directly... @@ -349,7 +349,7 @@ function materialize(tile::Tile, p::AbstractVector, t::Number) return materialize(parameterized_tile, nothing, t) end function materialize(tile::Tile, ::Nothing, t::Number) - IgnoreTypes = _ignored_types(tile) + IgnoreTypes = Utils.ignored_types(tile) # ==== Compute dynamic parameter values ==== # # TODO: perhaps should allow dependence on local layer state; # this would likely require deconstruction/reconstruction of layers in order to @@ -365,7 +365,7 @@ function materialize(tile::Tile, ::Nothing, t::Number) end function checkstate!(tile::Tile, state::TileState, u, du, label::Symbol) - if CryoGrid.CRYOGRID_DEBUG + if CryoGrid.DEBUG @inbounds for i in eachindex(u) if !isfinite(u[i]) debughook!(tile, state, AssertionError("[$label] Found NaN/Inf value in current state vector at index $i")) @@ -403,7 +403,7 @@ function _initstatevars(@nospecialize(strat::Stratigraphy), @nospecialize(grid:: end function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvider) - IgnoreTypes = _ignored_types(tile) + IgnoreTypes = Utils.ignored_types(tile) inputs = Flatten.flatten(tile, Flatten.flattenable, Input, IgnoreTypes) names = keys(inputprovider) for input in inputs @@ -413,6 +413,6 @@ function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvide end # helper method that returns a Union type of all types that should be ignored by Flatten.flatten -@inline _ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual} +@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual} # ===================================================================== # diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 185c5d7d..4c20a1e7 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -172,22 +172,36 @@ function pstrip(obj; keep_units=false) end # TODO: this should be in ModelParameters.jl, not here. -function Unitful.uconvert(u::Unitful.Units, p::Param) +function Unitful.uconvert(u::Unitful.Units, p::paraType) where {paraType<:AbstractParam} nt = parent(p) @set! nt.val = ustrip(u, stripparams(p)) @set! nt.units = u - return Param(nt) + return paraType(nt) end """ ModelParameters.stripunits(obj) -Additional override for `stripunits` which reconstructs `obj` with all fields that have unitful quantity +Additional dispatch for `stripunits` which reconstructs `obj` with all fields that have unitful quantity types converted to base SI units and then stripped to be unit free. """ function ModelParameters.stripunits(obj) values = Flatten.flatten(obj, Flatten.flattenable, Unitful.AbstractQuantity, Flatten.IGNORE) return Flatten.reconstruct(obj, map(ustrip ∘ normalize_units, values), Unitful.AbstractQuantity, Flatten.IGNORE) end + +""" + ModelParameters.stripparams(::Type{paraType}, obj) where {paraType<:AbstractParam} + +Additional dispatch for `stripparams` which allows specification of a specific type. +""" +function ModelParameters.stripparams(::Type{paraType}, obj) where {paraType<:AbstractParam} + IgnoreTypes = ignored_types(obj) + selected_params = Flatten.flatten(obj, Flatten.flattenable, paraType, IgnoreTypes) + return Flatten._reconstruct(obj, map(stripparams, selected_params), Flatten.flattenable, paraType, IgnoreTypes, 1)[1] +end + +ignored_types(obj) = ModelParameters.IGNORE + # pretty print Param types Base.show(io::IO, mime::MIME"text/plain", p::Param) = print(io, "Param($(p.val))") diff --git a/src/methods.jl b/src/methods.jl index 0017f588..0900440f 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -233,3 +233,18 @@ thickness(::Layer, state, i) = Δ(state.grid)[i] thickness(l::Layer, state, ::typeof(first)) = thickness(l, state, 1) thickness(l::Layer, state, ::typeof(last)) = thickness(l, state, lastindex(state.grid)-1) thickness(::Union{Top,Bottom}, state) = Inf + + +""" + param([::Type{paraType}], defval; kwargs...) + +Creates a new parameter type from the given default value and keyword properties. +""" +param(::Type{paraType}, defval; kwargs...) where {paraType<:AbstractParam} = paraType(defval; kwargs...) +function param(defval; kwargs...) + if AUTOPARA + return param(Param, defval; kwargs...) + else + return param(FixedParam, defval; kwargs...) + end +end diff --git a/src/problem.jl b/src/problem.jl index daf876a3..6c05881f 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -47,8 +47,7 @@ Constructor for `CryoGridProblem` that automatically generates all necessary cal function CryoGridProblem( tile::Tile, u0::ComponentVector, - tspan::NTuple{2,Float64}, - p=nothing; + tspan::NTuple{2,Float64}; diagnostic_stepsize=3600.0, saveat=3600.0, savevars=(), @@ -66,16 +65,10 @@ function CryoGridProblem( ) getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...)) savefunc(u, t, integrator) = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator)) - tile, p = if isnothing(p) && isempty(ModelParameters.params(tile)) - # case 1: no parameters provided - tile, nothing - else - # case 2: parameters are provided; use Model interface to reconstruct Tile with new parameter values - model_tile = Model(tile) - p = isnothing(p) ? collect(model_tile[:val]) : p - model_tile[:val] = p - parent(model_tile), p - end + # strip all "fixed" parameters + tile = stripparams(FixedParam, tile) + # retrieve variable parameters + p = length(ModelParameters.params(tile)) > 0 ? parameters(tile) : nothing du0 = zero(u0) # remove units tile = stripunits(tile) @@ -142,7 +135,7 @@ function CryoGrid.odefunction(::DefaultJac, setup::typeof(tile), u0, p, tspan) # make sure to return an instance of ODEFunction end ... -prob = CryoGridProblem(tile, tspan, p) +prob = CryoGridProblem(tile, u0, tspan) ``` `JacobianStyle` can also be extended to create custom traits which can then be applied to compatible `Tile`s. -- GitLab