Add special tokens in text-generation pipeline if tokenizer requires #1370


Merged: 5 commits, Jul 20, 2025
9 changes: 6 additions & 3 deletions src/pipelines.js
@@ -996,6 +996,11 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
let isBatched = false;
let isChatInput = false;

// By default, do not add special tokens, unless the tokenizer specifies otherwise
let add_special_tokens = generate_kwargs.add_special_tokens
?? (this.tokenizer.add_bos_token || this.tokenizer.add_eos_token)
?? false;

// Normalize inputs
/** @type {string[]} */
let inputs;
@@ -1021,11 +1026,9 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
add_generation_prompt: true,
})
));
add_special_tokens = false; // Chat template handles this already
}

// By default, do not add special tokens
const add_special_tokens = generate_kwargs.add_special_tokens ?? false;

// By default, return full text
const return_full_text = isChatInput
? false
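For context, a minimal sketch of how the new default resolves (the model id is an illustrative assumption, not one taken from this PR):

```js
import { pipeline } from "@huggingface/transformers";

// Hypothetical Llama-style model whose tokenizer_config.json sets add_bos_token: true.
const generator = await pipeline("text-generation", "Xenova/llama2.c-stories15M");

// 1. An explicit option always wins:
await generator("hello", { max_new_tokens: 3, add_special_tokens: false });

// 2. Otherwise the tokenizer config decides: since add_bos_token is true,
//    a BOS token is now prepended automatically. Before this change the
//    pipeline unconditionally defaulted to add_special_tokens = false.
await generator("hello", { max_new_tokens: 3 });
```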
3 changes: 3 additions & 0 deletions src/tokenizers.js
@@ -2659,6 +2659,9 @@ export class PreTrainedTokenizer extends Callable {
this.padding_side = tokenizerConfig.padding_side;
}

this.add_bos_token = tokenizerConfig.add_bos_token;
this.add_eos_token = tokenizerConfig.add_eos_token;

this.legacy = false;

this.chat_template = tokenizerConfig.chat_template ?? null;
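A quick sketch of what the new properties expose (the model id is again an assumption):

```js
import { AutoTokenizer } from "@huggingface/transformers";

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama2.c-stories15M");

// Both flags are read straight from tokenizer_config.json and may be
// undefined when the config does not define them.
console.log(tokenizer.add_bos_token); // e.g. true for Llama-style tokenizers
console.log(tokenizer.add_eos_token); // e.g. false
```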
8 changes: 4 additions & 4 deletions tests/bundles.test.js
@@ -9,17 +9,17 @@ const result = await generator("hello", { max_new_tokens: 3, return_full_text: f
process.stdout.write(result[0].generated_text);
`;

const TARGET_OUTPUT = "erdingsAndroid Load";
const TARGET_OUTPUT = "erdingsdelete mely";

const wrap_async_iife = (code) => `(async function() { ${code} })();`;

const check = (code, module = false) => {
const args = ["-e", code];
if (module) args.push("--input-type=module");
const { status, stdout, stderr } = spawnSync("node", args);
expect(stderr.toString()).toBe(""); // No warnings or errors are printed
expect(stdout.toString()).toBe(TARGET_OUTPUT); // The output should match
expect(status).toBe(0); // The process should exit cleanly
expect(stderr.toString()).toEqual(""); // No warnings or errors are printed
expect(stdout.toString()).toEqual(TARGET_OUTPUT); // The output should match
expect(status).toEqual(0); // The process should exit cleanly
};

describe("Testing the bundle", () => {
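The expected output changes because the test model's tokenizer sets add_bos_token, so the prompt is now encoded with a leading BOS id and greedy decoding takes a different path. A sketch of the difference (the model id and exact token ids are assumptions, inferred from the comments in tests/utils/logits_process.test.js below):

```js
import { AutoTokenizer } from "@huggingface/transformers";

// Assumed Llama-style test tokenizer; 1 = <s> (BOS), 22172 = "hello".
const tokenizer = await AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM");

const { input_ids } = tokenizer("hello");
// Before this PR the pipeline encoded the prompt as [22172];
// with the tokenizer's add_bos_token honored it becomes [1, 22172],
// so the generated continuation (and thus TARGET_OUTPUT) differs.
```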
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipelines_text_generation.js
@@ -20,7 +20,7 @@ export default () => {

describe("batch_size=1", () => {
const text_input = "hello";
const generated_text_target = "erdingsAndroid Load";
const generated_text_target = "erdingsdelete mely";
const text_target = [{ generated_text: text_input + generated_text_target }];
const new_text_target = [{ generated_text: generated_text_target }];

79 changes: 56 additions & 23 deletions tests/tokenizers.test.js
@@ -54,7 +54,6 @@ describe("Tokenizer padding/truncation", () => {
}, MAX_TOKENIZER_LOAD_TIME);

describe("return_tensor=false (jagged array)", () => {

test("jagged array output when return_tensor is false", () => {
const output = tokenizer(inputs, {
return_tensor: false,
@@ -105,7 +104,6 @@ describe("Tokenizer padding/truncation", () => {
compare(output, expected);
});


test("No padding, max_length=3 (implicit truncation strategy)", () => {
const output = tokenizer(inputs_2, {
padding: false,
@@ -129,9 +127,18 @@ describe("Tokenizer padding/truncation", () => {
return_tensor: false,
});
const expected = {
input_ids: [[1037, 0, 0, 0, 0], [1038, 1039, 1040, 1041, 1042]],
token_type_ids: [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
attention_mask: [[1, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
input_ids: [
[1037, 0, 0, 0, 0],
[1038, 1039, 1040, 1041, 1042],
],
token_type_ids: [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
],
attention_mask: [
[1, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
],
};
compare(output, expected);
});
@@ -161,48 +168,75 @@ describe("Tokenizer padding/truncation", () => {
return_tensor: false,
});
const expected = {
input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
token_type_ids: [[0, 0, 0], [0, 0, 0]],
attention_mask: [[1, 0, 0], [1, 1, 1]],
input_ids: [
[1037, 0, 0],
[1038, 1039, 1040],
],
token_type_ids: [
[0, 0, 0],
[0, 0, 0],
],
attention_mask: [
[1, 0, 0],
[1, 1, 1],
],
};
compare(output, expected);
});

test("Padding 'max_length' without truncation, max_length=3", () => {
const output = tokenizer(inputs_2, {
padding: 'max_length',
padding: "max_length",
truncation: false,
max_length: 3,
add_special_tokens: false,
return_tensor: false,
});
const expected = {
input_ids: [[1037, 0, 0], [1038, 1039, 1040, 1041, 1042]],
token_type_ids: [[0, 0, 0], [0, 0, 0, 0, 0]],
attention_mask: [[1, 0, 0], [1, 1, 1, 1, 1]],
input_ids: [
[1037, 0, 0],
[1038, 1039, 1040, 1041, 1042],
],
token_type_ids: [
[0, 0, 0],
[0, 0, 0, 0, 0],
],
attention_mask: [
[1, 0, 0],
[1, 1, 1, 1, 1],
],
};
compare(output, expected);
});

test("Padding 'max_length' with truncation, max_length=3", () => {
const output = tokenizer(inputs_2, {
padding: 'max_length',
padding: "max_length",
truncation: true,
max_length: 3,
add_special_tokens: false,
return_tensor: false,
});
const expected = {
input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
token_type_ids: [[0, 0, 0], [0, 0, 0]],
attention_mask: [[1, 0, 0], [1, 1, 1]],
input_ids: [
[1037, 0, 0],
[1038, 1039, 1040],
],
token_type_ids: [
[0, 0, 0],
[0, 0, 0],
],
attention_mask: [
[1, 0, 0],
[1, 1, 1],
],
};
compare(output, expected);
});

test("Padding 'max_length' without truncation and max_length=null", () => {
const output = tokenizer(inputs_2, {
padding: 'max_length',
padding: "max_length",
truncation: false,
max_length: null,
add_special_tokens: false,
@@ -211,23 +245,22 @@ describe("Tokenizer padding/truncation", () => {
const expected = {
input_ids: [
[1037, ...Array(511).fill(0)],
[1038, 1039, 1040, 1041, 1042, ...Array(507).fill(0)]
[1038, 1039, 1040, 1041, 1042, ...Array(507).fill(0)],
],
token_type_ids: [
[0, ...Array(511).fill(0)],
[0, 0, 0, 0, 0, ...Array(507).fill(0)]
[0, 0, 0, 0, 0, ...Array(507).fill(0)],
],
attention_mask: [
[1, ...Array(511).fill(0)],
[1, 1, 1, 1, 1, ...Array(507).fill(0)]
[1, 1, 1, 1, 1, ...Array(507).fill(0)],
],
};
compare(output, expected);
});
});

describe("return_tensor=true", () => {

test("throws error when tensor output is requested for a jagged array", () => {
expect(() => tokenizer(inputs)).toThrow("Unable to create tensor");
});
@@ -329,7 +362,7 @@ describe("Tokenizer padding/truncation", () => {

test("padding:'max_length' pads to the specified max_length", () => {
const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
padding: 'max_length',
padding: "max_length",
truncation: true,
add_special_tokens: false,
max_length: 3,
Expand All @@ -347,7 +380,7 @@ describe("Tokenizer padding/truncation", () => {
[0n, 0n, 0n],
]);
});
})
});
});

describe("Token type ids", () => {
26 changes: 13 additions & 13 deletions tests/utils/logits_process.test.js
@@ -35,17 +35,17 @@ describe("Logits Processors", () => {
async () => {
const text_input = "hello";

const generated_text_target = " Bert explicit wed digasset";
const generated_text_target = "\uff0d Giuseppeitte natoud";
const text_target = [{ generated_text: text_input + generated_text_target }];

const output = await pipe(text_input, {
max_new_tokens: 5,
bad_words_ids: [
// default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
// default: [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n]
[18547],

// block #1: [22172n, 16662n, 6261n, 18916n, 29109n, 799n]
[6261, 18916],
// block #1: [1n, 22172n, 31583n, 18824n, 16621n, 8136n, 16012n]
[18824, 16621],
],
});
compare(output, text_target);
@@ -58,22 +58,22 @@ describe("Logits Processors", () => {
async () => {
const text_input = "hello";

const generated_text_target = "erdingsdeletearus)?nor";
const generated_text_target = "erdingsdelete войsequ族";
const text_target = [{ generated_text: text_input + generated_text_target }];

// Construct long list of bad words
const bad_words_ids = [];
// default: [22172n, 18547n, 8136n, 16012n, 28064n, 11361n]
// default: [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n]
for (let i = 0; i < 100000; ++i) {
bad_words_ids.push([i * 2]); // block all even numbers
}
// block #1: [22172n, 18547n, 8143n, 30327n, 20061n, 18193n]
// block #1: [1n, 22172n, 18547n, 8143n, 30327n, 624n, 2806n, 2004n]
bad_words_ids.push([8143, 30327]);

// block #2: [22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
// block #2: [1n, 22172n, 18547n, 8143n, 29485n, 3799n, 29331n]
bad_words_ids.push([18547, 8143, 29485]);

// block #3: [22172n, 18547n, 8143n, 26465n, 6877n, 15459n]
// block #3: [1n, 22172n, 18547n, 8143n, 7587n, 6831n, 30999n]
const output = await pipe(text_input, { max_new_tokens: 5, bad_words_ids });
compare(output, text_target);
},
@@ -85,19 +85,19 @@ describe("Logits Processors", () => {
async () => {
const text_input = "this is a test";

const generated_text_target = "кт México constructed lake user";
const generated_text_target = "кт México constructed lake års";
const text_target = [{ generated_text: text_input + generated_text_target }];

const output = await pipe(text_input, {
max_new_tokens: 5,
bad_words_ids: [
// default: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
// default: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 31252n]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3931], // should never trigger (longer than input sequence)

// block #1: [445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 16426n]
// block #1: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 7811n, 21645n, 31252n]
[3931, 14756, 7811],

// result: [445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 1404n]
// result: [1n, 445n, 338n, 263n, 1243n, 3931n, 14756n, 13319n, 19437n, 21948n]
],
});
compare(output, text_target);
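The regenerated comments all gain a leading 1n, the BOS id now prepended by the pipeline. For reference, a conceptual sketch of how a bad_words_ids entry blocks generation (this mirrors the documented semantics, not the library's internals):

```js
// An entry [...prefix, last] bans `last` as the next token whenever the
// sequence generated so far ends with `prefix`. A single-token entry
// bans that token unconditionally.
function bannedNextToken(sequence, badWord) {
  const prefix = badWord.slice(0, -1);
  if (prefix.length > sequence.length) return null; // can never trigger
  const tail = sequence.slice(sequence.length - prefix.length);
  const matches = prefix.every((id, i) => id === tail[i]);
  return matches ? badWord[badWord.length - 1] : null;
}

// With the ids from the comments above: after generating [1, 22172, 18547, 8143],
// the entry [18547, 8143, 29485] bans 29485 at the next step.
console.log(bannedNextToken([1, 22172, 18547, 8143], [18547, 8143, 29485])); // 29485
```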