Support importing safetensors format #2721

Merged: 42 commits, May 6, 2025
Commits
15bac95
add support for safetensors in pytorch reader
wandbrandon Jan 20, 2025
6a0330e
update book
wandbrandon Jan 20, 2025
dbd40ce
support safetensors format, for it's own crate
wandbrandon Jan 28, 2025
0b87d62
Merge branch 'main' of https://github.com/tracel-ai/burn into bmw-bur…
wandbrandon Apr 26, 2025
64f3ea6
remove some duplicate code and depend on pytorch adapter
wandbrandon Apr 26, 2025
47a4a0a
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora Apr 28, 2025
10ecc1e
Fix clippy error
antimora Apr 28, 2025
e8663ba
Fix clippy errors
antimora Apr 28, 2025
5fe486d
Update doc strings
antimora Apr 28, 2025
197dfa7
Move common types/functions into common module
antimora Apr 28, 2025
b360fea
Refactor and add adapter option
antimora Apr 28, 2025
914ce5c
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora Apr 28, 2025
fa8a3e9
Update doc strings
antimora Apr 28, 2025
2b04545
Updated example for importing pt and safetensors weights
antimora Apr 29, 2025
d1151bf
Add book section on safetensors
antimora Apr 29, 2025
09538db
Mention Safetensors in the main README file
antimora Apr 29, 2025
34c7df8
Remove dead code
antimora Apr 29, 2025
9144c1f
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora May 2, 2025
66f5a7c
SafeTensors -> Safetensors type
antimora May 2, 2025
fa71374
SafeTensors -> Safetensors in docs/messages
antimora May 2, 2025
b63d562
SafeTensors -> Safetensors in book
antimora May 2, 2025
33ab1bd
SafeTensors -> Safetensors in example
antimora May 2, 2025
5bdc846
SafeTensors -> Safetensors in readme
antimora May 2, 2025
b685ef8
Removed TensorFlow
antimora May 2, 2025
30f2600
Removed excess documentation
antimora May 2, 2025
420e9c1
Update pytorch-model to look the formatting similar to safetensors md
antimora May 2, 2025
8666c4a
Update pytorch-model.md
antimora May 2, 2025
7dda904
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora May 5, 2025
883b773
Removed pytorch copied tests and replaced with one multilayer test
antimora May 5, 2025
e8c75a2
Update README.md
antimora May 5, 2025
1fe96dd
Update burn-book/src/import/README.md
antimora May 5, 2025
1e0ffb0
Update burn-book/src/import/README.md
antimora May 5, 2025
1c7a170
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora May 5, 2025
6423eb0
Merge branch 'bmw-burn-safetensors' of https://github.com/wandbrandon…
antimora May 5, 2025
316133d
Shorter example
antimora May 5, 2025
526969f
Remove redundant comment
antimora May 5, 2025
88d5677
Removed extra wording
antimora May 5, 2025
7f5b32b
Replace unwrap_or_else with expect
antimora May 5, 2025
91b1a97
Removed imports for pytorch doc consistency
antimora May 6, 2025
cc75399
Merge remote-tracking branch 'upstream/main' into pr/2721
antimora May 6, 2025
f3bc4f4
Removed imports
antimora May 6, 2025
a22aa59
Fix formating
antimora May 6, 2025
37 changes: 20 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -8,8 +8,8 @@ members = [
"crates/*",
"crates/burn-import/pytorch-tests",
"crates/burn-import/onnx-tests",
"crates/burn-import/safetensors-tests",
"examples/*",
"examples/pytorch-import/model",
"xtask",
]

24 changes: 13 additions & 11 deletions README.md
@@ -188,18 +188,18 @@ for more details.
<div align="left">
<img align="right" src="https://raw.github.com/tracel-ai/burn/main/assets/backend-chip.png" height="96px"/>


Burn strives to be as fast as possible on as many hardware platforms as possible, with robust implementations.
We believe this flexibility is crucial for modern needs where you may train your models in the cloud,
then deploy on customer hardware, which varies from user to user.

</div>

<br />

**Supported Backends**

| Backend | Devices | Class |
| -------- | ---------------------------- | ----------- |
| CUDA | NVIDIA GPUs | First-Party |
| ROCm | AMD GPUs | First-Party |
| Metal | Apple GPUs | First-Party |
@@ -426,13 +426,16 @@ Our ONNX support is further described in

<details>
<summary>
Importing PyTorch or Safetensors Models 🚚
</summary>
<br />

You can load weights from PyTorch or Safetensors formats directly into your Burn-defined models. This makes it easy to reuse existing models while benefiting from Burn's performance and deployment features.

Learn more:

- [Import pre-trained PyTorch models into Burn](https://burn.dev/burn-book/import/pytorch-model.html)
- [Load models from Safetensors format](https://burn.dev/burn-book/import/safetensors-model.html)

</details>

@@ -468,24 +471,23 @@ means it can run in bare metal environment such as embedded devices without an operating system

<br />


### Benchmarks

To evaluate performance across different backends and track improvements over time, we provide a
dedicated benchmarking suite.

Run and compare benchmarks using [burn-bench](https://github.com/tracel-ai/burn-bench).


> ⚠️ **Warning**
> When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain.
> To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file:
>
> ```rust
> #![recursion_limit = "256"]
> ```
>
> The default recursion limit (128) is often just below the required depth (typically 130-150) due to deeply nested associated types and trait bounds.


## Getting Started

<div align="left">
1 change: 1 addition & 0 deletions burn-book/src/SUMMARY.md
@@ -24,6 +24,7 @@
- [Import Models](./import/README.md)
- [ONNX Model](./import/onnx-model.md)
- [PyTorch Model](./import/pytorch-model.md)
- [Safetensors Model](./import/safetensors-model.md)
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
- [Quantization (Beta)](./quantization.md)
- [Advanced](./advanced/README.md)
15 changes: 9 additions & 6 deletions burn-book/src/import/README.md
@@ -1,10 +1,13 @@
# Importing Models

Burn supports importing models from other frameworks and file formats, enabling you to use pre-trained weights in your Burn applications.

## Supported Formats

Burn currently supports three primary model import formats:

| Format | Description | Use Case |
|--------|-------------|----------|
| [**ONNX**](./onnx-model.md) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
| [**PyTorch**](./pytorch-model.md) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
| [**Safetensors**](./safetensors-model.md) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |
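The Safetensors container layout referenced in this table is simple enough to sketch with standard-library code: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and byte offsets, then the raw tensor data. The sketch below (Python, stdlib only; the file name and tensor values are illustrative, not part of this PR) writes and re-reads a minimal file:

```python
import json
import struct

def save_safetensors(path, tensors):
    """Write {name: (dtype, shape, raw_bytes)} in the safetensors layout:
    an 8-byte LE header size, a JSON header, then concatenated tensor data."""
    header, data, offset = {}, b"", 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {"dtype": dtype, "shape": shape,
                        "data_offsets": [offset, offset + len(raw)]}
        data += raw
        offset += len(raw)
    head = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(head)) + head + data)

def read_header(path):
    """Read only the JSON header: tensor names, dtypes, shapes, offsets."""
    with open(path, "rb") as f:
        (size,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(size))

# A 2x2 F32 identity matrix stored as raw little-endian bytes.
raw = struct.pack("<4f", 1.0, 0.0, 0.0, 1.0)
save_safetensors("toy.safetensors", {"weight": ("F32", [2, 2], raw)})
print(read_header("toy.safetensors"))
```

Files produced by Hugging Face's `safetensors` library follow this same layout; inspecting the header this way can help confirm that tensor names and shapes match the Burn module you plan to load them into.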