Skip to content

Commit eb57d7a

Browse files
wandbrandonantimoralaggui
authored
Support importing safetensors format (#2721)
* add support for safetensors in pytorch reader * update book * support safetensors format, for it's own crate * remove some duplicate code and depend on pytorch adapter * Fix clippy error * Fix clippy errors * Update doc strings * Move common types/functions into common module * Refactor and add adapter option * Update doc strings * Updated example for importing pt and safetensors weights * Add book section on safetensors * Mention Safetensors in the main README file * Remove dead code * SafeTensors -> Safetensors type * SafeTensors -> Safetensors in docs/messages * SafeTensors -> Safetensors in book * SafeTensors -> Safetensors in example * SafeTensors -> Safetensors in readme * Removed TensorFlow * Removed excess documentation * Update pytorch-model to look the formatting similar to safetensors md * Update pytorch-model.md * Removed pytorch copied tests and replaced with one multilayer test * Update README.md Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com> * Update burn-book/src/import/README.md Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com> * Update burn-book/src/import/README.md Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com> * Shorter example * Remove redundant comment * Removed extra wording * Replace unwrap_or_else with expect * Removed imports for pytorch doc consistency * Removed imports * Fix formating --------- Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
1 parent 1f92ec1 commit eb57d7a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1785
-814
lines changed

Cargo.lock

Lines changed: 20 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ members = [
88
"crates/*",
99
"crates/burn-import/pytorch-tests",
1010
"crates/burn-import/onnx-tests",
11+
"crates/burn-import/safetensors-tests",
1112
"examples/*",
12-
"examples/pytorch-import/model",
1313
"xtask",
1414
]
1515

README.md

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,18 +188,18 @@ for more details.
188188
<div align="left">
189189
<img align="right" src="https://raw.github.com/tracel-ai/burn/main/assets/backend-chip.png" height="96px"/>
190190

191-
192191
Burn strives to be as fast as possible on as many hardwares as possible, with robust implementations.
193192
We believe this flexibility is crucial for modern needs where you may train your models in the cloud,
194193
then deploy on customer hardwares, which vary from user to user.
194+
195195
</div>
196196

197197
<br />
198198

199199
**Supported Backends**
200200

201201
| Backend | Devices | Class |
202-
| ------- | ---------------------------- | ----------- |
202+
| -------- | ---------------------------- | ----------- |
203203
| CUDA | NVIDIA GPUs | First-Party |
204204
| ROCm | AMD GPUs | First-Party |
205205
| Metal | Apple GPUs | First-Party |
@@ -426,13 +426,16 @@ Our ONNX support is further described in
426426

427427
<details>
428428
<summary>
429-
Importing PyTorch Models 🚚
429+
Importing PyTorch or Safetensors Models 🚚
430430
</summary>
431431
<br />
432432

433-
Support for loading of PyTorch model weights into Burn’s native model architecture, ensuring
434-
seamless integration. See
435-
[Burn Book 🔥 section on importing PyTorch](https://burn.dev/burn-book/import/pytorch-model.html)
433+
You can load weights from PyTorch or Safetensors formats directly into your Burn-defined models. This makes it easy to reuse existing models while benefiting from Burn's performance and deployment features.
434+
435+
Learn more:
436+
437+
- [Import pre-trained PyTorch models into Burn](https://burn.dev/burn-book/import/pytorch-model.html)
438+
- [Load models from Safetensors format](https://burn.dev/burn-book/import/safetensors-model.html)
436439

437440
</details>
438441

@@ -468,24 +471,23 @@ means it can run in bare metal environment such as embedded devices without an o
468471

469472
<br />
470473

471-
472474
### Benchmarks
473475

474476
To evaluate performance across different backends and track improvements over time, we provide a
475477
dedicated benchmarking suite.
476478

477479
Run and compare benchmarks using [burn-bench](https://github.com/tracel-ai/burn-bench).
478480

479-
480-
> ⚠️ **Warning**
481-
> When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain.
481+
> ⚠️ **Warning**
482+
> When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain.
482483
> To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file:
484+
>
483485
> ```rust
484486
> #![recursion_limit = "256"]
485487
> ```
488+
>
486489
> The default recursion limit (128) is often just below the required depth (typically 130-150) due to deeply nested associated types and trait bounds.
487490
488-
489491
## Getting Started
490492
491493
<div align="left">

burn-book/src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- [Import Models](./import/README.md)
2525
- [ONNX Model](./import/onnx-model.md)
2626
- [PyTorch Model](./import/pytorch-model.md)
27+
- [Safetensors Model](./import/safetensors-model.md)
2728
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
2829
- [Quantization (Beta)](./quantization.md)
2930
- [Advanced](./advanced/README.md)

burn-book/src/import/README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# Importing Models
22

3-
The Burn project supports the import of models from various frameworks, emphasizing efficiency and
4-
compatibility. Currently, it handles two primary model formats:
3+
Burn supports importing models from other frameworks and file formats, enabling you to use pre-trained weights in your Burn applications.
54

6-
1. [ONNX](./onnx-model.md): Facilitates direct import, ensuring the model's performance and structure
7-
are maintained.
5+
## Supported Formats
86

9-
2. [PyTorch](./pytorch-model.md): Enables the loading of PyTorch model weights into Burn’s native model
10-
architecture, ensuring seamless integration.
7+
Burn currently supports three primary model import formats:
8+
9+
| Format | Description | Use Case |
10+
|--------|-------------|----------|
11+
| [**ONNX**](./onnx-model.md) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
12+
| [**PyTorch**](./pytorch-model.md) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
13+
| [**Safetensors**](./safetensors-model.md) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |

0 commit comments

Comments
 (0)