From 4c1ef9a0a877459ba8dfc4987ff34722517e04c3 Mon Sep 17 00:00:00 2001
From: Brian Groenke <brian.groenke@awi.de>
Date: Mon, 9 Dec 2024 17:00:29 +0100
Subject: [PATCH] Update parameter ensemble script

---
 examples/cglite_parameter_ensembles.jl | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/examples/cglite_parameter_ensembles.jl b/examples/cglite_parameter_ensembles.jl
index 4c6c40e5..bac784e6 100644
--- a/examples/cglite_parameter_ensembles.jl
+++ b/examples/cglite_parameter_ensembles.jl
@@ -14,7 +14,9 @@ if Threads.nthreads() == 1
     @warn "Only one thread is available. Ensemble execution will run sequentially. Did you start julia with `--threads=auto` ?"
 end
 
-# Load forcings and build stratigraphy like before.
+# Load forcings and build stratigraphy like before, except this time we assign
+# `Param` values to the quantiies which we want to vary in the ensemble. Here
+# we vary the porosity in each layer as well as the n-factors.
 forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_MkL3_CCSM4_long_term);
 soilprofile = SoilProfile(
     0.0u"m" => SimpleSoil(por=Param(0.80, prior=Uniform(0.65,0.95)),sat=1.0,org=0.75),
@@ -27,7 +29,7 @@ soilprofile = SoilProfile(
 z_top = -2.0u"m"
 z_bot = 1000.0u"m"
 upperbc = TemperatureBC(
-    forcings.Tair,
+    Input(:Tair),
     NFactor(
         nf=Param(0.5, prior=Beta(1,1)),
         nt=Param(0.9, prior=Beta(1,1)),
@@ -43,27 +45,25 @@ strat = Stratigraphy(
     z_bot => Bottom(GeothermalHeatFlux(0.053u"W/m^2"))
 );
 modelgrid = CryoGrid.DefaultGrid_2cm
-tile = Tile(strat, modelgrid, ssinit);
+tile = Tile(strat, modelgrid, forcings, ssinit);
 # Since the solver can take daily timesteps, we can easily specify longer simulation time spans at minimal cost.
 # Here we specify a time span of 10 years.
 tspan = (DateTime(2000,1,1), DateTime(2010,12,31))
 u0, du0 = initialcondition!(tile, tspan);
 prob = CryoGridProblem(tile, u0, tspan, saveat=24*3600.0, savevars=(:T,))
 
+# Here we retrieve the `CryoGridParams` from the `CryoGridProblem` constructed above.
 # The CryoGridParams type behaves like a table and can be easily converted
 # to a DataFrame with DataFrame(params) when DataFrames.jl is loaded.
-params = CryoGrid.parameters(tile)
+params = prob.p
 
-# you can use Julia's `vec` method to convert `CryoGridParams` into a `ComponentVector`
+# Note that you can also use Julia's `vec` method to convert `CryoGridParams` into a `ComponentVector` with labels.
 p0 = vec(params)
 
-# extract prior distributions and collect them into a multivariate Product distribution;
+# Here we extract prior distributions and collect them into a multivariate Product distribution;
 # note that this assumes each parameter to be independent from the others
 prior = Product(collect(params[:prior]))
 
-# declare 
-const rng = Random.MersenneTwister(1234)
-
 # Method 1: SciML EnsembleProblem
 
 function make_prob_func(ensmeble::AbstractMatrix)
@@ -77,11 +77,12 @@ function output_func(sol, i)
     return CryoGridOutput(sol), false
 end
 
-# sample parameter values from prior with fixed RNG;
+# Now we sample parameter values from prior with fixed RNG;
 # the number of samples determines the size of the ensemble
+const rng = Random.MersenneTwister(1234)
 prior_ensemble = rand(rng, prior, 64)
 
-# create EnsembleProblem from CryoGridProblem and prob/output functions;
+# We create an `EnsembleProblem` from `CryoGridProblem` and prob/output functions;
 # note that we use safetycopy=true because we're using multithreading;
 # this prevents the different threads from using the same state caches
 prob_func = make_prob_func(prior_ensemble)
@@ -90,7 +91,7 @@ ensprob = EnsembleProblem(prob; prob_func, output_func, safetycopy=true)
 # alternatively, one can specify EnsembleDistributed() for process or slurm parallelization or EnsembleSerial() for sequential execution.
 enssol = @time solve(ensprob, LiteImplicitEuler(), EnsembleThreads(), trajectories=size(prior_ensemble,2))
 
-# extract permafrost temperatures at 20m depth and plot the ensemble
+# Now we will extract permafrost temperatures at 20m depth and plot the ensemble.
 T20m_ens = reduce(hcat, map(out -> out.T[Z(Near(20.0u"m"))], enssol))
 Plots.plot(T20m_ens, leg=nothing, c=:black, alpha=0.5, ylabel="Permafrost temperature")
 
-- 
GitLab