Skip to content
Snippets Groups Projects
Commit cfdbe015 authored by Brian Groenke's avatar Brian Groenke
Browse files

Merge branch 'bugfix/initial-params' into 'master'

Fix param handling bug in initialization

See merge request sparcs/cryogrid/cryogridjulia!59
parents 9ea13637 43763dad
No related branches found
No related tags found
1 merge request!59Fix param handling bug in initialization
name = "CryoGrid"
uuid = "a535b82e-5f3d-4d97-8b0b-d6483f5bebd5"
authors = ["Brian Groenke <brian.groenke@awi.de>", "Moritz Langer <moritz.langer@awi.de>"]
version = "0.5.5"
version = "0.5.6"
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
......
......@@ -12,24 +12,24 @@ out-of-place (copying arrays).
"""
abstract type AbstractTile{iip} end
"""
(model::AbstractTile{inplace})(du,u,p,t)
(model::AbstractTile{ooplace})(u,p,t)
(tile::AbstractTile{inplace})(du,u,p,t)
(tile::AbstractTile{ooplace})(u,p,t)
Invokes the corresponding `step` function to compute the time derivative du/dt.
"""
(model::AbstractTile{inplace})(du,u,p,t) = step!(model,du,u,p,t)
(model::AbstractTile{ooplace})(u,p,t) = step(model,u,p,t)
(tile::AbstractTile{inplace})(du,u,p,t) = step!(tile,du,u,p,t)
(tile::AbstractTile{ooplace})(u,p,t) = step(tile,u,p,t)
"""
step!(::T,du,u,p,t) where {T<:AbstractTile}
In-place step function for model `T`. Computes du/dt and stores the result in `du`.
In-place step function for tile `T`. Computes du/dt and stores the result in `du`.
"""
step!(::T,du,u,p,t) where {T<:AbstractTile} = error("no implementation of in-place step! for $T")
"""
step(::T,u,p,t) where {T<:AbstractTile}
Out-of-place step function for model `T`. Computes and returns du/dt as vector with same size as `u`.
Out-of-place step function for tile `T`. Computes and returns du/dt as vector with same size as `u`.
"""
step(::T,u,p,t) where {T<:AbstractTile} = error("no implementation of out-of-place step for $T")
......@@ -57,7 +57,7 @@ end
ConstructionBase.constructorof(::Type{Tile{TStrat,TGrid,TStates,iip,obsv}}) where {TStrat,TGrid,TStates,iip,obsv} =
(strat, grid, state, hist) -> Tile(strat,grid,state,hist,iip,length(obsv) > 0 ? collect(obsv) : Symbol[])
Base.show(io::IO, ::MIME"text/plain", model::Tile{TStrat,TGrid,TStates,iip,obsv}) where {TStrat,TGrid,TStates,iip,obsv} = print(io, "Tile ($iip) with layers $(map(componentname, components(model.strat))), observables=$obsv, $TGrid, $TStrat")
Base.show(io::IO, ::MIME"text/plain", tile::Tile{TStrat,TGrid,TStates,iip,obsv}) where {TStrat,TGrid,TStates,iip,obsv} = print(io, "Tile ($iip) with layers $(map(componentname, components(tile.strat))), observables=$obsv, $TGrid, $TStrat")
"""
Constructs a `Tile` from the given stratigraphy and grid. `arrayproto` keyword arg should be an array instance
......@@ -119,18 +119,18 @@ prognosticstep!(layer i, ...)
Note for developers: All sections of code wrapped in quote..end blocks are generated. Code outside of quote blocks
is only executed during compilation and will not appear in the compiled version.
"""
@generated function step!(model::Tile{TStrat,TGrid,TStates,inplace,obsv}, _du,_u,_p,t) where {TStrat,TGrid,TStates,obsv}
@generated function step!(tile::Tile{TStrat,TGrid,TStates,inplace,obsv}, _du,_u,_p,t) where {TStrat,TGrid,TStates,obsv}
nodetyps = componenttypes(TStrat)
N = length(nodetyps)
expr = Expr(:block)
# Declare variables
@>> quote
p = updateparams!(_p, model, _du, _u, t)
strat = Flatten.reconstruct(model.strat, p, ModelParameters.SELECT, ModelParameters.IGNORE)
p = updateparams!(_p, tile, _du, _u, t)
strat = Flatten.reconstruct(tile.strat, p, ModelParameters.SELECT, ModelParameters.IGNORE)
_du .= zero(eltype(_du))
du = ComponentArray(_du, getaxes(model.state.uproto))
u = ComponentArray(_u, getaxes(model.state.uproto))
state = TileState(model.state, boundaries(strat), u, du, t, Val{inplace}())
du = ComponentArray(_du, getaxes(tile.state.uproto))
u = ComponentArray(_u, getaxes(tile.state.uproto))
state = TileState(tile.state, boundaries(strat), u, du, t, Val{inplace}())
end push!(expr.args)
# Initialize variables for all layers
for i in 1:N
......@@ -193,22 +193,23 @@ is only executed during compilation and will not appear in the compiled version.
return expr
end
"""
initialcondition!(model::Tile, tspan::NTuple{2,Float64}, p::AbstractVector, initializers::VarInit...)
initialcondition!(model::Tile, tspan::NTuple{2,DateTime}, p::AbstractVector, initializers::VarInit...)
initialcondition!(tile::Tile, tspan::NTuple{2,Float64}, p::AbstractVector, initializers::VarInit...)
initialcondition!(tile::Tile, tspan::NTuple{2,DateTime}, p::AbstractVector, initializers::VarInit...)
Calls `initialcondition!` on all layers/processes and returns the fully constructed u0 and du0 states.
"""
initialcondition!(model::Tile, tspan::NTuple{2,DateTime}, p::AbstractVector, args...) = initialcondition!(model, convert_tspan(tspan), p, args...)
@generated function initialcondition!(model::Tile{TStrat,TGrid,TStates,iip,obsv}, tspan::NTuple{2,Float64}, p::AbstractVector, initializers::Numerics.VarInit...) where {TStrat,TGrid,TStates,iip,obsv}
initialcondition!(tile::Tile, tspan::NTuple{2,DateTime}, _p::AbstractVector, args...) = initialcondition!(tile, convert_tspan(tspan), _p, args...)
@generated function initialcondition!(tile::Tile{TStrat,TGrid,TStates,iip,obsv}, tspan::NTuple{2,Float64}, _p::AbstractVector, initializers::Numerics.VarInit...) where {TStrat,TGrid,TStates,iip,obsv}
nodetyps = componenttypes(TStrat)
N = length(nodetyps)
expr = Expr(:block)
# Declare variables
@>> quote
du = zero(similar(model.state.uproto, eltype(p)))
u = zero(similar(model.state.uproto, eltype(p)))
strat = Flatten.reconstruct(model.strat, p, ModelParameters.SELECT, ModelParameters.IGNORE)
state = TileState(model.state, boundaries(strat), u, du, tspan[1], Val{iip}())
du = zero(similar(tile.state.uproto, eltype(_p)))
u = zero(similar(tile.state.uproto, eltype(_p)))
p = updateparams!(_p, tile, du, u, tspan[1])
strat = Flatten.reconstruct(tile.strat, p, ModelParameters.SELECT, ModelParameters.IGNORE)
state = TileState(tile.state, boundaries(strat), u, du, tspan[1], Val{iip}())
end push!(expr.args)
# Call initializers
for i in 1:N
......@@ -266,56 +267,56 @@ Calls the initializer for state variable `varname`.
initvar!(state::LayerState, ::Stratigraphy, init::Numerics.VarInit{varname}) where {varname} = init!(state[varname], init)
initvar!(state::LayerState, ::Stratigraphy, init::Numerics.InterpInit{varname}) where {varname} = init!(state[varname], init, state.grids[varname])
"""
getvar(name::Symbol, model::Tile, u)
getvar(::Val{name}, model::Tile, u)
getvar(name::Symbol, tile::Tile, u)
getvar(::Val{name}, tile::Tile, u)
Retrieves the (diagnostic or prognostic) grid variable from `model` given prognostic state `u`.
If `name` is not a variable in the model, or if it is not a grid variable, `nothing` is returned.
Retrieves the (diagnostic or prognostic) grid variable from `tile` given prognostic state `u`.
If `name` is not a variable in the tile, or if it is not a grid variable, `nothing` is returned.
"""
Numerics.getvar(name::Symbol, model::Tile, u) = getvar(Val{name}(), model, u)
Numerics.getvar(::Val{name}, model::Tile, u) where name = getvar(Val{name}(), model.state, withaxes(u, model))
Numerics.getvar(name::Symbol, tile::Tile, u) = getvar(Val{name}(), tile, u)
Numerics.getvar(::Val{name}, tile::Tile, u) where name = getvar(Val{name}(), tile.state, withaxes(u, tile))
"""
getstate(layername::Symbol, model::Tile, u, du, t)
getstate(::Val{layername}, model::Tile{TStrat,TGrid,<:VarStates{layernames},iip}, _u, _du, t)
getstate(layername::Symbol, tile::Tile, u, du, t)
getstate(::Val{layername}, tile::Tile{TStrat,TGrid,<:VarStates{layernames},iip}, _u, _du, t)
Constructs a `LayerState` representing the full state of `layername` given `model`, state vectors `u` and `du`, and the
Constructs a `LayerState` representing the full state of `layername` given `tile`, state vectors `u` and `du`, and the
time step `t`.
"""
getstate(layername::Symbol, model::Tile, u, du, t) = getstate(Val{layername}(), model, u, du, t)
function getstate(::Val{layername}, model::Tile{TStrat,TGrid,<:VarStates{layernames},iip}, _u, _du, t) where {layername,TStrat,TGrid,iip,layernames}
du = ComponentArray(_du, getaxes(model.state.uproto))
u = ComponentArray(_u, getaxes(model.state.uproto))
getstate(layername::Symbol, tile::Tile, u, du, t) = getstate(Val{layername}(), tile, u, du, t)
function getstate(::Val{layername}, tile::Tile{TStrat,TGrid,<:VarStates{layernames},iip}, _u, _du, t) where {layername,TStrat,TGrid,iip,layernames}
du = ComponentArray(_du, getaxes(tile.state.uproto))
u = ComponentArray(_u, getaxes(tile.state.uproto))
i = 1
for j in 1:length(model.strat)
for j in 1:length(tile.strat)
if layernames[j] == layername
i = j
break
end
end
z = boundarypairs(map(ustrip, stripparams(boundaries(model.strat))), ustrip(model.grid[end]))[i]
return LayerState(model.state, z, u, du, t, Val{layername}(), Val{iip}())
z = boundarypairs(map(ustrip, stripparams(boundaries(tile.strat))), ustrip(tile.grid[end]))[i]
return LayerState(tile.state, z, u, du, t, Val{layername}(), Val{iip}())
end
"""
variables(model::Tile)
variables(tile::Tile)
Returns a tuple of all variables defined in the model.
Returns a tuple of all variables defined in the tile.
"""
variables(model::Tile) = Tuple(unique(Flatten.flatten(model.state.vars, Flatten.flattenable, Var)))
variables(tile::Tile) = Tuple(unique(Flatten.flatten(tile.state.vars, Flatten.flattenable, Var)))
"""
withaxes(u::AbstractArray, ::Tile)
Constructs a `ComponentArray` with labeled axes from the given state vector `u`. Assumes `u` to be of the same type/shape
as `setup.uproto`.
"""
withaxes(u::AbstractArray, model::Tile) = ComponentArray(u, getaxes(model.state.uproto))
withaxes(u::AbstractArray, tile::Tile) = ComponentArray(u, getaxes(tile.state.uproto))
withaxes(u::ComponentArray, ::Tile) = u
"""
Gets the
"""
function getstate(model::Tile{TStrat,TGrid,TStates,iip}, _u, _du, t) where {TStrat,TGrid,TStates,iip}
du = ComponentArray(_du, getaxes(model.state.uproto))
u = ComponentArray(_u, getaxes(model.state.uproto))
return TileState(model.strat, model.state, u, du, t, Val{iip}())
function getstate(tile::Tile{TStrat,TGrid,TStates,iip}, _u, _du, t) where {TStrat,TGrid,TStates,iip}
du = ComponentArray(_du, getaxes(tile.state.uproto))
u = ComponentArray(_u, getaxes(tile.state.uproto))
return TileState(tile.strat, tile.state, u, du, t, Val{iip}())
end
"""
Collects and validates all declared variables (`Var`s) for the given strat component.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment