Skip to content
Snippets Groups Projects
Commit 2f54ba20 authored by Brian Groenke's avatar Brian Groenke
Browse files

Minor update to autodiff example

parent 2df13aa7
No related branches found
No related tags found
No related merge requests found
......@@ -19,11 +19,14 @@ tile = CryoGrid.SoilHeatTile(
initT;
grid=grid
)
tspan = (DateTime(2010,10,1),DateTime(2010,10,2))
tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
u0, du0 = @time initialcondition!(tile, tspan);
# We can retrieve the parameters of the system from `tile`:
para = CryoGrid.parameters(tile)
# Create the `CryoGridProblem`.
prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0);
prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
# Solve the forward problem with default parameter settings:
sol = @time solve(prob)
......@@ -35,20 +38,26 @@ using SciMLSensitivity
using Zygote
# Define a "loss" function; here we'll just take the mean over the final temperature field.
using OrdinaryDiffEq
using Statistics
function loss(prob::CryoGridProblem, p)
newprob = remake(prob, p=p)
# Here we specify the sensitivity algorithm. Note that this is only
# necessary for reverse-mode autodiff with Zygote.
# autojacvec = true uses ForwardDiff to calculate the jacobian;
# enabling checkpointing (theroetically) reduces the memory cost of the backwards pass.
# enabling checkpointing (theoretically) reduces the memory cost of the backwards pass.
sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true)
newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg);
newout = CryoGridOutput(newsol)
return mean(ustrip.(newout.T[:,end]))
end
# Compute gradient with forward diff:
pvec = prob.p
# Compute gradient with forward-mode autodiff:
pvec = vec(prob.p)
fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec)
# We can also try with reverse-mode autodiff. This is generally slower for smaller numbers
# of parmaeters (<100) but could be worthwhile for model configurations with high-dimensional
# parameterizations.
zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec)
@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-4 "Forward and reverse gradients don't match!"
@show fd_grad
@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-6 "Forward and reverse gradients don't match!"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment