-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnzymeCore
weakdep and an extension with a custom rule for the Levin transformation
#97
Changes from 1 commit
444db1a
17b90ca
766cff5
6db401e
dbb3821
e2d1f41
fc1e4a7
71334c4
afcb934
9c6d063
6f3445c
3c30394
0e687a1
232f566
e317d3a
1578b6f
27d257f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,38 +499,25 @@ besselk_power_series(v, x::Float32) = Float32(besselk_power_series(v, Float64(x) | |
besselk_power_series(v, x::ComplexF32) = ComplexF32(besselk_power_series(v, ComplexF64(x))) | ||
|
||
function besselk_power_series(v, x::ComplexOrReal{T}) where T | ||
MaxIter = 1000 | ||
S = eltype(x) | ||
v, x = S(v), S(x) | ||
|
||
z = x / 2 | ||
zz = z * z | ||
logz = log(z) | ||
xd2_v = exp(v*logz) | ||
xd2_nv = inv(xd2_v) | ||
|
||
# use the reflection identify to calculate gamma(-v) | ||
# use relation gamma(v)*v = gamma(v+1) to avoid two gamma calls | ||
gam_v = gamma(v) | ||
gam_nv = π / (sinpi(-abs(v)) * gam_v * v) | ||
gam_1mv = -gam_nv * v | ||
gam_1mnv = gam_v * v | ||
|
||
_t1 = gam_v * xd2_nv * gam_1mv | ||
_t2 = gam_nv * xd2_v * gam_1mnv | ||
(xd2_pow, fact_k, out) = (one(S), one(S), zero(S)) | ||
for k in 0:MaxIter | ||
t1 = xd2_pow * T(0.5) | ||
tmp = muladd(_t1, gam_1mnv, _t2 * gam_1mv) | ||
tmp *= inv(gam_1mv * gam_1mnv * fact_k) | ||
term = t1 * tmp | ||
out += term | ||
abs(term / out) < eps(T) && break | ||
(gam_1mnv, gam_1mv) = (gam_1mnv*(one(S) + v + k), gam_1mv*(one(S) - v + k)) | ||
xd2_pow *= zz | ||
fact_k *= k + one(S) | ||
Math.isnearint(v) && return besselk_power_series_int(v, x) | ||
MaxIter = 5000 | ||
gam = gamma(v) | ||
ngam = π / (sinpi(-abs(v)) * gam * v) | ||
|
||
s1, s2 = zero(T), zero(T) | ||
t1, t2 = one(T), one(T) | ||
|
||
for k in 1:MaxIter | ||
s1 += t1 | ||
s2 += t2 | ||
t1 *= x^2 / (4k * (k - v)) | ||
t2 *= x^2 / (4k * (k + v)) | ||
abs(t1) < eps(T) && break | ||
end | ||
return out | ||
|
||
xpv = (x/2)^v | ||
s = gam * s1 + xpv^2 * ngam * s2 | ||
return s / (2*xpv) | ||
end | ||
besselk_power_series_cutoff(nu, x::Float64) = x < 2.0 || nu > 1.6x - 1.0 | ||
besselk_power_series_cutoff(nu, x::Float32) = x < 10.0f0 || nu > 1.65f0*x - 8.0f0 | ||
|
@@ -615,3 +602,68 @@ end | |
end | ||
) | ||
end | ||
|
||
# This is an expansion of the function | ||
# | ||
# f_0(v, x) = (x^v)*gamma(-v) + (x^(-v))*gamma(v) | ||
# = (x^v)*(gamma(-v) + (x^(-2*v))*gamma(v)) | ||
# | ||
# around v ∼ 0. As you can see by plugging that second form into Wolfram Alpha | ||
# and getting an expansion back, this is actually a bivariate polynomial in | ||
# (v^2, log(x)). So that's how this is structured. | ||
@inline function f0_local_expansion_v0(v, x) | ||
lx = log(x) | ||
c0 = evalpoly(lx, (-1.1544313298030657, -2.0)) | ||
c2 = evalpoly(lx, ( 1.4336878944573288, -1.978111990655945, -0.5772156649015329, -0.3333333333333333)) | ||
c4 = evalpoly(lx, (-0.6290784463642211, -1.4584260788225176, -0.23263776388631713, -0.32968533177599085, -0.048101305408461074, -0.016666666666666666)) | ||
evalpoly(v*v, (c0,c2,c4))/2 | ||
end | ||
|
||
# This function assumes |v| < 1e-6 or 1e-7! | ||
# | ||
# TODO (cg 2023/05/16 18:07): lots of micro-optimizations. | ||
function besselk_power_series_temme_basal(v::V, x::Float64) where{V} | ||
max_iter = 50 | ||
T = promote_type(V,Float64) | ||
z = x/2 | ||
zz = z*z | ||
fk = f0_local_expansion_v0(v, x/2) | ||
zv = z^v | ||
znv = inv(zv) | ||
gam_1pv = GammaFunctions.gamma_near_1(1+v) | ||
gam_1nv = GammaFunctions.gamma_near_1(1-v) | ||
(pk, qk, _ck, factk, vv) = (znv*gam_1pv/2, zv*gam_1nv/2, one(T), one(T), v*v) | ||
(out_v, out_vp1) = (zero(T), zero(T)) | ||
for k in 1:max_iter | ||
# add to the series: | ||
ck = _ck/factk | ||
term_v = ck*fk | ||
term_vp1 = ck*(pk - (k-1)*fk) | ||
out_v += term_v | ||
out_vp1 += term_vp1 | ||
# check for convergence: | ||
((abs(term_v) < eps(T)) && (abs(term_vp1) < eps(T))) && break | ||
# otherwise, increment new quantities: | ||
fk = (k*fk + pk + qk)/(k^2 - vv) | ||
pk /= (k-v) | ||
qk /= (k+v) | ||
_ck *= zz | ||
factk *= k | ||
end | ||
(out_v, out_vp1/z) | ||
end | ||
|
||
function besselk_power_series_int(v, x::Float64) | ||
v < zero(v) && return besselk_power_series_int(-v, x) | ||
flv = Int(floor(v)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's usually easiest to just indiscrimantely have I think I generally like using the |
||
_v = v - flv | ||
(kv, kvp1) = besselk_power_series_temme_basal(_v, x) | ||
abs(v) < 1/2 && return kv | ||
twodx = 2/x | ||
for _ in 1:(flv-1) | ||
_v += 1 | ||
(kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv)) | ||
end | ||
kvp1 | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make sure that these two
gamma_near_1
calls are auto-vectorizing. They probably all but it might be worth a@inline
declaration on their function definition which should help the SIMD.This is a good way to do this though as using SIMD will be faster than using reflection formula!