
Conversation

pramodith (Collaborator)

What does this PR do?

#3933 identified that training a model with top_entropy_quantile < 1 hangs indefinitely in a multi-GPU setting, because the gather operation on the non-padded entropy tensor can receive tensors of different shapes on different devices.

This PR fixes the issue with the following steps (a short sketch follows the list):

  1. Gathering the number of non-padded entropy tokens on each device.
  2. Finding the maximum number of non-padded entropy tokens across all GPUs.
  3. Padding the local non-padded entropies to that maximum length.
  4. Gathering all the entropies after padding.
  5. Computing the threshold on the gathered entropies.
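
A minimal sketch of the padding-then-gather approach described above (illustrative only; it loosely follows the variable names in the diff but is not the exact implementation in this PR):

import torch
from accelerate import Accelerator

def gather_non_padded_entropies(non_pad_entropies: torch.Tensor, accelerator: Accelerator) -> torch.Tensor:
    # Steps 1-2: share each rank's count of non-padded entropy tokens and take the global max.
    local_len = torch.tensor([non_pad_entropies.numel()], device=non_pad_entropies.device)
    max_len = accelerator.gather(local_len).max().item()

    # Step 3: pad the local entropies up to the global max length and keep a mask of the real positions.
    pad_len = max_len - non_pad_entropies.numel()
    padding = torch.zeros(pad_len, device=non_pad_entropies.device, dtype=non_pad_entropies.dtype)
    padded_entropies = torch.cat([non_pad_entropies, padding])
    padded_entropies_mask = torch.cat([
        torch.ones(non_pad_entropies.numel(), device=non_pad_entropies.device),
        torch.zeros(pad_len, device=non_pad_entropies.device),
    ])

    # Step 4: every rank now contributes tensors of identical shape, so gather cannot hang.
    all_padded_entropies = accelerator.gather(padded_entropies)
    all_padded_entropies_mask = accelerator.gather(padded_entropies_mask)

    # Step 5 happens downstream: drop the padding, then compute the quantile threshold on the result.
    return all_padded_entropies[all_padded_entropies_mask.bool()]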

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pramodith (Collaborator, Author)

Tested that training works correctly on a machine with 2 A40s using this script:

from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer

dataset = load_dataset("trl-lib/tldr", split="train[:50]")

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    max_steps=10,
    per_device_train_batch_size=8,
    num_generations=8,
    logging_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # must be False for DDP
    top_entropy_quantile=0.2,
    max_completion_length=1024
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Launched with:

accelerate launch --num_processes 2 --multi_gpu --mixed_precision=bf16 examples/scripts/test.py

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pramodith requested a review from Copilot August 27, 2025 14:50

@Copilot (Copilot AI, Contributor) left a comment


Pull Request Overview

This PR fixes a multi-GPU training hang issue when using entropy-based token masking with top_entropy_quantile < 1 in GRPO training. The problem occurred because the gather operation attempted to collect entropy tensors of different shapes across devices, causing the training to hang indefinitely.

  • Implements safe gathering of variable-length entropy tensors across multiple GPUs
  • Adds padding strategy to ensure consistent tensor shapes before gathering
  • Maintains single-GPU compatibility with conditional branching
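
For context, a minimal sketch of the entropy-quantile masking that this gathering feeds into (illustrative only; it assumes the threshold is the (1 - top_entropy_quantile) quantile of the gathered non-padded entropies, and the function name below is made up, not the library's API):

import torch

def high_entropy_token_mask(entropies: torch.Tensor, completion_mask: torch.Tensor, top_entropy_quantile: float) -> torch.Tensor:
    # Keep only the top `top_entropy_quantile` fraction of completion tokens, ranked by entropy.
    non_pad_entropies = entropies[completion_mask.bool()]
    threshold = torch.quantile(non_pad_entropies.float(), 1.0 - top_entropy_quantile)
    # Tokens at or above the threshold contribute to the loss; padding stays masked out.
    return (entropies >= threshold) & completion_mask.bool()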


all_padded_entropies = self.accelerator.gather(padded_entropies)
all_padded_entropies_mask = self.accelerator.gather(padded_entropies_mask)
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
else:
Member:
do we need this if/else?

Member:
I think it can be removed

Collaborator (Author):
Ummm, there are a few ops that aren't useful to have in the single-GPU case, right? Like the first four lines in the if block aren't relevant/needed for a single GPU.

Member:
they'll probably be no-ops, no?

Collaborator (Author):
These two won't be no-ops:

 non_pad_entropies_seq_length = torch.tensor([non_pad_entropies.numel()], device=entropies.device)
 max_non_pad_entropies_seq_length = self.accelerator.gather(non_pad_entropies_seq_length).max().item()

torch.zeros(
    max_non_pad_entropies_seq_length - non_pad_entropies.numel(),
    device=non_pad_entropies.device,
),

Member:
indeed but that's ok I think
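
(For context on why keeping those ops in the single-GPU path is acceptable: a minimal sketch, assuming accelerate's usual non-distributed behavior where gather simply returns the local tensor, so the extra bookkeeping is cheap.)

import torch
from accelerate import Accelerator

accelerator = Accelerator()  # launched as a single process, e.g. plain `python script.py`

non_pad_entropies = torch.rand(5)
local_len = torch.tensor([non_pad_entropies.numel()])

# With one process, gather hands back the local count, so max() is just the local length
# and the subsequent padding is an empty torch.zeros(0) tensor.
max_len = accelerator.gather(local_len).max().item()
print(max_len)  # 5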

@qgallouedec (Member):
just merge main to your branch to fix the CI

)
all_padded_entropies = self.accelerator.gather(padded_entropies)
all_padded_entropies_mask = self.accelerator.gather(padded_entropies_mask)
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
Member:
Suggested change
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()].flatten()
all_non_padded_entropies = all_padded_entropies[all_padded_entropies_mask.bool()]

already flat
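
(Illustrative check of why the .flatten() is redundant: boolean-mask indexing already returns a flat 1-D tensor.)

import torch

all_padded_entropies = torch.tensor([0.1, 0.0, 0.7, 0.0])
all_padded_entropies_mask = torch.tensor([1, 0, 1, 0])

selected = all_padded_entropies[all_padded_entropies_mask.bool()]
print(selected.shape)  # torch.Size([2]): already one-dimensional, no .flatten() needed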

@qgallouedec (Member):

can you also remove accelerator=None from the function signature?

@qgallouedec (Member) left a comment:

LGTM! Thanks

@pramodith merged commit 3bfa981 into huggingface:main Sep 3, 2025
10 checks passed
SamY724 pushed a commit to SamY724/trl that referenced this pull request Sep 6, 2025
…uggingface#3964)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
@pramodith deleted the pramodith/entropy_mask_gather_bug branch September 8, 2025 20:27