Running stats with gradient checkpointing #1035
Thanks for reporting. I slightly modified your code to demonstrate how it works:
As you can see, you need to run the backward pass to make `running_mean` match; the forward pass alone is not enough. `checkpoint_wrapper` is meant for training, so doing only the forward pass does not make sense IMHO. With the backward pass, the stats match correctly.
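The sketch below is a minimal, self-contained illustration of this behaviour. It rests on two assumptions. First, plain `torch.utils.checkpoint.checkpoint` in reentrant mode stands in for fairscale's `checkpoint_wrapper`. Second, the hypothetical hook pair `patch_bn_like_fairscale` mimics what the issue describes for `patch_batchnorm`, where running stats are skipped while grad is disabled. It is not fairscale's code. Reentrant checkpointing runs the first forward under `no_grad`, so the stats only change when the backward pass triggers the grad-enabled recomputation.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def patch_bn_like_fairscale(module: nn.Module) -> None:
    """Hypothetical stand-in for the patch described in the issue:
    BatchNorm layers skip running-stat updates while grad is disabled."""
    def pre_hook(bn, inputs):
        if not torch.is_grad_enabled():
            bn._track_backup = bn.track_running_stats
            # Training-mode BN with track_running_stats=False ignores its buffers.
            bn.track_running_stats = False

    def post_hook(bn, inputs, output):
        if not torch.is_grad_enabled():
            bn.track_running_stats = bn._track_backup

    for m in module.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.register_forward_pre_hook(pre_hook)
            m.register_forward_hook(post_hook)


class Checkpointed(nn.Module):
    """Wraps a module with reentrant activation checkpointing."""
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # Reentrant checkpointing runs this forward under no_grad first and
        # recomputes it with grad enabled during backward.
        return checkpoint(self.inner, x, use_reentrant=True)


torch.manual_seed(0)
ref = nn.BatchNorm1d(4)                 # plain, un-checkpointed reference
bn = copy.deepcopy(ref)
patch_bn_like_fairscale(bn)
model = Checkpointed(bn)

x = torch.randn(8, 4, requires_grad=True)
ref(x)                                  # reference updates its stats in the forward

out = model(x)
after_forward = bn.running_mean.clone() # forward alone: still all zeros
out.sum().backward()                    # recomputation updates the stats once
```

After the backward pass, `bn.running_mean` equals `ref.running_mean` and `bn.num_batches_tracked` is 1. The stats are counted once, not twice, even though the layer ran forward two times.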
There is one Kaggle trick: you can run multiple forward passes on the test set to adapt the running stats to it. This is a weird case, but still :)
Oh I see. That's interesting! Do you have example code or pseudocode for it?
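Code for this trick? If yes, it is as simple as a loop of forward passes over the test set with the model in train mode, so that BatchNorm updates its running stats.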
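The commenter's actual code is not preserved here. Below is a minimal sketch of such a loop, assuming unlabeled test batches are available as an iterable of tensors. The helper name `adapt_bn_stats` and its `reset` flag are hypothetical. PyTorch ships a related utility, `torch.optim.swa_utils.update_bn(loader, model)`, which resets the BN stats and recomputes them as a cumulative average over a loader.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def adapt_bn_stats(model: nn.Module, batches, reset: bool = False) -> None:
    """Hypothetical helper: forward unlabeled batches in train mode so that
    BatchNorm layers re-estimate running mean/var (no backward, no optimizer)."""
    if reset:
        for m in model.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.reset_running_stats()
    was_training = model.training
    model.train()                 # BN only updates running stats in train mode
    for x in batches:
        model(x)                  # forward only
    model.train(was_training)     # restore the previous mode


# Usage sketch with synthetic "test" data whose mean is shifted to 3.0.
torch.manual_seed(0)
model = nn.Sequential(nn.BatchNorm1d(3))
model.eval()
test_batches = [torch.randn(16, 3) + 3.0 for _ in range(50)]
adapt_bn_stats(model, test_batches)
```

Note that `model.train()` also puts Dropout layers into training mode during this loop. If the BN layers were checkpointed and patched as described in this issue, this `no_grad` loop would leave their statistics unchanged.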
I see. After this loop, do you then proceed with normal training, for 1 epoch or for the whole N epochs?
No, I run this loop on the test set (on which I want to get the highest target metric in the competition) after training is complete. The idea is to adapt the BN statistics to the test set, which can have a slightly different distribution.
I think I can come up with a solution next week, OK?
@vovaf709 OK
According to the `patch_batchnorm` source code, if a layer that collects running stats (e.g. BatchNorm) is checkpointed, it accumulates statistics only when grad is enabled, i.e. during the recomputation in the backward pass. This induces an inconsistency:
I think this behaviour should be changed so that statistics are accumulated on the first forward pass, or at least be mentioned in the docs.
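Below is a hypothetical sketch of one way to read this proposal. It is not fairscale's code or an accepted fix. The idea is to invert the condition: update BN stats on the first (no-grad) pass of reentrant checkpointing and skip them during the grad-enabled recomputation. This is only valid under the stated assumption that the BN layers are always run through reentrant `torch.utils.checkpoint`. A normal grad-enabled forward outside checkpointing would never update the stats. The helper name `track_stats_on_first_pass` is made up for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def track_stats_on_first_pass(module: nn.Module) -> None:
    """Hypothetical patch: skip BN running-stat updates while grad is ENABLED
    (the recomputation), so only the first no-grad pass updates them."""
    def pre_hook(bn, inputs):
        if torch.is_grad_enabled():
            bn._track_backup = bn.track_running_stats
            bn.track_running_stats = False

    def post_hook(bn, inputs, output):
        if torch.is_grad_enabled():
            bn.track_running_stats = bn._track_backup

    for m in module.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.register_forward_pre_hook(pre_hook)
            m.register_forward_hook(post_hook)


class Checkpointed(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # First pass runs under no_grad; backward recomputes with grad enabled.
        return checkpoint(self.inner, x, use_reentrant=True)


torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
track_stats_on_first_pass(bn)
model = Checkpointed(bn)

# Training step: the first pass updates stats, the recomputation does not.
x = torch.randn(8, 4, requires_grad=True)
model(x).sum().backward()
after_train_step = int(bn.num_batches_tracked)

# Forward-only pass under no_grad (the Kaggle trick): stats update as well.
with torch.no_grad():
    model(torch.randn(8, 4))
after_no_grad_pass = int(bn.num_batches_tracked)
```

With this inversion, a training step and a forward-only pass each count once. The cost is that correctness now depends on how the checkpoint implementation sets grad mode on each pass.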