Skip to content
Snippets Groups Projects
Commit bc037435 authored by Brian Groenke's avatar Brian Groenke
Browse files

Disable reverse-mode in autodiff example

parent 0ecc92f3
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,8 @@ using CryoGrid ...@@ -8,7 +8,8 @@ using CryoGrid
# Set up forcings and boundary conditions similarly to other examples: # Set up forcings and boundary conditions similarly to other examples:
forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044); forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
soilprofile, tempprofile = CryoGrid.SamoylovDefault soilprofile, tempprofile = CryoGrid.SamoylovDefault
soilprofile = SoilProfile(0.0u"m" => SimpleSoil()) freezecurve = PainterKarra(swrc=VanGenuchten())
soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve))
grid = CryoGrid.DefaultGrid_5cm grid = CryoGrid.DefaultGrid_5cm
initT = initializer(:T, tempprofile) initT = initializer(:T, tempprofile)
tile = CryoGrid.SoilHeatTile( tile = CryoGrid.SoilHeatTile(
...@@ -43,10 +44,10 @@ using OrdinaryDiffEq ...@@ -43,10 +44,10 @@ using OrdinaryDiffEq
using Statistics using Statistics
function loss(prob::CryoGridProblem, p) function loss(prob::CryoGridProblem, p)
newprob = remake(prob, p=p) newprob = remake(prob, p=p)
# Here we specify the sensitivity algorithm. Note that this is only ## Here we specify the sensitivity algorithm. Note that this is only
# necessary for reverse-mode autodiff with Zygote. ## necessary for reverse-mode autodiff with Zygote.
# autojacvec = true uses ForwardDiff to calculate the jacobian; ## autojacvec = true uses ForwardDiff to calculate the jacobian;
# enabling checkpointing (theoretically) reduces the memory cost of the backwards pass. ## enabling checkpointing (theoretically) reduces the memory cost of the backwards pass.
sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true) sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true)
newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg); newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg);
newout = CryoGridOutput(newsol) newout = CryoGridOutput(newsol)
...@@ -57,8 +58,8 @@ end ...@@ -57,8 +58,8 @@ end
pvec = vec(prob.p) pvec = vec(prob.p)
fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec) fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec)
# We can also try with reverse-mode autodiff. This is generally slower for smaller numbers ## We can also try with reverse-mode autodiff. This is generally slower for smaller numbers
# of parmaeters (<100) but could be worthwhile for model configurations with high-dimensional ## of parmaeters (<100) but could be worthwhile for model configurations with high-dimensional
# parameterizations. ## parameterizations.
zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec) ## zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec)
@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-6 "Forward and reverse gradients don't match!" ## @assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-6 "Forward and reverse gradients don't match!"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment