Support importing safetensors format #2721


Merged
merged 42 commits into tracel-ai:main on May 6, 2025

Conversation

wandbrandon
Contributor

@wandbrandon wandbrandon commented Jan 20, 2025

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #626

Changes

This pull request introduces support for the SafeTensors format in the Burn library, enabling secure and efficient model weight imports. It includes updates to documentation, new examples, and test cases to demonstrate the usage of SafeTensors. The changes also enhance the project's modularity by adding SafeTensors-specific features and dependencies.

Support for SafeTensors Format:

  • Added SafeTensors support to burn-import with a new feature flag and dependencies in Cargo.toml. This includes updates to the default features to enable SafeTensors by default.

Documentation Updates:

  • Updated the Burn Book to include a new section on SafeTensors in the summary and import documentation. This includes a detailed guide for exporting and importing SafeTensors weights, troubleshooting, and advanced features like key remapping and framework-specific adapters.
  • Enhanced the README with information about SafeTensors support, including updated descriptions for importing models and links to relevant guides.

Codebase Adjustments:

  • Removed outdated examples and added a new example directory for importing model weights, aligning with the new SafeTensors functionality.

These changes collectively improve the Burn library's flexibility and usability, particularly for users looking to leverage the SafeTensors format for secure and efficient model weight handling.

Testing

Added end-to-end safetensors-tests.
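
For background, the SafeTensors container this PR imports has a deliberately simple layout: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and byte offsets, then the raw tensor data. A minimal std-only writer sketch of that layout (illustrative only; this is not Burn's implementation, which reads such files via Candle):

```rust
// Sketch of the safetensors container layout, based on the published
// format spec: [u64 header length | JSON header | raw tensor bytes].
fn write_safetensors(name: &str, values: &[f32]) -> Vec<u8> {
    // Raw little-endian tensor bytes follow the header.
    let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    // JSON header: tensor name -> dtype, shape, and byte offsets into the data section.
    let header = format!(
        r#"{{"{}":{{"dtype":"F32","shape":[{}],"data_offsets":[0,{}]}}}}"#,
        name,
        values.len(),
        data.len()
    );
    let mut out = Vec::new();
    // First 8 bytes: header length as a little-endian u64.
    out.extend_from_slice(&(header.len() as u64).to_le_bytes());
    out.extend_from_slice(header.as_bytes());
    out.extend_from_slice(&data);
    out
}

fn main() {
    let bytes = write_safetensors("model.weight", &[1.0, 2.0]);
    let header_len = u64::from_le_bytes(bytes[..8].try_into().unwrap()) as usize;
    // 8-byte length prefix + header + two f32 values (8 data bytes).
    assert_eq!(bytes.len(), 8 + header_len + 8);
    println!("wrote {} bytes", bytes.len());
}
```

The importer added in this PR goes the other way: it parses this header to locate each tensor and map it into a Burn module record.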


codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 80.56680% with 48 lines in your changes missing coverage. Please review.

Project coverage is 81.36%. Comparing base (1f92ec1) to head (a22aa59).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
crates/burn-import/src/safetensors/recorder.rs 43.63% 31 Missing ⚠️
crates/burn-import/src/pytorch/config.rs 55.55% 8 Missing ⚠️
crates/burn-import/src/common/candle.rs 93.50% 5 Missing ⚠️
crates/burn-import/src/pytorch/recorder.rs 33.33% 2 Missing ⚠️
crates/burn-import/src/safetensors/reader.rs 93.75% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2721      +/-   ##
==========================================
- Coverage   81.37%   81.36%   -0.02%     
==========================================
  Files         818      821       +3     
  Lines      117643   117791     +148     
==========================================
+ Hits        95736    95835      +99     
- Misses      21907    21956      +49     

☔ View full report in Codecov by Sentry.
Member

@laggui laggui left a comment

Thanks for the addition 🙏

Looks pretty good overall, just some minor comments.

@Nikaidou-Shinku
Contributor

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.
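
The gating suggested here would look roughly like this in burn-import's lib.rs (hypothetical sketch; the actual module layout and feature names may differ):

```rust
// Hypothetical feature gate: only compile the SafeTensors entry point
// when users request the `safetensors` feature, keeping builds lean.
#[cfg(feature = "safetensors")]
pub mod safetensors;
```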

@wandbrandon
Contributor Author

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I think this is a good point, and it also builds the scaffolding for potentially rewriting it to remove the Candle dependency.

@laggui
Member

laggui commented Jan 21, 2025

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I agree that the format is not strongly related to pytorch, but I think most models available in safetensor format are pytorch models 😅

Unless you mean supporting the safetensor format as another recorder to load and save modules. In that case, I'm not sure it's a meaningful addition.

@antimora antimora self-requested a review January 27, 2025 17:48
Collaborator

@antimora antimora left a comment

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

Yes, I strongly agree with this opinion. My intention was to have separate support for each format. We can reuse common code in burn-import, but the entry points should be different. We will need a separate module and feature for GGUF files as well (see #1187), and it would follow the same pattern.

PyTorch's .pt format and SafeTensors are unrelated, and we should not mix them up in code or documentation. It just happens that we're using Candle's reader, but that could change.

P.S. Thanks for taking up this problem! It's an often-requested feature.

@antimora
Collaborator

I suggest creating a dedicated SafeTensorFileRecorder to handle SafeTensor files independently from PyTorch's .pt files. This approach ensures a clear separation between different file formats and supports framework-specific transformations during the import process.

Additionally, I propose providing configurable options (via LoadArgs, similar to PyTorch's recorder) within the recorder to specify the appropriate transformation adapter. By default, this could use the PyTorchAdapter but allow customization for other frameworks, such as TensorFlow. This design enhances flexibility and decouples the handling of different tensor file formats. Moreover, it might be beneficial to support passing a user-defined implementation of BurnModuleAdapter when needed.

Lastly, we should replicate PyTorch import tests to ensure comprehensive coverage. Over time, we can expand these tests to include SafeTensor files exported from TensorFlow.

One more thing: we should introduce a new feature flag, safetensors.
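
A std-only sketch of the adapter design proposed above (the names mirror this comment, but the signatures and the key rewriting are illustrative assumptions, not Burn's actual API; real adapters also transform tensor values, e.g. transposing linear weights):

```rust
// Illustrative sketch of the proposed adapter-based recorder design.
trait BurnModuleAdapter {
    /// Map a source checkpoint key to Burn's parameter naming.
    fn adapt_key(&self, key: &str) -> String;
}

/// Default: load the safetensors file as-is, no renaming.
struct DefaultAdapter;
impl BurnModuleAdapter for DefaultAdapter {
    fn adapt_key(&self, key: &str) -> String {
        key.to_string()
    }
}

/// PyTorch-flavoured adapter. As a toy example, rename normalization
/// parameters the way Burn names them (weight -> gamma, bias -> beta).
struct PyTorchAdapter;
impl BurnModuleAdapter for PyTorchAdapter {
    fn adapt_key(&self, key: &str) -> String {
        if let Some(prefix) = key.strip_suffix(".weight") {
            format!("{prefix}.gamma")
        } else if let Some(prefix) = key.strip_suffix(".bias") {
            format!("{prefix}.beta")
        } else {
            key.to_string()
        }
    }
}

fn main() {
    let adapter = PyTorchAdapter;
    for key in ["norm.weight", "norm.bias", "fc.running_stat"] {
        println!("{key} -> {}", adapter.adapt_key(key));
    }
    assert_eq!(adapter.adapt_key("norm.weight"), "norm.gamma");
    assert_eq!(DefaultAdapter.adapt_key("norm.weight"), "norm.weight");
}
```

Under this design, a SafeTensorsFileRecorder would take the adapter choice through LoadArgs, defaulting to the PyTorch adapter while allowing a user-defined BurnModuleAdapter implementation.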

@antimora antimora added the feature label Jan 27, 2025
@wandbrandon
Contributor Author

wandbrandon commented Jan 28, 2025

Hi all, I went through and essentially copied over the PyTorch recorder implementation to create the SafeTensors recorder. It's a lot of new files that are mostly copied code with small adjustments. I think this gives a good base for removing the Candle dependency and adding further SafeTensors support in the future.

Member

@laggui laggui left a comment

Thanks for making the changes.

That is a lot of code duplication 😅 I don't think "replicate" was meant in this way hahah

We should re-use the existing PyTorchAdapter when both pytorch and safetensors features are enabled. Definitely don't want to copy its implementation over. And by default we can use the existing DefaultAdapter in burn::record::serde::adapter to simply load the safetensor file as is.

For the tests, we also don't need to copy all the python scripts. We can just have the existing scripts under pytorch-tests save both in pickle and safetensor formats. And since the current tests added come from pytorch, we can add the additional safetensor tests to the existing pytorch tests under the safetensors feature flag guard. We would need the safetensor recorder to use the pytorch adapter anyway, so these tests can live under the pytorch-tests (with the addition of the safetensors feature flag). If we want to add standalone tests for models saved in safetensors that don't require any other transformations we could expand that to have safetensors-tests.

Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for taking this up. This feature will help the many who requested it.

This is in the right direction, but we should try to reduce code duplication. I would prioritize Rust code de-duplication first. We can leave the examples and tests duplicated for now, since de-duplicating them will take time. It's up to @laggui to allow it.

I also suggest creating a new section under the book specifically for SafeTensors. You can talk about the transformation for model modules.

The last thing, don't forget to update LoadArgs for SafeTensorsFileRecorder per my previous comment.

@laggui
Member

laggui commented Jan 29, 2025

I would prioritize Rust code de-duplication first. We can leave the example and test duplicated for now (because it will take time).

Agreed for the bold part.

Contributor

github-actions bot commented Mar 1, 2025

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added the stale label Mar 1, 2025
@ivila
Contributor

ivila commented Mar 12, 2025

Just a feature request, can ignore this if you think it is unnecessary 😂: since safetensors now supports no_std, could we also have a SafeTensorBytesRecorder?

@github-actions github-actions bot removed the stale label Mar 12, 2025
@laggui
Member

laggui commented Mar 13, 2025

Just a feature request, can ignore this if you think it is unnecessary 😂: since safetensors now supports no_std, could we also have a SafeTensorBytesRecorder?

A bit hesitant on adding new record formats tbh, not sure of the value. To import safetensor format (like the initial target of this PR), sure. But maybe you could provide a bit more info to justify? 🙂

@antimora
Collaborator

Just a feature request, can ignore this if you think it is unnecessary 😂: since safetensors now supports no_std, could we also have a SafeTensorBytesRecorder?

A bit hesitant on adding new record formats tbh, not sure of the value. To import safetensor format (like the initial target of this PR), sure. But maybe you could provide a bit more info to justify? 🙂

We don't need a new type. We can provide it via an arg option.

Contributor

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added and removed the stale label Apr 13, 2025
@reneleonhardt
Contributor

@laggui This has been a lot of work. What can be done to complete the review?

@antimora
Collaborator

@laggui This has been a lot of work. What can be done to complete the review?

The PR still needs to be refactored.

Do you have a use case for this feature? We could complete this ourselves.

@antimora antimora requested a review from laggui May 2, 2025 18:44
Member

@laggui laggui left a comment

Thanks for removing the bloated tests 😅

Much better now. I have a few comments mostly regarding docs

@antimora antimora requested a review from laggui May 5, 2025 21:34
Member

@laggui laggui left a comment

Missed a small thing in the previous review! Otherwise, should be good to go after.

Thanks for taking this over btw!

@antimora antimora requested a review from laggui May 6, 2025 16:16
@antimora
Collaborator

antimora commented May 6, 2025

@laggui Done updating. Hopefully no issues.

Thanks for reviewing the long PR. Takes up lots of context switching.

@laggui
Member

laggui commented May 6, 2025

Thanks for reviewing the long PR. Takes up lots of context switching.

Of course! Longer PRs take more time to review carefully so time-to-merge is usually longer as well 😅

There has been an uptick in activity since the 0.17 release too so I have to balance it out. Sorry if your other PRs are not entirely reviewed yet, should come soon!

@laggui laggui merged commit eb57d7a into tracel-ai:main May 6, 2025
10 of 11 checks passed
Labels
feature The feature request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support importing safetensors format
6 participants