Skip to content
Snippets Groups Projects
Commit efaa0493 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix regression caused by previous merge

Transform parameters were missing the component and fieldname parameters provided
by ModelParameters.

The fix also causes minor breakage in changing the format of collected parameters.
Transforms now must declare an identifier which is used to dereference them from
ParameterVector.
parent 594252f4
No related branches found
No related tags found
1 merge request!71Fix regression caused by previous merge
abstract type ParamTransform end
shortname(::ParamTransform) = :transform
# mapping
struct ParamMapping{T,name,layer}
transform::T
......@@ -28,27 +29,22 @@ Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}}
Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))")
ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals)
_paramval(x) = x
_paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units
function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...)
type2nt(p::Param) = p
type2nt(obj) = (; filter(p -> isa(p[2], Param) || !isempty(p[2]), map(n -> Symbol(n) => type2nt(getfield(obj, n)), fieldnames(typeof(obj))))...)
getparam(x) = x
function getparam(p::AbstractVector)
# currently, we assume only one variable of each name in each layer;
# this could be relaxed in the future but will need to be appropriately handled
@assert length(p) == 1 "Found duplicate parameters in a layer: $p; this is not currently supported."
return p[1]
end
getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x))
paramval(x) = ustrip(x)
paramval(x::Param) = ustrip(x.val)
m = Model(model)
nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple)
mappedparams = nestedparams
mappings = ParamMapping[]
for (layer,(var,transform)) in transforms
@set! mappedparams[layer][var] = mapflat(getparam, type2nt(transform); maptype=NamedTuple)
m_transform = Model(transform)
m_transform[:transform] = repeat([shortname(transform)], length(ModelParameters.params(m_transform)))
@set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple)
push!(mappings, ParamMapping(transform, var, layer))
end
mappedarr = ComponentArray(mapflat(_paramval, mappedparams))
mappedarr = ComponentArray(mapflat(paramval, mappedparams))
return ParameterVector(mappedarr, mappedparams, mappings...)
end
@inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t)
......@@ -96,6 +92,7 @@ Applies a linear trend to a parameter `p` by reparameterizing it as: `p = p₁*t
minval::Float64 = -Inf
maxval::Float64 = Inf
end
shortname(::LinearTrend) = :trend
function transform(state, trend::LinearTrend)
let t = min(state.t - trend.tstart, trend.tstop),
β = trend.slope / trend.period,
......@@ -104,19 +101,29 @@ function transform(state, trend::LinearTrend)
end
end
"""
PiecewiseLinear{N,Tb,Tv,Tl,I} <: ParamTransform
Helper type for PiecewiseLinear.
"""
struct PiecewiseKnot{Tw,Tv}
binwidth::Tw
value::Tv
end
"""
PiecewiseLinear{N,Tw,Tv} <: ParamTransform
Reparameterizes parameter `p` as `p = p₁δ₁t + ⋯ + pₖδₖt` where δₖ are indicators
for when `tₖ₋₁ <= t <= tₖ`. To facilitate sampling and optimization, change points
tᵢ are parameterized as bin widths, which should be strictly positive. `PiecewiseLinear`
will normalize them and scale by the size of the time interval.
"""
@with_kw struct PiecewiseLinear{Nb,Tb,Nv,Tv} <: ParamTransform
bins::NTuple{Nb,Tb} = (1.0,); @assert Nb > 0; @assert all(bins .> 0.0)
values::NTuple{Nv,Tv} = (0.0,); @assert Nv == Nb+1 "need n+1 knots for n bins"
@with_kw struct PiecewiseLinear{N,Tw,Tv} <: ParamTransform
initialvalue::Tv
knots::NTuple{N,PiecewiseKnot{Tw,Tv}}
tstart::Float64 = 0.0; @assert isfinite(tstart)
tstop::Float64 = 1.0; @assert tstop > tstart; @assert isfinite(tstop)
end
PiecewiseLinear(initialvalue::T; tstart=0.0, tstop=1.0) where T = PiecewiseLinear{0,Nothing,T}(initialvalue, (), tstart, tstop)
PiecewiseLinear(initialvalue, knots...; tstart=0.0, tstop=1.0) = PiecewiseLinear(initialvalue, Tuple(map(Base.splat(PiecewiseKnot), knots)), tstart, tstop)
shortname(::PiecewiseLinear) = :linear
function transform(state, pc::PiecewiseLinear)
function binindex(values::Tuple, st, en, x)
mid = Int(floor((st + en)/2))
......@@ -132,8 +139,9 @@ function transform(state, pc::PiecewiseLinear)
end
let tspan = pc.tstop - pc.tstart,
t = min(max(state.t - pc.tstart, zero(state.t)), tspan),
ts = (0.0, cumsum((pc.bins ./ sum(pc.bins)).*tspan)...),
vals = pc.values,
bins = map(k -> k.binwidth, pc.knots),
vals = (pc.initialvalue, map(k -> k.value, pc.knots)...),
ts = (0.0, cumsum((bins ./ sum(bins)).*tspan)...),
i = binindex(ts, 1, length(ts), t);
vals[i] + (vals[i+1] - vals[i])*(t - ts[i]) / (ts[i+1] - ts[i])
end
......
......@@ -13,14 +13,14 @@ using CryoGrid
trend = LinearTrend(slope=0.1, intercept=0.5, tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=2.0,), trend) 0.6
end
@testset "PiecewiseConstant" begin
pc = PiecewiseLinear(bins=(1.0,), values=(0.0,1.0), tstart=0.0, tstop=1.0)
@testset "PiecewiseLinear" begin
pc = PiecewiseLinear(0.0, (1.0,1.0); tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=-0.1,), pc) 0.0
@test CryoGrid.Strat.transform((t=0.0,), pc) 0.0
@test CryoGrid.Strat.transform((t=0.5,), pc) 0.5
@test CryoGrid.Strat.transform((t=1.0,), pc) 1.0
@test CryoGrid.Strat.transform((t=1.1,), pc) 1.0
pc = PiecewiseLinear(bins=(0.4,0.6), values=(1.0,0.5,0.0), tstart=0.0, tstop=1.0)
pc = PiecewiseLinear(1.0, (0.4,0.5), (0.6,0.0); tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=-0.1,), pc) 1.0
@test CryoGrid.Strat.transform((t=0.0,), pc) 1.0
@test CryoGrid.Strat.transform((t=0.7,), pc) 0.25
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment