
@pramodith (Collaborator) commented Aug 20, 2025

What does this PR do?

Addresses #3927, where it's possible for all the assistant_mask tokens to be 0 when the inputs are truncated because max_length is set.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pramodith changed the title from "Update assistant mask exception." to "[SFTTrainer]: Check for assistant mask up to max_length" on Aug 20, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) commented Aug 20, 2025

In my understanding, if there is one sample in the dataset for which the prompt is too long, so that the prompt+completion part gets truncated, this would make the training fail?

EDIT for clarification:
For example, if there is one sample with prompt_len=70, completion_len=32, and max_length=64, this would make the training fail instead of ignoring this example?
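
Concretely, something like this toy sketch (the mask layout is simplified and only illustrative, not TRL internals):

```python
# One sample: the completion starts only after the prompt
prompt_len, completion_len, max_length = 70, 32, 64

# 0 = prompt token (masked), 1 = completion token (contributes to the loss)
completion_mask = [0] * prompt_len + [1] * completion_len

# Truncating to max_length keeps only prompt tokens
truncated_mask = completion_mask[:max_length]
print(sum(truncated_mask))  # 0 -> no trainable tokens left for this example
```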

@pramodith (Collaborator, Author)

> In my understanding, if there is one sample in the dataset for which the prompt is too long, so that the prompt+completion part gets truncated, this would make the training fail?

True! Should we check for some threshold ratio of rows and, if that threshold is reached, just log a warning instead of raising an exception, to handle the truncation case?

@pramodith (Collaborator, Author) commented Aug 20, 2025

> EDIT for clarification:
> For example, if there is one sample with prompt_len=70, completion_len=32, and max_length=64, this would make the training fail instead of ignoring this example?

I actually like the idea of ignoring a row! I think we should drop/ignore the rows in the dataset that have no assistant tokens post-truncation and log the % of ignored rows to the user. I think I can accomplish this with some dataset ops.
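
A rough sketch of what I mean with 🤗 Datasets (the column name, the max_length handling, and the logger are assumptions here, not the final implementation):

```python
import logging

logger = logging.getLogger(__name__)

def has_trainable_tokens(example, max_length):
    # Keep the row only if at least one assistant token survives truncation
    return sum(example["assistant_masks"][:max_length]) > 0

# `dataset` and `max_length` come from the trainer's dataset-preparation step (assumed here)
kept = dataset.filter(has_trainable_tokens, fn_kwargs={"max_length": max_length})
dropped_pct = 100 * (len(dataset) - len(kept)) / len(dataset)
logger.warning(f"{dropped_pct:.1f}% of rows have no assistant tokens after truncation and were dropped.")
```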

@qgallouedec (Member)

Yep, agreed. Maybe also log the number of active (not masked) tokens.

@pramodith requested a review from Copilot on August 21, 2025 10:57
@Copilot (Copilot AI) left a comment


Pull Request Overview

This PR addresses an issue where all assistant mask tokens could become 0 after truncation when max_length is set, making training ineffective. The fix adds validation and filtering to ensure trainable assistant tokens remain after truncation.

  • Adds validation to check for remaining assistant tokens after truncation
  • Filters out examples with no assistant tokens and provides detailed logging
  • Raises an error if no trainable tokens remain after truncation


@qgallouedec (Member)

Maybe we can even check the labels directly; it would allow accounting for all unmasked tokens (just check label != -100).
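
For example, on a collated batch (a sketch, assuming labels has already been built by the collator with -100 for masked positions):

```python
# Count the tokens that actually contribute to the loss in a collated batch
num_trainable = (batch["labels"] != -100).sum().item()
```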

@pramodith (Collaborator, Author)

> Maybe we can even check the labels directly; it would allow accounting for all unmasked tokens (just check label != -100).

Hmmm, my understanding is that the labels column is populated in the data collator and not at the dataset preparation stage. This'd mean that we'd need to keep a running sum of the number of tokens, since the data collator is called per iteration of the DataLoader. Is this something we want to do?

I also think that just counting the assistant tokens would give the same number as the total number of trainable tokens; everything else would be either a pad token or a non-assistant token, both of which need to be masked in assistant-only training mode.

@qgallouedec (Member)

> Hmmm, my understanding is that the labels column is populated in the data collator

Ah yes you're right!

> I also think that just counting the assistant tokens would give the same number as the total number of trainable tokens

Not exactly: when you train on completion-only data, you have a completion_mask that can also be full of 0s after truncation.

@pramodith (Collaborator, Author)

> Not exactly: when you train on completion-only data, you have a completion_mask that can also be full of 0s after truncation.

Ahh yes, I forgot about that. I made some changes to cover all three dataset types: conversational, prompt-completion/instruction-tuning, and language modeling.

Comment on lines 1021 to 1025:

if args.assistant_only_loss:
    total_trainable_tokens_before_truncation = get_trainable_tokens(dataset, "assistant_masks")
# Prompt Completions/Instruction Tuning Dataset
elif "completion_mask" in first_row:
    total_trainable_tokens_before_truncation = get_trainable_tokens(dataset, "completion_mask")

So it can be both conversational and prompt-completion. In such a case, the loss is only computed for tokens that are (1) in the completion and (2) from the assistant role.

@qgallouedec (Member) commented Sep 2, 2025

@albertvillanova can you review this one?
Basically the idea is that we have a dataset containing the column input_ids (list of ints), maybe a column assistant_masks (list of bools), and maybe a column completion_mask (list of bools).
A token contributes to the loss only if:

  • assistant_masks is not a column or assistant_masks=1 for this token AND
  • completion_mask is not a column or completion_mask=1 for this token

We want to log the number of tokens contributing to the loss and filter the examples where none of the tokens contribute to the loss.
I'm not sure if it's the best way to leverage 🤗 Datasets in this case.
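
Something like this is what I have in mind (a sketch only; the helper and column names follow the description above and are not an actual implementation):

```python
def count_trainable(example, mask_columns):
    # A token contributes to the loss only if every mask column that exists is 1 for it
    mask = [1] * len(example["input_ids"])
    for col in mask_columns:
        mask = [m and v for m, v in zip(mask, example[col])]
    return {"num_trainable": sum(mask)}

mask_columns = [c for c in ("assistant_masks", "completion_mask") if c in dataset.column_names]
dataset = dataset.map(count_trainable, fn_kwargs={"mask_columns": mask_columns})
total_trainable = sum(dataset["num_trainable"])  # number of tokens contributing to the loss
dataset = dataset.filter(lambda ex: ex["num_trainable"] > 0)  # drop fully-masked examples
```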

@albertvillanova (Member) left a comment

I think the approach of looping row-by-row is going to be problematic for scale:

  • Performance / memory: Iterating over the dataset in pure Python materializes every row and defeats the purpose of 🤗 Datasets' efficient Arrow backend. On anything larger than a toy dataset, this will be orders of magnitude slower than a vectorized map/filter.
  • No batching / parallelism: You're not leveraging the dataset pipeline’s ability to process in batches and in parallel.

I'd recommend re-implementing with dataset.map(..., batched=True, num_proc=...) + dataset.filter(...), which keeps the whole pipeline efficient, parallel, and robust.
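
For instance, something along these lines (a sketch; the column names, num_proc value, and the num_trainable helper column are placeholders, not the final code):

```python
def add_num_trainable(batch, mask_columns):
    # Per-row count of tokens for which every present mask column is 1
    counts = []
    for i in range(len(batch["input_ids"])):
        mask = [1] * len(batch["input_ids"][i])
        for col in mask_columns:
            mask = [m and v for m, v in zip(mask, batch[col][i])]
        counts.append(sum(mask))
    return {"num_trainable": counts}

mask_columns = [c for c in ("assistant_masks", "completion_mask") if c in dataset.column_names]
dataset = dataset.map(add_num_trainable, batched=True, num_proc=4, fn_kwargs={"mask_columns": mask_columns})
dataset = dataset.filter(lambda n: n > 0, input_columns="num_trainable")
```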

@pramodith (Collaborator, Author) commented Sep 4, 2025

> I'd recommend re-implementing with dataset.map(..., batched=True, num_proc=...) + dataset.filter(...), which keeps the whole pipeline efficient, parallel, and robust.

@albertvillanova let me know if I'm wrong, but what we're trying to do is more of a reduce operation than a map operation, since we're trying to compute the total number of non-masked tokens.

I can run a map with batched=True to get the row-wise sums, but I'll still need to iterate over the entire dataset to reduce those to a final dataset-level sum. Am I missing something here?

I came across an example of converting a dataset to a Polars dataframe and running a reduce operation on the Polars dataframe here. Would that help in this scenario? My main concern is whether we'd effectively be duplicating the dataset in memory with this approach.

@albertvillanova (Member) left a comment

Yes, I think that is a good approach:

  • Do as many intermediate steps as possible with map(batched=True) and/or filter.
    • These are efficient and fast: 🤗 Datasets caches all the transformations in Arrow.
    • At this stage, data is not in memory, but memory-mapped on disk.
  • For the final aggregation, convert to Polars and run the reduce (see the sketch below).
    • This operation works with zero-copy on the underlying Arrow buffers.
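
Concretely, something like this (a sketch; it assumes a num_trainable column was already added by a batched map, and that a Polars dependency is acceptable):

```python
import polars as pl

# Intermediate map/filter steps stay Arrow-backed and memory-mapped on disk;
# the final reduce hands the underlying Arrow table to Polars without copying the data.
total_trainable = pl.from_arrow(dataset.data.table)["num_trainable"].sum()
print(f"Trainable tokens: {total_trainable}")
```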

@pramodith closed this Sep 8, 2025
@pramodith deleted the pramodith/update_assistant_token_exception branch September 8, 2025 20:27
@pramodith reopened this Sep 10, 2025
@pramodith (Collaborator, Author) commented Sep 10, 2025

@albertvillanova re-opened this one per your request 😄.

Adding here my comment from the new PR that I had opened (and subsequently closed after re-opening this one), so that it doesn't get lost.

I don't think I can avoid the for-loops: even with batched=True, I have to iterate through all the rows in the batch to get the sum of tokens in each row, so I'm not sure how much this'll help.

@qgallouedec (Member)

Can we try to do it without Polars? I'd like to avoid adding this dependency just for this feature.

@pramodith (Collaborator, Author)

Feels like there isn't a nice, efficient way of counting the number of trainable tokens ahead of time without multiple passes through the entire dataset, and I don't feel this is a valuable enough addition to the library to pursue any further. Unless anyone has a simple solution in mind, I'm considering closing out this PR.
