diff --git a/src/Strat/params.jl b/src/Strat/params.jl index 772bea2970e7775a8f83118a9c8c8d13e070453d..b45321cb5c5b1dafe92405cf0825fbdb19cad46e 100644 --- a/src/Strat/params.jl +++ b/src/Strat/params.jl @@ -1,4 +1,5 @@ abstract type ParamTransform end +shortname(::ParamTransform) = :transform # mapping struct ParamMapping{T,name,layer} transform::T @@ -28,27 +29,22 @@ Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}} Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))") ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals) -_paramval(x) = x -_paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...) - type2nt(p::Param) = p - type2nt(obj) = (; filter(p -> isa(p[2], Param) || !isempty(p[2]), map(n -> Symbol(n) => type2nt(getfield(obj, n)), fieldnames(typeof(obj))))...) getparam(x) = x - function getparam(p::AbstractVector) - # currently, we assume only one variable of each name in each layer; - # this could be relaxed in the future but will need to be appropriately handled - @assert length(p) == 1 "Found duplicate parameters in a layer: $p; this is not currently supported." - return p[1] - end + getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x)) + paramval(x) = ustrip(x) + paramval(x::Param) = ustrip(x.val) m = Model(model) nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple) mappedparams = nestedparams mappings = ParamMapping[] for (layer,(var,transform)) in transforms - @set! mappedparams[layer][var] = mapflat(getparam, type2nt(transform); maptype=NamedTuple) + m_transform = Model(transform) + m_transform[:transform] = repeat([shortname(transform)], length(ModelParameters.params(m_transform))) + @set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple) push!(mappings, ParamMapping(transform, var, layer)) end - mappedarr = ComponentArray(mapflat(_paramval, mappedparams)) + mappedarr = ComponentArray(mapflat(paramval, mappedparams)) return ParameterVector(mappedarr, mappedparams, mappings...) end @inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t) @@ -96,6 +92,7 @@ Applies a linear trend to a parameter `p` by reparameterizing it as: `p = pâ‚*t minval::Float64 = -Inf maxval::Float64 = Inf end +shortname(::LinearTrend) = :trend function transform(state, trend::LinearTrend) let t = min(state.t - trend.tstart, trend.tstop), β = trend.slope / trend.period, @@ -104,19 +101,29 @@ function transform(state, trend::LinearTrend) end end """ - PiecewiseLinear{N,Tb,Tv,Tl,I} <: ParamTransform +Helper type for PiecewiseLinear. +""" +struct PiecewiseKnot{Tw,Tv} + binwidth::Tw + value::Tv +end +""" + PiecewiseLinear{N,Tw,Tv} <: ParamTransform Reparameterizes parameter `p` as `p = pâ‚δâ‚t + ⋯ + pₖδₖt` where δₖ are indicators for when `tâ‚–â‚‹â‚ <= t <= tâ‚–`. To facilitate sampling and optimization, change points táµ¢ are parameterized as bin widths, which should be strictly positive. `PiecewiseLinear` will normalize them and scale by the size of the time interval. """ -@with_kw struct PiecewiseLinear{Nb,Tb,Nv,Tv} <: ParamTransform - bins::NTuple{Nb,Tb} = (1.0,); @assert Nb > 0; @assert all(bins .> 0.0) - values::NTuple{Nv,Tv} = (0.0,); @assert Nv == Nb+1 "need n+1 knots for n bins" +@with_kw struct PiecewiseLinear{N,Tw,Tv} <: ParamTransform + initialvalue::Tv + knots::NTuple{N,PiecewiseKnot{Tw,Tv}} tstart::Float64 = 0.0; @assert isfinite(tstart) tstop::Float64 = 1.0; @assert tstop > tstart; @assert isfinite(tstop) end +PiecewiseLinear(initialvalue::T; tstart=0.0, tstop=1.0) where T = PiecewiseLinear{0,Nothing,T}(initialvalue, (), tstart, tstop) +PiecewiseLinear(initialvalue, knots...; tstart=0.0, tstop=1.0) = PiecewiseLinear(initialvalue, Tuple(map(Base.splat(PiecewiseKnot), knots)), tstart, tstop) +shortname(::PiecewiseLinear) = :linear function transform(state, pc::PiecewiseLinear) function binindex(values::Tuple, st, en, x) mid = Int(floor((st + en)/2)) @@ -132,8 +139,9 @@ function transform(state, pc::PiecewiseLinear) end let tspan = pc.tstop - pc.tstart, t = min(max(state.t - pc.tstart, zero(state.t)), tspan), - ts = (0.0, cumsum((pc.bins ./ sum(pc.bins)).*tspan)...), - vals = pc.values, + bins = map(k -> k.binwidth, pc.knots), + vals = (pc.initialvalue, map(k -> k.value, pc.knots)...), + ts = (0.0, cumsum((bins ./ sum(bins)).*tspan)...), i = binindex(ts, 1, length(ts), t); vals[i] + (vals[i+1] - vals[i])*(t - ts[i]) / (ts[i+1] - ts[i]) end diff --git a/test/Strat/param_tests.jl b/test/Strat/param_tests.jl index a2d58a0a62acd1c9e85524517c45dabd469b81d3..d7bd5e37e24c582544f9c3be44678076d43d8d08 100644 --- a/test/Strat/param_tests.jl +++ b/test/Strat/param_tests.jl @@ -13,14 +13,14 @@ using CryoGrid trend = LinearTrend(slope=0.1, intercept=0.5, tstart=0.0, tstop=1.0) @test CryoGrid.Strat.transform((t=2.0,), trend) ≈ 0.6 end - @testset "PiecewiseConstant" begin - pc = PiecewiseLinear(bins=(1.0,), values=(0.0,1.0), tstart=0.0, tstop=1.0) + @testset "PiecewiseLinear" begin + pc = PiecewiseLinear(0.0, (1.0,1.0); tstart=0.0, tstop=1.0) @test CryoGrid.Strat.transform((t=-0.1,), pc) ≈ 0.0 @test CryoGrid.Strat.transform((t=0.0,), pc) ≈ 0.0 @test CryoGrid.Strat.transform((t=0.5,), pc) ≈ 0.5 @test CryoGrid.Strat.transform((t=1.0,), pc) ≈ 1.0 @test CryoGrid.Strat.transform((t=1.1,), pc) ≈ 1.0 - pc = PiecewiseLinear(bins=(0.4,0.6), values=(1.0,0.5,0.0), tstart=0.0, tstop=1.0) + pc = PiecewiseLinear(1.0, (0.4,0.5), (0.6,0.0); tstart=0.0, tstop=1.0) @test CryoGrid.Strat.transform((t=-0.1,), pc) ≈ 1.0 @test CryoGrid.Strat.transform((t=0.0,), pc) ≈ 1.0 @test CryoGrid.Strat.transform((t=0.7,), pc) ≈ 0.25