Skip to content
Snippets Groups Projects
Commit 792402f8 authored by Brian Groenke's avatar Brian Groenke
Browse files

Improve SFCC tabulation

parent 7ef7f6fb
No related branches found
No related tags found
1 merge request!68Minor bug fixes and refactoring
......@@ -192,6 +192,10 @@ of input values (knots) at which to evaluate the function. `A` may also be a
function tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...)
initknots(a::AbstractArray) = Interpolations.deduplicate_knots!(a)
initknots(x::Number) = initknots([x,x])
interp(::AbstractArray) = Gridded(Linear())
interp(::Number) = Gridded(Constant())
extrap(::AbstractArray) = Flat()
extrap(::Number) = Throw()
names = map(first, argknots)
# get knots for each argument, duplicating if only one value is provided
knots = map(initknots, map(last, argknots))
......@@ -199,8 +203,8 @@ function tabulate(f, argknots::Pair{Symbol,<:Union{Number,AbstractArray}}...)
@assert all(map(name -> name names, f_argnames)) "Missing one or more arguments $f_argnames in $f"
arggrid = Iterators.product(knots...)
# evaluate function construct interpolant
interp = interpolate(Tuple(knots), map(Base.splat(f), arggrid), Gridded(Linear()))
return interp
f = extrapolate(interpolate(Tuple(knots), map(Base.splat(f), arggrid), map(interp last, argknots)), map(extrap last, argknots))
return f
end
function(f::AbstractInterpolation)
gradient(args...) = Interpolations.gradient(f, args...)
......
......@@ -176,7 +176,7 @@ end
Produces an `SFCCTable` function which is a tabulation of `f`.
"""
Numerics.Tabulated(f::SFCCFunction, args...) = SFCCTable(f, Numerics.tabulate(f, args...))
Numerics.Tabulated(f::SFCCFunction, args...; kwargs...) = SFCCTable(f, Numerics.tabulate(f, args...; kwargs...))
"""
SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver())
......@@ -186,7 +186,7 @@ Constructs a SFCC from the precomputed `SFCCTable`. The derivative is generated
function SFCC(f::SFCCTable, s::SFCCSolver=SFCCNewtonSolver())
# we wrap ∇f with Base.splat here to avoid a weird issue with in-place splatting causing allocations
# when applied to runtime generated functions.
SFCC(f, Base.splat((f.f_tab)), s)
SFCC(f, Base.splat(first (f.f_tab)), s)
end
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment