Skip to content
Snippets Groups Projects
Commit efaa0493 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix regression caused by previous merge

Transform parameters were missing the component and fieldname parameters provided
by ModelParameters.

The fix also causes minor breakage in changing the format of collected parameters.
Transforms now must declare an identifier which is used to dereference them from
ParameterVector.
parent 594252f4
No related branches found
No related tags found
1 merge request!71Fix regression caused by previous merge
abstract type ParamTransform end abstract type ParamTransform end
shortname(::ParamTransform) = :transform
# mapping # mapping
struct ParamMapping{T,name,layer} struct ParamMapping{T,name,layer}
transform::T transform::T
...@@ -28,27 +29,22 @@ Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}} ...@@ -28,27 +29,22 @@ Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}}
Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))") Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))")
ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals) ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals)
_paramval(x) = x
_paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units
function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...) function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...)
type2nt(p::Param) = p
type2nt(obj) = (; filter(p -> isa(p[2], Param) || !isempty(p[2]), map(n -> Symbol(n) => type2nt(getfield(obj, n)), fieldnames(typeof(obj))))...)
getparam(x) = x getparam(x) = x
function getparam(p::AbstractVector) getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x))
# currently, we assume only one variable of each name in each layer; paramval(x) = ustrip(x)
# this could be relaxed in the future but will need to be appropriately handled paramval(x::Param) = ustrip(x.val)
@assert length(p) == 1 "Found duplicate parameters in a layer: $p; this is not currently supported."
return p[1]
end
m = Model(model) m = Model(model)
nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple) nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple)
mappedparams = nestedparams mappedparams = nestedparams
mappings = ParamMapping[] mappings = ParamMapping[]
for (layer,(var,transform)) in transforms for (layer,(var,transform)) in transforms
@set! mappedparams[layer][var] = mapflat(getparam, type2nt(transform); maptype=NamedTuple) m_transform = Model(transform)
m_transform[:transform] = repeat([shortname(transform)], length(ModelParameters.params(m_transform)))
@set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple)
push!(mappings, ParamMapping(transform, var, layer)) push!(mappings, ParamMapping(transform, var, layer))
end end
mappedarr = ComponentArray(mapflat(_paramval, mappedparams)) mappedarr = ComponentArray(mapflat(paramval, mappedparams))
return ParameterVector(mappedarr, mappedparams, mappings...) return ParameterVector(mappedarr, mappedparams, mappings...)
end end
@inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t) @inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t)
...@@ -96,6 +92,7 @@ Applies a linear trend to a parameter `p` by reparameterizing it as: `p = p₁*t ...@@ -96,6 +92,7 @@ Applies a linear trend to a parameter `p` by reparameterizing it as: `p = p₁*t
minval::Float64 = -Inf minval::Float64 = -Inf
maxval::Float64 = Inf maxval::Float64 = Inf
end end
shortname(::LinearTrend) = :trend
function transform(state, trend::LinearTrend) function transform(state, trend::LinearTrend)
let t = min(state.t - trend.tstart, trend.tstop), let t = min(state.t - trend.tstart, trend.tstop),
β = trend.slope / trend.period, β = trend.slope / trend.period,
...@@ -104,19 +101,29 @@ function transform(state, trend::LinearTrend) ...@@ -104,19 +101,29 @@ function transform(state, trend::LinearTrend)
end end
end end
""" """
PiecewiseLinear{N,Tb,Tv,Tl,I} <: ParamTransform Helper type for PiecewiseLinear.
"""
struct PiecewiseKnot{Tw,Tv}
binwidth::Tw
value::Tv
end
"""
PiecewiseLinear{N,Tw,Tv} <: ParamTransform
Reparameterizes parameter `p` as `p = p₁δ₁t + ⋯ + pₖδₖt` where δₖ are indicators Reparameterizes parameter `p` as `p = p₁δ₁t + ⋯ + pₖδₖt` where δₖ are indicators
for when `tₖ₋₁ <= t <= tₖ`. To facilitate sampling and optimization, change points for when `tₖ₋₁ <= t <= tₖ`. To facilitate sampling and optimization, change points
tᵢ are parameterized as bin widths, which should be strictly positive. `PiecewiseLinear` tᵢ are parameterized as bin widths, which should be strictly positive. `PiecewiseLinear`
will normalize them and scale by the size of the time interval. will normalize them and scale by the size of the time interval.
""" """
@with_kw struct PiecewiseLinear{Nb,Tb,Nv,Tv} <: ParamTransform @with_kw struct PiecewiseLinear{N,Tw,Tv} <: ParamTransform
bins::NTuple{Nb,Tb} = (1.0,); @assert Nb > 0; @assert all(bins .> 0.0) initialvalue::Tv
values::NTuple{Nv,Tv} = (0.0,); @assert Nv == Nb+1 "need n+1 knots for n bins" knots::NTuple{N,PiecewiseKnot{Tw,Tv}}
tstart::Float64 = 0.0; @assert isfinite(tstart) tstart::Float64 = 0.0; @assert isfinite(tstart)
tstop::Float64 = 1.0; @assert tstop > tstart; @assert isfinite(tstop) tstop::Float64 = 1.0; @assert tstop > tstart; @assert isfinite(tstop)
end end
PiecewiseLinear(initialvalue::T; tstart=0.0, tstop=1.0) where T = PiecewiseLinear{0,Nothing,T}(initialvalue, (), tstart, tstop)
PiecewiseLinear(initialvalue, knots...; tstart=0.0, tstop=1.0) = PiecewiseLinear(initialvalue, Tuple(map(Base.splat(PiecewiseKnot), knots)), tstart, tstop)
shortname(::PiecewiseLinear) = :linear
function transform(state, pc::PiecewiseLinear) function transform(state, pc::PiecewiseLinear)
function binindex(values::Tuple, st, en, x) function binindex(values::Tuple, st, en, x)
mid = Int(floor((st + en)/2)) mid = Int(floor((st + en)/2))
...@@ -132,8 +139,9 @@ function transform(state, pc::PiecewiseLinear) ...@@ -132,8 +139,9 @@ function transform(state, pc::PiecewiseLinear)
end end
let tspan = pc.tstop - pc.tstart, let tspan = pc.tstop - pc.tstart,
t = min(max(state.t - pc.tstart, zero(state.t)), tspan), t = min(max(state.t - pc.tstart, zero(state.t)), tspan),
ts = (0.0, cumsum((pc.bins ./ sum(pc.bins)).*tspan)...), bins = map(k -> k.binwidth, pc.knots),
vals = pc.values, vals = (pc.initialvalue, map(k -> k.value, pc.knots)...),
ts = (0.0, cumsum((bins ./ sum(bins)).*tspan)...),
i = binindex(ts, 1, length(ts), t); i = binindex(ts, 1, length(ts), t);
vals[i] + (vals[i+1] - vals[i])*(t - ts[i]) / (ts[i+1] - ts[i]) vals[i] + (vals[i+1] - vals[i])*(t - ts[i]) / (ts[i+1] - ts[i])
end end
......
...@@ -13,14 +13,14 @@ using CryoGrid ...@@ -13,14 +13,14 @@ using CryoGrid
trend = LinearTrend(slope=0.1, intercept=0.5, tstart=0.0, tstop=1.0) trend = LinearTrend(slope=0.1, intercept=0.5, tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=2.0,), trend) 0.6 @test CryoGrid.Strat.transform((t=2.0,), trend) 0.6
end end
@testset "PiecewiseConstant" begin @testset "PiecewiseLinear" begin
pc = PiecewiseLinear(bins=(1.0,), values=(0.0,1.0), tstart=0.0, tstop=1.0) pc = PiecewiseLinear(0.0, (1.0,1.0); tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=-0.1,), pc) 0.0 @test CryoGrid.Strat.transform((t=-0.1,), pc) 0.0
@test CryoGrid.Strat.transform((t=0.0,), pc) 0.0 @test CryoGrid.Strat.transform((t=0.0,), pc) 0.0
@test CryoGrid.Strat.transform((t=0.5,), pc) 0.5 @test CryoGrid.Strat.transform((t=0.5,), pc) 0.5
@test CryoGrid.Strat.transform((t=1.0,), pc) 1.0 @test CryoGrid.Strat.transform((t=1.0,), pc) 1.0
@test CryoGrid.Strat.transform((t=1.1,), pc) 1.0 @test CryoGrid.Strat.transform((t=1.1,), pc) 1.0
pc = PiecewiseLinear(bins=(0.4,0.6), values=(1.0,0.5,0.0), tstart=0.0, tstop=1.0) pc = PiecewiseLinear(1.0, (0.4,0.5), (0.6,0.0); tstart=0.0, tstop=1.0)
@test CryoGrid.Strat.transform((t=-0.1,), pc) 1.0 @test CryoGrid.Strat.transform((t=-0.1,), pc) 1.0
@test CryoGrid.Strat.transform((t=0.0,), pc) 1.0 @test CryoGrid.Strat.transform((t=0.0,), pc) 1.0
@test CryoGrid.Strat.transform((t=0.7,), pc) 0.25 @test CryoGrid.Strat.transform((t=0.7,), pc) 0.25
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment