Skip to content
Snippets Groups Projects
Commit 1bbf0d10 authored by Brian Groenke's avatar Brian Groenke
Browse files

Fix type bounds breaking forward diff

parent b3d4fc80
No related branches found
No related tags found
1 merge request!74Fix unnecessary type bounds breaking autodiff
...@@ -12,8 +12,8 @@ struct VarStates{names,griddvars,TU,TD,TV,DF,DG} ...@@ -12,8 +12,8 @@ struct VarStates{names,griddvars,TU,TD,TV,DF,DG}
griddiag::NamedTuple{griddvars,DG} # on-grid non-prognostic variables griddiag::NamedTuple{griddvars,DG} # on-grid non-prognostic variables
gridcache::Dict{ClosedInterval{Int},TD} # grid cache; indices -> subgrid gridcache::Dict{ClosedInterval{Int},TD} # grid cache; indices -> subgrid
end end
@generated function getvar(::Val{name}, vs::VarStates{layers,griddvars,TU,TD,TV}, u::TU, du::Union{Nothing,TU}=nothing) where @generated function getvar(::Val{name}, vs::VarStates{layers,griddvars}, u, du=nothing) where {name,layers,griddvars}
{name,layers,griddvars,T,A,pax,TU<:ComponentVector{T,A,Tuple{Axis{pax}}},TD,TV} pax = ComponentArrays.indexmap(first(ComponentArrays.getaxes(u)))
dnames = map(n -> Symbol(:d,n), keys(pax)) dnames = map(n -> Symbol(:d,n), keys(pax))
if name griddvars if name griddvars
quote quote
......
...@@ -165,15 +165,17 @@ function initialcondition!(soil::Soil, heat::Heat, sfcc::SFCC{F,∇F,<:SFCCPreSo ...@@ -165,15 +165,17 @@ function initialcondition!(soil::Soil, heat::Heat, sfcc::SFCC{F,∇F,<:SFCCPreSo
Hmax = enthalpy(Tmax, C(Tmax), L, θ(Tmax)), Hmax = enthalpy(Tmax, C(Tmax), L, θ(Tmax)),
dH = sfcc.solver.dH, dH = sfcc.solver.dH,
Hs = Hmin:dH:Hmax; Hs = Hmin:dH:Hmax;
θs = [θres] θs = Vector{eltype(state.θl)}(undef, length(Hs))
Ts = [Tmin] θs[1] = θres
for _ in Hs[2:end] Ts = Vector{eltype(state.T)}(undef, length(Hs))
θᵢ = θs[end] Ts[1] = Tmin
Tᵢ = Ts[end] for i in 2:length(Hs)
θᵢ = θs[i-1]
Tᵢ = Ts[i-1]
dTdH = 1.0 / (C(θᵢ) + dθdT(Tᵢ)*(Tᵢ*cw + L)) dTdH = 1.0 / (C(θᵢ) + dθdT(Tᵢ)*(Tᵢ*cw + L))
dθdH = 1.0 / (1.0/dθdT(Tᵢ)*C(θᵢ)+Tᵢ*cw + L) dθdH = 1.0 / (1.0/dθdT(Tᵢ)*C(θᵢ)+Tᵢ*cw + L)
push!(θs, θᵢ + dH*dθdH) θs[i] = θᵢ + dH*dθdH
push!(Ts, Tᵢ + dH*dTdH) Ts[i] = Tᵢ + dH*dTdH
end end
sfcc.solver.cache.f = Interpolations.extrapolate( sfcc.solver.cache.f = Interpolations.extrapolate(
Interpolations.interpolate((Vector(Hs),), θs, Interpolations.Gridded(Interpolations.Linear())), Interpolations.interpolate((Vector(Hs),), θs, Interpolations.Gridded(Interpolations.Linear())),
......
...@@ -24,7 +24,7 @@ function Base.getproperty(state::LayerState, sym::Symbol) ...@@ -24,7 +24,7 @@ function Base.getproperty(state::LayerState, sym::Symbol)
getproperty(getfield(state, :states), sym) getproperty(getfield(state, :states), sym)
end end
end end
@inline function LayerState(vs::VarStates, zs::NTuple{2,Tz}, u, du, t, ::Val{layername}, ::Val{iip}=Val{inplace}()) where {Tz,layername,iip} @inline function LayerState(vs::VarStates, zs::NTuple{2}, u, du, t, ::Val{layername}, ::Val{iip}=Val{inplace}()) where {layername,iip}
z_inds = subgridinds(edges(vs.grid), zs[1]..zs[2]) z_inds = subgridinds(edges(vs.grid), zs[1]..zs[2])
return LayerState( return LayerState(
_makegrids(Val{layername}(), getproperty(vs.vars, layername), vs, z_inds), _makegrids(Val{layername}(), getproperty(vs.vars, layername), vs, z_inds),
...@@ -58,9 +58,15 @@ function Base.getproperty(state::TileState, sym::Symbol) ...@@ -58,9 +58,15 @@ function Base.getproperty(state::TileState, sym::Symbol)
end end
end end
@inline @generated function TileState(vs::VarStates{names}, zs::NTuple, u=copy(vs.uproto), du=similar(vs.uproto), t=0.0, ::Val{iip}=Val{inplace}()) where {names,iip} @inline @generated function TileState(vs::VarStates{names}, zs::NTuple, u=copy(vs.uproto), du=similar(vs.uproto), t=0.0, ::Val{iip}=Val{inplace}()) where {names,iip}
layerstates = (:(LayerState(vs, (ustrip(bounds[$i][1]), ustrip(bounds[$i][2])), u, du, t, Val{$(QuoteNode(names[i]))}(), Val{iip}())) for i in 1:length(names)) layerstates = (
quote
bounds_i = (ustrip(bounds[$i][1]), ustrip(bounds[$i][2]))
LayerState(vs, bounds_i, u, du, t, Val{$(QuoteNode(names[i]))}(), Val{iip}())
end
for i in 1:length(names)
)
quote quote
bounds = boundarypairs(zs, vs.grid[end]) bounds = boundarypairs(zs, convert(eltype(zs), vs.grid[end]))
return TileState( return TileState(
vs.grid, vs.grid,
NamedTuple{tuple($(map(QuoteNode,names)...))}(tuple($(layerstates...))), NamedTuple{tuple($(map(QuoteNode,names)...))}(tuple($(layerstates...))),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment