diff --git a/src/Strat/params.jl b/src/Strat/params.jl index 1281473a424ddc7a3f95553f3787fa1e98e706ae..8c998bd19ba0f8f85ae9cb86bf73040684f6c018 100644 --- a/src/Strat/params.jl +++ b/src/Strat/params.jl @@ -4,7 +4,6 @@ struct ParamMapping{T,name,layer} transform::T ParamMapping(transform::T, name::Symbol, layer::Symbol) where {T<:ParamTransform} = new{T,name,layer}(transform) end - struct ParameterVector{T,TV,P,M} <: DenseArray{T,1} vals::TV # input/reparameterized param vector params::P # parameters grouped by layer and name @@ -24,7 +23,8 @@ Base.getproperty(rv::ParameterVector, sym::Symbol) = getproperty(getfield(rv, :v Base.getindex(rv::ParameterVector, i) = getfield(rv, :vals)[i] Base.setproperty!(rv::ParameterVector, val, i) = setproperty!(getfield(rv, :vals), val, sym) Base.setindex!(rv::ParameterVector, val, i) = setindex!(getfield(rv, :vals), val, i) -Base.show(io, ::MIME"text/plain", rv::ParameterVector) = println(io, getfield(rv, :vals)) +Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}}) where {T} = println(io, "$(length(rv))-element ParameterVector{T}:\n$(getfield(rv, :vals))") +Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))") ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals) _paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units