Add utility methods to split gradients to GradientParams #2311
Conversation
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2311      +/-   ##
==========================================
- Coverage   85.41%   85.36%   -0.05%
==========================================
  Files         767      768       +1
  Lines       97960    98719     +759
==========================================
+ Hits        83669    84274     +605
- Misses      14291    14445     +154

☔ View full report in Codecov by Sentry.
It is a minor but practical change, I think.
Thanks! Yeah let me just actually include the ID -> u64 change here as well, just so this has a nicer API from the get-go.
That should be pretty useful! Thanks for the addition 🙏
I had a resnet record on disk saved using 0.14 and tested with your changes. The backward compatibility is preserved so looks good on that front.
Would be nice to add an example to the book to document the usage. But not strictly required for this PR if you don't have the time.
Also, before we merge we should update the binary models in some of the examples (e.g., mnist-inference-web). But I can take care of that.
Ah great! Thanks for checking the backwards compat. I've added a small test, at least, for decoding "legacy" 6-byte IDs; it doesn't exactly guarantee compatibility, but it's something. I've also added a small section to the custom training loop, let me know if that looks good. I'm really not familiar with the "non-custom" training loop, so I don't know how things would work there haha. I don't know much about how to update the old models, so if you're happy to do that, that'd be amazing!
LGTM!
The section in the custom training loop is perfect for now! We don't really have a section dedicated to optimizers yet; maybe with future documentation changes it'll be moved, but that's not for now 🙂
The models to update are strictly for the examples that use the binary record format (which is not as portable). I just need to save them to another format and re-save them with the latest changes. I'll take care of that in a minute.
Hmmm, looks like the model used in the examples actually uses an id format that predates the currently covered backward compat (parameter ids look like UUIDs). I think I'll have to go over the saved model in a bit more detail 😅
Oh hurr that's weird :/ 30b82c23-788d-4d63-a743-ada258d5f13c seems like way more than a 6-byte id! I guess those might be coming from a different source?
Yeah, I totally forgot the param ids used to be 128-bit UUIDs before it was changed in #1912. It didn't break before since these were also stored as a string, but this explains the error. I still think the rest of the changes are good! I'll just have to check for this case since users might still have UUIDs in their records that would otherwise break when trying to load.
Just added additional support for the 16-byte UUIDs. Turns out the binary model doesn't have to be updated; I just had to add backward compatibility for the parameter ids 🙂
Changes
Currently, working with multiple different learning rates in Burn can be a bit of a hassle. The easiest way is to create a different GradientsParams for each set of variables and then step the optimizer for each of these. E.g., in my codebase I had something like:
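(Roughly the following shape, as a sketch only: grads_for_group below is a made-up placeholder for the manual filtering this involved, and the submodule names and learning rates are illustrative, not the actual code.)

```rust
// Rough reconstruction of the pattern only -- not the actual code from the PR.
// Assumes something like `use burn::optim::GradientsParams;` is in scope.
let grads = loss.backward();

// One GradientsParams per parameter group; `grads_for_group` is a made-up
// placeholder for the manual filtering this used to require.
let encoder_grads: GradientsParams = grads_for_group(&grads, &model.encoder);
let decoder_grads: GradientsParams = grads_for_group(&grads, &model.decoder);

// A separate optimizer step per group, each with its own learning rate.
model = optim.step(lr_encoder, model, encoder_grads);
model = optim.step(lr_decoder, model, decoder_grads);
```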
This PR just adds some utility methods to turn this into
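Presumably something along these lines (a minimal sketch using the new from_module utility; the submodule names and learning rates are made up):

```rust
// Minimal sketch, assuming `use burn::optim::GradientsParams;` is in scope.
let mut grads = loss.backward();

// Split the gradients per submodule without consuming the whole container.
let encoder_grads = GradientsParams::from_module(&mut grads, &model.encoder);
let decoder_grads = GradientsParams::from_module(&mut grads, &model.decoder);

// Step each group with its own learning rate.
model = optim.step(lr_encoder, model, encoder_grads);
model = optim.step(lr_decoder, model, decoder_grads);
```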
Not entirely sure about the names/shape of the API! It feels like this could be a method on Gradients instead but that's currently not really public.
To make the API slightly simpler I've made ParamId: Copy by making the ID a 64-bit random hash or something instead of a 6(?)-byte random string-encoded ID.
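As a rough illustration of the idea only (not the exact definition in the PR), the id essentially becomes a small Copy newtype over a random u64:

```rust
use rand::Rng;

/// Sketch of the idea only: a Copy-able parameter id backed by a random u64
/// instead of a heap-allocated, string-encoded id.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ParamId {
    value: u64,
}

impl ParamId {
    pub fn new() -> Self {
        Self {
            value: rand::thread_rng().gen::<u64>(),
        }
    }
}
```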
I've left from_grads in place as-is to not break compatibility, but I've also added a from_module, which allows people to extract a GradientsParams for a specific module without consuming the gradients, so you could do something like:
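For instance, a minimal sketch (names and learning rates are made up; whether the later from_grads call only sees the remaining entries depends on from_module removing what it extracts, which is an assumption based on the splitting behaviour described here):

```rust
let mut grads = loss.backward();

// Take just the encoder's gradients without giving up the whole container.
let encoder_grads = GradientsParams::from_module(&mut grads, &model.encoder);

// The existing from_grads can still consume whatever is left.
let remaining_grads = GradientsParams::from_grads(grads, &model);

model = optim.step(lr_encoder, model, encoder_grads);
model = optim.step(lr_rest, model, remaining_grads);
```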
from_grads could also just be the same thing as from_module if it's ok to break compat.
Something that's a bit gnarly: if you were to extract a GradientsParams for each param individually, each optimizer step would map over all parameters and things become O(n^2). I doubt that matters much in practice.