Fix GRU to match pytorch (#2701). #2704
Merged
Pull Request Template
Checklist
The `run-checks all` script has been executed.
Related Issues/PRs
This addresses issue #2701
Changes
Update the GRU implementation of the "new" gate to match the PyTorch implementation (see the equations after this list). This can change numerical output in some cases.
Add a GRU unit test with sequence length > 1.
Fix the GRU input state dimensions and hidden state handling. This is an API change, since the dimensions of the optional hidden state input are corrected to the right sizes.
These changes affect numerical results and change the API slightly. Updating to the correct API dimensions seems like the best option, since the previous implementation was incorrect, not just different from PyTorch.
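For reference, PyTorch defines a GRU step as below; the detail this PR aligns with is that the reset gate $r_t$ multiplies the hidden-side linear term of the "new" gate, including its bias $b_{hn}$, rather than being applied to the hidden state before the linear transform:

```latex
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```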
Testing
These changes were tested with a small unit test. For this test, the correct values were computed manually using the GRU equations above; a sketch of that kind of manual check follows.
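The PR's actual unit test isn't reproduced here; this is only a minimal, self-contained sketch of the manual check, with made-up scalar weights, for a GRU with input size 1 and hidden size 1 following the equations above:

```rust
// Minimal sketch (not the PR's test): a single-feature GRU step
// following the PyTorch equations. All weights and values are made up.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// One GRU step with input size 1 and hidden size 1.
/// w_i* / w_h* are the per-gate input/hidden weights, b_* the biases.
fn gru_step(
    x: f32,
    h: f32,
    (w_ir, w_iz, w_in): (f32, f32, f32),
    (w_hr, w_hz, w_hn): (f32, f32, f32),
    (b_ir, b_iz, b_in): (f32, f32, f32),
    (b_hr, b_hz, b_hn): (f32, f32, f32),
) -> f32 {
    let r = sigmoid(w_ir * x + b_ir + w_hr * h + b_hr);
    let z = sigmoid(w_iz * x + b_iz + w_hz * h + b_hz);
    // Reset gate applied to the hidden linear term *including* its bias.
    let n = (w_in * x + b_in + r * (w_hn * h + b_hn)).tanh();
    (1.0 - z) * n + z * h
}

fn main() {
    // Run a length-2 sequence from a zero hidden state.
    let (mut h, xs) = (0.0_f32, [0.5_f32, -1.0]);
    for x in xs {
        h = gru_step(
            x,
            h,
            (0.1, 0.2, 0.3),
            (0.4, 0.5, 0.6),
            (0.01, 0.02, 0.03),
            (0.04, 0.05, 0.06),
        );
        println!("h = {h}");
    }
}
```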
I also tested these changes against PyTorch. A PyTorch GRU was randomly initialized, its weights and biases were saved and then split into per-gate sections with a custom script, and the input and output tensors were saved separately and loaded into a Rust test program (a sketch of the per-gate split follows). With this PR, the results from Burn and PyTorch agreed to about six decimal places. I tried input sizes of 1, 2, and 8; hidden sizes of 1, 2, and 8; and sequence lengths of 1, 2, and 3.
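The custom splitting script itself isn't included in this description; as an illustration only, here is a hypothetical sketch of the per-gate split, assuming PyTorch's documented layout where `weight_ih_l0` has shape `[3 * hidden_size, input_size]` with the reset, update, and new-gate weights stacked along dim 0:

```rust
// Hypothetical sketch of the per-gate split (not the PR's actual script).
// PyTorch stores `weight_ih_l0` with shape [3 * hidden_size, input_size],
// stacking the reset, update, and new-gate weights along dim 0.
fn split_gates(flat: &[f32], hidden_size: usize, input_size: usize) -> [&[f32]; 3] {
    let gate = hidden_size * input_size;
    assert_eq!(flat.len(), 3 * gate);
    [&flat[..gate], &flat[gate..2 * gate], &flat[2 * gate..]]
}

fn main() {
    // Toy example: hidden_size = 2, input_size = 1.
    let flat: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let [w_r, w_z, w_n] = split_gates(&flat, 2, 1);
    println!("reset: {w_r:?}, update: {w_z:?}, new: {w_n:?}");
}
```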