From efaa0493decabb1e3e53d2a827972c63df065732 Mon Sep 17 00:00:00 2001
From: Brian Groenke <brian.groenke@awi.de>
Date: Sat, 22 Jan 2022 16:03:58 +0100
Subject: [PATCH] Fix regression caused by previous merge

Transform parameters were missing the component and fieldname parameters provided
by ModelParameters.

The fix also causes minor breakage in changing the format of collected parameters.
Transforms now must declare an identifier which is used to dereference them from
ParameterVector.
---
 src/Strat/params.jl       | 44 +++++++++++++++++++++++----------------
 test/Strat/param_tests.jl |  6 +++---
 2 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/src/Strat/params.jl b/src/Strat/params.jl
index 772bea29..b45321cb 100644
--- a/src/Strat/params.jl
+++ b/src/Strat/params.jl
@@ -1,4 +1,5 @@
 abstract type ParamTransform end
+shortname(::ParamTransform) = :transform
 # mapping
 struct ParamMapping{T,name,layer}
     transform::T
@@ -28,27 +29,22 @@ Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}}
 Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))")
 ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals)
 
-_paramval(x) = x
-_paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units
 function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...)
-    type2nt(p::Param) = p
-    type2nt(obj) = (; filter(p -> isa(p[2], Param) || !isempty(p[2]), map(n -> Symbol(n) => type2nt(getfield(obj, n)), fieldnames(typeof(obj))))...)
     getparam(x) = x
-    function getparam(p::AbstractVector)
-        # currently, we assume only one variable of each name in each layer;
-        # this could be relaxed in the future but will need to be appropriately handled
-        @assert length(p) == 1 "Found duplicate parameters in a layer: $p; this is not currently supported."
-        return p[1]
-    end
+    getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x))
+    paramval(x) = ustrip(x)
+    paramval(x::Param) = ustrip(x.val)
     m = Model(model)
     nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple)
     mappedparams = nestedparams
     mappings = ParamMapping[]
     for (layer,(var,transform)) in transforms
-        @set! mappedparams[layer][var] = mapflat(getparam, type2nt(transform); maptype=NamedTuple)
+        m_transform = Model(transform)
+        m_transform[:transform] = repeat([shortname(transform)], length(ModelParameters.params(m_transform)))
+        @set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple)
         push!(mappings, ParamMapping(transform, var, layer))
     end
-    mappedarr = ComponentArray(mapflat(_paramval, mappedparams))
+    mappedarr = ComponentArray(mapflat(paramval, mappedparams))
     return ParameterVector(mappedarr, mappedparams, mappings...)
 end
 @inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t)
@@ -96,6 +92,7 @@ Applies a linear trend to a parameter `p` by reparameterizing it as: `p = p₁*t
     minval::Float64 = -Inf
     maxval::Float64 = Inf
 end
+shortname(::LinearTrend) = :trend
 function transform(state, trend::LinearTrend)
     let t = min(state.t - trend.tstart, trend.tstop),
         β = trend.slope / trend.period,
@@ -104,19 +101,29 @@ function transform(state, trend::LinearTrend)
     end
 end
 """
-    PiecewiseLinear{N,Tb,Tv,Tl,I} <: ParamTransform
+Helper type for PiecewiseLinear.
+"""
+struct PiecewiseKnot{Tw,Tv}
+    binwidth::Tw
+    value::Tv
+end
+"""
+    PiecewiseLinear{N,Tw,Tv} <: ParamTransform
 
 Reparameterizes parameter `p` as `p = p₁δ₁t + ⋯ + pₖδₖt` where δₖ are indicators
 for when `tₖ₋₁ <= t <= tₖ`. To facilitate sampling and optimization, change points
 táµ¢ are parameterized as bin widths, which should be strictly positive. `PiecewiseLinear`
 will normalize them and scale by the size of the time interval.
 """
-@with_kw struct PiecewiseLinear{Nb,Tb,Nv,Tv} <: ParamTransform
-    bins::NTuple{Nb,Tb} = (1.0,); @assert Nb > 0; @assert all(bins .> 0.0)
-    values::NTuple{Nv,Tv} = (0.0,); @assert Nv == Nb+1 "need n+1 knots for n bins"
+@with_kw struct PiecewiseLinear{N,Tw,Tv} <: ParamTransform
+    initialvalue::Tv
+    knots::NTuple{N,PiecewiseKnot{Tw,Tv}}
     tstart::Float64 = 0.0; @assert isfinite(tstart)
     tstop::Float64 = 1.0; @assert tstop > tstart; @assert isfinite(tstop)
 end
+PiecewiseLinear(initialvalue::T; tstart=0.0, tstop=1.0) where T = PiecewiseLinear{0,Nothing,T}(initialvalue, (), tstart, tstop)
+PiecewiseLinear(initialvalue, knots...; tstart=0.0, tstop=1.0) = PiecewiseLinear(initialvalue, Tuple(map(Base.splat(PiecewiseKnot), knots)), tstart, tstop)
+shortname(::PiecewiseLinear) = :linear
 function transform(state, pc::PiecewiseLinear)
     function binindex(values::Tuple, st, en, x)
         mid = Int(floor((st + en)/2))
@@ -132,8 +139,9 @@ function transform(state, pc::PiecewiseLinear)
     end
     let tspan = pc.tstop - pc.tstart,
         t = min(max(state.t - pc.tstart, zero(state.t)), tspan),
-        ts = (0.0, cumsum((pc.bins ./ sum(pc.bins)).*tspan)...),
-        vals = pc.values,
+        bins = map(k -> k.binwidth, pc.knots),
+        vals = (pc.initialvalue, map(k -> k.value, pc.knots)...),
+        ts = (0.0, cumsum((bins ./ sum(bins)).*tspan)...),
         i = binindex(ts, 1, length(ts), t);
         vals[i] + (vals[i+1] - vals[i])*(t - ts[i]) / (ts[i+1] - ts[i])
     end
diff --git a/test/Strat/param_tests.jl b/test/Strat/param_tests.jl
index a2d58a0a..d7bd5e37 100644
--- a/test/Strat/param_tests.jl
+++ b/test/Strat/param_tests.jl
@@ -13,14 +13,14 @@ using CryoGrid
        trend = LinearTrend(slope=0.1, intercept=0.5, tstart=0.0, tstop=1.0)
        @test CryoGrid.Strat.transform((t=2.0,), trend) ≈ 0.6
     end
-    @testset "PiecewiseConstant" begin
-        pc = PiecewiseLinear(bins=(1.0,), values=(0.0,1.0), tstart=0.0, tstop=1.0)
+    @testset "PiecewiseLinear" begin
+        pc = PiecewiseLinear(0.0, (1.0,1.0); tstart=0.0, tstop=1.0)
         @test CryoGrid.Strat.transform((t=-0.1,), pc) ≈ 0.0
         @test CryoGrid.Strat.transform((t=0.0,), pc) ≈ 0.0
         @test CryoGrid.Strat.transform((t=0.5,), pc) ≈ 0.5
         @test CryoGrid.Strat.transform((t=1.0,), pc) ≈ 1.0
         @test CryoGrid.Strat.transform((t=1.1,), pc) ≈ 1.0
-        pc = PiecewiseLinear(bins=(0.4,0.6), values=(1.0,0.5,0.0), tstart=0.0, tstop=1.0)
+        pc = PiecewiseLinear(1.0, (0.4,0.5), (0.6,0.0); tstart=0.0, tstop=1.0)
         @test CryoGrid.Strat.transform((t=-0.1,), pc) ≈ 1.0
         @test CryoGrid.Strat.transform((t=0.0,), pc) ≈ 1.0
         @test CryoGrid.Strat.transform((t=0.7,), pc) ≈ 0.25
-- 
GitLab