GRPOTrainer: fix prompt truncation for multimodal inputs with multiple image tokens #3879
base: main
Conversation
…tokenizer, and not the whole processor
I think I'm missing something: the inputs to the truncate_with_protected_tokens function will be the same even after the changes, right? The tokenization code moved down into the if block is going to be applied in both the previous and the proposed changes as long as max_prompt_length is set.
trl/trainer/grpo_trainer.py
Outdated
    padding_side="left",
    add_special_tokens=False
)
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
Suggested change:
- prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
This block of code also ensures that prompts_text never contains more than one image_token, so I'm wondering where you were running into failures. An example of where the current code breaks would be helpful.
trl/trl/trainer/grpo_trainer.py
Lines 1416 to 1434 in de27d61
if self.image_token is not None:
    escaped_img_token = re.escape(self.image_token)
    # Search for the image token in the chat template
    if re.search(escaped_img_token, self.processing_class.chat_template):
        prompts_text = [
            re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
        ]
    else:
        # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
        if self.vision_end_token_id is not None:
            escaped_eoi_token = re.escape(
                self.processing_class.tokenizer.decode([self.vision_end_token_id])
            )
            prompts_text = [
                re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
            ]
        else:
            # If vision_end_token_id is None, just remove the image tokens
            prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
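To make the first branch concrete, here is a minimal, standalone sketch of the collapsing step; the "<image>" placeholder string and the example prompt are assumptions for illustration, not values taken from any specific processor:

import re

image_token = "<image>"  # assumed placeholder; in the trainer this comes from the processor
escaped_img_token = re.escape(image_token)

# A prompt where the placeholder already appears several times in a row
text = "USER: <image><image><image> Describe the picture. ASSISTANT:"

# Collapse any run of consecutive image tokens into a single one, as the branch above does
collapsed = re.sub(rf"({escaped_img_token})+", image_token, text)
print(collapsed)  # USER: <image> Describe the picture. ASSISTANT: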
The problem is that this code:
prompt_inputs = self.processing_class(
    text=prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
    **kwargs,
)
may generate input_ids that include both prompt and image tokens.
For example, LlavaNextProcessor for a short prompt and a 459x320 image gives prompt_ids of length 1822. Most of those are image tokens, and truncate_with_protected_tokens would fail trying to truncate the sequence to 512.
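A hedged, self-contained sketch of that failure mode, assuming a LlavaNext checkpoint and a blank 459x320 image; the checkpoint name and prompt are illustrative, and the exact token counts depend on the transformers version, checkpoint, and image size:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")  # assumed checkpoint
image = Image.new("RGB", (459, 320))
prompt = "USER: <image> What is shown here? ASSISTANT:"

# Full processor: the image placeholder is expanded into many image tokens
full = processor(text=[prompt], images=[image], return_tensors="pt")
print(full["input_ids"].shape[-1])  # on the order of 1800+ ids, mostly image tokens

# Tokenizer only: just the text tokens, no image expansion
text_only = processor.tokenizer([prompt], return_tensors="pt")
print(text_only["input_ids"].shape[-1])  # a few dozen ids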
In the fixed version
prompt_inputs = tokenizer(
    prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False
)
the tokenizer processes only the prompt text, and the prompt_ids length is 41, so there is nothing to truncate.
Co-authored-by: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Thanks for the clarification! Looks good to me.
So if I understand correctly, instead of enforcing the maximum number of tokens on text tokens + image tokens, you apply it only to text tokens?
Correct.
The issue, though, is that the final sequence (the one taken by the model as input) may be longer than max_prompt_length.
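To make that concrete with the numbers from this thread (so the figures are approximate): with max_prompt_length = 512 the ~41 text token ids are not truncated at all, but the processor later expands the image placeholder back into its image tokens (~1800 in the LlavaNext example above), so the model actually receives roughly 41 + 1800 ≈ 1840 tokens, well above the configured 512.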
What does this PR do?
This PR fixes an issue in prompt truncation when max_prompt_length is set for multimodal inputs.
Previously, the code generated prompt_inputs by calling the entire processing_class (processor) on prompts_text and images.
When the processor included image inputs, it could insert multiple image tokens into the tokenized sequence.
This sometimes caused truncate_with_protected_tokens to fail when reducing the sequence to max_prompt_length because the multiple protected tokens consumed too much of the allowed space.
Now we call only the tokenizer on prompts_text, ensuring that only the textual prompt is tokenized for truncation, while still respecting protected tokens.
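As a compact, hedged summary of the change (a sketch, not the literal trainer code; argument handling such as how images reach the processor is simplified here):

# Before (sketch): the full processor tokenizes the text and expands the image
# placeholders, so input_ids may already be dominated by image tokens before truncation.
prompt_inputs = self.processing_class(
    text=prompts_text,
    images=images,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)

# After (sketch): only the text tokenizer runs before truncation to max_prompt_length;
# image handling stays with the processor later in the pipeline.
prompt_inputs = self.processing_class.tokenizer(
    prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    add_special_tokens=False,
)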
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.