Skip to content

Update pytorch-model.md with a new troubleshooting help #3081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions burn-book/src/import/pytorch-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ tutorial on importing a more complex model.
## How to export a PyTorch model

If you have a PyTorch model that you want to import into Burn, you will need to export it first,
unless you are using a pre-trained published model. To export a PyTorch model, you can use the
`torch.save` function.
unless you are using a pre-trained published model. To export a PyTorch model correctly, you need to
save only the model weights (state_dict) using the `torch.save` function, not the entire model.

The following is an example of how to export a PyTorch model:
The following is an example of how to properly export a PyTorch model:

```python
import torch
Expand All @@ -38,12 +38,24 @@ class Net(nn.Module):
if __name__ == "__main__":
torch.manual_seed(42) # To make it reproducible
model = Net().to(torch.device("cpu"))
model_weights = model.state_dict()
torch.save(model_weights, "conv2d.pt")
model_weights = model.state_dict() # This extracts just the weights
torch.save(model_weights, "conv2d.pt") # Save only the weights, not the entire model
```

Use [Netron](https://github.com/lutzroeder/netron) to view the exported model. You should see
something like this:
If you accidentally save the entire model instead of just the weights, you may encounter errors
during import like:

```
Failed to decode foobar: DeserializeError("Serde error: other error:
Missing source values for the 'foo1' field of type 'BarRecordItem'.
Please verify the source data and ensure the field name is correct")
```

You can verify if your model is exported correctly by opening the `.pt` file in
[Netron](https://github.com/lutzroeder/netron). A properly exported weights file will show a flat
structure of tensors, while an incorrectly exported file will display nested blocks representing the
entire model architecture. When viewing the exported model in Netron, you should see something like
this:

![image alt>](./conv2d.svg)

Expand Down