-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnzymeCore
weakdep and an extension with a custom rule for the Levin transformation
#97
Changes from 13 commits
444db1a
17b90ca
766cff5
6db401e
dbb3821
e2d1f41
fc1e4a7
71334c4
afcb934
9c6d063
6f3445c
3c30394
0e687a1
232f566
e317d3a
1578b6f
27d257f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
module BesselsEnzymeCoreExt | ||
|
||
# TODO (cg 2023/05/08 10:02): Compat of any kind. | ||
|
||
using Bessels, EnzymeCore | ||
using EnzymeCore.EnzymeRules | ||
using Bessels.Math | ||
|
||
# A manual method that separately transforms the `val` and `dval`, because | ||
# sometimes the `val` can converge while the `dval` hasn't, so just using an | ||
# early return or something can give incorrect derivatives in edge cases. | ||
# | ||
# https://github.com/JuliaMath/Bessels.jl/issues/96 | ||
# | ||
# and links with for discussion. | ||
# | ||
# TODO (cg 2023/05/08 10:00): I'm not entirely sure how best to "generalize" | ||
# this to cases like a return type of DuplicatedNoNeed, or something being a | ||
# `Enzyme.Const`. These shouldn't in principle affect the "point" of this | ||
# function (which is just to check for convergence before applying a | ||
# function), but on its face this approach would mean I need a lot of | ||
# hand-written extra methods. I have an open issue on the Enzyme.jl repo at | ||
# | ||
# https://github.com/EnzymeAD/Enzyme.jl/issues/786 | ||
# | ||
# that gets at this problem a bit. But it's a weird request and I'm sure Billy | ||
# has a lot of asks on his time. | ||
function EnzymeRules.forward(func::Const{typeof(levin_transform)}, | ||
::Type{<:Duplicated}, | ||
s::Duplicated, | ||
w::Duplicated) | ||
(sv, dv, N) = (s.val, s.dval, length(s.val)) | ||
ls = levin_transform(sv, w.val) | ||
dls = levin_transform(dv, w.dval) | ||
Duplicated(ls, dls) | ||
end | ||
|
||
# This is fixing a straight bug in Enzyme. | ||
function EnzymeRules.forward(func::Const{typeof(sinpi)}, | ||
::Type{<:Duplicated}, | ||
x::Duplicated) | ||
Duplicated(sinpi(x.val), pi*cospi(x.val)) | ||
end | ||
|
||
function EnzymeRules.forward(func::Const{typeof(sinpi)}, | ||
::Type{<:Const}, | ||
x::Const) | ||
sinpi(x.val) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,38 +499,25 @@ besselk_power_series(v, x::Float32) = Float32(besselk_power_series(v, Float64(x) | |
besselk_power_series(v, x::ComplexF32) = ComplexF32(besselk_power_series(v, ComplexF64(x))) | ||
|
||
function besselk_power_series(v, x::ComplexOrReal{T}) where T | ||
MaxIter = 1000 | ||
S = eltype(x) | ||
v, x = S(v), S(x) | ||
|
||
z = x / 2 | ||
zz = z * z | ||
logz = log(z) | ||
xd2_v = exp(v*logz) | ||
xd2_nv = inv(xd2_v) | ||
|
||
# use the reflection identify to calculate gamma(-v) | ||
# use relation gamma(v)*v = gamma(v+1) to avoid two gamma calls | ||
gam_v = gamma(v) | ||
gam_nv = π / (sinpi(-abs(v)) * gam_v * v) | ||
gam_1mv = -gam_nv * v | ||
gam_1mnv = gam_v * v | ||
|
||
_t1 = gam_v * xd2_nv * gam_1mv | ||
_t2 = gam_nv * xd2_v * gam_1mnv | ||
(xd2_pow, fact_k, out) = (one(S), one(S), zero(S)) | ||
for k in 0:MaxIter | ||
t1 = xd2_pow * T(0.5) | ||
tmp = muladd(_t1, gam_1mnv, _t2 * gam_1mv) | ||
tmp *= inv(gam_1mv * gam_1mnv * fact_k) | ||
term = t1 * tmp | ||
out += term | ||
abs(term / out) < eps(T) && break | ||
(gam_1mnv, gam_1mv) = (gam_1mnv*(one(S) + v + k), gam_1mv*(one(S) - v + k)) | ||
xd2_pow *= zz | ||
fact_k *= k + one(S) | ||
Math.isnearint(v) && return besselk_power_series_int(v, x) | ||
MaxIter = 5000 | ||
gam = gamma(v) | ||
ngam = π / (sinpi(-abs(v)) * gam * v) | ||
|
||
s1, s2 = zero(T), zero(T) | ||
t1, t2 = one(T), one(T) | ||
|
||
for k in 1:MaxIter | ||
s1 += t1 | ||
s2 += t2 | ||
t1 *= x^2 / (4k * (k - v)) | ||
t2 *= x^2 / (4k * (k + v)) | ||
abs(t1) < eps(T) && break | ||
end | ||
return out | ||
|
||
xpv = (x/2)^v | ||
s = gam * s1 + xpv^2 * ngam * s2 | ||
return s / (2*xpv) | ||
end | ||
besselk_power_series_cutoff(nu, x::Float64) = x < 2.0 || nu > 1.6x - 1.0 | ||
besselk_power_series_cutoff(nu, x::Float32) = x < 10.0f0 || nu > 1.65f0*x - 8.0f0 | ||
|
@@ -578,15 +565,16 @@ end | |
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N} | ||
:( | ||
begin | ||
s_0 = zero(T) | ||
s = zero(T) | ||
t = one(T) | ||
@nexprs $N i -> begin | ||
s_{i} = s_{i-1} + t | ||
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i) | ||
w_{i} = 1 / t | ||
end | ||
sequence = @ntuple $N i -> s_{i} | ||
weights = @ntuple $N i -> w_{i} | ||
s += t | ||
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i) | ||
s_{i} = s | ||
w_{i} = t | ||
heltonmc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
sequence = @ntuple $N i -> s_{i} | ||
weights = @ntuple $N i -> w_{i} | ||
heltonmc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return levin_transform(sequence, weights) * sqrt(π / 2x) | ||
end | ||
) | ||
|
@@ -614,3 +602,69 @@ end | |
end | ||
) | ||
end | ||
|
||
# This is an expansion of the function | ||
# | ||
# f_0(v, x) = (x^v)*gamma(-v) + (x^(-v))*gamma(v) | ||
# = (x^v)*(gamma(-v) + (x^(-2*v))*gamma(v)) | ||
# | ||
# around v ∼ 0. As you can see by plugging that second form into Wolfram Alpha | ||
# and getting an expansion back, this is actually a bivariate polynomial in | ||
# (v^2, log(x)). So that's how this is structured. | ||
@inline function f0_local_expansion_v0(v, x) | ||
lx = log(x) | ||
c0 = evalpoly(lx, (-1.1544313298030657, -2.0)) | ||
c2 = evalpoly(lx, ( 1.4336878944573288, -1.978111990655945, -0.5772156649015329, -0.3333333333333333)) | ||
c4 = evalpoly(lx, (-0.6290784463642211, -1.4584260788225176, -0.23263776388631713, -0.32968533177599085, -0.048101305408461074, -0.016666666666666666)) | ||
evalpoly(v*v, (c0,c2,c4))/2 | ||
end | ||
|
||
# This function assumes |v| < 1e-6 or 1e-7! | ||
# | ||
# TODO (cg 2023/05/16 18:07): lots of micro-optimizations. | ||
function besselk_power_series_temme_basal(v::V, x::Float64) where{V} | ||
max_iter = 50 | ||
T = promote_type(V,Float64) | ||
z = x/2 | ||
zz = z*z | ||
fk = f0_local_expansion_v0(v, x/2) | ||
zv = z^v | ||
znv = inv(zv) | ||
gam_1pv = GammaFunctions.gamma_near_1(1+v) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make sure that these two This is a good way to do this though as using SIMD will be faster than using reflection formula! |
||
gam_1nv = GammaFunctions.gamma_near_1(1-v) | ||
(pk, qk, _ck, factk, vv) = (znv*gam_1pv/2, zv*gam_1nv/2, one(T), one(T), v*v) | ||
(out_v, out_vp1) = (zero(T), zero(T)) | ||
for k in 1:max_iter | ||
# add to the series: | ||
ck = _ck/factk | ||
term_v = ck*fk | ||
term_vp1 = ck*(pk - (k-1)*fk) | ||
out_v += term_v | ||
out_vp1 += term_vp1 | ||
# check for convergence: | ||
((abs(term_v) < eps(T)) && (abs(term_vp1) < eps(T))) && break | ||
# otherwise, increment new quantities: | ||
fk = (k*fk + pk + qk)/(k^2 - vv) | ||
pk /= (k-v) | ||
qk /= (k+v) | ||
_ck *= zz | ||
factk *= k | ||
end | ||
(out_v, out_vp1/z) | ||
end | ||
|
||
function besselk_power_series_int(v, x::Float64) | ||
v = abs(v) | ||
(_v, flv) = modf(v) | ||
if _v > 1/2 | ||
(_v, flv) = (_v-one(_v), flv+1) | ||
end | ||
(kv, kvp1) = besselk_power_series_temme_basal(_v, x) | ||
twodx = 2/x | ||
for _ in 1:flv | ||
_v += 1 | ||
(kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv)) | ||
end | ||
kv | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any way we can get this fixed in enzyme itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, this should use
sincospi
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I poked over at EnzymeAD/Enzyme.jl#443
I don’t think I know exactly how to solve that though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As it turns out this is not the only problem. Something else in the generic power series is broken for
Enzyme
but notForwardDiff
. Leaving a summary comment below, one moment.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you have time could post a specific example where this is broken? I will try to figure out what line is causing the issue even when separating out the sinpi.
These issues though are especially annoying......
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ironically sincospi using Enzyme should be fine. I'm adding a pr for sinpi/cospi now which hopefully will be available in a few days.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, though I'm solving in a different way from this PR (internal to Enzyme proper rather than Enzyme.jl custom rule), rules like this are welcome as PR's to Enzyme.jl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect! Thanks for looking at this. I'll change it over here once that is available.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It occurs to me to mention here for people looking at this PR in the future that the problem was just the
sinpi
, but that I didn't understand how to properly writeEnzymeRules
and just needed to propagate thex.dval
in the derivative part of the returnedDuplicated
object. I didn't include tests for power series accuracy here because that will probably be a little bit of a project to get the last few digits, but once I fixed my custom rule that worked fine.@wsmoses, would you like me to make a PR with this rule in the meantime? If it is fixed and will be available in the next release, maybe not point unless you would make a more immediate release that has the custom rule. I'd be happy to try and make that PR if you want, but I understand if it isn't the most useful.
Sorry that this thread looks on a skim like Enzyme problems but was actually "Chris doesn't know how to write custom rules" problems.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See EnzymeAD/Enzyme#1216. Hopefully that fixes this issue here and we can remove that part in the future.
P.s. I have the general
besselk
working now locally so hopefully we can get that merged soon and can test the general derivative cases.