Add EnzymeCore weakdep and an extension with a custom rule for the Levin transformation #97
Conversation
Codecov Report
Patch coverage: …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master      #97      +/-   ##
==========================================
- Coverage   96.63%   94.21%   -2.42%
==========================================
  Files          23       24       +1
  Lines        2377     2439      +62
==========================================
+ Hits         2297     2298       +1
- Misses         80      141      +61
==========================================
```

☔ View full report in Codecov by Sentry.
… a mysterious "cannot merge projects" error.

This is a very silly issue. The problem is that you've named your module …
Hah! That is hilarious. I should have caught that; sorry for wasting your time here. I'll make that name change, thanks!

No problem. It's a very silly error, but I can totally see myself making it and not noticing.
@heltonmc: this is all working now, and all tests pass. But before I write a real test file and ask for a merge, I just wanted to check in with you, because this changes the Levin transformation code and I haven't touched base with you in a little while about how many other functions you've been developing that also use it.

As a reminder of the "point": your improved version here is reorganized so that it takes the sequence values and weights as inputs, so the functions that produce that ntuple never risk a divide-by-zero problem, and the custom rule here checks for exact convergence before passing those tuples into the Levin transform.

Also, as far as tests go: I know we've discussed this in bits and pieces somewhere else, but considering that we don't really have the same level of quality reference values for the derivatives, maybe I'll just use a BigFloat + high-order finite difference as the reference. If the answer isn't obvious to you guys, maybe for now I'll just put in a handful of the super-expensive derivatives and check the `atol` while we discuss? Practically speaking, if it gives the right answer for any half-integer order and any full-integer order, I think we can be reasonably confident that the autodiff transformation is happening correctly, and more subtle issues about accuracy to the last ulp can be a follow-up discussion.
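For readers, here is a hedged, minimal sketch of the classical Levin transformation in its textbook form (partial sums `s_j` plus remainder estimates `ω_j`); Bessels.jl's tuple-based version is organized differently, and `levin_demo` and `β` are illustrative names, not from this PR:

```julia
# Classical Levin transformation: a ratio of binomially weighted sums.
# NOTE: this sketch divides by ω inside the loop; the PR's version instead
# takes precomputed weights so no division (hence no divide-by-zero) happens.
function levin_demo(s::AbstractVector, ω::AbstractVector; β = 1.0)
    n = length(s) - 1
    num, den = 0.0, 0.0
    for j in 0:n
        c = (-1.0)^j * binomial(n, j) * ((β + j) / (β + n))^(n - 1)
        num += c * s[j + 1] / ω[j + 1]
        den += c / ω[j + 1]
    end
    return num / den
end

# Example: accelerate the alternating series for log(2) = 1 - 1/2 + 1/3 - …
a = [(-1.0)^(k + 1) / k for k in 1:8]        # terms
s = cumsum(a)                                 # partial sums
ω = [(-1.0)^(k + 2) / (k + 1) for k in 1:8]   # remainder estimates ω_j = a_{j+1}
levin_demo(s, ω)   # ≈ 0.693147…, far better than the raw s[end] ≈ 0.6345
```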
The general rule for finite differencing is that you lose about half of the bits. As long as you have more than 100 or so bits, you'll have plenty left over for a Float64-quality reference.
Oh, great. That's a very helpful rule of thumb. I'll crank up to 150 bits to be safe and go for the real `rtol` metric then. Thanks!
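A quick self-contained illustration (not from the thread) of that rule of thumb, using `sin` as a stand-in since its derivative is known exactly:

```julia
# 150-bit BigFloats carry about 45 decimal digits; a central difference with
# h ≈ sqrt(eps) should then leave roughly half of them, i.e. ~22 digits.
setprecision(BigFloat, 150)
x, h = big"1.1", big"1e-22"
fd = (sin(x + h) - sin(x - h)) / (2h)
abs(fd - cos(x))   # ~1e-23: about half of the 45 digits survive
```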
Okay, so here's a funny issue: the finite-difference reference values disagree depending on which `besselkx` I difference. You can check out the script I used in the README in the data file directory. But look at this:

```julia
julia> FD(_v -> arb_besselkx(_v, 13.1), 1.1) # using Arb for the central_fdm(10,1)
0.02902891941590647355381775055510323973333759074136335060419825427995335892649301646

julia> FD(_v -> besselkx(_v, 13.1), 1.1) # using the SpecialFunctions.jl/AMOS besselkx
0.029028807353843788

julia> dbeskx_dv(1.1, 13.1)[2] # our Enzyme-based derivative, which agrees with AMOS to almost every digit
0.02902880735384335

julia> arb_besselkx(1.1, 13.1) # to confirm no simple code error
0.3587283774392823423904762107496216329388054642187977726932970216277242684767524250767252332284

julia> besselkx(1.1, 13.1) # looking pretty good, agreeing to every digit
0.3587283774392823
```

I recognize that we shouldn't necessarily trust the AMOS one more, because of what Oscar points out. But considering that our functions, which are accurate in the non-derivative case, agree so well with the AMOS finite difference yet differ in about the fourth digit from the Arb finite difference, something smells fishy to me here. Michael pointed out some potential issues with the Arb routine as well.

EDIT:

```julia
julia> simple_fd(f, x, h) = (f(x + h) - f(x)) / h
simple_fd (generic function with 1 method)

julia> simple_fd(_v -> arb_besselkx(_v, 13.1), 1.1, ArbReal(1e-40))
0.02902880735384334062097957282707699252341403970380459

julia> dbeskx_dv(1.1, 13.1)[2] # agreement to almost every digit now
0.02902880735384335
```

@oscardssmith, could this conceivably be a bug in FiniteDifferences.jl?
I believe the problem is https://github.com/JuliaDiff/FiniteDifferences.jl#dealing-with-numerical-noise. Since the Arb routines aren't actually accurate to as many digits as they claim to be, `central_fdm` chooses its step size based on precision the function values don't really have.
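For reference, the fix that README section suggests looks like this; the `factor` value below is illustrative, not tuned for the Arb routines:

```julia
using FiniteDifferences

# Tell the method the function values are noisier than eps(): `factor` scales
# the assumed relative error of f, so a larger (safer) step size is chosen.
fdm = central_fdm(10, 1; factor = 1e6)
fdm(sin, 1.1)   # ≈ cos(1.1); with a noisy f this avoids an overly small step
```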
… (v,x) pairs. Setting to 5e-14 for now, which seems pretty acceptable considering the possibility of edge cases in the Arb function too.
Man, I need to resist the urge to put you on my speed dial. You have seen it all. Since I don't really know how to set that factor for the Arb stuff, I just switched to a simple finite difference with a very small `ArbReal` step size.

Considering that the tests are in and pass (with that tiny amount of fudge), I think this is ready for a review and merge when you and @heltonmc have a chance to take a look. Thanks again for all your help!
… simplifies the manual rule.
This looks ready to merge to me. I checked it out locally and ran a few benchmarks: it is pretty much exactly twice as slow to generate both the scalar and the derivative, which makes sense to me since we have to do them separately. This seems very optimal to me. The one thing is that it requires a large number of sequence terms. This right now is working great.
… gives valid derivatives. Tests TODO/to debug still.
This last commit brings in the new near-integer power series stuff. I haven't written tests yet, so it's certainly not ready to merge, and your new form of the power series in general seems to be causing some of the issues …
src/BesselFunctions/besselk.jl (outdated)

```julia
v < zero(v) && return besselk_power_series_int(-v, x)
flv = Int(floor(v))
```
I think it's usually easiest to just indiscriminately have `v = abs(v)` at the top. I'm wondering how this line affects derivative information in any way, as the zero-order derivative is obviously zero from this line. Though we don't really need this check, because we are checking at the top level, I think this is fine.

I generally like using the `modf` function, and I think it would be a good fit for these types of problems.
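A hedged sketch of that suggestion (variable names beyond `v`/`flv` are illustrative):

```julia
# Take |v| once at the top, then split the order into fractional + integer parts.
v = abs(-13.1)          # besselk is even in the order, so drop the sign up front
vfrac, vint = modf(v)   # ≈ (0.1, 13.0): fractional and integer parts in one call
flv = Int(vint)         # the same integer the code currently gets via floor(v)
```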
src/BesselFunctions/besselk.jl (outdated)

```julia
fk = f0_local_expansion_v0(v, x/2)
zv = z^v
znv = inv(zv)
gam_1pv = GammaFunctions.gamma_near_1(1 + v)
```
Let's make sure that these two `gamma_near_1` calls are auto-vectorizing. They probably are, but it might be worth an `@inline` declaration on their function definition, which should help the SIMD.

This is a good way to do it, though, as using SIMD will be faster than using the reflection formula!
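As a hedged illustration only (the real `gamma_near_1` lives in `GammaFunctions`; the polynomial below is just a two-term Taylor expansion of Γ(x) about x = 1):

```julia
# Marking the small kernel @inline lets both call sites inline, which is what
# usually enables SIMD across the surrounding straight-line arithmetic.
@inline gamma_near_1_demo(x) =
    evalpoly(x - 1, (1.0, -0.5772156649015329, 0.9890559953279725))
```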
… that makes the AD just a bit less accurate than it should be.
@@ -35,4 +35,17 @@ module BesselsEnzymeCoreExt

```julia
    Duplicated(ls, dls)
end

# This is fixing a straight bug in Enzyme.
function EnzymeRules.forward(func::Const{typeof(sinpi)},
```
Any way we can get this fixed in Enzyme itself?
Also, this should use `sincospi`.
I poked over at EnzymeAD/Enzyme.jl#443, but I don't think I know exactly how to solve that.
As it turns out, this is not the only problem. Something else in the generic power series is broken for Enzyme but not for ForwardDiff. Leaving a summary comment below, one moment.
When you have time, could you post a specific example where this is broken? I will try to figure out which line is causing the issue, even with the `sinpi` separated out.

These issues are especially annoying, though…
Ironically, `sincospi` using Enzyme should be fine. I'm adding a PR for `sinpi`/`cospi` now, which hopefully will be available in a few days.
Also, though I'm solving this in a different way from this PR (internal to Enzyme proper rather than an Enzyme.jl custom rule), rules like this are welcome as PRs to Enzyme.jl.
Perfect! Thanks for looking at this. I'll change it over here once that is available.
It occurs to me to mention, for people looking at this PR in the future, that the problem was just the `sinpi`: I didn't understand how to properly write `EnzymeRules` and just needed to propagate the `x.dval` in the derivative part of the returned `Duplicated` object. I didn't include tests for power series accuracy here, because getting the last few digits will probably be a bit of a project, but once I fixed my custom rule it worked fine.

@wsmoses, would you like me to make a PR with this rule in the meantime? If it is fixed and will be available in the next release, maybe there's no point, unless you'd make a more immediate release that has the custom rule. I'd be happy to try to make that PR if you want, but I understand if it isn't the most useful.

Sorry that, on a skim, this thread looks like Enzyme problems when it was actually "Chris doesn't know how to write custom rules" problems.
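For future readers, a hedged reconstruction of the kind of forward rule being discussed, in the EnzymeCore signature of that era (matching the diff above); the exact rule in the PR may differ:

```julia
using EnzymeCore
using EnzymeCore.EnzymeRules

# d/dx sinpi(x) = π * cospi(x): one sincospi call gives both the primal and
# the cosine, and x.dval is propagated into the derivative slot of Duplicated.
function EnzymeRules.forward(func::Const{typeof(sinpi)},
                             ::Type{<:Duplicated}, x::Duplicated)
    s, c = sincospi(x.val)
    return Duplicated(s, π * c * x.dval)
end
```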
See EnzymeAD/Enzyme#1216. Hopefully that fixes the issue here and we can remove that part in the future.

P.S. I have the general `besselk` working locally now, so hopefully we can get that merged soon and can test the general derivative cases.
… of v=0. But probably ready to merge.
Okay, so this newest commit gives code that has decent timings/optimizations (modulo the issues below). The much bigger issues are:

- something in the generic power series still breaks the derivatives under Enzyme (but not ForwardDiff), and
- the Float32 accuracy problem.

I know @heltonmc would like to get this merged, so maybe it would be best to just merge, and I will track down the power series issue in another PR, which can be smaller? Let me know!
```julia
sp = evalpoly(vv, (1.0, 1.6449340668482264, 1.8940656589944918, 1.9711021825948702))
g1 = evalpoly(vv, (-0.5772156649015329, 0.04200263503409518, 0.042197734555544306))
g2 = evalpoly(vv, (1.0, -0.6558780715202539, 0.16653861138229145))
sh = evalpoly(mu*mu, (1.0, 0.16666666666666666, 0.008333333333333333, 0.0001984126984126984, 2.7557319223985893e-6))
```
I'm kinda wondering how many terms we need for this expansion as `mu` slowly grows…

```julia
julia> v = 1e-4
0.0001

julia> v * log(2 / 24.0)
-0.0002484906649788
```

Five terms is probably OK? I did some quick checks adding another term below, but it didn't seem to change much. Seems like a reasonable approximation that we have here. Just checking: this `f0` is pretty accurate? It isn't contributing to the errors we are seeing?
```mathematica
SeriesData[x, 0, {
  1.`20., 0, 0.16666666666666666666666666666666666667`20., 0,
  0.00833333333333333333333333333333333333`20., 0,
  0.00019841269841269841269841269841269841`20., 0,
  2.75573192239858906525573192239859`20.*^-6, 0,
  2.505210838544171877505210838544`20.*^-8}, 0, 11, 1]
```
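Since the `SeriesData` above is sinh(x)/x through x¹⁰, wiring its next coefficient (1/11! ≈ 2.5052108385441718e-8) into the `evalpoly` from the diff would look like this (a sketch of the "adding another term" check mentioned above; `mu` is set to an illustrative value here, since it comes from the surrounding routine):

```julia
mu = 1e-4   # illustrative; mu is computed by the surrounding routine
# sinh-type expansion with one extra term appended to the coefficient tuple.
sh = evalpoly(mu * mu, (1.0, 0.16666666666666666, 0.008333333333333333,
                        0.0001984126984126984, 2.7557319223985893e-6,
                        2.505210838544172e-8))
```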
Can't really say where the difference is coming from here…

```julia
julia> Bessels.BesselFunctions.besselk_power_series_int(0.000, 10.0)
1.778006219756317e-5

julia> ArbNumerics.besselk(ArbFloat(0.0000), ArbFloat(10.0))
1.7780062316167651811301192799427833154e-5

julia> ArbNumerics.besselk(ArbFloat(12.0000), ArbFloat(10.0))
0.010278998056493335846252984780767697567

julia> Bessels.BesselFunctions.besselk_power_series_int(12.000, 10.0)
0.010278998068072438
```
Well, those are huge `x` values: weren't we only going to use this for `x < 1.5` or so?

In general, I'm sure more terms couldn't hurt; I can tinker with that. But it isn't obvious that it will help, because if `abs(v) < 1e-5` and these are polynomials in `v^2`, then the sixth-order term will contribute at most about `1e-30`. I know the story is more complicated for derivatives, though, so I'll see about putting in an extra term to check whether it helps.
Ya, that's right, haha. I will need to adjust cutoffs and stuff accordingly. So what we probably should do is just have general routines for `abs(v) < 1.5` and then use forward recurrence (see the sketch after this comment). It's actually much faster to do two Levin calculations plus forward recurrence than it is to do the integer power series. So let's adjust the whole routine. For `v > 25` we will use the uniform expansion, so forward recurrence will be fast there.

But ya, I think as long as we verify derivatives for `v` near 0.0 and `x < 1.5` for this power series, we should be OK. The tough thing is that the derivative in `v` is zero at `v = 0`, so something really close to `v = 0` will be tricky to get right…
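A hedged sketch of that strategy (names and the driver are illustrative, not Bessels.jl internals): compute K_ν(x) and K_{ν+1}(x) for the fractional order, e.g. from two Levin-transformed evaluations, then climb with the upward recurrence, which is stable for `besselk` in the increasing-order direction:

```julia
# K_{ν+1}(x) = (2ν / x) * K_ν(x) + K_{ν-1}(x), applied n - 1 times so the
# second element of the running pair ends at K_{ν+n}(x).
function besselk_up_recurrence_demo(x, kv, kvp1, v, n)
    for _ in 1:n-1
        kv, kvp1 = kvp1, (2 * (v + 1) / x) * kvp1 + kv
        v += 1
    end
    return kvp1   # K_{v0+n}(x), given kv = K_{v0}(x) and kvp1 = K_{v0+1}(x)
end
```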
Dang, sorry, ya, I need to rethink the whole routine now that the intermediate range using the Levin transform is so fast. I used to try to avoid that range completely by extending the power series as far as possible, but now it makes sense to favor the Levin transform and just fall back to forward recurrence when necessary.

It actually makes checking the accuracy of the whole routine much easier, because we essentially just have to check our scalar and derivative information for `v < 1.0`, and we know that forward recurrence is also stable and accurate for `besselk`. This should greatly reduce the number of points we need to explicitly check. Of course, we should still do some scattered spot checks to make sure the derivatives are carried out OK with AD.
Ya, I think as long as the power series is accurate for scalar evaluation, then I'm OK merging, because I need to change the routine in general anyway. If the scalar is accurate, any remaining derivative problem should be an issue in the AD system, which we can track down in a different PR. The Float32 issue needs to be tracked down, though, so I will look at that now.
Actually, ya, just change it back to the prior power series and I will track it down when I edit the whole routine. I don't want to fix it here and then have to change the routine again :)
This PR adds a custom `EnzymeCore` rule (see #96). All issues have been resolved and all tests pass. I think it is good for a review+merge when ready.

I currently get the error that `forward` is not defined, although it is defined, as `EnzymeCore.EnzymeRules.forward`. So something is happening that I'm not understanding. @oscardssmith, would you be willing to take a look at this?