From e70dc115e1a135f6eb08d5160912aab72a6820f0 Mon Sep 17 00:00:00 2001 From: Brian Groenke <brian.groenke@awi.de> Date: Sat, 22 Jan 2022 21:16:58 +0100 Subject: [PATCH] Fix ParameterVector breaking ModelParameters collection --- src/Strat/params.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Strat/params.jl b/src/Strat/params.jl index b45321cb..2e9db5a7 100644 --- a/src/Strat/params.jl +++ b/src/Strat/params.jl @@ -31,9 +31,10 @@ ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals) function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...) getparam(x) = x - getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x)) + getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : Tuple(getparam.(x)) paramval(x) = ustrip(x) paramval(x::Param) = ustrip(x.val) + paramval(x::Tuple) = collect(x) m = Model(model) nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple) mappedparams = nestedparams @@ -44,7 +45,7 @@ function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTr @set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple) push!(mappings, ParamMapping(transform, var, layer)) end - mappedarr = ComponentArray(mapflat(paramval, mappedparams)) + mappedarr = ComponentArray(mapflat(paramval, mappedparams; maptype=NamedTuple)) return ParameterVector(mappedarr, mappedparams, mappings...) end @inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t) -- GitLab