Skip to content
Snippets Groups Projects
Commit 2f54ba20 authored by Brian Groenke's avatar Brian Groenke
Browse files

Minor update to autodiff example

parent 2df13aa7
No related branches found
No related tags found
No related merge requests found
...@@ -19,11 +19,14 @@ tile = CryoGrid.SoilHeatTile( ...@@ -19,11 +19,14 @@ tile = CryoGrid.SoilHeatTile(
initT; initT;
grid=grid grid=grid
) )
tspan = (DateTime(2010,10,1),DateTime(2010,10,2)) tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
u0, du0 = @time initialcondition!(tile, tspan); u0, du0 = @time initialcondition!(tile, tspan);
# We can retrieve the parameters of the system from `tile`:
para = CryoGrid.parameters(tile)
# Create the `CryoGridProblem`. # Create the `CryoGridProblem`.
prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0); prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
# Solve the forward problem with default parameter settings: # Solve the forward problem with default parameter settings:
sol = @time solve(prob) sol = @time solve(prob)
...@@ -35,20 +38,26 @@ using SciMLSensitivity ...@@ -35,20 +38,26 @@ using SciMLSensitivity
using Zygote using Zygote
# Define a "loss" function; here we'll just take the mean over the final temperature field. # Define a "loss" function; here we'll just take the mean over the final temperature field.
using OrdinaryDiffEq
using Statistics using Statistics
function loss(prob::CryoGridProblem, p) function loss(prob::CryoGridProblem, p)
newprob = remake(prob, p=p) newprob = remake(prob, p=p)
# Here we specify the sensitivity algorithm. Note that this is only
# necessary for reverse-mode autodiff with Zygote.
# autojacvec = true uses ForwardDiff to calculate the jacobian; # autojacvec = true uses ForwardDiff to calculate the jacobian;
# enabling checkpointing (theroetically) reduces the memory cost of the backwards pass. # enabling checkpointing (theoretically) reduces the memory cost of the backwards pass.
sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true) sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true)
newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg); newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg);
newout = CryoGridOutput(newsol) newout = CryoGridOutput(newsol)
return mean(ustrip.(newout.T[:,end])) return mean(ustrip.(newout.T[:,end]))
end end
# Compute gradient with forward diff: # Compute gradient with forward-mode autodiff:
pvec = prob.p pvec = vec(prob.p)
fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec) fd_grad = @time ForwardDiff.gradient(pᵢ -> loss(prob, pᵢ), pvec)
# We can also try with reverse-mode autodiff. This is generally slower for smaller numbers
# of parmaeters (<100) but could be worthwhile for model configurations with high-dimensional
# parameterizations.
zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec) zy_grad = @time Zygote.gradient(pᵢ -> loss(prob, pᵢ), pvec)
@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-4 "Forward and reverse gradients don't match!" @assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-6 "Forward and reverse gradients don't match!"
@show fd_grad
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment