Skip to content
Snippets Groups Projects
Commit e70dc115 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix ParameterVector breaking ModelParameters collection

parent dd66f48b
No related branches found
No related tags found
1 merge request!72Fix ParameterVector breaking ModelParameters collection
......@@ -31,9 +31,10 @@ ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals)
function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTransform}}...)
getparam(x) = x
getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : collect(getparam.(x))
getparam(x::Union{<:AbstractVector,<:Tuple}) = length(x) == 1 ? getparam(x[1]) : Tuple(getparam.(x))
paramval(x) = ustrip(x)
paramval(x::Param) = ustrip(x.val)
paramval(x::Tuple) = collect(x)
m = Model(model)
nestedparams = mapflat(getparam, groupparams(m, :layer, :fieldname); maptype=NamedTuple)
mappedparams = nestedparams
......@@ -44,7 +45,7 @@ function parameters(model::Tile, transforms::Pair{Symbol,<:Pair{Symbol,<:ParamTr
@set! mappedparams[layer][var] = mapflat(getparam, groupparams(m_transform, :transform, :fieldname); maptype=NamedTuple)
push!(mappings, ParamMapping(transform, var, layer))
end
mappedarr = ComponentArray(mapflat(paramval, mappedparams))
mappedarr = ComponentArray(mapflat(paramval, mappedparams; maptype=NamedTuple))
return ParameterVector(mappedarr, mappedparams, mappings...)
end
@inline @generated function updateparams!(v::AbstractVector, model::Tile, u, du, t)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment