diff --git a/examples/heat_simple_autodiff_grad.jl b/examples/heat_simple_autodiff_grad.jl
index a7318439138f257471a99916cdacbfe76bea79af..3c5d86814234960b7c450cfc1c1d627e9f09058c 100644
--- a/examples/heat_simple_autodiff_grad.jl
+++ b/examples/heat_simple_autodiff_grad.jl
@@ -4,7 +4,6 @@
 #
 # TODO: add more detail/background
 using CryoGrid
-CryoGrid.debug(true)
 
 # Set up forcings and boundary conditions similarly to other examples:
 forcings = loadforcings(CryoGrid.Forcings.Samoylov_ERA_obs_fitted_1979_2014_spinup_extended_2044);
@@ -14,7 +13,7 @@ soilprofile = SoilProfile(0.0u"m" => SimpleSoil(; freezecurve))
 grid = CryoGrid.DefaultGrid_5cm
 initT = initializer(:T, tempprofile)
 tile = CryoGrid.SoilHeatTile(
-    :H,
+    :T,
     TemperatureBC(Input(:Tair), NFactor(nf=Param(0.5), nt=Param(0.9))),
     GeothermalHeatFlux(0.053u"W/m^2"),
     soilprofile,
@@ -22,30 +21,17 @@ tile = CryoGrid.SoilHeatTile(
     initT;
     grid=grid
 )
-tspan = (DateTime(2010,9,1),DateTime(2010,10,1))
+tspan = (DateTime(2010,9,1),DateTime(2011,10,1))
 u0, du0 = @time initialcondition!(tile, tspan);
 
 # We can retrieve the parameters of the system from `tile`:
 para = CryoGrid.parameters(tile)
 
 # Create the `CryoGridProblem`.
-prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0, savevars=(:T,))
-
-function testfunc(prob)
-    function(p)
-        newprob = remake(prob, p=p)
-        sol = solve(newprob, Euler(), dt=300.0)
-        out = CryoGridOutput(sol)
-        y = mean(ustrip.(Array(out.T)))
-        return y
-    end
-end
-
-f = testfunc(prob)
-grad = @time ForwardDiff.gradient(f, vec(prob.p))
+prob = CryoGridProblem(tile, u0, tspan, saveat=3600.0)
 
 # Solve the forward problem with default parameter settings:
-sol = @time solve(prob);
+sol = @time solve(prob)
 out = CryoGridOutput(sol)
 
 # Import relevant packages for automatic differentiation.
diff --git a/src/IO/output.jl b/src/IO/output.jl
index 630f91a17b0ffcc75750ae1f03d00b5618c55132..82d7e32fe8058aafde2720382810cfdaff2fa817 100644
--- a/src/IO/output.jl
+++ b/src/IO/output.jl
@@ -99,10 +99,3 @@ function reset!(cache::SaveCache)
     resize!(cache.t, 0)
     resize!(cache.vals, 0)
 end
-
-struct SaveConfig
-    savevars::Tuple
-    saveat::Vector{Float64}
-    save_start::Bool
-    save_everystep::Bool
-end
diff --git a/src/Solvers/DiffEq/ode_solvers.jl b/src/Solvers/DiffEq/ode_solvers.jl
index f86006dbd99c0143c711737a4e29b62478319465..cd7c4e9d8819a7f284072ea4df01e5bd9bdb24c8 100644
--- a/src/Solvers/DiffEq/ode_solvers.jl
+++ b/src/Solvers/DiffEq/ode_solvers.jl
@@ -81,14 +81,12 @@ function CommonSolve.init(
     kwargs...
 )
     ode_prob = ODEProblem(prob)
-    ode_integrator = init(ode_prob, alg, args...; kwargs...)
-    integrator = CryoGridDiffEqIntegrator(prob, ode_integrator)
-    return integrator
+    integrator = init(ode_prob, alg, args...; kwargs...)
+    return CryoGridDiffEqIntegrator(prob, integrator)
 end
 
 function CommonSolve.step!(integrator::CryoGridDiffEqIntegrator, args...; kwargs...)
-    rv = step!(integrator.integrator, args...; kwargs...)
-    return rv
+    return step!(integrator.integrator, args...; kwargs...)
 end
 
 function CommonSolve.solve!(integrator::CryoGridDiffEqIntegrator)
diff --git a/src/Solvers/basic_solvers.jl b/src/Solvers/basic_solvers.jl
index b7c1e0364c9d6d870597af5aa1685d903706b11e..f93da9b37e754e9b76dc57e5013c30fccce259d2 100644
--- a/src/Solvers/basic_solvers.jl
+++ b/src/Solvers/basic_solvers.jl
@@ -37,7 +37,7 @@ function CommonSolve.init(prob::CryoGridProblem, alg::CGEuler, args...; dt=60.0,
     )
     p = isnothing(prob.p) ? prob.p : collect(prob.p)
     if isnothing(saveat)
-        saveat = prob.savecfg.saveat
+        saveat = prob.saveat
     elseif isa(saveat, Number)
         saveat = collect(prob.tspan[1]:saveat:prob.tspan[2])
     end
diff --git a/src/Solvers/integrator.jl b/src/Solvers/integrator.jl
index de23bec5974c2610d6381b62159d384832fb7f0e..ba7b2bfed2a231ba1431e1767b1e24420d0c011e 100755
--- a/src/Solvers/integrator.jl
+++ b/src/Solvers/integrator.jl
@@ -97,8 +97,7 @@ function CommonSolve.solve!(integrator::CryoGridIntegrator)
         integrator.sol.retcode = ReturnCode.Success
     end
     # if no save points are specified, save final state
-    prob = integrator.sol.prob
-    if isempty(prob.savecfg.saveat)
+    if isempty(integrator.sol.prob.saveat)
         prob.savefunc(integrator.u, integrator.t, integrator)
         push!(integrator.sol.u, integrator.u)
         push!(integrator.sol.t, integrator.t)
@@ -162,7 +161,7 @@ function InputOutput.CryoGridOutput(sol::AbstractCryoGridSolution, tspan::NTuple
     save_interval = ClosedInterval(tspan...)
     tile = Tile(sol.prob.f) # Tile
     grid = Grid(tile.grid.*u"m")
-    savecache = sol.prob.savecache
+    savecache = sol.prob.savefunc.cache
     ts = savecache.t # use save cache time points
     # check if last value is duplicated
     ts = ts[end] == ts[end-1] ? ts[1:end-1] : ts
diff --git a/src/problem.jl b/src/problem.jl
index 8c90f8e357f241c7ee49c23a7917e3fb92a443e9..a144ef8aee3e09514021b0f7ace8b37316928b96 100644
--- a/src/problem.jl
+++ b/src/problem.jl
@@ -3,15 +3,14 @@
 
 Represents a CryoGrid discretized PDE forward model configuration using the `SciMLBase`/`DiffEqBase` problem interface.
 """
-struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
+struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tcb,Tdf,Tkw} <: SciMLBase.AbstractODEProblem{Tu,Tt,iip}
     f::TT
     u0::Tu
     tspan::NTuple{2,Tt}
     p::Tp
     callbacks::Tcb
-    savecfg::Tsv
+    saveat::Tsv
     savefunc::Tsf
-    savecache::Tsc
     isoutofdomain::Tdf
     kwargs::Tkw
     CryoGridProblem{iip}(
@@ -20,13 +19,12 @@ struct CryoGridProblem{iip,Tu,Tt,Tp,TT,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} <: SciMLBase.Abs
         tspan::NTuple{2,Tt},
         p::Tp,
         cbs::Tcb,
-        savecfg::Tsv,
+        saveat::Tsv,
         savefunc::Tsf,
-        savecache::Tsc,
         iood::Tdf,
         kwargs::Tkw
-    ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw} =
-        new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tsc,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, savecfg, savefunc, savecache, iood, kwargs)
+    ) where {iip,TF,Tu,Tt,Tp,Tsv,Tsf,Tcb,Tdf,Tkw} =
+        new{iip,Tu,Tt,Tp,TF,Tsv,Tsf,Tcb,Tdf,Tkw}(f, u0, tspan, p, cbs, saveat, savefunc, iood, kwargs)
 end
 
 """
@@ -83,8 +81,13 @@ function CryoGridProblem(
     du0 = zero(u0)
     # remove units
     tile = stripunits(tile)
+    # set up saving callback
+    saveat = expandtstep(saveat, tspan)
+    savecache = InputOutput.SaveCache(Float64[], [])
+    savefunc = saving_function(savecache, savevars...)
+    savingcallback = FunctionCallingCallback(savefunc; funcat=saveat, func_start=save_start, func_everystep=save_everystep)
     diagnostic_step_callback = PresetTimeCallback(tspan[1]:diagnostic_stepsize:tspan[end], diagnosticstep!)
-    defaultcallbacks = (diagnostic_step_callback,)
+    defaultcallbacks = (savingcallback, diagnostic_step_callback)
     # add step limiter to default callbacks, if defined
     if !isnothing(step_limiter)
         defaultcallbacks = (
@@ -96,18 +99,13 @@ function CryoGridProblem(
     layercallbacks = _makecallbacks(tile)
     # add user callbacks
     usercallbacks = isnothing(callback) ? () : callback
-    callbacks = Callbacks(defaultcallbacks, layercallbacks, usercallbacks)
+    callbacks = CallbackSet(defaultcallbacks..., layercallbacks..., usercallbacks...)
     # build mass matrix
     mass_matrix = Numerics.build_mass_matrix(tile.state)
     # get params
     p = isnothing(p) && !isnothing(tilepara) ? ustrip.(vec(tilepara)) : p
 	func = odefunction(tile, u0, p, tspan; mass_matrix, specialization, function_kwargs...)
-    # set up saving config
-    saveat = expandtstep(saveat, tspan)
-    saveconfig = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
-    savecache = InputOutput.SaveCache(Float64[], [])
-    savefunc = saving_function(savecache, savevars...)
-	return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveconfig, savefunc, savecache, isoutofdomain, prob_kwargs)
+	return CryoGridProblem{true}(func, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, prob_kwargs)
 end
 
 function SciMLBase.remake(
@@ -116,13 +114,11 @@ function SciMLBase.remake(
     u0=nothing,
     tspan=prob.tspan,
     p=prob.p,
-    savevars=prob.savecfg.savevars,
-    saveat=prob.savecfg.saveat,
-    save_start=prob.savecfg.save_start,
-    save_everystep=prob.savecfg.save_everystep,
+    callbacks=prob.callbacks,
+    saveat=prob.saveat,
+    savefunc=prob.savefunc,
     isoutofdomain=prob.isoutofdomain,
     kwargs=prob.kwargs,
-    callbacks=nothing,
 ) where iip
     # always re-run initialcondition! with the given tspan and parameters
     _u0, du0 = initialcondition!(Tile(f), tspan, p)
@@ -133,15 +129,7 @@ function SciMLBase.remake(
     else
         u0 = _u0
     end
-    # create new save cache
-    saveat = expandtstep(saveat, tspan)
-    savecfg = InputOutput.SaveConfig(savevars, saveat, save_start, save_everystep)
-    savecache = InputOutput.SaveCache(Float64[], [])
-    savefunc = saving_function(savecache, savevars...)
-    # rebuild Callbacks struct with new user callbacks if provided
-    callbacks = isnothing(callbacks) ? prob.callbacks : Callbacks(prob.callbacks.default, prob.callbacks.layer, callbacks)
-    # construct new CryoGridProblem
-    return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, savecfg, savefunc, savecache, isoutofdomain, kwargs)
+    return CryoGridProblem{iip}(f, u0, tspan, p, callbacks, saveat, savefunc, isoutofdomain, kwargs)
 end
 
 CommonSolve.init(prob::CryoGridProblem, alg, args...; kwargs...) = error("init not defined for CryoGridProblem with solver $(typeof(alg))")
@@ -151,30 +139,15 @@ function CommonSolve.solve(prob::CryoGridProblem, alg, args...; kwargs...)
     return solve!(integrator)
 end
 
-function SciMLBase.ODEProblem(prob::CryoGridProblem)
-    savingcallback = FunctionCallingCallback(
-        prob.savefunc;
-        funcat=prob.savecfg.saveat,
-        func_start=prob.savecfg.save_start,
-        func_everystep=prob.savecfg.save_everystep
-    )
-    callbacks = CallbackSet(
-        savingcallback,
-        prob.callbacks.default...,
-        prob.callbacks.layer...,
-        prob.callbacks.user...
-    )
-    odeprob = ODEProblem(
-        prob.f,
-        prob.u0,
-        prob.tspan,
-        prob.p;
-        callback=callbacks,
-        isoutofdomain=prob.isoutofdomain,
-        prob.kwargs...
-    )
-    return odeprob
-end
+SciMLBase.ODEProblem(prob::CryoGridProblem) = ODEProblem(
+    prob.f,
+    prob.u0,
+    prob.tspan,
+    prob.p;
+    callback=prob.callbacks,
+    isoutofdomain=prob.isoutofdomain,
+    prob.kwargs...
+)
 
 DiffEqBase.get_concrete_problem(prob::CryoGridProblem, isadapt; kwargs...) = prob
 
@@ -231,14 +204,7 @@ function diagnosticstep!(integrator::SciMLBase.DEIntegrator)
     DiffEqBase.u_modified!(integrator, u_modified)
 end
 
-# Callback utilities
-
-struct Callbacks
-    default
-    layer
-    user
-end
-
+# callback building functions
 function _makecallbacks(tile::Tile)
     eventname(::Event{name}) where name = name
     isgridevent(::GridContinuousEvent) = true