Skip to content
Snippets Groups Projects
Commit 9cde74bf authored by Brian Groenke's avatar Brian Groenke
Browse files

Revert to forward-diff compatible var cache with modifications

parent 0816dcca
No related branches found
No related tags found
1 merge request!46Refactor parameter and state handling system
......@@ -34,7 +34,7 @@ end
Constructs a `CryoGridSetup` from the given stratigraphy and grid. `arrayproto` keyword arg should be an array instance
(of any arbitrary length, including zero, contents are ignored) that will determine the array type used for all state vectors.
"""
function CryoGridSetup(strat::Stratigraphy, grid::Grid{Edges,<:Numerics.Geometry,<:DistQuantity}; arrayproto::AbstractArray=zeros(), observed::Vector{Symbol}=Symbol[])
function CryoGridSetup(strat::Stratigraphy, grid::Grid{Edges,<:Numerics.Geometry,<:DistQuantity}; arrayproto::AbstractArray=zeros(), observed::Vector{Symbol}=Symbol[], chunksize=nothing)
pvar_arrays = OrderedDict()
param_arrays = OrderedDict()
layer_metas = OrderedDict()
......@@ -60,7 +60,8 @@ function CryoGridSetup(strat::Stratigraphy, grid::Grid{Edges,<:Numerics.Geometry
nparams = (length(meta.paramvars) for meta in nt_meta) |> sum
@assert (npvars + ndvars) > 0 "No variable definitions found. Did you add a method definition for CryoGrid.variables(::L,::P) where {L<:Layer,P<:Process}?"
@assert npvars > 0 "At least one prognostic variable must be specified."
nt_cache = NamedTuple{Tuple(nodenames)}(Tuple(_buildcaches(strat, nt_meta, arrayproto)))
chunksize = isnothing(chunksize) ? nparams : chunksize
nt_cache = NamedTuple{Tuple(nodenames)}(Tuple(_buildcaches(strat, nt_meta, arrayproto, chunksize)))
# construct prototype of u (prognostic state) array (note that this currently performs a copy)
uproto = ComponentArray(nt_prog)
# ditto for parameter array (need a hack here to get an empty ComponentArray...)
......@@ -506,7 +507,7 @@ end
"""
Constructs per-layer variable caches given the Stratigraphy and layer-metadata named tuple.
"""
function _buildcaches(strat, metadata, arrayproto)
function _buildcaches(strat, metadata, arrayproto, chunksize)
map(strat) do node
name = nodename(node)
dvars = metadata[name].diagvars
......@@ -514,32 +515,41 @@ function _buildcaches(strat, metadata, arrayproto)
caches = map(dvars) do dvar
dvarname = varname(dvar)
grid = metadata[name].grids[dvarname]
VarCache(dvarname, grid, arrayproto)
VarCache(dvarname, grid, arrayproto, chunksize)
end
NamedTuple{Tuple(varnames)}(Tuple(caches))
end
end
struct VarCache{name,A}
x::A
function VarCache(name::Symbol, grid::AbstractArray, arrayproto::AbstractArray)
"""
VarCache{name,N,TCache}
Wrapper for `DiffEqBase.DiffCache` that stores state variables in forward-diff compatible cache arrays.
"""
struct VarCache{name,N,TCache}
cache::TCache
function VarCache(name::Symbol, grid::AbstractArray, arrayproto::AbstractArray, chunksize::Int)
# use dual cache for automatic compatibility with ForwardDiff
A = similar(arrayproto, length(grid))
A .= zero(eltype(A))
new{name,typeof(A)}(A)
cache = DiffEqBase.dualcache(A, Val{chunksize})
new{name,chunksize,typeof(cache)}(cache)
end
end
_retrieve(c::VarCache, proto) = 0*proto .+ c.x
retrieve(c::VarCache) = c.x
# dispatches for autodiff types; create a
retrieve(c::VarCache, u::AbstractArray{T}) where {T<:Union{<:ForwardDiff.Dual,<:ReverseDiff.TrackedReal}} = _retrieve(c, similar(u,length(c.x)))
retrieve(c::VarCache, u::ReverseDiff.TrackedArray) = _retrieve(c, similar(u,length(c.x)))
retrieve(c::VarCache, u::AbstractArray{T}) where {T} = retrieve(c)
retrieve(c::VarCache, u::AbstractArray, t) = retrieve(c, u)
Base.show(io::IO, cache::VarCache{name}) where name = print(io, "VarCache{$name} of length $(length(cache.cache.du)) with eltype $(eltype(cache.cache.du))")
Base.show(io::IO, mime::MIME{Symbol("text/plain")}, cache::VarCache{name}) where name = show(io, cache)
# type piracy to reduce clutter in compiled type names
Base.show(io::IO, ::Type{<:VarCache{name}}) where name = print(io, "VarCache{$name}")
# use pre-cached array if chunk size matches
retrieve(varcache::VarCache{name,N}, u::AbstractArray{T}) where {name,tag,U,N,T<:ForwardDiff.Dual{tag,U,N}} = DiffEqBase.get_tmp(varcache.cache, u)
# otherwise just make a new copy with compatible type
retrieve(varcache::VarCache, u::AbstractArray{T}) where {T<:Union{<:ForwardDiff.Dual,<:ReverseDiff.TrackedReal}} = copyto!(similar(u, length(varcache.cache.du)), varcache.cache.du)
retrieve(varcache::VarCache, u::ReverseDiff.TrackedArray) = copyto!(similar(identity.(u), length(varcache.cache.du)), varcache.cache.du)
retrieve(varcache::VarCache, u::AbstractArray{T}) where {T} = reinterpret(T, varcache.cache.du)
# this covers the case for Rosenbrock solvers where only t has differentiable type
function retrieve(c::VarCache, u::AbstractArray, t::T) where {T<:ForwardDiff.Dual}
proto = similar(u, T, length(c.x))
return _retrieve(c, proto)
end
retrieve(varcache::VarCache, u::AbstractArray, t::T) where {T<:ForwardDiff.Dual} = retrieve(varcache, similar(u, T))
retrieve(varcache::VarCache, u::AbstractArray, t) = retrieve(varcache, u)
retrieve(varcache::VarCache) = diffcache.du
# default to doing nothing on non-autodiff writeback
writeback!(c::VarCache, x::AbstractArray) = nothing
writeback!(c::VarCache, x::AbstractArray{T}) where {T<:Union{ForwardDiff.Dual,ReverseDiff.TrackedReal}} = c.x .= Utils.adstrip.(x)
writeback!(varcache::VarCache, x::AbstractArray) = nothing
writeback!(varcache::VarCache, x::AbstractArray{T}) where {T<:Union{<:ForwardDiff.Dual,<:ReverseDiff.TrackedReal}} = varcache.cache.du .= Utils.adstrip.(x)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment