Skip to content
Snippets Groups Projects
Commit c1131df0 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix pretty printing of ParameterVector

parent 792402f8
No related branches found
No related tags found
1 merge request!68Minor bug fixes and refactoring
......@@ -4,7 +4,6 @@ struct ParamMapping{T,name,layer}
transform::T
ParamMapping(transform::T, name::Symbol, layer::Symbol) where {T<:ParamTransform} = new{T,name,layer}(transform)
end
struct ParameterVector{T,TV,P,M} <: DenseArray{T,1}
vals::TV # input/reparameterized param vector
params::P # parameters grouped by layer and name
......@@ -24,7 +23,8 @@ Base.getproperty(rv::ParameterVector, sym::Symbol) = getproperty(getfield(rv, :v
Base.getindex(rv::ParameterVector, i) = getfield(rv, :vals)[i]
Base.setproperty!(rv::ParameterVector, val, i) = setproperty!(getfield(rv, :vals), val, sym)
Base.setindex!(rv::ParameterVector, val, i) = setindex!(getfield(rv, :vals), val, i)
Base.show(io, ::MIME"text/plain", rv::ParameterVector) = println(io, getfield(rv, :vals))
Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any,Tuple{}}) where {T} = println(io, "$(length(rv))-element ParameterVector{T}:\n$(getfield(rv, :vals))")
Base.show(io::IO, ::MIME"text/plain", rv::ParameterVector{T,<:Any,<:Any}) where {T} = println(io, "$(length(rv))-element ParameterVector{T} with $(length(mappings(rv))) mappings\n$(mappings(rv)):\n$(getfield(rv, :vals))")
ComponentArrays.ComponentArray(rv::ParameterVector) = getfield(rv, :vals)
_paramval(p::Param) = ustrip(p.val) # extracts value from Param type and strips units
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment