
Conversation

artem-spector

What does this PR do?

This PR fixes an issue with prompt truncation when max_prompt_length is set for multimodal inputs.
Previously, the code generated prompt_inputs by calling the entire processing_class (processor) on prompts_text and the images.
When the processor handled image inputs, it could insert many image placeholder tokens into the tokenized sequence.
This sometimes caused truncate_with_protected_tokens to fail when reducing the sequence to max_prompt_length, because the protected image tokens alone consumed more than the allowed length.
Now only the tokenizer is called on prompts_text, ensuring that just the textual prompt is tokenized for truncation while still respecting protected tokens.
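A minimal sketch of the change, assuming processing_class is a multimodal processor exposing a .tokenizer attribute (variable names mirror the trainer code, but the snippet is illustrative rather than the exact diff; the real code passes the images through keyword arguments):

# Previously: the full processor expands image placeholders into many image
# tokens, so the sequence handed to truncation is dominated by protected tokens.
prompt_inputs = self.processing_class(
    text=prompts_text,
    images=images,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)

# Now: only the tokenizer runs before truncation, so the sequence contains just
# the textual prompt (with at most a single protected image placeholder).
tokenizer = self.processing_class.tokenizer
prompt_inputs = tokenizer(
    prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)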

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@artem-spector artem-spector changed the title Fix prompt truncation for multimodal inputs with multiple image tokens GRPOTrainer : fix prompt truncation for multimodal inputs with multiple image tokens Aug 11, 2025
Collaborator

@pramodith pramodith left a comment


I think I'm missing something: the inputs to the truncate_with_protected_tokens function will be the same even after the changes, right? The tokenization code moved down into the if block will be applied in both the previous and proposed versions as long as max_prompt_length is set.

    padding_side="left",
    add_special_tokens=False
)
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
Collaborator


Suggested change
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)

Collaborator


This block of code also ensures that prompts_text never contains more than one image_token, so I'm wondering where you were running into failures; an example of where the current code breaks would be helpful.

if self.image_token is not None:
    escaped_img_token = re.escape(self.image_token)
    # Search for the image token in the chat template
    if re.search(escaped_img_token, self.processing_class.chat_template):
        prompts_text = [
            re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
        ]
    else:
        # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
        if self.vision_end_token_id is not None:
            escaped_eoi_token = re.escape(
                self.processing_class.tokenizer.decode([self.vision_end_token_id])
            )
            prompts_text = [
                re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
            ]
        else:
            # If vision_end_token_id is None, just remove the image tokens
            prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
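For illustration, here is a standalone version of the collapse step; the "<image>" string is just a stand-in for whatever token the processor actually uses:

import re

image_token = "<image>"  # placeholder; the real token string comes from the processor
escaped_img_token = re.escape(image_token)

text = "<image><image><image> Describe the picture."
# A run of repeated image tokens collapses to a single token.
print(re.sub(rf"({escaped_img_token})+", image_token, text))
# -> "<image> Describe the picture."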

Author


The problem is that this code:
prompt_inputs = self.processing_class(
    text=prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
    **kwargs,
)
may generate input_ids that include both text and image tokens.
For example, LlavaNextProcessor, given a short prompt and a 459x320 image, produces prompt_ids of length 1822. Most of those are image tokens, and truncate_with_protected_tokens would fail trying to truncate the sequence to 512.
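Roughly, the failure mode looks like this; the check is a simplified stand-in for what truncate_with_protected_tokens enforces, and the token counts are only illustrative, derived from the numbers above:

max_prompt_length = 512
num_prompt_ids = 1822   # text + expanded image tokens produced by the processor
num_protected = 1781    # image placeholder tokens that must survive truncation (illustrative)

# If the protected tokens alone exceed the target length, truncation cannot succeed.
if num_protected > max_prompt_length:
    raise ValueError(
        f"Cannot truncate to {max_prompt_length} tokens while keeping {num_protected} protected tokens."
    )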

In the fixed version
prompt_inputs = tokenizer(
    prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False
)
the tokenizer processes only the prompt text, and the prompt_ids length is 41, so there is nothing to truncate.

Co-authored-by: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Collaborator

@pramodith pramodith left a comment


Thanks for the clarification! Looks good to me.

@qgallouedec
Member

So if I understand correctly, instead of enforcing the maximum number of tokens on text tokens + image tokens, you apply it only to text tokens?

@artem-spector
Author

So if I understand correctly, instead of enforcing the maximum number of tokens on text tokens + image tokens, you apply it only to text tokens?

Correct.

@qgallouedec
Member

qgallouedec commented Sep 3, 2025

The issue, though, is that the final sequence (the one taken by the model as input) may be longer than max_length.
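A quick arithmetic sketch of that concern, reusing the illustrative counts from the LlavaNext example above:

max_prompt_length = 512
text_token_count = 41       # text-only prompt ids; already under the limit, so truncation is a no-op
image_token_count = 1781    # image placeholders expanded when the processor builds the model input (illustrative)

final_input_length = text_token_count + image_token_count
print(final_input_length, final_input_length > max_prompt_length)  # 1822 True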
