Skip to content
Snippets Groups Projects
Commit 1bbf0d10 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix type bounds breaking forward diff

parent b3d4fc80
No related branches found
No related tags found
1 merge request!74Fix unnecessary type bounds breaking autodiff
......@@ -12,8 +12,8 @@ struct VarStates{names,griddvars,TU,TD,TV,DF,DG}
griddiag::NamedTuple{griddvars,DG} # on-grid non-prognostic variables
gridcache::Dict{ClosedInterval{Int},TD} # grid cache; indices -> subgrid
end
@generated function getvar(::Val{name}, vs::VarStates{layers,griddvars,TU,TD,TV}, u::TU, du::Union{Nothing,TU}=nothing) where
{name,layers,griddvars,T,A,pax,TU<:ComponentVector{T,A,Tuple{Axis{pax}}},TD,TV}
@generated function getvar(::Val{name}, vs::VarStates{layers,griddvars}, u, du=nothing) where {name,layers,griddvars}
pax = ComponentArrays.indexmap(first(ComponentArrays.getaxes(u)))
dnames = map(n -> Symbol(:d,n), keys(pax))
if name griddvars
quote
......
......@@ -165,15 +165,17 @@ function initialcondition!(soil::Soil, heat::Heat, sfcc::SFCC{F,∇F,<:SFCCPreSo
Hmax = enthalpy(Tmax, C(Tmax), L, θ(Tmax)),
dH = sfcc.solver.dH,
Hs = Hmin:dH:Hmax;
θs = [θres]
Ts = [Tmin]
for _ in Hs[2:end]
θᵢ = θs[end]
Tᵢ = Ts[end]
θs = Vector{eltype(state.θl)}(undef, length(Hs))
θs[1] = θres
Ts = Vector{eltype(state.T)}(undef, length(Hs))
Ts[1] = Tmin
for i in 2:length(Hs)
θᵢ = θs[i-1]
Tᵢ = Ts[i-1]
dTdH = 1.0 / (C(θᵢ) + dθdT(Tᵢ)*(Tᵢ*cw + L))
dθdH = 1.0 / (1.0/dθdT(Tᵢ)*C(θᵢ)+Tᵢ*cw + L)
push!(θs, θᵢ + dH*dθdH)
push!(Ts, Tᵢ + dH*dTdH)
θs[i] = θᵢ + dH*dθdH
Ts[i] = Tᵢ + dH*dTdH
end
sfcc.solver.cache.f = Interpolations.extrapolate(
Interpolations.interpolate((Vector(Hs),), θs, Interpolations.Gridded(Interpolations.Linear())),
......
......@@ -24,7 +24,7 @@ function Base.getproperty(state::LayerState, sym::Symbol)
getproperty(getfield(state, :states), sym)
end
end
@inline function LayerState(vs::VarStates, zs::NTuple{2,Tz}, u, du, t, ::Val{layername}, ::Val{iip}=Val{inplace}()) where {Tz,layername,iip}
@inline function LayerState(vs::VarStates, zs::NTuple{2}, u, du, t, ::Val{layername}, ::Val{iip}=Val{inplace}()) where {layername,iip}
z_inds = subgridinds(edges(vs.grid), zs[1]..zs[2])
return LayerState(
_makegrids(Val{layername}(), getproperty(vs.vars, layername), vs, z_inds),
......@@ -58,9 +58,15 @@ function Base.getproperty(state::TileState, sym::Symbol)
end
end
@inline @generated function TileState(vs::VarStates{names}, zs::NTuple, u=copy(vs.uproto), du=similar(vs.uproto), t=0.0, ::Val{iip}=Val{inplace}()) where {names,iip}
layerstates = (:(LayerState(vs, (ustrip(bounds[$i][1]), ustrip(bounds[$i][2])), u, du, t, Val{$(QuoteNode(names[i]))}(), Val{iip}())) for i in 1:length(names))
layerstates = (
quote
bounds_i = (ustrip(bounds[$i][1]), ustrip(bounds[$i][2]))
LayerState(vs, bounds_i, u, du, t, Val{$(QuoteNode(names[i]))}(), Val{iip}())
end
for i in 1:length(names)
)
quote
bounds = boundarypairs(zs, vs.grid[end])
bounds = boundarypairs(zs, convert(eltype(zs), vs.grid[end]))
return TileState(
vs.grid,
NamedTuple{tuple($(map(QuoteNode,names)...))}(tuple($(layerstates...))),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment