Add utility methods to split gradients to GradientParams #2311

Merged: 13 commits merged into tracel-ai:main on Oct 4, 2024

Conversation

@ArthurBrussee (Contributor) commented on Sep 26, 2024

Changes

Currently, working with multiple learning rates in Burn can be a bit of a hassle. The easiest way is to create a separate GradientsParams for each set of variables and then step the optimizer once for each of them. E.g., in my codebase I had something like:

let mut grad_means = GradientsParams::new();
grad_means.register(
    splats.means.id.clone(),
    splats.means.grad_remove(&mut grads).unwrap(),
);
let mut grad_opac = GradientsParams::new();
grad_opac.register(
    splats.raw_opacity.id.clone(),
    splats.raw_opacity.grad_remove(&mut grads).unwrap(),
);
let mut grad_rest = GradientsParams::new();
grad_rest.register(
    splats.sh_coeffs.id.clone(),
    splats.sh_coeffs.grad_remove(&mut grads).unwrap(),
);
grad_rest.register(
    splats.rotation.id.clone(),
    splats.rotation.grad_remove(&mut grads).unwrap(),
);
grad_rest.register(
    splats.log_scales.id.clone(),
    splats.log_scales.grad_remove(&mut grads).unwrap(),
);

This PR just adds some utility methods to turn this into

let grad_mean = GradientsParams::from_params(&mut grads, &splats, &[splats.means.id]);
let grad_opac = GradientsParams::from_params(&mut grads, &splats, &[splats.raw_opacity.id]);
// Since the two calls above remove those grads, this will just be the remaining grads.
let grad_rest = GradientsParams::from_module(&mut grads, &splats);

Not entirely sure about the names/shape of the API! It feels like this could be a method on Gradients instead but that's currently not really public.

To make the API slightly simpler I've made ParamId: Copy by switching the ID to a random 64-bit value instead of the current 6(?)-byte string-encoded ID.
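
Roughly the shape I have in mind (an illustrative sketch only, not the exact code in this PR; the rand usage is just an example):

use rand::Rng;

// Hypothetical sketch of a 64-bit, Copy-able ParamId. The point is that call
// sites no longer need `.clone()` when registering gradients.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ParamId(u64);

impl ParamId {
    // Generate a fresh random id.
    pub fn new() -> Self {
        Self(rand::thread_rng().gen())
    }

    // Expose the raw value, e.g. for serialization.
    pub fn val(&self) -> u64 {
        self.0
    }
}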

I've left from_grads in place as-is to not break compatibility, but also added a from_module, which lets people extract a GradientsParams for a specific module without consuming the gradients, so you could do something like

let linear_grads = GradientsParams::from_module(&mut all_grads, &self.model.linear.a);
let rest_grads = GradientsParams::from_grads(all_grads, &self.model);

from_grads could also just be the same thing as from_module if it's ok to break compat.

Something that's a bit gnarly is that if you were to extract a GradientsParams for each param individually, each optimizer step would map over all parameters and things would become O(n^2). I doubt that matters much in practice.
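
For completeness, the split gradients end up being used roughly like this with per-group learning rates (a sketch only: splats, the learning rates, and the optimizer are placeholders, and it assumes the usual Optimizer::step(lr, module, grads) shape):

// Split the gradients per parameter group (illustrative field names).
let grad_means = GradientsParams::from_params(&mut grads, &splats, &[splats.means.id]);
let grad_rest = GradientsParams::from_module(&mut grads, &splats);

// Step the same module once per group, each with its own learning rate.
splats = optim.step(lr_means, splats, grad_means);
splats = optim.step(lr_rest, splats, grad_rest);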

codecov bot commented Sep 26, 2024

Codecov Report

Attention: Patch coverage is 71.05263% with 44 lines in your changes missing coverage. Please review.

Project coverage is 85.36%. Comparing base (ce2d8e0) to head (1a23036).
Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-core/src/optim/grads.rs 58.06% 13 Missing ⚠️
crates/burn-core/src/optim/visitor.rs 30.76% 9 Missing ⚠️
crates/burn-core/src/module/base.rs 25.00% 6 Missing ⚠️
crates/burn-core/src/module/param/tensor.rs 33.33% 6 Missing ⚠️
crates/burn-core/src/module/param/visitor.rs 33.33% 4 Missing ⚠️
crates/burn-core/src/record/primitive.rs 50.00% 4 Missing ⚠️
crates/burn-core/src/module/param/running.rs 66.66% 1 Missing ⚠️
crates/burn-core/src/module/quantize.rs 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2311      +/-   ##
==========================================
- Coverage   85.41%   85.36%   -0.05%     
==========================================
  Files         767      768       +1     
  Lines       97960    98719     +759     
==========================================
+ Hits        83669    84274     +605     
- Misses      14291    14445     +154     


@AsherJingkongChen (Contributor)

It is a minor but practical change. I think ParamId is fine to be u64 instead of String.

@ArthurBrussee (Contributor, Author)

Thanks! Yeah let me just actually include the ID -> u64 change here as well, just so this has a nicer API from the get-go.

@laggui (Member) left a comment

That should be pretty useful! Thanks for the addition 🙏

I had a resnet record on disk saved using 0.14 and tested with your changes. The backward compatibility is preserved so looks good on that front.

Would be nice to add an example to the book to document the usage. But not strictly required for this PR if you don't have the time.

Also, before we merge we should update the binary models in some of the examples (e.g., mnist-inference-web). But I can take care of that.

@ArthurBrussee (Contributor, Author)

Ah great! Thanks for checking the backwards compat. I've added a small test at least for decoding "legacy" 6-byte IDs; it doesn't exactly guarantee compatibility but it's something.
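
Something along these lines (a sketch only; from_legacy_str is a made-up name standing in for whatever the conversion helper ends up being called):

#[test]
fn param_id_decodes_legacy_string_id() {
    // A legacy-style string id as written by older records (placeholder value).
    let legacy = "ka1if9ui";
    // Hypothetical conversion: legacy string ids map deterministically into the
    // new u64 id space, so the same record always loads to the same id.
    let id = ParamId::from_legacy_str(legacy);
    assert_eq!(id, ParamId::from_legacy_str(legacy));
}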

I've added a small section to the custom training loop docs, lmk if that looks good. I'm really not familiar with the "non-custom" training loop so idk how things would work there haha.

I don't know much about how to update the old models so if you're happy to do that that'd be amazing!

laggui previously approved these changes on Sep 30, 2024

@laggui (Member) left a comment

LGTM!

The section in the custom training loop is perfect for now! We don't really have a section dedicated to optimizers yet; maybe with the documentation changes it'll be moved, but that's not for now 🙂

The models to update are strictly for the examples that use the binary record format (not as portable). Just need to convert them to another format and then re-save them with the latest changes. I'll take care of that in a minute.

@laggui (Member) commented on Sep 30, 2024

Hmmm, looks like the model used in the examples actually uses an id format that predates the currently covered backward compat (parameter ids look like UUIDs, e.g. 30b82c23-788d-4d63-a743-ada258d5f13c).

I think I'll have to go over the saved model in a bit more detail 😅

@laggui self-requested a review on September 30, 2024 17:48
@laggui dismissed their stale review on September 30, 2024 17:49

Need to review backwards compatibility

@ArthurBrussee (Contributor, Author)

Oh huh, that's weird :/ 30b82c23-788d-4d63-a743-ada258d5f13c seems like way more than a 6-byte id! I guess those might be coming from a different source?

@laggui (Member) commented on Sep 30, 2024

Yeah totally forgot the param ids used to be 128-bit UUIDs before it was changed in #1912, and the mnist-inference-web model I was trying to convert was generated before the switch.

It didn't break before since these were also stored as a string, but this explains the error.

I still think the rest of the changes are good! I'll just have to check for this case since users might still have UUIDs in their records that would otherwise break when trying to load.

@laggui (Member) commented on Oct 2, 2024

Just added additional support for the 16-byte UUIDs. Turns out the binary model doesn't have to be updated; I just had to add backward compatibility for the parameter ids 🙂
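
In rough terms, loading now has to accept ids in any of the historical formats. One way to sketch that (made-up helper name, not the exact code in the PR):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical: map any historical on-disk id representation to a 64-bit id.
fn param_id_from_record(serialized: &str) -> u64 {
    // Newest format: the id is already a plain integer.
    if let Ok(value) = serialized.parse::<u64>() {
        return value;
    }
    // Legacy formats (128-bit UUID strings like `30b82c23-...` from before #1912,
    // or the short string ids used after it): hash the string into the u64 space
    // so old records keep loading without conversion.
    let mut hasher = DefaultHasher::new();
    serialized.hash(&mut hasher);
    hasher.finish()
}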

@laggui merged commit dbd577a into tracel-ai:main on Oct 4, 2024
11 checks passed
@ArthurBrussee deleted the opt-split branch on October 16, 2024 19:54