Skip to content

Commit

Permalink
Merge pull request #97 from cgeoga/enzyme
Browse files Browse the repository at this point in the history
Add `EnzymeCore` weakdep and an extension with a custom rule for the Levin transformation
  • Loading branch information
heltonmc authored May 18, 2023
2 parents e48c6b4 + 27d257f commit caf2c2c
Show file tree
Hide file tree
Showing 15 changed files with 1,127 additions and 117 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
matrix:
version:
- '1'
- '1.8'
- '1.9'
- 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021-2022 Michael Helton, Oscar Smith, and contributors
Copyright (c) 2021-2023 Michael Helton, Oscar Smith, Chris Geoga, and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ version = "0.3.0-DEV"
SIMDMath = "5443be0b-e40a-4f70-a07e-dcd652efc383"

[compat]
julia = "1.8"
SIMDMath = "0.2.5"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
BesselsEnzymeCoreExt = "EnzymeCore"

[targets]
test = ["Test"]
34 changes: 34 additions & 0 deletions ext/BesselsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module BesselsEnzymeCoreExt

using Bessels, EnzymeCore
using EnzymeCore.EnzymeRules
using Bessels.Math

# Forward rule for `levin_transform` when both the sequence `s` and the
# weights `w` are `Duplicated` (i.e. carry a `val` and a `dval`).
#
# The `val` and the `dval` are transformed by two separate calls to
# `levin_transform`. This is deliberate: sometimes the `val` can converge while
# the `dval` hasn't, so just using an early return or something similar can
# give incorrect derivatives in edge cases.
#
# Returns a `Duplicated` holding the transformed `val` and transformed `dval`.
function EnzymeRules.forward(func::Const{typeof(levin_transform)},
                             ::Type{<:Duplicated},
                             s::Duplicated,
                             w::Duplicated)
    ls = levin_transform(s.val, w.val)
    dls = levin_transform(s.dval, w.dval)
    return Duplicated(ls, dls)
end

# This is fixing a straight bug in Enzyme.
#
# Forward rule for `sinpi` with a `Duplicated` argument: the value is
# `sinpi(x.val)` and the tangent is `pi * cospi(x.val) * x.dval`. Both
# trigonometric values come from a single `sincospi` call.
#
# NOTE(review): `sinpi` is a Base function rather than a Bessels one, so these
# two methods are not specific to Bessels code paths. Presumably they should be
# dropped once the upstream bug is fixed; confirm against current Enzyme.
function EnzymeRules.forward(func::Const{typeof(sinpi)},
                             ::Type{<:Duplicated},
                             x::Duplicated)
    (sp, cp) = sincospi(x.val)
    return Duplicated(sp, pi*cp*x.dval)
end

# Forward rule for `sinpi` with a `Const` argument and `Const` return
# annotation: there is no tangent to propagate, so only the plain value is
# returned.
function EnzymeRules.forward(func::Const{typeof(sinpi)},
                             ::Type{<:Const},
                             x::Const)
    return sinpi(x.val)
end

end
82 changes: 74 additions & 8 deletions src/BesselFunctions/besselk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ besselk_power_series(v, x::Float32) = Float32(besselk_power_series(v, Float64(x)
besselk_power_series(v, x::ComplexF32) = ComplexF32(besselk_power_series(v, ComplexF64(x)))

function besselk_power_series(v, x::ComplexOrReal{T}) where T
Math.isnearint(v) && return besselk_power_series_int(v, x)
MaxIter = 1000
S = eltype(x)
v, x = S(v), S(x)
Expand All @@ -512,7 +513,8 @@ function besselk_power_series(v, x::ComplexOrReal{T}) where T
# use the reflection identity to calculate gamma(-v)
# use relation gamma(v)*v = gamma(v+1) to avoid two gamma calls
gam_v = gamma(v)
gam_nv = π / (sinpi(-abs(v)) * gam_v * v)
#gam_nv = π / (sin(-pi*abs(v)) * gam_v * v) # not using sinpi here to avoid Enzyme bug
gam_nv = π / (sinpi(-abs(v)) * gam_v * v) # sinpi is used here; the Enzyme sinpi bug is worked around by the forward rule in ext/BesselsEnzymeCoreExt.jl
gam_1mv = -gam_nv * v
gam_1mnv = gam_v * v

Expand Down Expand Up @@ -578,15 +580,16 @@ end
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N}
:(
begin
s_0 = zero(T)
s = zero(T)
t = one(T)
@nexprs $N i -> begin
s_{i} = s_{i-1} + t
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
w_{i} = 1 / t
end
sequence = @ntuple $N i -> s_{i}
weights = @ntuple $N i -> w_{i}
s += t
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
s_{i} = s
w_{i} = t
end
sequence = @ntuple $N i -> s_{i}
weights = @ntuple $N i -> w_{i}
return levin_transform(sequence, weights) * sqrt(π / 2x)
end
)
Expand Down Expand Up @@ -614,3 +617,66 @@ end
end
)
end

# Temme's proposed f_0 (1975 JCP, see reference above), rewritten so that every
# factor is evaluated through a short local polynomial expansion. The functions
# involved are well-behaved, but their standard closed forms can't be naively
# evaluated by a computer at the origin.
#
# All polynomial coefficients are Float64 literals. `sinhc` uses the
# coefficients 1, 1/6, 1/120, 1/5040, 1/362880, i.e. the leading terms of
# sinh(mu)/mu as a polynomial in mu^2. The remaining three tables are
# polynomials in v^2 (presumably the expansions from Temme's paper; the
# derivation is not shown here).
@inline function f0_local_expansion_v0(v, x)
    log_two_over_x = log(2/x)
    mu = v*log_two_over_x
    v_sq = v*v
    mu_sq = mu*mu

    prefactor_coefs = (1.0, 1.6449340668482264, 1.8940656589944918, 1.9711021825948702)
    g1_coefs = (-0.5772156649015329, 0.04200263503409518, 0.042197734555544306)
    g2_coefs = (1.0, -0.6558780715202539, 0.16653861138229145)
    sinhc_coefs = (1.0, 0.16666666666666666, 0.008333333333333333,
                   0.0001984126984126984, 2.7557319223985893e-6)

    prefactor = evalpoly(v_sq, prefactor_coefs)
    g1 = evalpoly(v_sq, g1_coefs)
    g2 = evalpoly(v_sq, g2_coefs)
    sinhc = evalpoly(mu_sq, sinhc_coefs)

    bracket = g1*cosh(mu) + g2*sinhc*log_two_over_x
    return prefactor*bracket
end

# This function assumes |v|<1e-5!
#
# Sums two power series side by side, seeded with `f0_local_expansion_v0(v, x)`,
# and returns them as a 2-tuple. `besselk_power_series_int` unpacks that tuple
# as `(kv, kvp1)`. The second sum is divided by `x/2` before it is returned.
#
# At most 50 terms are added. The loop stops early once the magnitudes of the
# latest terms of both series fall below `eps(T)`, where
# `T = promote_type(V, X)`.
function besselk_power_series_temme_basal(v::V, x::X) where {V,X}
    T = promote_type(V, X)
    max_terms = 50
    half_x = x/2
    half_x_sq = half_x*half_x
    v_sq = v*v

    # Starting values for the f/p/q recurrences.
    f = f0_local_expansion_v0(v, x)
    pow_pos = half_x^v
    pow_neg = inv(pow_pos)
    gamma_coefs = (1.0, -0.5772156649015329, 0.9890559953279725, -0.23263776388631713)
    gam_plus = evalpoly(v, gamma_coefs)
    gam_minus = evalpoly(-v, gamma_coefs)
    p = pow_neg*gam_plus/2
    q = pow_pos*gam_minus/2

    # `z_pow/fact` is the scalar that multiplies each term: (x/2)^(2(k-1)) / (k-1)!.
    z_pow = one(T)
    fact = one(T)

    sum_v = zero(T)
    sum_vp1 = zero(T)
    for k in 1:max_terms
        # Accumulate the current term of each series.
        c = z_pow/fact
        term_v = c*f
        term_vp1 = c*(p - (k-1)*f)
        sum_v += term_v
        sum_vp1 += term_vp1

        # Stop as soon as both series have converged.
        if abs(term_v) < eps(T) && abs(term_vp1) < eps(T)
            break
        end

        # Advance the recurrences. `f` must be updated first because it
        # consumes the not-yet-updated `p` and `q`.
        f = (k*f + p + q)/(k^2 - v_sq)
        p /= (k-v)
        q /= (k+v)
        z_pow *= half_x_sq
        fact *= k
    end
    return (sum_v, sum_vp1/half_x)
end

# Evaluate the series for an order `v` that is at (or very near) an integer.
#
# `abs(v)` is split with `modf` into a fractional and a whole part. If the
# fractional part exceeds 1/2 it is shifted down by one (and the whole part up
# by one), so the remaining offset is the signed distance to the nearest
# integer. The pair returned by `besselk_power_series_temme_basal` at that
# offset is then stepped upward `whole` times with a three-term recurrence, and
# the first element of the final pair is returned.
function besselk_power_series_int(v, x::Float64)
    (offset, whole) = modf(abs(v))
    if offset > 1/2
        offset = offset - one(offset)
        whole = whole + 1
    end

    (k_lo, k_hi) = besselk_power_series_temme_basal(offset, x)
    two_over_x = 2/x
    for _ in 1:whole
        offset += 1
        k_next = muladd(two_over_x*offset, k_hi, k_lo)
        k_lo = k_hi
        k_hi = k_next
    end
    return k_lo
end

2 changes: 2 additions & 0 deletions src/GammaFunctions/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,5 @@ function gamma(n::Integer)
n > 20 && return gamma(float(n))
@inbounds return Float64(factorial(n-1))
end

# Evaluate a fixed cubic polynomial in `x - 1` whose constant term is 1, so the
# result is exactly 1 at `x == 1`. Judging by the name, this is a local
# approximation of `gamma` around 1; the coefficients are Float64 literals and
# their derivation is not shown here.
function gamma_near_1(x)
    coefs = (1.0, -0.5772156649015329, 0.9890559953279725, -0.23263776388631713)
    shifted = x - one(x)
    return evalpoly(shifted, coefs)
end
8 changes: 6 additions & 2 deletions src/Math/Math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ end
#@inline levin_scale(B::T, n, k) where T = -(B + n) * (B + n + k)^(k - one(T)) / (B + n + k + one(T))^k
# Scale factor used inside `levin_transform`:
#   -(B + n + k) * (B + n + k - 1) / ((B + n + 2k) * (B + n + 2k - 1))
@inline function levin_scale(B::T, n, k) where T
    shifted = B + n + k
    doubled = B + n + 2k
    numer = -shifted * (shifted - 1)
    denom = doubled * (doubled - 1)
    return numer / denom
end

@inline @generated function levin_transform(s::NTuple{N, T}, w::NTuple{N, T}) where {N, T <: FloatTypes}
@inline @generated function levin_transform(s::NTuple{N, T},
w::NTuple{N, T}) where {N, T <: FloatTypes}
len = N - 1
:(
begin
@nexprs $N i -> a_{i} = Vec{2, T}((s[i] * w[i], w[i]))
@nexprs $N i -> a_{i} = iszero(w[i]) ? (return s[i]) : Vec{2, T}((s[i] / w[i], 1 / w[i]))
@nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
return (a_1[1] / a_1[2])
end
Expand All @@ -153,4 +154,7 @@ end
)
end

# TODO (cg 2023/05/16 18:09): dispute this cutoff.
#
# Return `true` when `x` lies strictly within 1e-5 of the integer it rounds to.
function isnearint(x)
    distance = abs(x - round(x))
    return distance < 1e-5
end

end
Loading

0 comments on commit caf2c2c

Please sign in to comment.