From 2f54ba20c6768112976a2f0b055fae22427c8f6a Mon Sep 17 00:00:00 2001
From: Brian Groenke <brian.groenke@awi.de>
Date: Fri, 6 Dec 2024 17:50:10 +0100
Subject: [PATCH] Minor update to autodiff example

---
 examples/heat_simple_autodiff_grad.jl | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl
index f644a103..9bc97ad5 100644
--- a/examples/heat_simple_autodiff_grad.jl
+++ b/examples/heat_simple_autodiff_grad.jl
@@ -19,11 +19,14 @@ tile = CryoGrid.SoilHeatTile(
     initT;
     grid=grid
 )
-tspan = (DateTime(2010,10,1),DateTime(2010,10,2))
+tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
 u0, du0 = @time initialcondition!(tile, tspan);
 
+# We can retrieve the parameters of the system from `tile`:
+para = CryoGrid.parameters(tile)
+
 # Create the `CryoGridProblem`.
-prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0);
+prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
 
 # Solve the forward problem with default parameter settings:
 sol = @time solve(prob)
@@ -35,20 +38,26 @@ using SciMLSensitivity
 using Zygote
 
 # Define a "loss" function; here we'll just take the mean over the final temperature field.
+using OrdinaryDiffEq
 using Statistics
 function loss(prob::CryoGridProblem, p)
     newprob = remake(prob, p=p)
+    # Here we specify the sensitivity algorithm. Note that this is only
+    # necessary for reverse-mode autodiff with Zygote.
     # autojacvec = true uses ForwardDiff to calculate the jacobian;
-    # enabling checkpointing (theroetically) reduces the memory cost of the backwards pass.
+    # enabling checkpointing (theoretically) reduces the memory cost of the backwards pass.
     sensealg = InterpolatingAdjoint(autojacvec=true, checkpointing=true)
     newsol = solve(newprob, Euler(), dt=300.0, sensealg=sensealg);
     newout = CryoGridOutput(newsol)
     return mean(ustrip.(newout.T[:,end]))
 end
 
-# Compute gradient with forward diff:
-pvec = prob.p
+# Compute gradient with forward-mode autodiff:
+pvec = vec(prob.p)
 fd_grad = @time ForwardDiff.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
+
+# We can also try with reverse-mode autodiff. This is generally slower for smaller numbers
+# of parmaeters (<100) but could be worthwhile for model configurations with high-dimensional
+# parameterizations.
 zy_grad = @time Zygote.gradient(páµ¢ -> loss(prob, páµ¢), pvec)
-@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-4 "Forward and reverse gradients don't match!"
-@show fd_grad
+@assert maximum(abs.(fd_grad .- zy_grad)) .< 1e-6 "Forward and reverse gradients don't match!"
-- 
GitLab