From 2df13aa7facf6c8be285fb5646a45b47acb8319a Mon Sep 17 00:00:00 2001
From: Brian Groenke <brian.groenke@awi.de>
Date: Fri, 6 Dec 2024 17:24:08 +0100
Subject: [PATCH] Add param method with global auto-param switch

Adds a new method `param` and type `FixedParam` which represents a parameter
of the system that is assumed constant at simulation time. This is useful
for tracking all parameters of the system but still permitting the user
to individually specify which values to vary as parameters.

The global switch `AUTOPARAM` can be enabled to mark all parameters
as variable by default.
---
 examples/heat_simple_autodiff_grad.jl | 21 ++++++++---------
 src/CryoGrid.jl                       | 34 ++++++++++++++++++++++-----
 src/Diagnostics/spinup.jl             |  6 ++---
 src/IO/InputOutput.jl                 |  3 ---
 src/IO/params/param_types.jl          | 25 ++++++++++++++++++++
 src/IO/params/parameterizations.jl    |  5 ----
 src/IO/params/params.jl               | 13 ++++++++--
 src/Physics/Heat/heat_bc.jl           |  4 ++--
 src/Physics/Hydrology/water_ET.jl     |  6 ++---
 src/Physics/Hydrology/water_types.jl  |  2 +-
 src/Physics/Salt/salt_types.jl        |  4 ++--
 src/Physics/Soils/para/simple.jl      | 14 +++++------
 src/Solvers/basic_solvers.jl          |  8 +++++--
 src/Tiles/stratigraphy.jl             |  2 +-
 src/Tiles/tile.jl                     | 10 ++++----
 src/Utils/Utils.jl                    | 20 +++++++++++++---
 src/methods.jl                        | 15 ++++++++++++
 src/problem.jl                        | 19 +++++----------
 18 files changed, 142 insertions(+), 69 deletions(-)
 create mode 100644 src/IO/params/param_types.jl

diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl
index 99f74eed..f644a103 100644
--- a/examples/heat_simple_autodiff_grad.jl
+++ b/examples/heat_simple_autodiff_grad.jl
@@ -3,9 +3,9 @@
 # two parameters (summer and winter n-factors) using forward-mode automatic simulation.
 #
 # TODO: add more detail/background
+using CryoGrid
 
 # Set up forcings and boundary conditions similarly to other examples:
-using CryoGrid
 forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
 soilprofile, tempprofile = CryoGrid.SamoylovDefault
 grid = CryoGrid.DefaultGrid_5cm
@@ -22,17 +22,14 @@ tile = CryoGrid.SoilHeatTile(
 tspan = (DateTime(2010,10,1),DateTime(2010,10,2))
 u0, du0 = @time initialcondition!(tile, tspan);
 
-# Collect model parameters
-p = CryoGrid.parameters(tile)
-
 # Create the `CryoGridProblem`.
-prob = CryoGridProblem(tile, u0, tspan, p, saveat=3600.0);
+prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0);
 
 # Solve the forward problem with default parameter settings:
 sol = @time solve(prob)
 out = CryoGridOutput(sol)
 
-# ForwardDiff provides tools for forward-mode automatic differentiation.
+# Import relevant packages for automatic differentiation.
 using ForwardDiff
 using SciMLSensitivity
 using Zygote
@@ -40,9 +37,9 @@ using Zygote
 # Define a "loss" function; here we'll just take the mean over the final temperature field.
 using Statistics
 function loss(prob::CryoGridProblem, p)
-    # local u0, _ = initialcondition!(tile, tspan, p)
-    # local prob = CryoGridProblem(tile, u0, tspan, p, saveat=24*3600.0)
     newprob = remake(prob, p=p)
+    # autojacvec = true uses ForwardDiff to calculate the jacobian;
+    # enabling checkpointing (theroetically) reduces the memory cost of the backwards pass.
     sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true)
     newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg);
     newout = CryoGridOutput(newsol)
@@ -50,6 +47,8 @@ function loss(prob::CryoGridProblem, p)
 end
 
 # Compute gradient with forward diff:
-pvec = vec(p)
-grad = @time ForwardDiff.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
-grad = @time Zygote.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
+pvec = prob.p
+fd_grad = @time ForwardDiff.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
+zy_grad = @time Zygote.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
+@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-4 "Forward and reverse gradients don't match!"
+@show fd_grad
diff --git a/src/CryoGrid.jl b/src/CryoGrid.jl
index 1f13ec07..9c227ea0 100755
--- a/src/CryoGrid.jl
+++ b/src/CryoGrid.jl
@@ -1,12 +1,34 @@
 module CryoGrid
 
-global CRYOGRID_DEBUG = haskey(ENV,"CG_DEBUG") && ENV["CG_DEBUG"] == "true"
-function debug(debug::Bool)
-    global CRYOGRID_DEBUG = debug
+global DEBUG = haskey(ENV,"CG_DEBUG") && ENV["CG_DEBUG"] == "true"
+@deprecate debug(x) debug!(x)
+
+"""
+    debug!(debug::Bool)
+
+Enable or disable global debug mode for CryoGrid. Debug mode disables certain optimizations (e.g. loop vectorizations)
+which sometimes ineterfere with the debugger. It also enables additional numerical stability checks.
+"""
+function debug!(debug::Bool)
+    global DEBUG = debug
     # disable loop vectorization in debug mode
     Numerics.turbo(!debug)
-    CRYOGRID_DEBUG && @warn "Debug mode enabled! Some performance features such as loop vectorization are now turned off by default."
-    return CRYOGRID_DEBUG
+    DEBUG && @warn "Debug mode enabled! Some performance features such as loop vectorization are now turned off by default."
+    return DEBUG
+end
+
+global AUTOPARA = haskey(ENV,"CG_AUTOPARA") && ENV["CG_AUTOPARA"] == "true"
+
+"""
+    autoparam!(parameterize::Bool)
+
+Enable or disable automatic parameterization mode. When enabled, model parameters in all layer/process types will be initialized
+as free `Param`s rather than `FixedParam`s.
+"""
+function autoparam!(parameterize::Bool)
+    global AUTOPARA = parameterize
+    AUTOPARA && @warn "Automatic parameterization enabled!" || @warn "Automatic parameterization disabled!"
+    return AUTOPARA
 end
 
 using Adapt
@@ -57,7 +79,7 @@ export BCKind
 include("traits.jl")
 
 export initialcondition!, computediagnostic!, interact!, interactmaybe!, computeprognostic!, resetfluxes!, diagnosticstep!
-export variables, processes, initializers, timestep, isactive, caninteract
+export variables, processes, initializers, timestep, isactive, caninteract, param
 export boundaryflux, boundaryvalue, criterion, criterion!, trigger!
 include("methods.jl")
 
diff --git a/src/Diagnostics/spinup.jl b/src/Diagnostics/spinup.jl
index 3049132f..2da0fbb5 100644
--- a/src/Diagnostics/spinup.jl
+++ b/src/Diagnostics/spinup.jl
@@ -5,9 +5,9 @@ Implements a simple, iterative spin-up procedure.
 Runs the model specified by `setup` over `tspan` until the profile mean up to `maxdepth` over the whole time span changes only within the given tolerance `tol`.
 Returns the `ODESolution` generated by the final iteration.
 """
-function spinup(setup::Tile, tspan::NTuple{2,DateTime}, p, tol, layername; maxdepth=100u"m", maxiter=1000, saveat=3*3600.0, solver=CGEuler(), dt=60.0, solve_args...)
-    u0, du0 = initialcondition!(tile, tspan, p)
-    prob = CryoGridProblem(setup, u0, tspan, p)
+function spinup(setup::Tile, tspan::NTuple{2,DateTime}, tol, layername; maxdepth=100u"m", maxiter=1000, saveat=3*3600.0, solver=CGEuler(), dt=60.0, solve_args...)
+    u0, du0 = initialcondition!(tile, tspan)
+    prob = CryoGridProblem(setup, u0, tspan)
     @info "Running initial solve ..."
     sol = solve(prob, solver, dt=dt, saveat=saveat, solve_args...)
     out = CryoGridOutput(sol)
diff --git a/src/IO/InputOutput.jl b/src/IO/InputOutput.jl
index aeb237e9..aa76df7b 100644
--- a/src/IO/InputOutput.jl
+++ b/src/IO/InputOutput.jl
@@ -60,9 +60,6 @@ include("ioutils.jl")
 export CryoGridParams
 include("params/params.jl")
 
-export ParamsJSON, ParamsYAML
-include("params/params_loaders.jl")
-
 export Input, InputProvider, InputFunctionProvider
 export inputs
 include("input.jl")
diff --git a/src/IO/params/param_types.jl b/src/IO/params/param_types.jl
new file mode 100644
index 00000000..f1fb4e6f
--- /dev/null
+++ b/src/IO/params/param_types.jl
@@ -0,0 +1,25 @@
+"""
+    FixedParam(p::NamedTuple)
+    FixedParam(; kw...)
+    FixedParam(val)
+
+Subtype of `AbstractParam` that rerpesents a "fixed" parameter which
+should not be included in the set of free parameters for the model.
+"""
+struct FixedParam{T<:Number} <: AbstractParam{T}
+    parent::NamedTuple
+    function FixedParam{T}(nt::NamedTuple) where {T<:Number}
+        @assert :val ∈ keys(nt)
+        new{T}(nt)
+    end
+end
+
+function FixedParam(nt::NT) where {NT<:NamedTuple}
+    FixedParam{typeof(nt.val)}(nt)
+end
+
+FixedParam(val; kwargs...) = FixedParam((; val, kwargs...))
+
+Base.parent(param::FixedParam) = getfield(param, :parent)
+
+ModelParameters.rebuild(param::FixedParam, nt::NamedTuple) = FixedParam(nt)
diff --git a/src/IO/params/parameterizations.jl b/src/IO/params/parameterizations.jl
index 83878c6c..8a0ee5c5 100644
--- a/src/IO/params/parameterizations.jl
+++ b/src/IO/params/parameterizations.jl
@@ -127,11 +127,6 @@ function binindex(values::Tuple, st, en, x)
     end
 end
 
-CryoGrid.parameters(pw::PiecewiseLinear) = (;
-    initialvalue = CryoGrid.parameters(pw.initialvalue),
-    knots = CryoGrid.parameters(pw.knots),
-)
-
 # Transformed parameterization
 
 """
diff --git a/src/IO/params/params.jl b/src/IO/params/params.jl
index cd2e93f9..227c3885 100644
--- a/src/IO/params/params.jl
+++ b/src/IO/params/params.jl
@@ -1,3 +1,6 @@
+export FixedParam
+include("param_types.jl")
+
 """
     CryoGridParams{T,TM} <: DenseArray{T,1}
 
@@ -6,7 +9,7 @@ type directly in math or linear algebra operations but rather to use `Base.value
 of parameter values.
 """
 struct CryoGridParams{T,TM} <: DenseArray{T,1}
-    obj::TM # param obj
+    obj::TM # parameter "model"
     CryoGridParams(m::AbstractModel) = new{eltype(m[:val]),typeof(m)}(m)
 end
 
@@ -66,7 +69,7 @@ function Base.show(io::IO, ::MIME"text/plain", ps::CryoGridParams{T}) where T
     ModelParameters.printparams(io, ps.obj)
 end
 
-paramname(p::Param, component::Type{T}, fieldname) where {T} = fieldname
+paramname(p::AbstractParam, component::Type{T}, fieldname) where {T} = fieldname
 
 Tables.columns(ps::CryoGridParams) = Tables.columns(ps.obj)
 Tables.rows(ps::CryoGridParams) = Tables.rows(ps.obj)
@@ -91,3 +94,9 @@ function _setparafields(m::Model)
     newparent = Flatten.reconstruct(parent(m), updated_parameterizations, CryoGrid.Parameterization)
     return Model(newparent)
 end
+
+export ParamsJSON, ParamsYAML
+include("params_loaders.jl")
+
+export PiecewiseConstant, PiecewiseLinear, LinearTrend, Transformed
+include("parameterizations.jl")
diff --git a/src/Physics/Heat/heat_bc.jl b/src/Physics/Heat/heat_bc.jl
index 0a358042..0f05f1d5 100644
--- a/src/Physics/Heat/heat_bc.jl
+++ b/src/Physics/Heat/heat_bc.jl
@@ -66,8 +66,8 @@ function CryoGrid.computediagnostic!(::Bottom, bc::TemperatureBC, state)
 end
 
 Base.@kwdef struct NFactor{W,S} <: CryoGrid.BoundaryEffect
-    nf::W = Param(1.0, domain=0..1) # applied when Tair <= 0
-    nt::S = Param(1.0, domain=0..1) # applied when Tair > 0
+    nf::W = param(1.0, domain=0..1) # applied when Tair <= 0
+    nt::S = param(1.0, domain=0..1) # applied when Tair > 0
 end
 
 CryoGrid.variables(::Top, bc::TemperatureBC{<:NFactor}) = (
diff --git a/src/Physics/Hydrology/water_ET.jl b/src/Physics/Hydrology/water_ET.jl
index fe6c3e7e..40c99376 100644
--- a/src/Physics/Hydrology/water_ET.jl
+++ b/src/Physics/Hydrology/water_ET.jl
@@ -11,9 +11,9 @@ struct EvapTop <: Evapotranspiration end
 Corresponds to evapotranspiration scheme 2 described in section 2.2.4 of Westermann et al. (2022).
 """
 Base.@kwdef struct DampedET{Tftr,Tdtr,Tdev} <: Evapotranspiration
-    f_tr::Tftr = Param(0.5, domain=0..1, desc="Factor between 0 and 1 weighting transpirative vs. evaporative fluxes.")
-    d_tr::Tdtr = Param(0.5, units=u"m", domain=0..Inf, desc="Damping depth for transpiration.")
-    d_ev::Tdev = Param(0.1, units=u"m", domain=0..Inf, desc="Damping depth for evaporation.")
+    f_tr::Tftr = param(0.5, domain=0..1, desc="Factor between 0 and 1 weighting transpirative vs. evaporative fluxes.")
+    d_tr::Tdtr = param(0.5, units=u"m", domain=0..Inf, desc="Damping depth for transpiration.")
+    d_ev::Tdev = param(0.1, units=u"m", domain=0..Inf, desc="Damping depth for evaporation.")
 end
 
 """
diff --git a/src/Physics/Hydrology/water_types.jl b/src/Physics/Hydrology/water_types.jl
index f5b2c72d..04c718bd 100644
--- a/src/Physics/Hydrology/water_types.jl
+++ b/src/Physics/Hydrology/water_types.jl
@@ -15,7 +15,7 @@ end
 Default material hydraulic properties.
 """
 Utils.@properties HydraulicProperties(
-    kw_sat = Param(1e-5, units=u"m/s"),
+    kw_sat = param(1e-5, units=u"m/s"),
 )
 
 """
diff --git a/src/Physics/Salt/salt_types.jl b/src/Physics/Salt/salt_types.jl
index 88176b45..c7df4b8a 100644
--- a/src/Physics/Salt/salt_types.jl
+++ b/src/Physics/Salt/salt_types.jl
@@ -1,6 +1,6 @@
 SaltProperties(
-    Ï„ = Param(1.5), # Turtuosity
-    dₛ₀ = Param(8.0e-10, units=u"m^2/s"), # salt diffusion coefficient    
+    Ï„ = param(1.5), # Turtuosity
+    dₛ₀ = param(8.0e-10, units=u"m^2/s"), # salt diffusion coefficient    
 ) = (; τ, dₛ₀)
 
 abstract type SaltOperator end
diff --git a/src/Physics/Soils/para/simple.jl b/src/Physics/Soils/para/simple.jl
index 13c7b825..aaa13fdf 100644
--- a/src/Physics/Soils/para/simple.jl
+++ b/src/Physics/Soils/para/simple.jl
@@ -6,9 +6,9 @@ i.e. natural porosity, saturation, and organic solid fraction. This is the stand
 of a discrete soil volume.
 """
 Base.@kwdef struct SimpleSoil{Tfc,Tpor,Tsat,Torg,Thp,Twp} <: SoilParameterization
-    por::Tpor = Param(0.5, domain=0..1, desc="Natural porosity of the soil volume.")
-    sat::Tsat = Param(1.0, domain=0..1, desc="Initial water+ice saturation level of the soil volume.")
-    org::Torg = Param(0.0, domain=0..1, desc="Organic solid fraction of the soil volume.")
+    por::Tpor = param(0.5, domain=0..1, desc="Natural porosity of the soil volume.")
+    sat::Tsat = param(1.0, domain=0..1, desc="Initial water+ice saturation level of the soil volume.")
+    org::Torg = param(0.0, domain=0..1, desc="Organic solid fraction of the soil volume.")
     freezecurve::Tfc = FreeWater()
     heat::Thp = SoilThermalProperties(SimpleSoil)
     water::Twp = SoilHydraulicProperties(SimpleSoil, fieldcapacity=0.20)
@@ -34,13 +34,13 @@ SoilThermalProperties(
     kh_w = ThermalProperties().kh_w,
     kh_i = ThermalProperties().kh_i,
     kh_a = ThermalProperties().kh_a,
-    kh_o=Param(0.25, units=u"W/m/K", domain=0..Inf), # organic [Hillel (1982)]
-    kh_m=Param(3.8, units=u"W/m/K", domain=0..Inf), # mineral [Hillel (1982)]
+    kh_o=param(0.25, units=u"W/m/K", domain=0..Inf), # organic [Hillel (1982)]
+    kh_m=param(3.8, units=u"W/m/K", domain=0..Inf), # mineral [Hillel (1982)]
     ch_w = ThermalProperties().ch_w,
     ch_i = ThermalProperties().ch_i,
     ch_a = ThermalProperties().ch_a,
-    ch_o=Param(2.5e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity organic
-    ch_m=Param(2.0e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity mineral
+    ch_o=param(2.5e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity organic
+    ch_m=param(2.0e6, units=u"J/K/m^3", domain=0..Inf), # heat capacity mineral
     kwargs...,
 ) = ThermalProperties(; kh_w, kh_i, kh_a, kh_m, kh_o, ch_w, ch_i, ch_a, ch_m, ch_o, kwargs...)
 
diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl
index 0a6c3a53..5e9afe2d 100644
--- a/src/Solvers/basic_solvers.jl
+++ b/src/Solvers/basic_solvers.jl
@@ -19,7 +19,7 @@ struct CGEulerCache{Tu} <: SciMLBase.DECache
     du::Tu
 end
 
-function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=3*3600.0, kwargs...)
+function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0, saveat=nothing, kwargs...)
     tile = Tile(prob.f)
     u0 = copy(prob.u0)
     du0 = zero(u0)
@@ -42,7 +42,11 @@ function DiffEqBase.__init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0
         similar(prob.u0),
     )
     p = isnothing(prob.p) ? prob.p : collect(prob.p)
-    saveat = isa(saveat, Number) ? collect(prob.tspan[1]:saveat:prob.tspan[2]) : saveat
+    if isnothing(saveat)
+        saveat = prob.saveat
+    elseif isa(saveat, Number)
+        saveat = collect(prob.tspan[1]:saveat:prob.tspan[2])
+    end
     opts = CryoGridIntegratorOptions(; saveat=expandtstep(saveat, prob.tspan), kwargs...)
     return CryoGridIntegrator(alg, cache, opts, sol, copy(u0), p, t0*one(eltype(u0)), dt*one(eltype(u0)), 1, 1)
 end
diff --git a/src/Tiles/stratigraphy.jl b/src/Tiles/stratigraphy.jl
index 2d2d65e3..cf1861c3 100644
--- a/src/Tiles/stratigraphy.jl
+++ b/src/Tiles/stratigraphy.jl
@@ -33,7 +33,7 @@ struct Stratigraphy{N,TLayers<:NamedTuple,TBoundaries}
         @nospecialize(sub::Tuple{Vararg{Pair{<:Number}}}),
         @nospecialize(bot::Pair{<:Number,<:Bottom})
     )
-        updateparam(p::Param) = Param(merge(parent(p), (layer=:strat,)))
+        updateparam(p::paraType) where {paraType<:AbstractParam} = paraType(merge(parent(p), (layer=:strat,)))
         updateparam(x) = x
         # check subsurface layers
         @assert length(sub) > 0 "At least one subsurface layer must be specified"
diff --git a/src/Tiles/tile.jl b/src/Tiles/tile.jl
index 96f2cd4f..797db01d 100644
--- a/src/Tiles/tile.jl
+++ b/src/Tiles/tile.jl
@@ -339,7 +339,7 @@ Materializes the given `tile` by:
 Returns the reconstructed `Tile` instance.
 """
 function materialize(tile::Tile, p::AbstractVector, t::Number)
-    IgnoreTypes = _ignored_types(tile)
+    IgnoreTypes = Utils.ignored_types(tile)
     # ==== Update parameter values ==== #
     # unfortunately, reconstruct causes allocations due to a mysterious dynamic dispatch when returning the result of _reconstruct;
     # I really don't know why, could be a compiler bug, but it doesn't happen if we call the internal _reconstruct method directly...
@@ -349,7 +349,7 @@ function materialize(tile::Tile, p::AbstractVector, t::Number)
     return materialize(parameterized_tile, nothing, t)
 end
 function materialize(tile::Tile, ::Nothing, t::Number)
-    IgnoreTypes = _ignored_types(tile)
+    IgnoreTypes = Utils.ignored_types(tile)
     # ==== Compute dynamic parameter values ==== #
     # TODO: perhaps should allow dependence on local layer state;
     # this would likely require deconstruction/reconstruction of layers in order to
@@ -365,7 +365,7 @@ function materialize(tile::Tile, ::Nothing, t::Number)
 end
 
 function checkstate!(tile::Tile, state::TileState, u, du, label::Symbol)
-    if CryoGrid.CRYOGRID_DEBUG
+    if CryoGrid.DEBUG
         @inbounds for i in eachindex(u)
             if !isfinite(u[i])
                 debughook!(tile, state, AssertionError("[$label] Found NaN/Inf value in current state vector at index $i"))
@@ -403,7 +403,7 @@ function _initstatevars(@nospecialize(strat::Stratigraphy), @nospecialize(grid::
 end
 
 function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvider)
-    IgnoreTypes = _ignored_types(tile)
+    IgnoreTypes = Utils.ignored_types(tile)
     inputs = Flatten.flatten(tile, Flatten.flattenable, Input, IgnoreTypes)
     names = keys(inputprovider)
     for input in inputs
@@ -413,6 +413,6 @@ function _validate_inputs(@nospecialize(tile::Tile), inputprovider::InputProvide
 end
 
 # helper method that returns a Union type of all types that should be ignored by Flatten.flatten
-@inline _ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual}
+@inline Utils.ignored_types(::Tile{TStrat,TGrid,TStates}) where {TStrat,TGrid,TStates} = Union{TGrid,TStates,TileData,Unitful.Quantity,Numerics.ForwardDiff.Dual}
 
 # ===================================================================== #
diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl
index 185c5d7d..4c20a1e7 100644
--- a/src/Utils/Utils.jl
+++ b/src/Utils/Utils.jl
@@ -172,22 +172,36 @@ function pstrip(obj; keep_units=false)
 end
 
 # TODO: this should be in ModelParameters.jl, not here.
-function Unitful.uconvert(u::Unitful.Units, p::Param)
+function Unitful.uconvert(u::Unitful.Units, p::paraType) where {paraType<:AbstractParam}
     nt = parent(p)
     @set! nt.val = ustrip(u, stripparams(p))
     @set! nt.units = u
-    return Param(nt)
+    return paraType(nt)
 end
 """
     ModelParameters.stripunits(obj)
 
-Additional override for `stripunits` which reconstructs `obj` with all fields that have unitful quantity
+Additional dispatch for `stripunits` which reconstructs `obj` with all fields that have unitful quantity
 types converted to base SI units and then stripped to be unit free.
 """
 function ModelParameters.stripunits(obj)
     values = Flatten.flatten(obj, Flatten.flattenable, Unitful.AbstractQuantity, Flatten.IGNORE)
     return Flatten.reconstruct(obj, map(ustrip ∘ normalize_units, values), Unitful.AbstractQuantity, Flatten.IGNORE)
 end
+
+"""
+    ModelParameters.stripparams(::Type{paraType}, obj) where {paraType<:AbstractParam}
+
+Additional dispatch for `stripparams` which allows specification of a specific type.
+"""
+function ModelParameters.stripparams(::Type{paraType}, obj) where {paraType<:AbstractParam}
+    IgnoreTypes = ignored_types(obj)
+    selected_params = Flatten.flatten(obj, Flatten.flattenable, paraType, IgnoreTypes)
+    return Flatten._reconstruct(obj, map(stripparams, selected_params), Flatten.flattenable, paraType, IgnoreTypes, 1)[1]
+end
+
+ignored_types(obj) = ModelParameters.IGNORE
+
 # pretty print Param types
 Base.show(io::IO, mime::MIME"text/plain", p::Param) = print(io, "Param($(p.val))")
 
diff --git a/src/methods.jl b/src/methods.jl
index 0017f588..0900440f 100644
--- a/src/methods.jl
+++ b/src/methods.jl
@@ -233,3 +233,18 @@ thickness(::Layer, state, i) = Δ(state.grid)[i]
 thickness(l::Layer, state, ::typeof(first)) = thickness(l, state, 1)
 thickness(l::Layer, state, ::typeof(last)) = thickness(l, state, lastindex(state.grid)-1)
 thickness(::Union{Top,Bottom}, state) = Inf
+
+
+"""
+    param([::Type{paraType}], defval; kwargs...)
+
+Creates a new parameter type from the given default value and keyword properties.
+"""
+param(::Type{paraType}, defval; kwargs...) where {paraType<:AbstractParam} = paraType(defval; kwargs...)
+function param(defval; kwargs...)
+    if AUTOPARA
+        return param(Param, defval; kwargs...)
+    else
+        return param(FixedParam, defval; kwargs...)
+    end
+end
diff --git a/src/problem.jl b/src/problem.jl
index daf876a3..6c05881f 100644
--- a/src/problem.jl
+++ b/src/problem.jl
@@ -47,8 +47,7 @@ Constructor for `CryoGridProblem` that automatically generates all necessary cal
 function CryoGridProblem(
     tile::Tile,
     u0::ComponentVector,
-    tspan::NTuple{2,Float64},
-    p=nothing;
+    tspan::NTuple{2,Float64};
     diagnostic_stepsize=3600.0,
     saveat=3600.0,
     savevars=(),
@@ -66,16 +65,10 @@ function CryoGridProblem(
 )
     getsavestate(tile::Tile, u, du) = deepcopy(Tiles.getvars(tile.state, Tiles.withaxes(u, tile), Tiles.withaxes(du, tile), savevars...))
     savefunc(u, t, integrator) = getsavestate(Tile(integrator), Tiles.withaxes(u, Tile(integrator)), get_du(integrator))
-    tile, p = if isnothing(p) && isempty(ModelParameters.params(tile))
-        # case 1: no parameters provided
-        tile, nothing
-    else
-        # case 2: parameters are provided; use Model interface to reconstruct Tile with new parameter values
-        model_tile = Model(tile)
-        p = isnothing(p) ? collect(model_tile[:val]) : p
-        model_tile[:val] = p
-        parent(model_tile), p
-    end
+    # strip all "fixed" parameters
+    tile = stripparams(FixedParam, tile)
+    # retrieve variable parameters
+    p = length(ModelParameters.params(tile)) > 0 ? parameters(tile) : nothing
     du0 = zero(u0)
     # remove units
     tile = stripunits(tile)
@@ -142,7 +135,7 @@ function CryoGrid.odefunction(::DefaultJac, setup::typeof(tile), u0, p, tspan)
     # make sure to return an instance of ODEFunction
 end
 ...
-prob = CryoGridProblem(tile, tspan, p)
+prob = CryoGridProblem(tile, u0, tspan)
 ```
 
 `JacobianStyle` can also be extended to create custom traits which can then be applied to compatible `Tile`s.
-- 
GitLab