
Add EnzymeCore weakdep and an extension with a custom rule for the Levin transformation #97

Merged: 17 commits from cgeoga:enzyme into JuliaMath:master, May 18, 2023

Conversation

@cgeoga (Contributor) commented May 8, 2023

This PR adds a custom EnzymeCore rule (see #96). All issues have been resolved and all tests pass, so I think it is good for a review and merge when ready.

I currently get an error that forward is not defined, even though it is defined as EnzymeCore.EnzymeRules.forward. So something is happening that I'm not understanding. @oscardssmith, would you be willing to take a look at this?

@codecov codecov bot commented May 8, 2023

Codecov Report

Patch coverage: 15.49% and project coverage change: -2.42 ⚠️

Comparison is base (152771d) 96.63% compared to head (27d257f) 94.21%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #97      +/-   ##
==========================================
- Coverage   96.63%   94.21%   -2.42%     
==========================================
  Files          23       24       +1     
  Lines        2377     2439      +62     
==========================================
+ Hits         2297     2298       +1     
- Misses         80      141      +61     
Impacted Files Coverage Δ
ext/BesselsEnzymeCoreExt.jl 0.00% <0.00%> (ø)
src/GammaFunctions/gamma.jl 98.63% <0.00%> (-1.37%) ⬇️
src/BesselFunctions/besselk.jl 81.14% <15.78%> (-15.62%) ⬇️
src/Math/Math.jl 65.47% <66.66%> (+0.41%) ⬆️

... and 1 file with indirect coverage changes


@oscardssmith (Member) commented

This is a very silly issue. The problem is that you've named your module EnzymeRules and are trying to import EnzymeCore.EnzymeRules, so Julia gets confused about which module EnzymeRules refers to. I would recommend changing the name to BesselsEnzymeCoreExt.
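The clash in miniature (a hypothetical reproduction, not this PR's code): inside a module that is itself named EnzymeRules, the module's self-binding wins, so a qualified EnzymeRules.forward looks inside the extension module, which defines no forward.

module EnzymeRules  # extension module unfortunately named like the import
    using EnzymeCore
    import EnzymeCore.EnzymeRules  # ignored: conflicts with the self-binding

    # EnzymeRules.forward now resolves against this module and throws
    # UndefVarError: forward not defined.
end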

@cgeoga (Contributor, author) commented May 9, 2023

Hah! That is hilarious. I should have caught that, sorry for wasting your time here. I'll make that name change, thanks!

@oscardssmith (Member) commented

no problem. It's a very silly error, but I can totally see myself making it and not noticing.

@cgeoga (Contributor, author) commented May 9, 2023

@heltonmc, this is all working now, and all tests pass. But before I write a real test file and ask for a merge, I wanted to check in with you, because it does change the Levin transformation code and I haven't touched base with you in a while about how many other functions you've been developing that also use it.

Just as a reminder of the point: your improved version here is reorganized so that it takes the sequence values and weights as inputs, meaning the functions that produce that ntuple never risk a divide-by-zero, and the custom rule here then checks for exact convergence before passing those tuples into the Levin transform.
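A minimal sketch of that idea (not the exact rule from this PR: levin_transform here is a hypothetical dot(w, s)/sum(w) stand-in for the real transform, and the 2023-era EnzymeCore.EnzymeRules forward signature, without a config argument, is assumed):

using EnzymeCore
using EnzymeCore: EnzymeRules
using LinearAlgebra

levin_transform(s, w) = dot(w, s) / sum(w)  # stand-in for the real transform

function EnzymeRules.forward(func::Const{typeof(levin_transform)},
                             ::Type{<:Duplicated},
                             s::Duplicated, w::Duplicated)
    # Exact convergence: successive terms agree, so return the converged
    # value and its derivative directly instead of dividing by degenerate
    # weights.
    if s.val[end] == s.val[end-1]
        return Duplicated(s.val[end], s.dval[end])
    end
    num, den = dot(w.val, s.val), sum(w.val)
    dnum  = dot(w.dval, s.val) + dot(w.val, s.dval)   # product rule
    ddual = (dnum * den - num * sum(w.dval)) / den^2  # quotient rule
    Duplicated(num / den, ddual)
end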

Also, as far as tests go: I know we've discussed this in bits and pieces somewhere else, but considering that we don't really have the same quality of reference values for the derivatives, maybe I'll just use a BigFloat + high-order central_fdm and check for something like isapprox(ref, candidate, atol=1e-15)? I'm a bit nervous to ask for rtols of 1e-15 from a finite difference rule, even with higher-precision floats and a very high-order adaptive rule. Maybe @oscardssmith has some opinions on this as well?

If the answer isn't obvious to you guys, maybe for now I'll just put in a handful of the super-expensive derivatives and check the atol while we discuss? Practically speaking, if it gives the right answer for any half-integer and any full-integer order, I think we can be reasonably confident that the autodiff transformation is happening correctly, and the subtler questions about accuracy to the last ulp can be a follow-up discussion.

@oscardssmith (Member) commented

The general rule for finite differencing is that you lose about half of the bits. As long as you have more than 100 or so bits, rtol=1e-15 should be fine (you may need extra bits for other reasons).
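That rule of thumb comes from balancing truncation error against rounding error. For the simplest first-order difference in working precision \varepsilon,

\mathrm{err}(h) \approx C h + \frac{\varepsilon}{h}, \qquad h^{\star} \approx \sqrt{\varepsilon / C} \;\Rightarrow\; \mathrm{err}(h^{\star}) \approx 2\sqrt{C\varepsilon},

so roughly half the significand bits survive; higher-order rules like central_fdm(10, 1) lose proportionally fewer.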

@cgeoga (Contributor, author) commented May 9, 2023

Oh, great. That's a very helpful rule of thumb. I'll crank up to 150 bits to be safe and go for the real rtol metric then. Thanks!

@cgeoga (Contributor, author) commented May 9, 2023

Okay, so here's a funny issue: the ArbNumerics/Arb besselk values are pretty dubious. If you use ArbFloat and crank the precision too high, everything starts coming out as essentially zeros and you get garbage. So I switched to ArbReal and set the digits almost as high as I could without the silent failure, but the values are still quite suspect.

You can check out the script I used in the README in the data file directory. But look at this:

julia> FD(_v->arb_besselkx(_v, 13.1), 1.1) # using Arb for the central_fdm(10,1)
0.02902891941590647355381775055510323973333759074136335060419825427995335892649301646

julia> FD(_v->besselkx(_v, 13.1), 1.1) # using the SpecialFunctions.jl/AMOS besselkx
0.029028807353843788

julia> dbeskx_dv(1.1, 13.1)[2] # this is our Enzyme-based derivative, which agrees with the AMOS to almost every digit
0.02902880735384335

julia> arb_besselkx(1.1, 13.1) # to confirm no simple code error
0.3587283774392823423904762107496216329388054642187977726932970216277242684767524250767252332284

julia> besselkx(1.1, 13.1) # looking pretty good, agreeing to every digit
0.3587283774392823

I recognize that we shouldn't necessarily trust the AMOS one more, because of what Oscar points out. But considering that our functions, which are accurate in the non-derivative case, agree so well with the AMOS finite difference yet differ around the fourth digit from the Arb finite difference, something smells fishy to me here.

Michael pointed out some potential issues with the Arb besselk before. But maybe this is also just a bug with FiniteDifferences. Will keep investigating, but I thought it might be useful to have a paper trail of this problem in case it turns into an upstream bug report somewhere.

EDIT:
Look at this:

julia> simple_fd(f,x,h) = (f(x+h) - f(x))/h
simple_fd (generic function with 1 method)

julia> simple_fd(_v->arb_besselkx(_v, 13.1), 1.1, ArbReal(1e-40))
0.02902880735384334062097957282707699252341403970380459

julia> dbeskx_dv(1.1, 13.1)[2] # agreement to almost every digit now
0.02902880735384335

@oscardssmith, could this conceivably be a bug in FiniteDifferences.jl? Do you think the adaptive step size or something makes a mistake when given ArbReals or ArbFloats?

@oscardssmith (Member) commented

I believe the problem is https://github.com/JuliaDiff/FiniteDifferences.jl#dealing-with-numerical-noise. Since the Arb routines aren't as accurate as they claim to be, FiniteDifferences is picking too small an epsilon.
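Per that README section, the assumed noise level can be supplied by hand so the adaptive method picks a larger step (the factor value below is illustrative, and arb_besselkx is the helper from the script above):

using FiniteDifferences

# Tell the estimator the function is ~1e12 times noisier than eps(Float64),
# so it chooses a correspondingly larger step size:
fd_noisy = central_fdm(10, 1; factor=1e12)
fd_noisy(_v -> arb_besselkx(_v, 13.1), 1.1)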

Commit: ... (v,x) pairs. Setting to 5e-14 for now, which seems pretty acceptable considering the possibility of edge cases in the Arb function too.
@cgeoga (Contributor, author) commented May 9, 2023

Man, I need to resist the urge to put you on speed dial. You have seen it all. Since I don't really know how to set that factor for Arb inputs, I just switched to a simple finite difference with h = ArbReal(1e-40). For all but six trial values the rtol of 1e-15 passes; for the other six I need to set it to about 5e-14. I know you guys won't love that, but it seems decent enough for an initial merge to me.

Considering that the tests are in and pass (with that tiny amount of fudge), I think this is ready for a review and merge when you and @heltonmc have a chance to take a look.

Thanks again for all your help!

(Resolved review threads on test/Manifest.toml and src/Math/Math.jl, now outdated.)
@heltonmc (Member) commented

This looks ready to merge to me.

I checked it out locally and ran a few benchmarks: it is pretty much exactly twice as slow to generate both the scalar and the derivative, which makes sense since we have to compute them separately. That seems close to optimal to me. The one concern is that it requires a large number of sequence terms, Val(30), which is rather significant when I found the scalar value alone needed only about 16 terms. That is something we can look into next, when we piece everything together.

This right now is working great.

Commit: ... gives valid derivatives. Tests TODO/to debug still.
@cgeoga (Contributor, author) commented May 16, 2023

This last commit brings in the new near-integer power series stuff. I haven't written tests yet, so it's certainly not ready to merge, and your new form of the power series seems to be causing some of the Float32 tests to fail. But it is so fast, and I'm sure there is just a little thing to fix somewhere, so I'll look at that and at setting up the AD tests tomorrow.

Comment on lines 657 to 658
v < zero(v) && return besselk_power_series_int(-v, x)
flv = Int(floor(v))
Member:

I think it's usually easiest to just indiscriminately have v = abs(v) at the top. I'm wondering how this line affects derivative information in any way, as the order-zero derivative is obviously zero from this line. We don't really need this check, though, because we are checking at the top level, but I think this is fine.

I generally like using the modf function, and I think it would be a good fit for this type of problem.
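A sketch of that suggestion (hypothetical structure, not the PR's code):

# modf splits a float into fractional and integral parts in one call,
# e.g. modf(3.5) == (0.5, 3.0).
v = abs(v)              # handle negative orders up front, as suggested
vfrac, vint = modf(v)   # replaces the separate floor/subtract bookkeeping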

fk = f0_local_expansion_v0(v, x/2)
zv = z^v
znv = inv(zv)
gam_1pv = GammaFunctions.gamma_near_1(1+v)
@heltonmc (Member) commented May 17, 2023

Let's make sure that these two gamma_near_1 calls are auto-vectorizing. They probably already do, but it might be worth an @inline declaration on their function definition, which should help the SIMD.

This is a good way to do it, though, as using SIMD here will be faster than using the reflection formula!
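The suggestion, sketched (the polynomial body is illustrative, just the leading Taylor terms of Gamma(1 + t), not Bessels' actual kernel):

# Marking the small kernel @inline lets the compiler fuse the two calls
# at the use site and vectorize them together.
@inline gamma_near_1(x) = evalpoly(x - 1.0,
    (1.0, -0.5772156649015329, 0.9890559953279725))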

@heltonmc (Member) commented

Float32 errors are annoying. I went through and checked the versions and this new version is actually slightly more accurate. It seems like a cutoff issue.
[Two screenshots comparing the accuracy of the old and new versions.]

Probably some weird cutoff issue that I'll track down... I mean, it's a weird section, because we are promoting the Float32 values to Float64 but using a different cutoff to avoid the intermediate routine. Now that Levin is so fast, it's almost not worthwhile to do this.

Commit: ... that makes the AD just a bit less accurate than it should be.
@@ -35,4 +35,17 @@ module BesselsEnzymeCoreExt
Duplicated(ls, dls)
end

# This is fixing a straight bug in Enzyme.
function EnzymeRules.forward(func::Const{typeof(sinpi)},
Member: Any way we can get this fixed in Enzyme itself?

Member: Also, this should use sincospi.
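A minimal sketch of that suggestion (assuming the 2023-era EnzymeCore.EnzymeRules forward signature used elsewhere in this diff): sincospi returns both sin(pi x) and cos(pi x) in one call, and the derivative part propagates x.dval, as discussed below.

using EnzymeCore
using EnzymeCore: EnzymeRules

function EnzymeRules.forward(func::Const{typeof(sinpi)},
                             ::Type{<:Duplicated}, x::Duplicated)
    (s, c) = sincospi(x.val)        # sin(pi*x) and cos(pi*x) together
    Duplicated(s, pi * c * x.dval)  # d/dx sin(pi*x) = pi*cos(pi*x)
end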

Member: I poked over at EnzymeAD/Enzyme.jl#443. I don't think I know exactly how to solve that, though.

Contributor (author): As it turns out, this is not the only problem. Something else in the generic power series is broken for Enzyme but not for ForwardDiff. Leaving a summary comment below, one moment.

Member: When you have time, could you post a specific example where this is broken? I will try to figure out which line is causing the issue even after separating out the sinpi.

These issues, though, are especially annoying...

Ironically, sincospi using Enzyme should be fine. I'm adding a PR for sinpi/cospi now, which hopefully will be available in a few days.

Also, though I'm solving it in a different way from this PR (internal to Enzyme proper rather than an Enzyme.jl custom rule), rules like this are welcome as PRs to Enzyme.jl.

Member: Perfect! Thanks for looking at this. I'll change it over here once that is available.

Contributor (author): It occurs to me to mention, for people looking at this PR in the future, that the problem was just the sinpi, but that I didn't understand how to properly write EnzymeRules and just needed to propagate the x.dval in the derivative part of the returned Duplicated object. I didn't include tests for power series accuracy here, because getting the last few digits will probably be a bit of a project, but once I fixed my custom rule it worked fine.

@wsmoses, would you like me to make a PR with this rule in the meantime? If it is fixed and will be available in the next release, maybe there's no point, unless you would make a more immediate release that has the custom rule. I'd be happy to try to make that PR if you want, but I understand if it isn't the most useful.

Sorry that this thread looks on a skim like Enzyme problems when it was actually "Chris doesn't know how to write custom rules" problems.

Member: See EnzymeAD/Enzyme#1216. Hopefully that fixes the issue here and we can remove that part in the future.

P.S. I have the general besselk working locally now, so hopefully we can get that merged soon and can test the general derivative cases.

@cgeoga (Contributor, author) commented May 18, 2023

Okay, so this newest commit gives code with decent timings/optimizations (modulo the muladd conversions), and it passes the autodiff tests for order and argument... albeit with reduced precision for abs(v) < 1e-8. But I think those derivatives are pretty hard, and I'm almost suspicious of the Arb version too. So personally, I would still vote to merge for the moment and revisit.

The much bigger issues are:

  1. The generic power series code, even with the sinpi fix, is still totally broken with Enzyme (but works great with ForwardDiff if I remove some type annotations in Bessels.gamma).
  2. The new power series code has some Float32 failures.

I know @heltonmc would like to get this merged, so maybe it would be best to just merge and I will track down the power series issue in another PR, which can be smaller? The Float32 issues are less obvious. If you'd like, I can just go back to the old power series code that passes the tests, and that can also be a separate PR.

Let me know!

sp = evalpoly(vv, (1.0, 1.6449340668482264, 1.8940656589944918, 1.9711021825948702))
g1 = evalpoly(vv, (-0.5772156649015329, 0.04200263503409518, 0.042197734555544306))
g2 = evalpoly(vv, (1.0, -0.6558780715202539, 0.16653861138229145))
sh = evalpoly(mu*mu, (1.0, 0.16666666666666666, 0.008333333333333333, 0.0001984126984126984, 2.7557319223985893e-6))
Member: I'm kind of wondering how many terms we need for this expansion as mu slowly grows...

julia> v = 1e-4
0.0001

julia> v * log(2 / 24.0)
-0.0002484906649788

Five terms is probably OK? I did some quick checks adding another term below, but it didn't seem to change much. Seems like a reasonable approximation that we have here. Just checking: this f0 is pretty accurate? It isn't contributing to the errors we are seeing?

SeriesData[x, 0, {
 1.`20., 0, 0.16666666666666666666666666666666666667`20., 0, 
  0.00833333333333333333333333333333333333`20., 0, 
  0.00019841269841269841269841269841269841`20., 0, 
  2.75573192239858906525573192239859`20.*^-6, 0, 
  2.505210838544171877505210838544`20.*^-8}, 0, 11, 1]
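For reference, that series is sinh(mu)/mu = sum over k of mu^(2k) / (2k+1)!, so folding the next term from the Mathematica output above into the evalpoly tuple would read (a sketch, using the k = 5 coefficient 1/11!):

sh = evalpoly(mu * mu, (1.0, 0.16666666666666666, 0.008333333333333333,
                        0.0001984126984126984, 2.7557319223985893e-6,
                        2.505210838544172e-8))  # last term is 1/11!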

Member: Can't really say where the difference is coming from here...

julia> Bessels.BesselFunctions.besselk_power_series_int(0.000, 10.0)
1.778006219756317e-5

julia> ArbNumerics.besselk(ArbFloat(0.0000), ArbFloat(10.0))
1.7780062316167651811301192799427833154e-5

julia> ArbNumerics.besselk(ArbFloat(12.0000), ArbFloat(10.0))
0.010278998056493335846252984780767697567

julia> Bessels.BesselFunctions.besselk_power_series_int(12.000, 10.0)
0.010278998068072438

Contributor (author): Well, those are huge x values. Weren't we only going to use this for x < 1.5 or so?

In general, I'm sure more terms couldn't hurt, and I can tinker with that. But it isn't obvious that it will help: if abs(v) < 1e-5 and these are polynomials in v^2, then the sixth-order term contributes at most about 1e-30. I know the story is more complicated for derivatives, though, so I'll see whether putting in an extra term helps.

Member: Ya, that's right, haha. I will need to adjust cutoffs and such accordingly. What we probably should do is just have general routines for abs(v) < 1.5 and then use forward recurrence. It's actually much faster to do two Levin calculations plus forward recurrence than it is to do the integer-order power series, so let's adjust the whole routine. For v > 25 we will use the uniform expansion, so the forward recurrence will be fast here.

But ya, I think as long as we verify derivatives for v near 0.0 and x < 1.5 for this power series, we should be OK. The tough thing is that the derivative with respect to v is zero at v = 0, so orders really close to zero will be tricky to get right...
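For reference, that zero derivative follows from besselk being even in the order:

K_{-\nu}(x) = K_{\nu}(x) \quad\Longrightarrow\quad \left.\frac{\partial}{\partial \nu} K_{\nu}(x)\right|_{\nu = 0} = 0.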

Member: Dang, sorry, ya. I need to rethink the whole routine now that the intermediate range using the Levin transform is so fast. I used to try to avoid that range completely by extending the power series as far as possible, but now it makes sense to favor the Levin transform when necessary and fall back to forward recurrence otherwise.

That actually makes checking the accuracy of the whole routine much easier, because we essentially just have to check our scalar and derivative information for v < 1.0, and we know that forward recurrence is also stable and accurate for besselk. This should greatly reduce the number of points we need to explicitly check. Of course, we should still do some scattered spot checks to make sure the derivatives are carried out OK with AD.
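A sketch of the upward recurrence being described (a hypothetical helper, not Bessels.jl's internal API; the recurrence K_{v+1}(x) = K_{v-1}(x) + (2v/x) K_v(x) is stable in the increasing-order direction for besselk):

function besselk_up_recurrence(x, k0, k1, v, n)
    # k0 = K_v(x), k1 = K_{v+1}(x); returns K_{v+n}(x) for integer n >= 1.
    for i in 1:(n - 1)
        k0, k1 = k1, k0 + (2 * (v + i) / x) * k1
    end
    return k1
end

So two Levin evaluations at orders v and v + 1 with 0 <= v < 1 seed the recurrence for any larger order.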

@heltonmc (Member) commented

"I know @heltonmc would like to get this merged, so maybe it would be best to just merge and I will track down the power series issue in another PR, which can be smaller?"

Ya, I think as long as the power series is accurate for scalar evaluation, then I'm OK merging, because I need to change the routine in general anyway. If the scalar is accurate, any derivative problem should be an issue in the AD system, which we can track down in a different PR.

The Float32 issue, though, needs to be tracked down, so I will look at that now.

@heltonmc (Member) commented

Actually ya just change it back to the prior power series and I will track it down when I edit the whole routine. I don't want to fix it here and then have to change the routine again :)

@heltonmc heltonmc merged commit caf2c2c into JuliaMath:master May 18, 2023
@cgeoga cgeoga deleted the enzyme branch May 30, 2023 13:55