-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnzymeCore
weakdep and an extension with a custom rule for the Levin transformation
#97
Changes from 14 commits
444db1a
17b90ca
766cff5
6db401e
dbb3821
e2d1f41
fc1e4a7
71334c4
afcb934
9c6d063
6f3445c
3c30394
0e687a1
232f566
e317d3a
1578b6f
27d257f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
module BesselsEnzymeCoreExt | ||
|
||
# TODO (cg 2023/05/08 10:02): Compat of any kind. | ||
|
||
using Bessels, EnzymeCore | ||
using EnzymeCore.EnzymeRules | ||
using Bessels.Math | ||
|
||
# A manual method that separately transforms the `val` and `dval`, because | ||
# sometimes the `val` can converge while the `dval` hasn't, so just using an | ||
# early return or something can give incorrect derivatives in edge cases. | ||
# | ||
# https://github.com/JuliaMath/Bessels.jl/issues/96 | ||
# | ||
# and links with for discussion. | ||
# | ||
# TODO (cg 2023/05/08 10:00): I'm not entirely sure how best to "generalize" | ||
# this to cases like a return type of DuplicatedNoNeed, or something being a | ||
# `Enzyme.Const`. These shouldn't in principle affect the "point" of this | ||
# function (which is just to check for convergence before applying a | ||
# function), but on its face this approach would mean I need a lot of | ||
# hand-written extra methods. I have an open issue on the Enzyme.jl repo at | ||
# | ||
# https://github.com/EnzymeAD/Enzyme.jl/issues/786 | ||
# | ||
# that gets at this problem a bit. But it's a weird request and I'm sure Billy | ||
# has a lot of asks on his time. | ||
function EnzymeRules.forward(func::Const{typeof(levin_transform)}, | ||
::Type{<:Duplicated}, | ||
s::Duplicated, | ||
w::Duplicated) | ||
(sv, dv, N) = (s.val, s.dval, length(s.val)) | ||
ls = levin_transform(sv, w.val) | ||
dls = levin_transform(dv, w.dval) | ||
Duplicated(ls, dls) | ||
end | ||
|
||
# This is fixing a straight bug in Enzyme. | ||
function EnzymeRules.forward(func::Const{typeof(sinpi)}, | ||
::Type{<:Duplicated}, | ||
x::Duplicated) | ||
Duplicated(sinpi(x.val), pi*cospi(x.val)) | ||
end | ||
|
||
function EnzymeRules.forward(func::Const{typeof(sinpi)}, | ||
::Type{<:Const}, | ||
x::Const) | ||
sinpi(x.val) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,38 +499,25 @@ besselk_power_series(v, x::Float32) = Float32(besselk_power_series(v, Float64(x) | |
besselk_power_series(v, x::ComplexF32) = ComplexF32(besselk_power_series(v, ComplexF64(x))) | ||
|
||
function besselk_power_series(v, x::ComplexOrReal{T}) where T | ||
MaxIter = 1000 | ||
S = eltype(x) | ||
v, x = S(v), S(x) | ||
|
||
z = x / 2 | ||
zz = z * z | ||
logz = log(z) | ||
xd2_v = exp(v*logz) | ||
xd2_nv = inv(xd2_v) | ||
|
||
# use the reflection identify to calculate gamma(-v) | ||
# use relation gamma(v)*v = gamma(v+1) to avoid two gamma calls | ||
gam_v = gamma(v) | ||
gam_nv = π / (sinpi(-abs(v)) * gam_v * v) | ||
gam_1mv = -gam_nv * v | ||
gam_1mnv = gam_v * v | ||
|
||
_t1 = gam_v * xd2_nv * gam_1mv | ||
_t2 = gam_nv * xd2_v * gam_1mnv | ||
(xd2_pow, fact_k, out) = (one(S), one(S), zero(S)) | ||
for k in 0:MaxIter | ||
t1 = xd2_pow * T(0.5) | ||
tmp = muladd(_t1, gam_1mnv, _t2 * gam_1mv) | ||
tmp *= inv(gam_1mv * gam_1mnv * fact_k) | ||
term = t1 * tmp | ||
out += term | ||
abs(term / out) < eps(T) && break | ||
(gam_1mnv, gam_1mv) = (gam_1mnv*(one(S) + v + k), gam_1mv*(one(S) - v + k)) | ||
xd2_pow *= zz | ||
fact_k *= k + one(S) | ||
Math.isnearint(v) && return besselk_power_series_int(v, x) | ||
MaxIter = 5000 | ||
gam = gamma(v) | ||
ngam = π / (sinpi(-abs(v)) * gam * v) | ||
|
||
s1, s2 = zero(T), zero(T) | ||
t1, t2 = one(T), one(T) | ||
|
||
for k in 1:MaxIter | ||
s1 += t1 | ||
s2 += t2 | ||
t1 *= x^2 / (4k * (k - v)) | ||
t2 *= x^2 / (4k * (k + v)) | ||
abs(t1) < eps(T) && break | ||
end | ||
return out | ||
|
||
xpv = (x/2)^v | ||
s = gam * s1 + xpv^2 * ngam * s2 | ||
return s / (2*xpv) | ||
end | ||
besselk_power_series_cutoff(nu, x::Float64) = x < 2.0 || nu > 1.6x - 1.0 | ||
besselk_power_series_cutoff(nu, x::Float32) = x < 10.0f0 || nu > 1.65f0*x - 8.0f0 | ||
|
@@ -578,15 +565,16 @@ end | |
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N} | ||
:( | ||
begin | ||
s_0 = zero(T) | ||
s = zero(T) | ||
t = one(T) | ||
@nexprs $N i -> begin | ||
s_{i} = s_{i-1} + t | ||
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i) | ||
w_{i} = 1 / t | ||
end | ||
sequence = @ntuple $N i -> s_{i} | ||
weights = @ntuple $N i -> w_{i} | ||
s += t | ||
t *= (4*v^2 - (2i - 1)^2) / (8 * x * i) | ||
s_{i} = s | ||
w_{i} = t | ||
heltonmc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
sequence = @ntuple $N i -> s_{i} | ||
weights = @ntuple $N i -> w_{i} | ||
heltonmc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return levin_transform(sequence, weights) * sqrt(π / 2x) | ||
end | ||
) | ||
|
@@ -614,3 +602,66 @@ end | |
end | ||
) | ||
end | ||
|
||
# This is a version of Temme's proposed f_0 (1975 JCP, see reference above) that | ||
# swaps in a bunch of local expansions for functions that are well-behaved but | ||
# whose standard forms can't be naively evaluated by a computer at the origin. | ||
@inline function f0_local_expansion_v0(v, x) | ||
l2dx = log(2/x) | ||
mu = v*l2dx | ||
vv = v*v | ||
sp = evalpoly(vv, (1.0, 1.6449340668482264, 1.8940656589944918, 1.9711021825948702)) | ||
g1 = evalpoly(vv, (-0.5772156649015329, 0.04200263503409518, 0.042197734555544306)) | ||
g2 = evalpoly(vv, (1.0, -0.6558780715202539, 0.16653861138229145)) | ||
sh = evalpoly(mu*mu, (1.0, 0.16666666666666666, 0.008333333333333333, 0.0001984126984126984, 2.7557319223985893e-6)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm kinda wondering how many terms we need for this expansion as mu slowly grows... julia> v = 1e-4
0.0001
julia> v * log(2 / 24.0)
-0.0002484906649788 Five terms is probably ok ? I did some quick checks adding another term below but didn't seem to change much. Seems like a reasonable approximation that we have here. Just checking this
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't really say where the difference is coming from here .... julia> Bessels.BesselFunctions.besselk_power_series_int(0.000, 10.0)
1.778006219756317e-5
julia> ArbNumerics.besselk(ArbFloat(0.0000), ArbFloat(10.0))
1.7780062316167651811301192799427833154e-5
julia> ArbNumerics.besselk(ArbFloat(12.0000), ArbFloat(10.0))
0.010278998056493335846252984780767697567
julia> Bessels.BesselFunctions.besselk_power_series_int(12.000, 10.0)
0.010278998068072438 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well those are huge In general, I'm sure more terms couldn't hurt---I can tinker with that. But it isn't obvious that it will help, because if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ya that's right haha. I will need to adjust cutoffs and stuff accordingly. So like what we probably should do is just have general routines for abs(v) < 1.5 and then use forward reccurence. It's much faster actually to do 2 Levin calculations and forward recurrence than it is to do the int power series. So let's adjust the whole routine. For v > 25 we will use the uniform expansion so forward recurrence will be fast here. But ya I think as long as we verify derivatives for v near 0.0 and x < 1.5 for this power series we should be ok. The tough thing about this though is that the derivatives are zero for v so something really close to v will be tricky to get right.... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dang sorry ya I need to rethink the whole routine now that the intemediate range using the Levin transform is so fast. I used to just try to avoid that completely by extending the power series range as much as possible but now it makes sense to favor that when necessary and just fall back to forward recurrence when necessary. It actually makes checking the accuracy of the whole routine much easier because we essentially have to just check our scalar and derivative information for v < 1.0 and we know that forward recurrence is also stable and accurate for besselk. This should greatly reduce the number of points we need to explicity check. Of course we should still do some scattered spotchecks to make sure the derivatives are carried out ok with AD |
||
sp*(g1*cosh(mu) + g2*sh*l2dx) | ||
end | ||
|
||
# This function assumes |v|<1e-5! | ||
function besselk_power_series_temme_basal(v::V, x::X) where{V,X} | ||
max_iter = 50 | ||
T = promote_type(V,X) | ||
z = x/2 | ||
zz = z*z | ||
fk = f0_local_expansion_v0(v,x) | ||
zv = z^v | ||
znv = inv(zv) | ||
gam_1_c = (1.0, -0.5772156649015329, 0.9890559953279725, -0.23263776388631713) | ||
gam_1pv = evalpoly(v, gam_1_c) | ||
gam_1nv = evalpoly(-v, gam_1_c) | ||
(pk, qk, _ck, factk, vv) = (znv*gam_1pv/2, zv*gam_1nv/2, one(T), one(T), v*v) | ||
(out_v, out_vp1) = (zero(T), zero(T)) | ||
for k in 1:max_iter | ||
# add to the series: | ||
ck = _ck/factk | ||
term_v = ck*fk | ||
term_vp1 = ck*(pk - (k-1)*fk) | ||
out_v += term_v | ||
out_vp1 += term_vp1 | ||
# check for convergence: | ||
((abs(term_v) < eps(T)) && (abs(term_vp1) < eps(T))) && break | ||
# otherwise, increment new quantities: | ||
fk = (k*fk + pk + qk)/(k^2 - vv) | ||
pk /= (k-v) | ||
qk /= (k+v) | ||
_ck *= zz | ||
factk *= k | ||
end | ||
(out_v, out_vp1/z) | ||
end | ||
|
||
function besselk_power_series_int(v, x::Float64) | ||
v = abs(v) | ||
(_v, flv) = modf(v) | ||
if _v > 1/2 | ||
(_v, flv) = (_v-one(_v), flv+1) | ||
end | ||
(kv, kvp1) = besselk_power_series_temme_basal(_v, x) | ||
twodx = 2/x | ||
for _ in 1:flv | ||
_v += 1 | ||
(kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv)) | ||
end | ||
kv | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any way we can get this fixed in enzyme itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, this should use
sincospi
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I poked over at EnzymeAD/Enzyme.jl#443
I don’t think I know exactly how to solve that though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As it turns out this is not the only problem. Something else in the generic power series is broken for
Enzyme
but notForwardDiff
. Leaving a summary comment below, one moment.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you have time could post a specific example where this is broken? I will try to figure out what line is causing the issue even when separating out the sinpi.
These issues though are especially annoying......
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ironically sincospi using Enzyme should be fine. I'm adding a pr for sinpi/cospi now which hopefully will be available in a few days.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, though I'm solving in a different way from this PR (internal to Enzyme proper rather than Enzyme.jl custom rule), rules like this are welcome as PR's to Enzyme.jl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect! Thanks for looking at this. I'll change it over here once that is available.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It occurs to me to mention here for people looking at this PR in the future that the problem was just the
sinpi
, but that I didn't understand how to properly writeEnzymeRules
and just needed to propagate thex.dval
in the derivative part of the returnedDuplicated
object. I didn't include tests for power series accuracy here because that will probably be a little bit of a project to get the last few digits, but once I fixed my custom rule that worked fine.@wsmoses, would you like me to make a PR with this rule in the meantime? If it is fixed and will be available in the next release, maybe not point unless you would make a more immediate release that has the custom rule. I'd be happy to try and make that PR if you want, but I understand if it isn't the most useful.
Sorry that this thread looks on a skim like Enzyme problems but was actually "Chris doesn't know how to write custom rules" problems.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See EnzymeAD/Enzyme#1216. Hopefully that fixes this issue here and we can remove that part in the future.
P.s. I have the general
besselk
working now locally so hopefully we can get that merged soon and can test the general derivative cases.