[SFTTrainer]: Check for assistant mask up to max_length #3930
Conversation
In my understanding, if there is one sample in the dataset for which the prompt is too long, so that the prompt+completion sequence gets truncated, this would make the training fail?
True! Should we check for some threshold ratio of rows and just log a warning if that threshold is reached, instead of raising an exception, to handle the case of truncation?
I actually like the idea of ignoring a row! I think we should drop/ignore the rows in the dataset that have no assistant tokens post-truncation and log the % of ignored rows to the user. I think I can accomplish this with some dataset ops.
Yep, agree. Maybe log the number of active tokens (not masked) as well.
Pull Request Overview
This PR addresses an issue where all assistant mask tokens could become 0 after truncation when max_length is set, making training ineffective. The fix adds validation and filtering to ensure trainable assistant tokens remain after truncation.
- Adds validation to check for remaining assistant tokens after truncation
- Filters out examples with no assistant tokens and provides detailed logging
- Raises an error if no trainable tokens remain after truncation
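As a rough illustration of that behavior (a minimal sketch; the helper name, logging details, and error message are illustrative, not the PR's actual code):

```python
import logging

from datasets import Dataset

logger = logging.getLogger(__name__)

def drop_untrainable_rows(dataset: Dataset, mask_column: str = "assistant_masks") -> Dataset:
    # Keep only rows that still contain at least one trainable token after truncation.
    num_before = len(dataset)
    dataset = dataset.filter(lambda row: sum(row[mask_column]) > 0)
    num_dropped = num_before - len(dataset)
    if num_dropped:
        logger.warning(
            f"Dropped {num_dropped} of {num_before} rows ({num_dropped / num_before:.1%}) "
            "with no trainable tokens after truncation."
        )
    if len(dataset) == 0:
        raise ValueError("No trainable tokens remain after truncation; consider increasing max_length.")
    return dataset
```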
Maybe we can even check the labels directly; it would allow us to account for all unmasked tokens (just check `label != -100`).
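For example, a toy sketch of that check on a single row:

```python
# -100 is the conventional "ignore" label: tokens marked -100 are excluded from the loss.
labels = [-100, -100, 42, 7, -100, 13]
num_trainable = sum(1 for label in labels if label != -100)
print(num_trainable)  # 3
```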
Hmmm, I also think that just accounting for the assistant token count would be the same as the total number of trainable tokens; everything else would either be a pad token or a non-assistant token, both of which would need to be masked in assistant-only training mode.
Ah yes you're right!
Not exactly, because when you train on completion-only, you have a `completion_mask`.
Ahh yes, forgot about that. Made some changes to reflect all three types of masks.
trl/trainer/sft_trainer.py (outdated)
```python
if args.assistant_only_loss:
    total_trainable_tokens_before_truncation = get_trainable_tokens(dataset, "assistant_masks")
# Prompt Completions/Instruction Tuning Dataset
elif "completion_mask" in first_row:
    total_trainable_tokens_before_truncation = get_trainable_tokens(dataset, "completion_mask")
```
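For reference, one plausible shape for the `get_trainable_tokens` helper referenced in this diff (the PR's actual implementation may differ):

```python
def get_trainable_tokens(dataset, mask_column: str) -> int:
    # Each mask entry is 1 for a trainable token and 0 otherwise,
    # so summing the masks across all rows gives the trainable-token count.
    return sum(sum(mask) for mask in dataset[mask_column])
```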
So it can be both conversational and prompt-completion. In such a case, the loss is only computed for the tokens that are (1) in the completion and (2) from the "assistant" role.
@albertvillanova can you review this one?
We want to log the number of tokens contributing to the loss and filter out the examples where none of the tokens contribute to the loss.
I think the approach of looping row-by-row is going to be problematic at scale:
- Performance / memory: iterating over the dataset in pure Python materializes every row and defeats the purpose of 🤗 Datasets' efficient Arrow backend. On anything larger than a toy dataset, this will be orders of magnitude slower than a vectorized map/filter.
- No batching / parallelism: you're not leveraging the dataset pipeline's ability to process in batches and in parallel.

I'd recommend re-implementing with `dataset.map(..., batched=True, num_proc=...)` + `dataset.filter(...)`, which keeps the whole pipeline efficient, parallel, and robust.
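A minimal sketch of that vectorized pipeline (the column name is assumed from the diff above; the count column is hypothetical):

```python
from datasets import Dataset

ds = Dataset.from_dict({"assistant_masks": [[0, 1, 1], [0, 0, 0], [1, 1, 1]]})

def count_trainable(batch):
    # Batched map: compute per-row trainable-token counts; 🤗 Datasets caches the result in Arrow.
    batch["num_trainable_tokens"] = [sum(mask) for mask in batch["assistant_masks"]]
    return batch

ds = ds.map(count_trainable, batched=True)
ds = ds.filter(lambda row: row["num_trainable_tokens"] > 0)
print(len(ds))  # 2 — the all-zero row is dropped
```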
@albertvillanova let me know if I'm wrong, but what we're trying to do here is more of a reduce operation. I can run a map to compute per-row counts, but the final total still needs an aggregation step. I came across an example of converting a dataset to a Polars dataframe and running a reduce operation on the Polars dataframe here. Would that help in this scenario? My main concern with this approach is whether we'll effectively be duplicating the dataset in memory.
Yes, I think that is a good approach:
- Do as many intermediate steps as possible with `map(batched=True)` and/or `filter`.
  - These are efficient and fast: 🤗 Datasets caches all the transformations in Arrow.
  - At this stage, data is not in-memory, but memory-mapped on disk.
- For the final aggregation, convert to Polars and run the reduce.
  - This operation works with zero-copy on the underlying Arrow buffers.
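Continuing the sketch above, the final aggregation could look like this (assuming a `num_trainable_tokens` column produced by the earlier map; `pl.from_arrow` wraps the dataset's underlying Arrow table without copying):

```python
import polars as pl
from datasets import Dataset

ds = Dataset.from_dict({"num_trainable_tokens": [2, 0, 3]})

# ds.data.table is the backing pyarrow.Table; Polars wraps it zero-copy, then a single reduce.
total = pl.from_arrow(ds.data.table)["num_trainable_tokens"].sum()
print(total)  # 5
```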
@albertvillanova re-opened this one per your request 😄. Adding my comment from the new PR that I had opened (and subsequently closed after re-opening this) here so that it doesn't get lost.
Can we try to do it without Polars? I'd like to avoid adding this dependency just for this feature.
Feels like there isn't a nice and efficient way of counting the number of trainable tokens ahead of time without multiple passes through the entire dataset, and I don't feel this is a valuable enough addition to the library to pursue any further. Unless anyone has a simple solution in mind, I'm considering closing out this PR.
What does this PR do?
Addresses #3927, where it's possible for all the `assistant_mask` tokens to be 0 when the inputs are truncated because `max_length` is set.
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.