[GRPO]: Fix Multi-GPU training for Entropy based masking of tokens. #3964
Conversation
Tested that training works correctly on a machine with 2 A40s using this script (collapsed in the original comment):
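The actual test script is collapsed in the original comment. The sketch below is a hypothetical stand-in, not the author's script: the dataset, model, reward function, and config values are all assumptions, chosen only to show the kind of run that exercises entropy-based masking on 2 GPUs (launched with `accelerate launch --num_processes 2`).

```python
# Hypothetical reproduction sketch (NOT the author's script): a minimal GRPO run
# with top_entropy_quantile < 1, launched on 2 GPUs via
#   accelerate launch --num_processes 2 repro_entropy_mask.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters long.
    return [-abs(50 - len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train[:128]")

args = GRPOConfig(
    output_dir="grpo-entropy-mask-repro",
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=64,
    top_entropy_quantile=0.2,  # < 1 enables entropy-based token masking
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=args,
    train_dataset=dataset,
)
trainer.train()
```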
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Pull Request Overview
This PR fixes a multi-GPU training hang that occurred when using entropy-based token masking with top_entropy_quantile < 1 in GRPO training. The gather operation attempted to collect entropy tensors of different shapes across devices, causing training to hang indefinitely.
- Implements safe gathering of variable-length entropy tensors across multiple GPUs
- Adds a padding strategy to ensure consistent tensor shapes before gathering (see the sketch after this list)
- Maintains single-GPU compatibility with conditional branching
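A rough sketch of the padding-and-masking approach the overview describes; this is not the PR's exact code, and the function and variable names are illustrative:

```python
# Sketch: gather 1-D tensors whose length differs per process by padding them
# to the global max length and carrying a validity mask through the gather.
import torch

def gather_variable_length(accelerator, values):
    """Gather a 1-D tensor whose length may differ on each process."""
    # 1. Find the longest tensor across all processes.
    local_len = torch.tensor([values.numel()], device=values.device)
    max_len = accelerator.gather(local_len).max().item()

    # 2. Pad the values, and a validity mask, up to that common length.
    pad_len = max_len - values.numel()
    padded = torch.cat([values, torch.zeros(pad_len, device=values.device, dtype=values.dtype)])
    mask = torch.cat([
        torch.ones(values.numel(), device=values.device),
        torch.zeros(pad_len, device=values.device),
    ])

    # 3. Every rank now contributes the same shape, so the gather cannot hang.
    all_padded = accelerator.gather(padded)
    all_mask = accelerator.gather(mask)

    # 4. Drop the padding entries after the gather.
    return all_padded[all_mask.bool()]
```

Gathering the mask alongside the values matters here because the gathered entropies feed a quantile computation; zeros left in from padding would skew that threshold.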
trl/trainer/grpo_trainer.py (Outdated)
all_padded_entropies = self.accelerator.gather(padded_entropies)
all_padded_entropies_mask = self.accelerator.gather(padded_entropies_mask)
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
else:
do we need this if/else?
I think it can be removed
Ummm there are a few ops that aren't useful to have in the single-GPU case, right? Like, the first four lines in the if block aren't relevant/needed for a single GPU.
they'll probably be no-ops, no?
These two won't be no-ops:
non_pad_entropies_seq_length = torch.tensor([non_pad_entropies.numel()], device=entropies.device)
max_non_pad_entropies_seq_length = self.accelerator.gather(non_pad_entropies_seq_length).max().item()
torch.zeros(
    max_non_pad_entropies_seq_length - non_pad_entropies.numel(),
    device=non_pad_entropies.device,
),
indeed but that's ok I think
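For reference, a small single-process illustration of why those two ops are cheap: with one process, Accelerate's gather returns its input unchanged, so the extra work reduces to a length lookup and a zero-length pad. This is a sketch under that assumption, not code from the PR:

```python
# Single-process case: accelerator.gather returns its input unchanged when only
# one process is running, so the padding degenerates to a zero-length tensor and
# the values pass through untouched.
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # assume a plain single-process run
values = torch.randn(7)

local_len = torch.tensor([values.numel()])
max_len = accelerator.gather(local_len).max().item()  # == 7: only one process

pad = torch.zeros(max_len - values.numel())           # length-0 tensor
padded = torch.cat([values, pad])                     # identical to values
assert torch.equal(padded, values)
```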
…m/pramodith/trl into pramodith/entropy_mask_gather_bug
Just merge main into your branch to fix the CI.
trl/trainer/grpo_trainer.py (Outdated)
)
all_padded_entropies = self.accelerator.gather(padded_entropies)
all_padded_entropies_mask = self.accelerator.gather(padded_entropies_mask)
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
Suggested change:
- all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
+ all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()]
already flat
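For context, boolean-mask indexing in PyTorch always returns a 1-D tensor of the selected elements, so the trailing .flatten() is redundant; a quick check:

```python
# Boolean-mask indexing already returns a flat 1-D tensor.
import torch

x = torch.arange(6.0).reshape(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
print(x[mask].shape)                         # torch.Size([3])
assert torch.equal(x[mask], x[mask].flatten())
```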
can you also remove
…m/pramodith/trl into pramodith/entropy_mask_gather_bug
LGTM! Thanks
…uggingface#3964) Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
What does this PR do?
#3933 identified that training a model with top_entropy_quantile < 1 causes training to hang indefinitely in a multi-GPU setting, because the gather operation used to collect the non-padded entropy tensor can receive tensors of different shapes on different devices. This PR fixes that issue by padding each process's non-padded entropy tensor, together with a boolean mask, to a common length before gathering, then using the gathered mask to recover only the valid entropy values.
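As an aside, and not the approach taken in this PR, Accelerate also ships a pad_across_processes utility for the same shape-mismatch problem; the sketch below (a hypothetical repro.py, meant to be launched with accelerate launch --num_processes 2) shows the general pattern. The PR instead pads manually and gathers a mask alongside the values, which lets it drop the padding afterwards so padded zeros don't skew the entropy quantile.

```python
# Sketch (hypothetical repro.py, run with: accelerate launch --num_processes 2 repro.py):
# each rank keeps a different number of high-entropy values, so a plain gather of
# the raw tensors can block -- the underlying all_gather expects identical shapes
# on every rank. Padding to a common length first makes the gather safe.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
rank = accelerator.process_index

# Variable-length tensor per rank: length 3 on rank 0, length 5 on rank 1, ...
values = torch.randn(3 + 2 * rank, device=accelerator.device)

# accelerator.gather(values) here could hang, since shapes differ across ranks.
padded = accelerator.pad_across_processes(values, dim=0, pad_index=0.0)
gathered = accelerator.gather(padded)

if accelerator.is_main_process:
    print(gathered.shape)  # (num_processes * max_len,)
```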
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.