trl/trainer/grpo_trainer.py (34 changes: 23 additions & 11 deletions)
@@ -1380,21 +1380,21 @@ def _generate_and_score_completions(

prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

prompt_inputs = self.processing_class(
    text=prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
    **kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

if self.max_prompt_length is not None:
    # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
    # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
    # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
    tokenizer = self.processing_class.tokenizer if has_images else self.processing_class
    prompt_inputs = tokenizer(
        prompts_text,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
    protected = [token for token in protected if token is not None]
    prompt_ids, prompt_mask = truncate_with_protected_tokens(
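
For context on `truncate_with_protected_tokens` (its arguments are collapsed in this view): the idea is left-truncation that never drops the protected vision tokens collected just above. A minimal sketch of that behavior, assuming a per-sequence helper over plain token-id lists; the real trl helper works on batched tensors and its exact signature may differ:

```python
def truncate_keep_protected(ids: list[int], protected: set[int], max_len: int) -> list[int]:
    # Protected positions (e.g. image/vision markers) are always kept.
    prot_pos = [i for i, tok in enumerate(ids) if tok in protected]
    budget = max_len - len(prot_pos)  # room left for ordinary tokens
    rest_pos = [i for i, tok in enumerate(ids) if tok not in protected]
    kept_rest = rest_pos[-budget:] if budget > 0 else []  # keep only the last tokens
    keep = sorted(prot_pos + kept_rest)  # restore the original ordering
    return [ids[i] for i in keep]

# With 9 standing in for a protected image token:
# truncate_keep_protected([1, 2, 9, 9, 3, 4], {9}, 4) -> [9, 9, 3, 4]
```

Per the comment in the hunk above, the trimmed ids are then decoded back into text without `skip_special_tokens=True`, so leading pad tokens have to be stripped manually, roughly like this (assuming the `tokenizer` selected above and that it defines `pad_token`):

```python
import re

prompts_text = tokenizer.batch_decode(prompt_ids, skip_special_tokens=False)
pad = re.escape(tokenizer.pad_token)
prompts_text = [re.sub(rf"^({pad})+", "", text) for text in prompts_text]
```
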
@@ -1433,6 +1433,18 @@ def _generate_and_score_completions(
    # If vision_end_token_id is None, just remove the image tokens
    prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]

prompt_inputs = self.processing_class(
    text=prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
    **kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

# Generate completions using either vLLM or regular generation
if self.use_vllm:
    # First, update the vLLM weights if needed
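
A worked example of the `re.sub` in the second hunk, which removes runs of image placeholder tokens from the decoded text (the concrete token string below is a hypothetical stand-in; the trainer derives `escaped_img_token` from the processor's actual image token):

```python
import re

img_token = "<image>"  # hypothetical placeholder string
escaped_img_token = re.escape(img_token)

text = "<image><image><image> Describe the picture."
print(re.sub(rf"({escaped_img_token})+", "", text))
# -> " Describe the picture."
```

Re-running `self.processing_class` on the cleaned text (the final added block) then lets the processor re-expand each image into the right number of placeholder tokens, so `prompt_ids` stays consistent with the visual inputs, which presumably travel in through `**kwargs`.
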