diff --git a/src/tokenizers.js b/src/tokenizers.js
index cc61f17a4..8e485b087 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -2787,22 +2787,29 @@ export class PreTrainedTokenizer extends Callable {
             // For single input, we just wrap in an array, and then unwrap later.
             encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })];
         }
-        // At this point, tokens is batched: [batch_size, tokens]
-        // However, array may be jagged. So, we pad to max_length
-
+        // At this point, `encodedTokens` is batched, of shape [batch_size, tokens].
+        // However, the array may be jagged, so we may need to pad to max_length.
         if (max_length === null) {
-            if (padding === 'max_length') {
+            max_length = this.model_max_length;
+        } else if (truncation === null) {
+            if (padding === true) {
+                console.warn(
+                    "`max_length` is ignored when `padding: true` and there is no truncation strategy. " +
+                    "To pad to max length, use `padding: 'max_length'`."
+                )
                 max_length = this.model_max_length;
-            } else {
-                // Calculate max length from sequences
-                max_length = max(encodedTokens.map(x => x.input_ids.length))[0];
-            }
-        } else {
-            if (!truncation) {
-                console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`)
+            } else if (padding === false) {
+                console.warn("Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length.");
+                truncation = true;
             }
         }
+        // padding: 'max_length' doesn't require any additional calculation
+        // but padding: true has to calculate max_length from the sequences
+        if (padding === true) {
+            max_length = Math.min(max(encodedTokens.map(x => x.input_ids.length))[0], max_length ?? Infinity);
+        }
+        // Ensure it does not exceed the model max length
         max_length = Math.min(max_length, this.model_max_length ?? Infinity);
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 2742513ee..dbb7f99d1 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -46,11 +46,16 @@ describe("Tokenizer padding/truncation", () => {
   const inputs = ["a", "b c"];
   const text_pair = ["d e", "f g h"];
 
-  it("should create a jagged array", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased");
+  const inputs_2 = ["a", "b c d e f"];
 
-    {
-      // support jagged array if `return_tensor=false`
+  let tokenizer;
+  beforeAll(async () => {
+    tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased");
+  }, MAX_TOKENIZER_LOAD_TIME);
+
+  describe("return_tensor=false (jagged array)", () => {
+
+    test("jagged array output when return_tensor is false", () => {
       const output = tokenizer(inputs, {
         return_tensor: false,
       });
@@ -69,9 +74,9 @@ describe("Tokenizer padding/truncation", () => {
         ],
       };
       compare(output, expected);
-    }
+    });
 
-    {
+    test("truncation output without special tokens when return_tensor is false", () => {
       const output = tokenizer(inputs, {
         return_tensor: false,
         truncation: true,
@@ -83,106 +88,266 @@ describe("Tokenizer padding/truncation", () => {
         token_type_ids: [[0], [0, 0]],
       };
       compare(output, expected);
-    }
-  });
+    });
-  it(
-    "should create a tensor",
-    async () => {
-      const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased");
+    test("no padding with max_length defined and truncation unset", () => {
+      const output = tokenizer(inputs, {
+        return_tensor: false,
+        padding: false,
+        max_length: 1,
+        add_special_tokens: false,
+      });
+      const expected = {
+        input_ids: [[1037], [1038]],
+        attention_mask: [[1], [1]],
+        token_type_ids: [[0], [0]],
+      };
+      compare(output, expected);
+    });
-      {
-        // Expected to throw error if jagged array
-        expect(() => tokenizer(inputs)).toThrow("Unable to create tensor");
-      }
-      {
-        // Truncation
-        const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
-          truncation: true,
-          max_length: 1,
-          add_special_tokens: false,
-        });
+    test("No padding, max_length=3 (implicit truncation strategy)", () => {
+      const output = tokenizer(inputs_2, {
+        padding: false,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037], [1038, 1039, 1040]],
+        token_type_ids: [[0], [0, 0, 0]],
+        attention_mask: [[1], [1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-        expect(input_ids.tolist()).toEqual([[1037n], [1038n]]);
-        expect(attention_mask.tolist()).toEqual([[1n], [1n]]);
-        expect(token_type_ids.tolist()).toEqual([[0n], [0n]]);
-      }
-      {
-        // Truncation w/ text pair
-        // TODO
-      }
+    test("Padding true, max_length=3 (implicit truncation strategy)", () => {
+      const output = tokenizer(inputs_2, {
+        padding: true,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037, 0, 0, 0, 0], [1038, 1039, 1040, 1041, 1042]],
+        token_type_ids: [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
+        attention_mask: [[1, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-      {
-        // Padding
-        const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
-          padding: true,
-          add_special_tokens: false,
-        });
+    test("No padding with explicit truncation, max_length=3", () => {
+      const output = tokenizer(inputs_2, {
+        padding: false,
+        truncation: true,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037], [1038, 1039, 1040]],
+        token_type_ids: [[0], [0, 0, 0]],
+        attention_mask: [[1], [1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-        expect(input_ids.tolist()).toEqual([
-          [1037n, 0n],
-          [1038n, 1039n],
-        ]);
-        expect(attention_mask.tolist()).toEqual([
-          [1n, 0n],
-          [1n, 1n],
-        ]);
-        expect(token_type_ids.tolist()).toEqual([
-          [0n, 0n],
-          [0n, 0n],
-        ]);
-      }
-      {
-        // Padding w/ text pair
-        const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
-          text_pair,
-          padding: true,
-          add_special_tokens: false,
-        });
+    test("Padding true with explicit truncation, max_length=3", () => {
+      const output = tokenizer(inputs_2, {
+        padding: true,
+        truncation: true,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
+        token_type_ids: [[0, 0, 0], [0, 0, 0]],
+        attention_mask: [[1, 0, 0], [1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-        expect(input_ids.tolist()).toEqual([
-          [1037n, 1040n, 1041n, 0n, 0n],
-          [1038n, 1039n, 1042n, 1043n, 1044n],
-        ]);
-        expect(attention_mask.tolist()).toEqual([
-          [1n, 1n, 1n, 0n, 0n],
-          [1n, 1n, 1n, 1n, 1n],
-        ]);
-        expect(token_type_ids.tolist()).toEqual([
-          [0n, 1n, 1n, 0n, 0n],
-          [0n, 0n, 1n, 1n, 1n],
-        ]);
-      }
+    test("Padding 'max_length' without truncation, max_length=3", () => {
+      const output = tokenizer(inputs_2, {
+        padding: 'max_length',
+        truncation: false,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037, 0, 0], [1038, 1039, 1040, 1041, 1042]],
+        token_type_ids: [[0, 0, 0], [0, 0, 0, 0, 0]],
+        attention_mask: [[1, 0, 0], [1, 1, 1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-      {
-        // Truncation + padding
-        const { input_ids, attention_mask, token_type_ids } = tokenizer(["a", "b c", "d e f"], {
-          padding: true,
-          truncation: true,
-          add_special_tokens: false,
-          max_length: 2,
-        });
+    test("Padding 'max_length' with truncation, max_length=3", () => {
+      const output = tokenizer(inputs_2, {
+        padding: 'max_length',
+        truncation: true,
+        max_length: 3,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [[1037, 0, 0], [1038, 1039, 1040]],
+        token_type_ids: [[0, 0, 0], [0, 0, 0]],
+        attention_mask: [[1, 0, 0], [1, 1, 1]],
+      };
+      compare(output, expected);
+    });
-        expect(input_ids.tolist()).toEqual([
-          [1037n, 0n],
-          [1038n, 1039n],
-          [1040n, 1041n],
-        ]);
-        expect(attention_mask.tolist()).toEqual([
-          [1n, 0n],
-          [1n, 1n],
-          [1n, 1n],
-        ]);
-        expect(token_type_ids.tolist()).toEqual([
-          [0n, 0n],
-          [0n, 0n],
-          [0n, 0n],
-        ]);
-      }
-    },
-    MAX_TEST_EXECUTION_TIME,
-  );
+    test("Padding 'max_length' without truncation and max_length=null", () => {
+      const output = tokenizer(inputs_2, {
+        padding: 'max_length',
+        truncation: false,
+        max_length: null,
+        add_special_tokens: false,
+        return_tensor: false,
+      });
+      const expected = {
+        input_ids: [
+          [1037, ...Array(511).fill(0)],
+          [1038, 1039, 1040, 1041, 1042, ...Array(507).fill(0)]
+        ],
+        token_type_ids: [
+          [0, ...Array(511).fill(0)],
+          [0, 0, 0, 0, 0, ...Array(507).fill(0)]
+        ],
+        attention_mask: [
+          [1, ...Array(511).fill(0)],
+          [1, 1, 1, 1, 1, ...Array(507).fill(0)]
+        ],
+      };
+      compare(output, expected);
+    });
+  });
+
+  describe("return_tensor=true", () => {
+
+    test("throws error when tensor output is requested for a jagged array", () => {
+      expect(() => tokenizer(inputs)).toThrow("Unable to create tensor");
+    });
+
+    test("truncation output for tensor inputs", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+        truncation: true,
+        max_length: 1,
+        add_special_tokens: false,
+      });
+      expect(input_ids.tolist()).toEqual([[1037n], [1038n]]);
+      expect(attention_mask.tolist()).toEqual([[1n], [1n]]);
+      expect(token_type_ids.tolist()).toEqual([[0n], [0n]]);
+    });
+
+    test("padding output for tensor inputs without text pair", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+        padding: true,
+        add_special_tokens: false,
+      });
+      expect(input_ids.tolist()).toEqual([
+        [1037n, 0n],
+        [1038n, 1039n],
+      ]);
+      expect(attention_mask.tolist()).toEqual([
+        [1n, 0n],
+        [1n, 1n],
+      ]);
+      expect(token_type_ids.tolist()).toEqual([
+        [0n, 0n],
+        [0n, 0n],
+      ]);
+    });
+
+    test("padding output for tensor inputs with text pair", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+        text_pair,
+        padding: true,
+        add_special_tokens: false,
+      });
+      expect(input_ids.tolist()).toEqual([
+        [1037n, 1040n, 1041n, 0n, 0n],
+        [1038n, 1039n, 1042n, 1043n, 1044n],
+      ]);
+      expect(attention_mask.tolist()).toEqual([
+        [1n, 1n, 1n, 0n, 0n],
+        [1n, 1n, 1n, 1n, 1n],
+      ]);
+      expect(token_type_ids.tolist()).toEqual([
+        [0n, 1n, 1n, 0n, 0n],
+        [0n, 0n, 1n, 1n, 1n],
+      ]);
+    });
+
+    test("truncation and padding output for tensor inputs", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(["a", "b c", "d e f"], {
+        padding: true,
+        truncation: true,
+        add_special_tokens: false,
+        max_length: 2,
+      });
+      expect(input_ids.tolist()).toEqual([
+        [1037n, 0n],
+        [1038n, 1039n],
+        [1040n, 1041n],
+      ]);
+      expect(attention_mask.tolist()).toEqual([
+        [1n, 0n],
+        [1n, 1n],
+        [1n, 1n],
+      ]);
+      expect(token_type_ids.tolist()).toEqual([
+        [0n, 0n],
+        [0n, 0n],
+        [0n, 0n],
+      ]);
+    });
+
+    test("padding:true pads to the longest encoding in the batch regardless of max_length", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+        padding: true,
+        truncation: true,
+        add_special_tokens: false,
+        max_length: 3,
+      });
+      expect(input_ids.tolist()).toEqual([
+        [1037n, 0n],
+        [1038n, 1039n],
+      ]);
+      expect(attention_mask.tolist()).toEqual([
+        [1n, 0n],
+        [1n, 1n],
+      ]);
+      expect(token_type_ids.tolist()).toEqual([
+        [0n, 0n],
+        [0n, 0n],
+      ]);
+    });
+
+    test("padding:'max_length' pads to the specified max_length", () => {
+      const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+        padding: 'max_length',
+        truncation: true,
+        add_special_tokens: false,
+        max_length: 3,
+      });
+      expect(input_ids.tolist()).toEqual([
+        [1037n, 0n, 0n],
+        [1038n, 1039n, 0n],
+      ]);
+      expect(attention_mask.tolist()).toEqual([
+        [1n, 0n, 0n],
+        [1n, 1n, 0n],
+      ]);
+      expect(token_type_ids.tolist()).toEqual([
+        [0n, 0n, 0n],
+        [0n, 0n, 0n],
+      ]);
+    });
+  })
 });
 
 describe("Token type ids", () => {