From 9367fcd027e0795f5c0562d4556b0a933eabf379 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 14 Jan 2025 19:26:48 +0000 Subject: [PATCH 01/45] convert character spans to token spans --- fast_llm/data/preparator/gpt_memmap/config.py | 3 ++ .../data/preparator/gpt_memmap/prepare.py | 36 ++++++++++++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 29730448..d29c8c80 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -44,6 +44,9 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Field of the dataset to use.", hint=FieldHint.optional, ) + spans_field: None | str = Field( + default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index dd475829..2527a048 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -22,14 +22,37 @@ class GPTMemmapDatasetPreparator(DatasetPreparator): _tokenizer: Tokenizer _data_type: DataType + def _tokenize_with_spans(self, sample): + """ + Perform span-aware tokenization and return the tokenized input_ids along with token spans. + """ + char_spans = sample[self._config.dataset.spans_field] + text = sample[self._config.dataset.field] + input_ids = [] + token_spans = [] + char_pos = 0 + for start, end in char_spans: + if char_pos < start: + curr_text = text[char_pos:start] + tokenized_text = self._tokenizer.tokenize(curr_text) + input_ids.extend(tokenized_text) + curr_text = text[start : end + 1] + tokenized_text = self._tokenizer.tokenize(curr_text) + input_ids.extend(tokenized_text) + token_spans.append((len(token_spans), len(token_spans) + len(tokenized_text) - 1)) + char_pos = end + 1 + if char_pos < len(text): + curr_text = text[char_pos:] + tokenized_text = self._tokenizer.tokenize(curr_text) + input_ids.extend(tokenized_text) + return np.array(input_ids, dtype=self._data_type.numpy), token_spans + def _tokenize_batch(self, batch): - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - for text in batch[self._config.dataset.field] - ] + input_ids, token_spans = zip(*[self._tokenize_with_spans(sample) for sample in batch]) num_tokens = [len(x) for x in input_ids] return { "input_ids": input_ids, + "token_spans": token_spans, "num_tokens": num_tokens, } @@ -126,6 +149,11 @@ def run(self): ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") + if ( + self._config.dataset.spans_field is not None + and self._config.dataset.spans_field not in dataset.column_names + ): + raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.") # Tokenize the dataset in parallel tokenized_dataset = dataset.map( From 515dcb50d25a30355e4cf6854d986e2d447de784 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 14 Jan 2025 19:34:41 +0000 Subject: [PATCH 02/45] handle null spans --- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2527a048..596a2e30 100644 --- 
a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -26,7 +26,7 @@ def _tokenize_with_spans(self, sample): """ Perform span-aware tokenization and return the tokenized input_ids along with token spans. """ - char_spans = sample[self._config.dataset.spans_field] + char_spans = sample.get(self._config.dataset.spans_field, []) text = sample[self._config.dataset.field] input_ids = [] token_spans = [] From 3457ba2148552ad443f347a4694f5cde6b0b5eaf Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 15 Jan 2025 02:32:20 +0000 Subject: [PATCH 03/45] handle spans in data iterator, fix test --- fast_llm/data/dataset/gpt/memmap.py | 47 ++++++++++++++++--- .../data/preparator/gpt_memmap/prepare.py | 6 ++- tests/test_memmap_dataset.py | 20 ++++++-- 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f24536f8..a2239f16 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -48,6 +48,25 @@ def _init(self, name: str, prefix: pathlib.Path | str): offset=offset + self._document_sizes.nbytes, ) + self._num_spans = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._num_documents, + offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + ) + spans = [] + offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes + for n_spans in self._num_spans: + span = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_spans * 2, + offset=offset, + ).reshape(-1, 2) + spans.append(span) + offset += span.nbytes + self._spans = spans + self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -64,11 +83,14 @@ def __del__(self): del self._index_bin_buffer_mmap def get(self, idx, offset=0, length=None): - return np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + return ( + np.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=self._document_sizes[idx] - offset if length is None else length, + offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + ), + self._spans[idx], ) @property @@ -92,20 +114,23 @@ def get_document_sizes(self) -> "np.ndarray": return self._document_sizes @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np.ndarray]): + def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[tuple[np.ndarray, np.ndarray]]): # Initialize metadata dtype = None num_documents = 0 lengths = [] pointers = [] offset = 0 + # number of spans for each document + num_spans = [] + spans = [] prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) # Write the binary data file (.bin) lazily with prefix.with_suffix(".bin").open("wb") as bin_stream: - for document in documents: + for document, mask_spans in documents: # Infer dtype from the first document if dtype is None: dtype = document.dtype @@ -121,12 +146,16 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np doc_length = len(document) lengths.append(doc_length) pointers.append(offset) + num_spans.append(len(mask_spans)) + spans.append(mask_spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 # Finalize metadata 
arrays lengths = np.array(lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) + num_spans = np.array(num_spans, dtype=np.int32) + spans = np.vstack(spans, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: @@ -142,5 +171,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[np idx_stream.write(lengths.tobytes(order="C")) # Sequence (document) begin offsets in the bin file idx_stream.write(pointers.tobytes(order="C")) + # Number of spans per document + idx_stream.write(num_spans.tobytes(order="C")) + # Span indices for each document + idx_stream.write(spans.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C")) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 596a2e30..d484c4ae 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -45,7 +45,7 @@ def _tokenize_with_spans(self, sample): curr_text = text[char_pos:] tokenized_text = self._tokenizer.tokenize(curr_text) input_ids.extend(tokenized_text) - return np.array(input_ids, dtype=self._data_type.numpy), token_spans + return np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32) def _tokenize_batch(self, batch): input_ids, token_spans = zip(*[self._tokenize_with_spans(sample) for sample in batch]) @@ -63,7 +63,9 @@ def _save_shard(self, args) -> dict: def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield np.array(item["input_ids"], dtype=self._data_type.numpy) + yield np.array(item["input_ids"], dtype=self._data_type.numpy), np.array( + item["token_spans"], dtype=np.int32 + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index 261f87e4..b84eb108 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -10,12 +10,26 @@ @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_gpt_memmap_dataset(dtype): - documents = [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)] + documents = list( + zip( + [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)], + np.array([[]] * 100, dtype=np.int32), + ) + ) with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): + for i, (document, spans) in enumerate(documents): + memmap_document, memmap_spans = dataset.get(i) assert np.array_equal( - dataset.get(i), document, equal_nan=True + memmap_document, document, equal_nan=True ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." + if len(spans) > 0: + assert np.array_equal( + memmap_spans, spans, equal_nan=True + ), f"Mismatch for non-empty spans {i}: {spans} != {dataset.get(i)}." + else: + assert np.array_equal( + memmap_spans.flatten(), spans.flatten(), equal_nan=True + ), f"Mismatch for empty spans {i}: {spans} != {dataset.get(i)}." 
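For reference, a minimal standalone sketch of the character-span to token-span mapping that this series implements (not part of the diffs above): the whitespace "tokenizer" and the helper names `toy_tokenize` / `tokenize_with_spans` below are illustrative stand-ins for the project's `Tokenizer.tokenize` and `_tokenize_with_spans`, and spans follow the same inclusive [start, end] character-offset convention.

import numpy as np

_vocab: dict[str, int] = {}


def toy_tokenize(text: str) -> list[int]:
    # Toy whitespace tokenizer standing in for Tokenizer.tokenize (illustration only).
    return [_vocab.setdefault(piece, len(_vocab)) for piece in text.split()]


def tokenize_with_spans(text: str, char_spans: list[tuple[int, int]]) -> tuple[np.ndarray, np.ndarray]:
    # Tokenize the text piecewise so that each character span maps to a contiguous,
    # inclusive range of token indices.
    input_ids: list[int] = []
    token_spans: list[tuple[int, int]] = []
    char_pos = 0
    for start, end in char_spans:
        if char_pos < start:
            # Text before the span: these tokens stay unmasked.
            input_ids.extend(toy_tokenize(text[char_pos:start]))
        # The span itself: record the inclusive token index range it occupies.
        span_tokens = toy_tokenize(text[start : end + 1])
        token_spans.append((len(input_ids), len(input_ids) + len(span_tokens) - 1))
        input_ids.extend(span_tokens)
        char_pos = end + 1
    if char_pos < len(text):
        # Trailing text after the last span.
        input_ids.extend(toy_tokenize(text[char_pos:]))
    return np.array(input_ids, dtype=np.int64), np.array(token_spans, dtype=np.int32).reshape(-1, 2)


ids, spans = tokenize_with_spans("mask this middle part but not the rest", [(5, 20)])
print(ids.shape, spans)  # (8,) [[1 3]] -> token indices 1..3 correspond to the masked span

These inclusive token ranges are what the patches above serialize as num_spans/spans in the .idx file and read back per document in GPTMemmapDataset.get.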
From c7373b946aad8be10b7ff0e2ec0fc2fa7a9067c0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 16 Jan 2025 18:31:04 +0000 Subject: [PATCH 04/45] bump dataset version --- fast_llm/data/dataset/gpt/memmap.py | 46 ++++++++++++++++++----------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a2239f16..960a7d1a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -29,7 +29,7 @@ def _init(self, name: str, prefix: pathlib.Path | str): with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER) - Assert.eq(struct.unpack(" Date: Thu, 16 Jan 2025 18:50:00 +0000 Subject: [PATCH 05/45] create a document class --- fast_llm/data/dataset/gpt/memmap.py | 23 +++++++++++------ .../data/preparator/gpt_memmap/prepare.py | 7 +++--- tests/test_memmap_dataset.py | 25 ++++++++++--------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 960a7d1a..44a22367 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,3 +1,4 @@ +import dataclasses import pathlib import struct import typing @@ -10,6 +11,12 @@ from fast_llm.utils import Assert, div +@dataclasses.dataclass +class GPTMemmapDocument: + text: np.ndarray + spans: np.ndarray + + class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, @@ -122,7 +129,7 @@ def get_document_sizes(self) -> "np.ndarray": return self._document_sizes @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[tuple[np.ndarray, np.ndarray]]): + def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTMemmapDocument]): # Initialize metadata dtype = None num_documents = 0 @@ -138,24 +145,24 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[tu # Write the binary data file (.bin) lazily with prefix.with_suffix(".bin").open("wb") as bin_stream: - for document, mask_spans in documents: + for document in documents: # Infer dtype from the first document if dtype is None: - dtype = document.dtype + dtype = document.text.dtype assert dtype is not None, "Document dtype could not be inferred from the data." # Ensure all documents have the same dtype - assert document.dtype == dtype, f"Expected dtype {dtype}, got {document.dtype}." + assert document.text.dtype == dtype, f"Expected dtype {dtype}, got {document.text.dtype}." 
# Write document to binary file - bin_stream.write(document.tobytes(order="C")) + bin_stream.write(document.text.tobytes(order="C")) # Update metadata - doc_length = len(document) + doc_length = len(document.text) lengths.append(doc_length) pointers.append(offset) - num_spans.append(len(mask_spans)) - spans.append(mask_spans) + num_spans.append(len(document.spans)) + spans.append(document.spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d484c4ae..015c788c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -8,7 +8,7 @@ import tqdm import transformers -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.tokenizer import Tokenizer @@ -63,8 +63,9 @@ def _save_shard(self, args) -> dict: def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield np.array(item["input_ids"], dtype=self._data_type.numpy), np.array( - item["token_spans"], dtype=np.int32 + yield GPTMemmapDocument( + np.array(item["input_ids"], dtype=self._data_type.numpy), + np.array(item["token_spans"], dtype=np.int32), ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index b84eb108..8c2235bf 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -4,32 +4,33 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_gpt_memmap_dataset(dtype): - documents = list( - zip( + documents = [ + GPTMemmapDocument(text, spans) + for text, spans in zip( [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)], np.array([[]] * 100, dtype=np.int32), ) - ) + ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, (document, spans) in enumerate(documents): + for i, document in enumerate(documents): memmap_document, memmap_spans = dataset.get(i) assert np.array_equal( - memmap_document, document, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." - if len(spans) > 0: + memmap_document, document.text, equal_nan=True + ), f"Mismatch for document {i}: {document.text} != {dataset.get(i)}." + if len(document.spans) > 0: assert np.array_equal( - memmap_spans, spans, equal_nan=True - ), f"Mismatch for non-empty spans {i}: {spans} != {dataset.get(i)}." + memmap_spans, document.spans, equal_nan=True + ), f"Mismatch for non-empty spans {i}: {document.spans} != {dataset.get(i)}." else: assert np.array_equal( - memmap_spans.flatten(), spans.flatten(), equal_nan=True - ), f"Mismatch for empty spans {i}: {spans} != {dataset.get(i)}." 
+ memmap_spans.flatten(), document.spans.flatten(), equal_nan=True + ), f"Mismatch for empty spans {i}: {document.spans} != {dataset.get(i)}." From 419acd70d22eb4d98fd91191e588bcd5dc9290bb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 24 Jan 2025 08:28:30 +0000 Subject: [PATCH 06/45] make loss masking work for prepare and training --- fast_llm/data/data/gpt/data.py | 16 ++++++++++ fast_llm/data/dataset/gpt/memmap.py | 22 +++++++------- fast_llm/data/dataset/gpt/sampled.py | 24 +++++++++++---- .../data/preparator/gpt_memmap/prepare.py | 29 +++++++++++++------ fast_llm/data/tokenizer.py | 4 +-- fast_llm/functional/cross_entropy.py | 24 +++++++++++---- fast_llm/models/gpt/model.py | 24 ++++++++++----- 7 files changed, 104 insertions(+), 39 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 36165da7..83a83900 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,3 +1,4 @@ +import dataclasses import json import logging import math @@ -5,6 +6,7 @@ import typing import warnings +import numpy as np import torch import torch.utils.data @@ -26,6 +28,19 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class GPTDataBatch: + ids: torch.Tensor + spans: torch.Tensor + + +def gpt_data_collate_fn(batch): + stacked_ids = np.stack([sample.ids for sample in batch]) + # stacked_spans = np.stack([sample.spans for sample in batch]) + stacked_spans = [torch.from_numpy(sample.spans) for sample in batch] + return GPTDataBatch(ids=torch.from_numpy(stacked_ids), spans=stacked_spans) + + class GPTData(Data): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. @@ -226,6 +241,7 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, + collate_fn=gpt_data_collate_fn, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 44a22367..8c805089 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,8 +58,8 @@ def _init(self, name: str, prefix: pathlib.Path | str): # Spans are introduced in version 2. 
Datasets tokenized with version 1 do not contain span information and # compute loss on all tokens by default if self._version == 1: - self._num_spans = 0 - self._spans = [] + self._num_spans = np.zeros(self._num_documents, dtype=np.int32) + self._spans = [np.array([], dtype=np.int32).reshape(-1, 2)] * self._num_documents elif self._version == 2: self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -98,15 +98,17 @@ def __del__(self): del self._index_bin_buffer_mmap def get(self, idx, offset=0, length=None): - return ( - np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ), - self._spans[idx], + ids = np.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=self._document_sizes[idx] - offset if length is None else length, + offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + spans = [] + for span in self._spans[idx]: + if span[0] < offset + len(ids) and span[1] >= offset: + spans.append([max(span[0], offset) - offset, min(span[1], offset + len(ids) - 1) - offset]) + return (ids, np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self): diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index beebc5ed..ad82de93 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -1,3 +1,4 @@ +import dataclasses import math import numpy as np @@ -20,6 +21,12 @@ _extension_available = False +@dataclasses.dataclass +class GPTSample: + ids: np.ndarray + spans: np.ndarray + + class GPTSampledIndexedDataset(SampledDataset): """ A GPT dataset augmented with a sampling, i.e., @@ -189,14 +196,21 @@ def __getitem__(self, idx): ) for doc in range(doc_f, doc_l + 1) ] - sample = np.concatenate( - sample_list, - dtype=np.int64, - ) + + sample_ids = [] + sample_spans = [] + span_offset = 0 + for ids, spans in sample_list: + sample_ids.extend(ids) + for span in spans: + sample_spans.append([span[0] + span_offset, span[1] + span_offset]) + span_offset += len(ids) + sample_ids = np.array(sample_ids, dtype=np.int64) + sample_spans = np.array(sample_spans, dtype=np.int32).reshape(-1, 2) if self._fim is not None: sample = self._fim(sample, np.random.RandomState(seed=(self._sampling_config.seed + idx) % MAX_SEED)) - return sample + return GPTSample(ids=sample_ids, spans=sample_spans) @property def name(self): diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 015c788c..659aca8b 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -22,33 +22,44 @@ class GPTMemmapDatasetPreparator(DatasetPreparator): _tokenizer: Tokenizer _data_type: DataType - def _tokenize_with_spans(self, sample): + def _tokenize_with_spans(self, text, char_spans): """ Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
""" - char_spans = sample.get(self._config.dataset.spans_field, []) - text = sample[self._config.dataset.field] input_ids = [] token_spans = [] char_pos = 0 + beginning_of_text = True for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] - tokenized_text = self._tokenizer.tokenize(curr_text) + tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text) + beginning_of_text = False input_ids.extend(tokenized_text) curr_text = text[start : end + 1] - tokenized_text = self._tokenizer.tokenize(curr_text) + tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text) + beginning_of_text = False + token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) input_ids.extend(tokenized_text) - token_spans.append((len(token_spans), len(token_spans) + len(tokenized_text) - 1)) char_pos = end + 1 if char_pos < len(text): curr_text = text[char_pos:] tokenized_text = self._tokenizer.tokenize(curr_text) input_ids.extend(tokenized_text) - return np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32) + return np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2) def _tokenize_batch(self, batch): - input_ids, token_spans = zip(*[self._tokenize_with_spans(sample) for sample in batch]) + input_ids, token_spans = map( + list, + zip( + *[ + self._tokenize_with_spans(text, char_spans) + for text, char_spans in zip( + batch[self._config.dataset.field], batch[self._config.dataset.spans_field] + ) + ] + ), + ) num_tokens = [len(x) for x in input_ids] return { "input_ids": input_ids, @@ -65,7 +76,7 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTMemmapDocument( np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32), + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index f5fde98d..06aa195f 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -31,8 +31,8 @@ def vocab(self): def inv_vocab(self): return self._inv_vocab - def tokenize(self, text: str): - return self.tokenizer.encode(text) + def tokenize(self, text: str, add_special_tokens: bool = True): + return self.tokenizer.encode(text, add_special_tokens=add_special_tokens) def detokenize(self, token_ids): return self.tokenizer.decode(token_ids) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0c2579c2..c9111902 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -8,7 +8,9 @@ from fast_llm.utils import Assert -def torch_cross_entropy_forward_backward(logits, target, grad_output, logits_scale_factor: float = 1.0): +def torch_cross_entropy_forward_backward( + logits, target, grad_output, logits_scale_factor: float = 1.0, ignore_index=-100 +): """ A wrapper for the pytorch implementation of cross-entropy. 
The cross-entropy kernels themselves are well-optimized, but the need for explicit casting @@ -19,13 +21,15 @@ def torch_cross_entropy_forward_backward(logits, target, grad_output, logits_sca logits_ = logits.float().detach().requires_grad_() if logits_scale_factor != 1.0: logits_ *= logits_scale_factor - loss = torch.nn.functional.cross_entropy(logits_, target).mean() + loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean() loss.backward(torch.full_like(loss, grad_output)) return loss.detach(), logits_.grad.detach().to(logits.dtype) @torch.compile -def fused_cross_entropy_forward_backward(logits, target, grad_output: float | None, logits_scale_factor: float = 1.0): +def fused_cross_entropy_forward_backward( + logits, target, grad_output: float | None, logits_scale_factor: float = 1.0, ignore_index=-100 +): """ A fused implementation of cross-entropy with torch compile. It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, @@ -33,6 +37,9 @@ def fused_cross_entropy_forward_backward(logits, target, grad_output: float | No """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. + mask = target != ignore_index + target = target[mask] + logits = logits[mask] target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: @@ -43,6 +50,7 @@ def fused_cross_entropy_forward_backward(logits, target, grad_output: float | No if grad_output is None: grad = None else: + grad = torch.zeros((mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1)) # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) @@ -50,7 +58,7 @@ def fused_cross_entropy_forward_backward(logits, target, grad_output: float | No if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - grad = exp_logits.to(logits.dtype) + grad.index_put_((mask,), exp_logits.to(logits.dtype)) loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean() @@ -59,7 +67,7 @@ def fused_cross_entropy_forward_backward(logits, target, grad_output: float | No @torch.compile def parallel_cross_entropy_forward_backward( - logits, target, grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0 + logits, target, grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0, ignore_index=-100 ): """ A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. @@ -67,6 +75,9 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). 
# TODO: Optimize, overlap/combine reductions + mask = target != ignore_index + target = target[mask] + logits = logits[mask] target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0] @@ -88,6 +99,7 @@ def parallel_cross_entropy_forward_backward( if grad_output is None: grad = None else: + grad = torch.zeros((mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits1 = exp_logits.scatter( 1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1) ) @@ -95,7 +107,7 @@ def parallel_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits2 *= logits_scale_factor - grad = exp_logits2.to(logits.dtype) + grad.index_put_((mask,), exp_logits2.to(logits.dtype)) predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 90f2f883..1179a6e4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -2,6 +2,7 @@ import torch +from fast_llm.data.data.gpt.data import GPTDataBatch from fast_llm.engine.base_model.base_model import BaseModel, LossDef from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -182,7 +183,7 @@ def preprocess_meta(self, input_: BatchConfig | torch.Tensor, phase: PhaseType) def preprocess( self, - batch: torch.Tensor, + batch: GPTDataBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, @@ -200,14 +201,14 @@ def preprocess( sequence_first = common_kwargs[TransformerKwargs.sequence_first] sequence_length = common_kwargs[TransformerKwargs.sequence_length] - batch = batch.to( + batch.ids = batch.ids.to( device=self._tensor_space.distributed.device, dtype=torch.int64, non_blocking=True, ) if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch = batch.transpose(0, 1).contiguous() + batch.ids = batch.ids.transpose(0, 1).contiguous() if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.create_tensors(sequence_length) @@ -221,10 +222,10 @@ def preprocess( for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch[sequence_k - sequence_q : sequence_k] + tokens = batch.ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = batch.ids[:, sequence_k - sequence_q : sequence_k].contiguous() # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -237,10 +238,19 @@ def preprocess( } if phase != PhaseType.inference: if sequence_first: - labels = batch[sequence_k - sequence_q + 1 : sequence_k + 1] + labels = batch.ids[sequence_k - sequence_q + 1 : sequence_k + 1] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + labels = batch.ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss + # TODO: take ignore_index from config + for i, spans_i in enumerate(batch.spans): + mask_indices = ( + torch.cat([torch.arange(s - 1, e) for s, e in spans_i]) + if len(spans_i) + else torch.tensor([], dtype=torch.int64) + ) + labels[i, mask_indices] = -100 kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) From daa2ad797610af2c76f6315a1e4adec180cd7781 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Sat, 25 Jan 2025 00:20:17 +0000 Subject: [PATCH 07/45] bos and eos options for tokenizer --- fast_llm/data/data/gpt/data.py | 3 ++- .../data/preparator/gpt_memmap/prepare.py | 4 ++-- fast_llm/data/tokenizer.py | 19 +++++++++++++++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 44079f94..b7a71a06 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -12,6 +12,7 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.tokenizer import Tokenizer @@ -30,7 +31,7 @@ class GPTDataBatch: spans: torch.Tensor -def gpt_data_collate_fn(batch): +def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTDataBatch: stacked_ids = np.stack([sample.ids for sample in batch]) # stacked_spans = np.stack([sample.spans for sample in batch]) stacked_spans = [torch.from_numpy(sample.spans) for sample in batch] diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 3c850bed..21aeeef8 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -33,11 +33,11 @@ def _tokenize_with_spans(self, text: str, char_spans: list[tuple[int, int]]) -> for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] - tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text) + tokenized_text = self._tokenizer.tokenize(curr_text, add_bos_token=beginning_of_text) beginning_of_text = False input_ids.extend(tokenized_text) curr_text = text[start : end + 1] - tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text) + tokenized_text = self._tokenizer.tokenize(curr_text, add_bos_token=beginning_of_text) beginning_of_text = False token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) input_ids.extend(tokenized_text) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index fabdaf5b..1a56116b 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -33,8 +33,23 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, add_special_tokens: bool = True) -> list[int]: - return self.tokenizer.encode(text, add_special_tokens=add_special_tokens) + def tokenize( + self, + text: str, + add_special_tokens: bool = True, + add_bos_token: 
bool | None = None, + add_eos_token: bool | None = None, + ) -> list[int]: + # add_special_tokens will use the default tokenizer behaviour. + # If add_bos_token or add_eos_token is set, we use them and ignore add_special_tokens. + if add_bos_token is not None or add_eos_token is not None: + return ( + ([self.tokenizer.bos_token_id] if add_bos_token and self.tokenizer.bos_token_id else []) + + self.tokenizer.encode(text, add_special_tokens=False) + + ([self.tokenizer.eos_token_id] if add_eos_token else []) + ) + else: + return self.tokenizer.encode(text, add_special_tokens=add_special_tokens) def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) From bb175bfe5ede982548541f11993fb7eda829ef2e Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 27 Jan 2025 18:20:41 +0000 Subject: [PATCH 08/45] loss masking for triton cross entropy --- fast_llm/functional/cross_entropy.py | 34 ++++++++++++--------- fast_llm/functional/triton/cross_entropy.py | 14 +++++++-- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index ff81ea34..0e806dd7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -11,9 +11,10 @@ def torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index=-100, + ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -38,6 +39,7 @@ def torch_cross_entropy_forward_backward( def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, ignore_index: int = -100, @@ -49,9 +51,6 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. 
- mask = target != ignore_index - target = target[mask] - logits = logits[mask] target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: @@ -62,7 +61,7 @@ def fused_cross_entropy_forward_backward( if grad_output is None: grad = None else: - grad = torch.zeros((mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) + grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1)) # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) @@ -70,7 +69,7 @@ def fused_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - grad.index_put_((mask,), exp_logits.to(logits.dtype)) + grad.index_put_((loss_mask,), exp_logits.to(logits.dtype)) loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean() @@ -79,7 +78,13 @@ def fused_cross_entropy_forward_backward( @torch.compile def parallel_cross_entropy_forward_backward( - logits, target, grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0, ignore_index=-100 + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + grad_output: float | None, + group: ProcessGroup, + logits_scale_factor: float = 1.0, + ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. @@ -87,9 +92,6 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). # TODO: Optimize, overlap/combine reductions - mask = target != ignore_index - target = target[mask] - logits = logits[mask] target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0] @@ -111,7 +113,7 @@ def parallel_cross_entropy_forward_backward( if grad_output is None: grad = None else: - grad = torch.zeros((mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) + grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits1 = exp_logits.scatter( 1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1) ) @@ -119,7 +121,7 @@ def parallel_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits2 *= logits_scale_factor - grad.index_put_((mask,), exp_logits2.to(logits.dtype)) + grad.index_put_((loss_mask,), exp_logits2.to(logits.dtype)) predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) @@ -142,6 +144,7 @@ def cross_entropy_forward_backward( group: ProcessGroup | None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, + ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -149,12 +152,15 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. 
""" + loss_mask = target != ignore_index + target = target[loss_mask] + logits = logits[loss_mask] if group: Assert.eq(implementation, CrossEntropyImpl.fused) return parallel_cross_entropy_forward_backward( - logits, target, grad_output, group, logits_scale_factor=logits_scale_factor + logits, target, loss_mask, grad_output, group, logits_scale_factor=logits_scale_factor ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor + logits, target, loss_mask, grad_output, logits_scale_factor=logits_scale_factor ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 118577e6..f78f1cc0 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,7 @@ import torch - import triton import triton.language as tl + from fast_llm.functional.config import TritonConfig @@ -9,6 +9,7 @@ def triton_cross_entropy_forward_backward_kernel( logits_ptr, labels_ptr, + loss_mask_ptr, grad_logits_ptr, losses_ptr, grad_losses, @@ -54,7 +55,11 @@ def triton_cross_entropy_forward_backward_kernel( def triton_cross_entropy_forward_backward( - logits, target, grad_output: float | None, logits_scale_factor: float = 1.0 + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -77,6 +82,7 @@ def triton_cross_entropy_forward_backward( triton_cross_entropy_forward_backward_kernel[(n_rows,)]( logits, target, + loss_mask, grad_logits, losses, 1 if grad_output is None else grad_output / n_rows, @@ -87,4 +93,6 @@ def triton_cross_entropy_forward_backward( block_size=block_size, num_warps=num_warps, ) - return losses.mean(), None if grad_output is None else grad_logits + full_grad_logits = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) + full_grad_logits.index_put_((loss_mask,), grad_logits) + return losses.mean(), None if grad_output is None else full_grad_logits From 0e7ad8b95d4bb00a6774cb2dc8caa5cc148acc6a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 01:05:56 +0000 Subject: [PATCH 09/45] fix random data tests --- .pre-commit-config.yaml | 4 +-- fast_llm/data/dataset/gpt/random.py | 17 ++++++++++-- tests/test_dataset.py | 40 ++++++++++++++++++++++------- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 548a4edc..4f174132 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - app/scripts/utility/shell.py - --remove-duplicate-keys - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.0 hooks: - id: isort name: isort (python) @@ -49,7 +49,7 @@ repos: - "--config" - "./pyproject.toml" - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.16.0 + rev: v0.17.2 hooks: - id: markdownlint-cli2 name: markdownlint diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 142dca71..463c3c10 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -2,6 +2,7 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.gpt.sampled import GPTSample 
class GPTRandomDataset(SamplableDataset): @@ -31,10 +32,22 @@ def __init__(self, config: GPTSamplingConfig, name: str): def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx) -> np.ndarray: - return np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( + def __getitem__(self, idx) -> GPTSample: + np_seed = self._seed + 48576439 + 74593 * idx + ids = np.random.RandomState(np_seed).randint( 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 ) + n_spans = np.random.RandomState(np_seed).randint(0, 3) + spans = [] + prev_end = -1 + for _ in range(n_spans): + start = np.random.RandomState(np_seed).randint(prev_end + 1, len(ids)) + end = np.random.RandomState(np_seed).randint(start, len(ids)) + spans.append([start, end]) + prev_end = end + if prev_end >= len(ids) - 1: + break + return GPTSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self) -> str: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d9e30b34..a1151a4f 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -75,8 +75,8 @@ def get_test_data_and_samples( batch_config.setup(distributed_config) batch_config.validate() samples = { - phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] - for phase, samples in samples_per_phase.items() + phase: list(data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)) + for phase, n_samples in samples_per_phase.items() } return data, samples @@ -88,12 +88,18 @@ def get_test_dataset_1(): return get_test_dataset(prefix=DATASET_PREFIX_MIX_1, seed=2345) -RANDOM_DATASET_EXPECTED_SAMPLES = [ +RANDOM_DATASET_EXPECTED_SAMPLES_IDS = [ [3954, 4105, 6766, 859, 5494, 1675, 1303, 6913], [1654, 5701, 32, 1662, 7053, 3487, 1861, 1502], [5409, 6240, 5504, 7458, 7667, 3955, 3151, 3912], [5640, 6131, 7750, 2699, 1349, 2585, 7113, 6981], ] +RANDOM_DATASET_EXPECTED_SAMPLES_SPANS = [ + [[2, 4], [7, 7]], + [[6, 6], [7, 7]], + [[1, 2]], + [], +] def test_gpt_random_dataset(): @@ -102,9 +108,17 @@ def test_gpt_random_dataset(): get_sampling_config(4, sequence_length=7) ) Assert.eq(len(sampled), 4) + sampled_ids = np.stack([sampled[i].ids for i in range(4)]) + sampled_spans = np.vstack([sampled[i].spans for i in range(4)]) + Assert.all_equal( + sampled_ids, + np.stack(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), + ) Assert.all_equal( - np.stack([sampled[i] for i in range(4)]), - np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + sampled_spans, + np.vstack( + [np.array(x, dtype=sampled_spans.dtype).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS] + ), ) @@ -121,16 +135,24 @@ def test_gpt_random_data(): sequence_length=7, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS]), ) def test_gpt_random_data_legacy(): _, samples = get_test_data_and_samples({"format": "random"}, {PhaseType.training: 4}, sequence_length=7) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(RANDOM_DATASET_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + 
        np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]),
        np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS]),
    )

From 989a8f89fd9108208f1cc4a58bde756df1307c95 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Tue, 28 Jan 2025 01:07:11 +0000
Subject: [PATCH 10/45] revert precommit versions
---
 .pre-commit-config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4f174132..548a4edc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,7 +31,7 @@ repos: - app/scripts/utility/shell.py - --remove-duplicate-keys - repo: https://github.com/pycqa/isort
- rev: 6.0.0
+ rev: 5.13.2
 hooks: - id: isort name: isort (python)
@@ -49,7 +49,7 @@ repos: - "--config" - "./pyproject.toml" - repo: https://github.com/DavidAnson/markdownlint-cli2
- rev: v0.17.2
+ rev: v0.16.0
 hooks: - id: markdownlint-cli2 name: markdownlint

From 9633f886a0c22664ece1fdf163e9f1e6a0b43ff9 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Tue, 28 Jan 2025 03:04:11 +0000
Subject: [PATCH 11/45] fix memmap dataset test
---
 fast_llm/data/dataset/gpt/memmap.py | 8 ++++++++
 fast_llm/data/dataset/gpt/random.py | 4 ++--
 tests/common.py | 15 ++++++++++++++-
 tests/test_dataset.py | 22 ++++++++++++++--------
 4 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py
index 9c556575..d0b1f621 100644
--- a/fast_llm/data/dataset/gpt/memmap.py
+++ b/fast_llm/data/dataset/gpt/memmap.py
@@ -136,6 +136,14 @@ def get_document_sizes(self) -> np.ndarray: """ return self._document_sizes
+ def get_span_sizes(self) -> np.ndarray:
+ """
+ The number of spans in each document in the dataset.
+ The resulting array could be very large, so this method should be called cautiously,
+ and derived classes should try to avoid holding the whole array in memory.
+ """ + return self._num_spans + @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTMemmapDocument]): # Initialize metadata diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 463c3c10..47dd1817 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -41,12 +41,12 @@ def __getitem__(self, idx) -> GPTSample: spans = [] prev_end = -1 for _ in range(n_spans): + if prev_end >= len(ids) - 1: + break start = np.random.RandomState(np_seed).randint(prev_end + 1, len(ids)) end = np.random.RandomState(np_seed).randint(start, len(ids)) spans.append([start, end]) prev_end = end - if prev_end >= len(ids) - 1: - break return GPTSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property diff --git a/tests/common.py b/tests/common.py index 9494fe14..d4a49a10 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,7 @@ import pytest import torch -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -230,6 +230,19 @@ def get_test_dataset( documents = [ np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size for document in documents ] + for idx, doc in enumerate(documents): + doc_seed = seed + idx + n_spans = random.Random(doc_seed).randint(0, 5) + spans = [] + prev_end = -1 + for _ in range(n_spans): + if prev_end >= len(doc) - 1: + break + start = random.Random(doc_seed).randint(prev_end + 1, len(doc) - 1) + end = random.Random(doc_seed).randint(start, len(doc) - 1) + spans.append([start, end]) + prev_end = end + documents[idx] = GPTMemmapDocument(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) GPTMemmapDataset.write_dataset(prefix, documents) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a1151a4f..ba513e5a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -159,11 +159,12 @@ def test_gpt_random_data_legacy(): # Most documents are too long to write here, we test a few known short ones. 
MEMMAP_DATASET_EXPECTED_LENGTH = 6153 MEMMAP_DATASET_EXPECTED_TOKENS = 508327 +MEMMAP_DATASET_EXPECTED_SPANS = 9138 MEMMAP_DATASET_EXPECTED_SAMPLES = { - 9: [], - 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], - 13: [78, 727, 74, 317, 1358, 89], - 15: [78], + 9: ([], []), + 10: ([80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], [[4, 6], [8, 9]]), + 13: ([78, 727, 74, 317, 1358, 89], []), + 15: ([78], [[0, 0]]), } @@ -173,11 +174,16 @@ def test_gpt_memmap(cache_directory): get_test_dataset() dataset = _get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() Assert.eq(len(dataset), MEMMAP_DATASET_EXPECTED_LENGTH) - sizes = dataset.get_document_sizes() - Assert.eq(sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) - Assert.all_equal([len(dataset.get(i)) for i in range(100)], sizes[:100]) + doc_sizes = dataset.get_document_sizes() + span_sizes = dataset.get_span_sizes() + Assert.eq(doc_sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) + Assert.eq(span_sizes.sum(), MEMMAP_DATASET_EXPECTED_SPANS) + Assert.all_equal([len(dataset.get(i).ids) for i in range(100)], doc_sizes[:100]) + Assert.all_equal([len(dataset.get(i).spans) for i in range(100)], span_sizes[:100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + ds_sample = dataset.get(i) + Assert.all_equal(ds_sample.ids, np.array(sample[0], dtype=np.uint16)) + Assert.all_equal(ds_sample.spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) GPT_SAMPLED_EXPECTED_SAMPLES = [ From 4f955ff5d7dcee621a6a4bc8ac355dde0a62290a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 07:57:55 +0000 Subject: [PATCH 12/45] fix remaining dataset tests --- fast_llm/data/dataset/gpt/fim.py | 23 +-- fast_llm/data/dataset/gpt/indexed.py | 6 + tests/test_dataset.py | 276 ++++++++++++++++++++++----- 3 files changed, 246 insertions(+), 59 deletions(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 1ffb3bfc..2e9bafe7 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -2,6 +2,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingConfig +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.engine.distributed.config import MAX_SEED @@ -42,12 +43,12 @@ def __getitem__(self, idx: int) -> np.ndarray: def name(self) -> str: return f"{self._dataset.name}_fim" - def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: + def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # FIM # TODO: permute segments in sample_list, before concatenating. - sample_len = sample.shape[0] + sample_len = sample.ids.shape[0] eod = self._tokenizer.eod - segment_breaks = np.argwhere(sample == eod) # split sample by document + segment_breaks = np.argwhere(sample.ids == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example curr_start_position = 0 @@ -57,26 +58,26 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # Only permute non-empty segments. 
if loc - curr_start_position > 0: # permute {prefix, suffix, middle} or {suffix, prefix, middle} - permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc], np_rng) + permuted = self._fim_split_and_permute_sequence(sample.ids[curr_start_position:loc], np_rng) new_samples += [permuted, [eod]] curr_start_position = loc + 1 # jump over the EOD token # Permute the segment after the last EOD - permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) + permuted = self._fim_split_and_permute_sequence(sample.ids[curr_start_position:], np_rng) new_samples.append(permuted) - sample = np.concatenate(new_samples) + sample.ids = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + sample.ids = self._fim_split_and_permute_sequence(sample.ids, np_rng) # Truncate or pad sequence to max-length - diff = sample.shape[0] - sample_len + diff = sample.ids.shape[0] - sample_len if diff > 0: # too long - sample = sample[:sample_len] + sample.ids = sample.ids[:sample_len] elif diff < 0: # too short - sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) + sample.ids = np.concatenate([sample.ids, np.full((-1 * diff), self._pad_tok_id)]) - assert sample.shape[0] == sample_len + assert sample.ids.shape[0] == sample_len return sample def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 2c158bff..e3ddbeda 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -36,6 +36,9 @@ def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return self._dataset.get_document_sizes()[self._begin : self._end] + def get_span_sizes(self) -> np.ndarray: + return self._dataset.get_span_sizes()[self._begin : self._end] + class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset @@ -45,3 +48,6 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_span_sizes(self) -> np.ndarray: + return np.concatenate([dataset.get_span_sizes() for dataset in self._datasets]) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ba513e5a..873893a9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -186,7 +186,7 @@ def test_gpt_memmap(cache_directory): Assert.all_equal(ds_sample.spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) -GPT_SAMPLED_EXPECTED_SAMPLES = [ +GPT_SAMPLED_EXPECTED_SAMPLES_IDS = [ [1725, 74, 207, 1635, 4440, 2774], [359, 489, 4266, 2052, 5351, 80], [374, 7534, 87, 1073, 79, 480], @@ -197,6 +197,17 @@ def test_gpt_memmap(cache_directory): [330, 155, 2449, 1136, 1106, 5370], ] +GPT_SAMPLED_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [[0, 2]], + [], + [], + [], + [[0, 5]], + [], + [], +] + def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. 
@@ -206,8 +217,14 @@ def test_gpt_sampled(): ) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS] + ), ) @@ -226,8 +243,12 @@ def test_gpt_sampled_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS]), ) @@ -238,12 +259,16 @@ def test_gpt_sampled_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS]), ) -GPT_CONCATENATED_EXPECTED_SAMPLES = [ +GPT_CONCATENATED_EXPECTED_SAMPLES_IDS = [ [243, 498, 7172, 777, 306, 74], [821, 6042, 89, 977, 4797, 499], [387, 74, 330, 328, 1858, 484], @@ -255,6 +280,18 @@ def test_gpt_sampled_data_legacy(): ] +GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [], + [], + [[0, 5]], + [], + [[0, 0]], + [], + [], +] + + def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. 
get_test_dataset() @@ -263,18 +300,28 @@ def test_gpt_concatenate(): GPTConcatenatedDatasetConfig, ).build() Assert.eq(len(dataset), 3 * MEMMAP_DATASET_EXPECTED_LENGTH) - sizes = dataset.get_document_sizes() - Assert.eq(sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) + doc_sizes = dataset.get_document_sizes() + span_sizes = dataset.get_span_sizes() + Assert.eq(doc_sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) + Assert.eq(span_sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_SPANS) for i in range(3): begin = i * MEMMAP_DATASET_EXPECTED_LENGTH - Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100]) + Assert.all_equal([len(dataset.get(begin + i).ids) for i in range(100)], doc_sizes[begin : begin + 100]) + Assert.all_equal([len(dataset.get(begin + i).spans) for i in range(100)], span_sizes[begin : begin + 100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(begin + i).ids, np.array(sample[0], dtype=np.uint16)) + Assert.all_equal(dataset.get(begin + i).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS] + ), ) @@ -292,19 +339,30 @@ def test_gpt_concatenate_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS]), ) -GPT_SLICE_EXPECTED_TRAINING_SAMPLES = [ +GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS = [ [2625, 76, 2625, 2639, 74, 243], [207, 481, 5546, 74, 414, 498], [74, 333, 1963, 310, 5337, 3628], [79, 2361, 80, 2012, 84, 480], ] -GPT_SLICE_EXPECTED_VALIDATION_SAMPLES = [ +GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS = [ + [], + [], + [[0, 2]], + [], +] + +GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS = [ [2352, 3687, 2311, 4900, 542, 3732], [2551, 5283, 900, 3140, 328, 68], [7979, 2283, 329, 727, 2740, 2818], @@ -315,6 +373,17 @@ def test_gpt_concatenate_data(): [243, 3712, 86, 476, 80, 2547], ] +GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS = [ + [], + [], + [[0, 5]], + [[0, 3]], + [], + [], + [], + [], +] + def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. 
@@ -325,15 +394,27 @@ def test_gpt_slice(): GPTDatasetSliceConfig, ).build() Assert.eq(len(dataset), 9) - sizes = dataset.get_document_sizes() - Assert.all_equal([len(dataset.get(i)) for i in range(9)], sizes[:9]) + doc_sizes = dataset.get_document_sizes() + span_sizes = dataset.get_span_sizes() + Assert.all_equal([len(dataset.get(i).ids) for i in range(9)], doc_sizes[:9]) + Assert.all_equal([len(dataset.get(i).spans) for i in range(9)], span_sizes[:9]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i - 9), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(i - 9).ids, np.array(sample[0], dtype=np.uint16)) + Assert.all_equal(dataset.get(i - 9).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [ + np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) + for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS + ] + ), ) @@ -365,12 +446,20 @@ def test_gpt_slice_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.validation]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.validation]]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), + np.vstack([batch.spans[0] for batch in samples[PhaseType.validation]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS]), + ) + Assert.all_equal( + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS]), ) @@ -382,12 +471,20 @@ def test_gpt_slice_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.validation]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.validation]]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.validation]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS]), + ) + Assert.all_equal( + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS), ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS]), ) @@ -402,6 +499,17 @@ def test_gpt_slice_data_legacy(): [409, 5091, 328, 1378, 5483, 88], ] +GPT_BLENDED_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [[1, 5]], + [[0, 2]], + [], + [], + [], + [], + [[0, 5]], +] + def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. 
@@ -420,9 +528,15 @@ def test_gpt_blended(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].ids for i in range(8)]), np.array(GPT_BLENDED_EXPECTED_SAMPLES), ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_BLENDED_EXPECTED_SAMPLES_SPANS] + ), + ) def test_gpt_blended_data(): @@ -445,12 +559,16 @@ def test_gpt_blended_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), np.array(GPT_BLENDED_EXPECTED_SAMPLES), ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_EXPECTED_SAMPLES_SPANS]), + ) -GPT_BLENDED_LEGACY_EXPECTED_SAMPLES = [ +GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_IDS = [ [1725, 74, 207, 1635, 4440, 2774], [328, 80, 263, 890, 1797, 88], [359, 489, 4266, 2052, 5351, 80], @@ -461,6 +579,17 @@ def test_gpt_blended_data(): [409, 5091, 328, 1378, 5483, 88], ] +GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [], + [[0, 2]], + [], + [], + [], + [], + [[0, 5]], +] + def test_gpt_blended_data_legacy(): get_test_dataset() @@ -475,12 +604,16 @@ def test_gpt_blended_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_SPANS]), ) -GPT_BLENDED_MIXED_EXPECTED_SAMPLES = [ +GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS = [ [1725, 74, 207, 1635, 4440, 2774], [916, 6683, 7685, 1277, 5106, 378], [359, 489, 4266, 2052, 5351, 80], @@ -491,6 +624,17 @@ def test_gpt_blended_data_legacy(): [2210, 8179, 73, 2582, 897, 1178], ] +GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [], + [[0, 2]], + [], + [], + [], + [], + [], +] + def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. 
@@ -508,8 +652,17 @@ def test_gpt_blended_mixed(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [ + np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) + for x in GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS + ] + ), ) @@ -528,12 +681,16 @@ def test_gpt_blended_mixed_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS]), ) -GPT_FIM_EXPECTED_SAMPLES = [ +GPT_FIM_EXPECTED_SAMPLES_IDS = [ [1725, 74, 207, 1635, 4440, 2774], [359, 489, 4266, 2052, 5351, 80], [86, 89, 22255, 1073, 79, 480], @@ -544,6 +701,17 @@ def test_gpt_blended_mixed_data(): [86, 89, 1461, 87, 330, 7876], ] +GPT_FIM_EXPECTED_SAMPLES_SPANS = [ + [[0, 5]], + [[0, 2]], + [], + [], + [], + [[0, 5]], + [], + [], +] + def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. @@ -567,8 +735,12 @@ def test_gpt_fim(): Assert.eq(len(sampled), 8) # TODO: Does this output make sense? Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_FIM_EXPECTED_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack([np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), ) @@ -592,8 +764,12 @@ def test_gpt_fim_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_FIM_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), ) @@ -610,6 +786,10 @@ def test_gpt_fim_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_FIM_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), ) From 1ac50522478c5651f7d97c7ba29601f22b9fe9ad Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 08:18:57 +0000 Subject: [PATCH 13/45] compose tests --- tests/test_dataset.py | 56 +++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6073fd48..b8b822b8 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -502,18 +502,19 @@ def test_gpt_slice_data_legacy(): COMPOSED_DATASET_EXPECTED_LENGTH = 24806 
COMPOSED_DATASET_EXPECTED_TOKENS = 2033639 +COMPOSED_DATASET_EXPECTED_SPANS = 37008 COMPOSED_DATASET_EXPECTED_SAMPLES = { **MEMMAP_DATASET_EXPECTED_SAMPLES, - 6930: [65, 2327], - 11962: [7078, 2713, 1431], - 15958: [207], - 19362: [69], - 24098: [555, 668, 70], + 6930: ([65, 2327], [[0, 0], [1, 1]]), + 11962: ([7078, 2713, 1431], [[2, 2]]), + 15958: ([207], [[0, 0]]), + 19362: ([69], [[0, 0]]), + 24098: ([555, 668, 70], [[1, 2]]), } -GPT_COMPOSED_EXPECTED_SAMPLES = [ +GPT_COMPOSED_EXPECTED_SAMPLES_IDS = [ [1411, 819, 6791, 7022, 285, 249], [329, 328, 512, 1985, 3069, 7838], [5158, 1023, 8171, 798, 1431, 313], @@ -524,6 +525,17 @@ def test_gpt_slice_data_legacy(): [79, 6042, 577, 225, 207, 207], ] +GPT_COMPOSED_EXPECTED_SAMPLES_SPANS = [ + [], + [], + [[0, 5]], + [], + [], + [], + [], + [], +] + def test_gpt_compose(): # Make sure dataset splitting works and check for unintended changes in behavior. @@ -534,17 +546,27 @@ def test_gpt_compose(): GPTConcatenatedMemmapConfig, ).build() Assert.eq(len(dataset), COMPOSED_DATASET_EXPECTED_LENGTH) - sizes = dataset.get_document_sizes() - Assert.eq(sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) - Assert.all_equal([len(dataset.get(i)) for i in range(0, len(dataset), 20)], sizes[::20]) + doc_sizes = dataset.get_document_sizes() + span_sizes = dataset.get_span_sizes() + Assert.eq(doc_sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) + Assert.eq(span_sizes.sum(), COMPOSED_DATASET_EXPECTED_SPANS) + Assert.all_equal([len(dataset.get(i).ids) for i in range(0, len(dataset), 20)], doc_sizes[::20]) + Assert.all_equal([len(dataset.get(i).spans) for i in range(0, len(dataset), 20)], span_sizes[::20]) for i, sample in COMPOSED_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(i).ids, np.array(sample[0], dtype=np.uint16)) + Assert.all_equal(dataset.get(i).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) - print(np.stack([sampled[i] for i in range(8)]).tolist()) + print(np.stack([sampled[i].ids for i in range(8)]).tolist()) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), - np.array(GPT_COMPOSED_EXPECTED_SAMPLES), + np.stack([sampled[i].ids for i in range(8)]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([sampled[i].spans for i in range(8)]), + np.vstack( + [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_COMPOSED_EXPECTED_SAMPLES_SPANS] + ), ) @@ -563,8 +585,12 @@ def test_gpt_composed_data(): sequence_length=5, ) Assert.all_equal( - np.stack(samples[PhaseType.training]), - np.array(GPT_COMPOSED_EXPECTED_SAMPLES), + np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES_IDS), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), + np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_COMPOSED_EXPECTED_SAMPLES_SPANS]), ) From aebb5a0e1d8a8cb01a0991904d0e662d948de6b8 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 18:23:30 +0000 Subject: [PATCH 14/45] handle special tokens from config --- fast_llm/data/config.py | 12 ++++ .../data/preparator/gpt_memmap/prepare.py | 38 +++--------- fast_llm/data/tokenizer.py | 59 ++++++++++++++----- 3 files changed, 66 insertions(+), 43 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 1586d370..220ed659 100644 --- 
a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -16,6 +16,13 @@ class MultiprocessingContext(str, enum.Enum): TokenizerFromFile = "TokenizerFromFile" +class SpecialTokensMode(str, enum.Enum): + tokenizer_default = "tokenizer_default" + bos_only = "bos_only" + eos_only = "eos_only" + bos_eos = "bos_eos" + + @config_class() class TokenizerConfig(Config): """ @@ -34,3 +41,8 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) + special_tokens_mode: SpecialTokensMode = Field( + default=SpecialTokensMode.bos_only, + desc="Special tokens configuration.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 687ff519..a3219b56 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -22,41 +22,21 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _tokenizer: Tokenizer _data_type: DataType - def _tokenize_with_spans(self, text: str, char_spans: list[tuple[int, int]]) -> tuple[np.ndarray, np.ndarray]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. - """ - input_ids = [] - token_spans = [] - char_pos = 0 - beginning_of_text = True - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self._tokenizer.tokenize(curr_text, add_bos_token=beginning_of_text) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - tokenized_text = self._tokenizer.tokenize(curr_text, add_bos_token=beginning_of_text) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self._tokenizer.tokenize(curr_text) - input_ids.extend(tokenized_text) - return np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2) - def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids, token_spans = map( list, zip( *[ - self._tokenize_with_spans(text, char_spans) - for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.spans_field] + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans, dtype=np.int32).reshape(-1, 2), ) + for input_ids, token_spans in [ + self._tokenizer.tokenize_with_spans(text, char_spans) + for text, char_spans in zip( + batch[self._config.dataset.field], batch[self._config.dataset.spans_field] + ) + ] ] ), ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 1a56116b..ced4dcc8 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -2,7 +2,7 @@ import torch from transformers import PreTrainedTokenizerFast -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.config import SpecialTokensMode, TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank @@ -16,9 +16,11 @@ def __init__(self, config: TokenizerConfig): self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained( pretrained_model_name_or_path=config.path, errors="replace", max_len=None ) + self.special_tokens_mode = config.special_tokens_mode if self.tokenizer.eos_token_id is None: raise ValueError("Tokenizer does not have an EOS token.") self.eod_id = 
self.tokenizer.eos_token_id + self.bod_id = self.tokenizer.bos_token_id self._inv_vocab = {v: k for k, v in self.vocab.items()} @property @@ -33,23 +35,52 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize( - self, - text: str, - add_special_tokens: bool = True, - add_bos_token: bool | None = None, - add_eos_token: bool | None = None, - ) -> list[int]: - # add_special_tokens will use the default tokenizer behaviour. - # If add_bos_token or add_eos_token is set, we use them and ignore add_special_tokens. - if add_bos_token is not None or add_eos_token is not None: + def tokenize(self, text: str, beginning_of_text=True, end_of_text=True) -> list[int]: + if self.special_tokens_mode == SpecialTokensMode.eos_only: + return self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) + elif self.special_tokens_mode == SpecialTokensMode.bos_only: + return ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode( + text, add_special_tokens=False + ) + elif self.special_tokens_mode == SpecialTokensMode.bos_eos: return ( - ([self.tokenizer.bos_token_id] if add_bos_token and self.tokenizer.bos_token_id else []) + ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.tokenizer.eos_token_id] if add_eos_token else []) + + ([self.eod_id] if end_of_text else []) ) else: - return self.tokenizer.encode(text, add_special_tokens=add_special_tokens) + # TODO: How do we handle when beginning_of_text=False or end_of_text=False? + return self.tokenizer.encode(text) + + def tokenize_with_spans( + self, text: str, char_spans: list[tuple[int, int]] + ) -> tuple[list[int], list[tuple[int, int]]]: + """ + Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
+ """ + input_ids = [] + token_spans = [] + char_pos = 0 + beginning_of_text = True + for start, end in char_spans: + if char_pos < start: + curr_text = text[char_pos:start] + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text) + beginning_of_text = False + input_ids.extend(tokenized_text) + curr_text = text[start : end + 1] + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text) + beginning_of_text = False + token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) + input_ids.extend(tokenized_text) + char_pos = end + 1 + if char_pos < len(text): + curr_text = text[char_pos:] + tokenized_text = self.tokenize(curr_text) + input_ids.extend(tokenized_text) + if self.special_tokens_mode in [SpecialTokensMode.eos_only, SpecialTokensMode.bos_eos]: + input_ids.append(self.eod_id) + return input_ids, token_spans def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) From d8e3ae1120b9c94a1ab86005625981f916557bef Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 18:58:04 +0000 Subject: [PATCH 15/45] fix fim to handle bos and eos --- fast_llm/data/dataset/gpt/fim.py | 8 +++++--- tests/test_dataset.py | 10 +++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2e9bafe7..e553e03c 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -151,9 +151,11 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, end_of_text=False)], dtype=np.int64) + middle = np.array( + [*self._tokenizer.tokenize(middle, beginning_of_text=False, end_of_text=False)], dtype=np.int64 + ) + suffix = np.array([*self._tokenizer.tokenize(suffix, beginning_of_text=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b8b822b8..c693b6ce 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -824,7 +824,11 @@ def test_gpt_fim(): get_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. 
sampling_config = get_sampling_config( - 8, sequence_length=5, tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})) + 8, + sequence_length=5, + tokenizer=Tokenizer( + TokenizerConfig.from_dict({"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}) + ), ) sampled = _get_dataset_config( { @@ -864,7 +868,7 @@ def test_gpt_fim_data(): "suffix_token": "z", } }, - "tokenizer": {"path": TOKENIZER_PATH}, + "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, }, {PhaseType.training: 8}, sequence_length=5, @@ -885,7 +889,7 @@ def test_gpt_fim_data_legacy(): "format": "list", "path": [str(DATASET_PREFIX)], "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, - "tokenizer": {"path": TOKENIZER_PATH}, + "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, "split": [1, 0, 0], }, {PhaseType.training: 8}, From a887dd68e1c9bb59647e81f9e97316f46a2b5272 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 20:42:25 +0000 Subject: [PATCH 16/45] address review comments --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/fim.py | 2 ++ fast_llm/data/tokenizer.py | 11 ++++---- fast_llm/functional/cross_entropy.py | 28 ++++++++++----------- fast_llm/functional/triton/cross_entropy.py | 6 +---- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4950f8aa..a4795e59 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -29,7 +29,7 @@ @dataclasses.dataclass class GPTDataBatch: ids: torch.Tensor - spans: torch.Tensor + spans: list[torch.Tensor] def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTDataBatch: diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index e553e03c..266c30ab 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -46,6 +46,8 @@ def name(self) -> str: def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # FIM # TODO: permute segments in sample_list, before concatenating. 
+ if self._config.rate > 0.0 and sample.spans.size > 0: + raise NotImplementedError("FIM is currently not compatible with loss masking.") sample_len = sample.ids.shape[0] eod = self._tokenizer.eod segment_breaks = np.argwhere(sample.ids == eod) # split sample by document diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index ced4dcc8..92b4ba26 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -65,21 +65,22 @@ def tokenize_with_spans( for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text) + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False) beginning_of_text = False input_ids.extend(tokenized_text) curr_text = text[start : end + 1] - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text) + if end >= len(text) - 1: + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True) + else: + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False) beginning_of_text = False token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) input_ids.extend(tokenized_text) char_pos = end + 1 if char_pos < len(text): curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text) + tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True) input_ids.extend(tokenized_text) - if self.special_tokens_mode in [SpecialTokensMode.eos_only, SpecialTokensMode.bos_eos]: - input_ids.append(self.eod_id) return input_ids, token_spans def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0e806dd7..62f120f4 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -11,10 +11,8 @@ def torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, - loss_mask: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -29,7 +27,7 @@ def torch_cross_entropy_forward_backward( if grad_output is None: loss = None else: - loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean() + loss = torch.nn.functional.cross_entropy(logits_, target).mean() loss.backward(torch.full_like(loss, grad_output)) loss.detach_() return loss.detach(), logits_.grad.detach().to(logits.dtype) @@ -39,10 +37,8 @@ def torch_cross_entropy_forward_backward( def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, - loss_mask: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. 
@@ -61,7 +57,6 @@ def fused_cross_entropy_forward_backward( if grad_output is None: grad = None else: - grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1)) # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) @@ -69,7 +64,7 @@ def fused_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - grad.index_put_((loss_mask,), exp_logits.to(logits.dtype)) + grad = exp_logits.to(logits.dtype) loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean() @@ -80,11 +75,9 @@ def fused_cross_entropy_forward_backward( def parallel_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, - loss_mask: torch.Tensor, grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0, - ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. @@ -113,7 +106,6 @@ def parallel_cross_entropy_forward_backward( if grad_output is None: grad = None else: - grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) exp_logits1 = exp_logits.scatter( 1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1) ) @@ -121,7 +113,7 @@ def parallel_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits2 *= logits_scale_factor - grad.index_put_((loss_mask,), exp_logits2.to(logits.dtype)) + grad = exp_logits2.to(logits.dtype) predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) @@ -157,10 +149,16 @@ def cross_entropy_forward_backward( logits = logits[loss_mask] if group: Assert.eq(implementation, CrossEntropyImpl.fused) - return parallel_cross_entropy_forward_backward( - logits, target, loss_mask, grad_output, group, logits_scale_factor=logits_scale_factor + loss, grad_logits = parallel_cross_entropy_forward_backward( + logits, target, grad_output, group, logits_scale_factor=logits_scale_factor ) else: - return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, loss_mask, grad_output, logits_scale_factor=logits_scale_factor + loss, grad_logits = _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( + logits, target, grad_output, logits_scale_factor=logits_scale_factor ) + if grad_logits is not None: + grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) + grad.index_put_((loss_mask,), grad_logits) + return loss, grad + else: + return loss, grad_logits diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index f78f1cc0..9e6e697f 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -9,7 +9,6 @@ def triton_cross_entropy_forward_backward_kernel( logits_ptr, labels_ptr, - loss_mask_ptr, grad_logits_ptr, losses_ptr, grad_losses, @@ -82,7 +81,6 @@ def triton_cross_entropy_forward_backward( triton_cross_entropy_forward_backward_kernel[(n_rows,)]( logits, target, - loss_mask, grad_logits, losses, 1 if grad_output is None else grad_output / n_rows, @@ -93,6 +91,4 @@ def 
triton_cross_entropy_forward_backward( block_size=block_size, num_warps=num_warps, ) - full_grad_logits = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) - full_grad_logits.index_put_((loss_mask,), grad_logits) - return losses.mean(), None if grad_output is None else full_grad_logits + return losses.mean(), None if grad_output is None else grad_logits From 40a80f6cabacf26c6607a962eeb32f22dc29eb42 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 21:49:01 +0000 Subject: [PATCH 17/45] fix memmap tests --- tests/test_memmap_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index 8c2235bf..6cb22eab 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -22,15 +22,15 @@ def test_gpt_memmap_dataset(dtype): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - memmap_document, memmap_spans = dataset.get(i) + memmap_sample = dataset.get(i) assert np.array_equal( - memmap_document, document.text, equal_nan=True + memmap_sample.ids, document.text, equal_nan=True ), f"Mismatch for document {i}: {document.text} != {dataset.get(i)}." if len(document.spans) > 0: assert np.array_equal( - memmap_spans, document.spans, equal_nan=True + memmap_sample.spans, document.spans, equal_nan=True ), f"Mismatch for non-empty spans {i}: {document.spans} != {dataset.get(i)}." else: assert np.array_equal( - memmap_spans.flatten(), document.spans.flatten(), equal_nan=True + memmap_sample.spans.flatten(), document.spans.flatten(), equal_nan=True ), f"Mismatch for empty spans {i}: {document.spans} != {dataset.get(i)}." From e908303e72148c41d9049e450bbf4631c6e06762 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Tue, 28 Jan 2025 22:33:38 +0000 Subject: [PATCH 18/45] fix fim tests --- tests/test_dataset.py | 109 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 15 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c693b6ce..6c5bb586 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -71,6 +71,7 @@ def get_test_data_and_samples( seed: int = 54983, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, + consumed_samples: int = 0, vocab_size=TEST_VOCAB_SIZE, ): distributed_config = DistributedConfig(seed=seed) @@ -82,7 +83,7 @@ def get_test_data_and_samples( batch_config.setup(distributed_config) batch_config.validate() samples = { - phase: list(data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)) + phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0)) for phase, n_samples in samples_per_phase.items() } return data, samples @@ -818,6 +819,8 @@ def test_gpt_blended_mixed_data(): [], ] +GPT_FIM_VALID_IDS = [2, 3, 4, 6, 7] + def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. @@ -845,17 +848,54 @@ def test_gpt_fim(): Assert.eq(len(sampled), 8) # TODO: Does this output make sense? 
Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + np.stack([sampled[i].ids for i in GPT_FIM_VALID_IDS]), + np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS]), ) Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack([np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), + np.vstack([sampled[i].spans for i in GPT_FIM_VALID_IDS]), + np.vstack( + [ + np.array(x, dtype=sampled[GPT_FIM_VALID_IDS[0]].spans.dtype).reshape(-1, 2) + for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS] + ] + ), ) def test_gpt_fim_data(): - _, samples = get_test_data_and_samples( + _, samples1 = get_test_data_and_samples( + { + "datasets": { + "Training": { + "type": "fim", + "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "rate": 0.5, + "prefix_token": "w", + "middle_token": "x", + "pad_token": "y", + "suffix_token": "z", + } + }, + "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, + }, + {PhaseType.training: 5}, + sequence_length=5, + consumed_samples=2, + ) + Assert.all_equal( + np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]), + np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]]), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]), + np.vstack( + [ + np.array(x, dtype=np.int32).reshape(-1, 2) + for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]] + ] + ), + ) + _, samples2 = get_test_data_and_samples( { "datasets": { "Training": { @@ -872,19 +912,52 @@ def test_gpt_fim_data(): }, {PhaseType.training: 8}, sequence_length=5, + consumed_samples=6, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]), + np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), + np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]), + np.vstack( + [ + np.array(x, dtype=np.int32).reshape(-1, 2) + for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]] + ] + ), ) def test_gpt_fim_data_legacy(): - _, samples = get_test_data_and_samples( + _, samples1 = get_test_data_and_samples( + { + "format": "list", + "path": [str(DATASET_PREFIX)], + "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, + "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, + "split": [1, 0, 0], + }, + {PhaseType.training: 5}, + sequence_length=5, + consumed_samples=2, + ) + Assert.all_equal( + np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]), + np.array( + [GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]], + ), + ) + Assert.all_equal( + np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]), + np.vstack( + [ + np.array(x, dtype=np.int32).reshape(-1, 2) + for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]] + ] + ), + ) + _, samples2 = get_test_data_and_samples( { "format": "list", "path": [str(DATASET_PREFIX)], @@ -894,12 +967,18 @@ def test_gpt_fim_data_legacy(): }, {PhaseType.training: 8}, sequence_length=5, + 
consumed_samples=6, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_FIM_EXPECTED_SAMPLES_IDS), + np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]), + np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_FIM_EXPECTED_SAMPLES_SPANS]), + np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]), + np.vstack( + [ + np.array(x, dtype=np.int32).reshape(-1, 2) + for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]] + ] + ), ) From 20ffae8e42a2ec8dc485b2eb1d6e7f221fd1333b Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 29 Jan 2025 20:13:21 +0000 Subject: [PATCH 19/45] special tokens mode -> sequence delimiters --- fast_llm/data/config.py | 9 +++++---- fast_llm/data/tokenizer.py | 10 ++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 220ed659..d525d6ea 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -16,11 +16,12 @@ class MultiprocessingContext(str, enum.Enum): TokenizerFromFile = "TokenizerFromFile" -class SpecialTokensMode(str, enum.Enum): +class SequenceDelimiters(str, enum.Enum): tokenizer_default = "tokenizer_default" bos_only = "bos_only" eos_only = "eos_only" bos_eos = "bos_eos" + no_delimiters = "no_delimiters" @config_class() @@ -41,8 +42,8 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) - special_tokens_mode: SpecialTokensMode = Field( - default=SpecialTokensMode.bos_only, - desc="Special tokens configuration.", + sequence_delimiters: SequenceDelimiters = Field( + default=SequenceDelimiters.bos_only, + desc="Boundary tokens (bos/eos) to use for tokenizing sequences", hint=FieldHint.core, ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 92b4ba26..dbbf419a 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -2,7 +2,7 @@ import torch from transformers import PreTrainedTokenizerFast -from fast_llm.data.config import SpecialTokensMode, TokenizerConfig +from fast_llm.data.config import SequenceDelimiters, TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank @@ -36,18 +36,20 @@ def inv_vocab(self) -> dict[int, str]: return self._inv_vocab def tokenize(self, text: str, beginning_of_text=True, end_of_text=True) -> list[int]: - if self.special_tokens_mode == SpecialTokensMode.eos_only: + if self.special_tokens_mode == SequenceDelimiters.eos_only: return self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) - elif self.special_tokens_mode == SpecialTokensMode.bos_only: + elif self.special_tokens_mode == SequenceDelimiters.bos_only: return ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode( text, add_special_tokens=False ) - elif self.special_tokens_mode == SpecialTokensMode.bos_eos: + elif self.special_tokens_mode == SequenceDelimiters.bos_eos: return ( ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) ) + elif self.special_tokens_mode == SequenceDelimiters.no_delimiters: + return self.tokenizer.encode(text, add_special_tokens=False) else: # TODO: How do we handle when 
beginning_of_text=False or end_of_text=False? return self.tokenizer.encode(text) From 753e73104580cacf3b5681a5ea66437d1a91300f Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 29 Jan 2025 20:19:17 +0000 Subject: [PATCH 20/45] GPTDataBatch -> GPTBatch --- fast_llm/data/data/gpt/data.py | 6 +++--- fast_llm/models/gpt/model.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a4795e59..6c664263 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -27,16 +27,16 @@ @dataclasses.dataclass -class GPTDataBatch: +class GPTBatch: ids: torch.Tensor spans: list[torch.Tensor] -def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTDataBatch: +def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch: stacked_ids = np.stack([sample.ids for sample in batch]) # stacked_spans = np.stack([sample.spans for sample in batch]) stacked_spans = [torch.from_numpy(sample.spans) for sample in batch] - return GPTDataBatch(ids=torch.from_numpy(stacked_ids), spans=stacked_spans) + return GPTBatch(ids=torch.from_numpy(stacked_ids), spans=stacked_spans) class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 18a115c9..921216ff 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.data.gpt.data import GPTDataBatch +from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -185,7 +185,7 @@ def preprocess_meta( def preprocess( self, - batch: GPTDataBatch, + batch: GPTBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, From cce070176e39775661e6de6e10af5662def3dbe1 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 29 Jan 2025 22:03:04 +0000 Subject: [PATCH 21/45] GPTMemmapDocument, GPTMemmapSample -> GPTSample --- fast_llm/data/data/gpt/data.py | 11 +++---- fast_llm/data/dataset/gpt/fim.py | 22 ++++++------- fast_llm/data/dataset/gpt/memmap.py | 32 ++++++------------- fast_llm/data/dataset/gpt/random.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 12 +++---- .../data/preparator/gpt_memmap/prepare.py | 5 +-- fast_llm/data/tokenizer.py | 10 +++--- fast_llm/models/gpt/model.py | 14 ++++---- tests/common.py | 5 +-- tests/test_memmap_dataset.py | 19 +++++------ 10 files changed, 61 insertions(+), 71 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6c664263..8dadc3bc 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -28,15 +28,14 @@ @dataclasses.dataclass class GPTBatch: - ids: torch.Tensor - spans: list[torch.Tensor] + token_ids: torch.Tensor + ignore_loss_spans: list[torch.Tensor] def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch: - stacked_ids = np.stack([sample.ids for sample in batch]) - # stacked_spans = np.stack([sample.spans for sample in batch]) - stacked_spans = [torch.from_numpy(sample.spans) for sample in batch] - return GPTBatch(ids=torch.from_numpy(stacked_ids), spans=stacked_spans) + stacked_ids = np.stack([sample.token_ids for sample in batch]) + stacked_spans = [torch.from_numpy(sample.ignore_loss_spans) for sample in batch] + return 
GPTBatch(token_ids=torch.from_numpy(stacked_ids), ignore_loss_spans=stacked_spans) class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 266c30ab..11513e76 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -46,11 +46,11 @@ def name(self) -> str: def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # FIM # TODO: permute segments in sample_list, before concatenating. - if self._config.rate > 0.0 and sample.spans.size > 0: + if self._config.rate > 0.0 and sample.ignore_loss_spans.size > 0: raise NotImplementedError("FIM is currently not compatible with loss masking.") - sample_len = sample.ids.shape[0] + sample_len = sample.token_ids.shape[0] eod = self._tokenizer.eod - segment_breaks = np.argwhere(sample.ids == eod) # split sample by document + segment_breaks = np.argwhere(sample.token_ids == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example curr_start_position = 0 @@ -60,26 +60,26 @@ def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # Only permute non-empty segments. if loc - curr_start_position > 0: # permute {prefix, suffix, middle} or {suffix, prefix, middle} - permuted = self._fim_split_and_permute_sequence(sample.ids[curr_start_position:loc], np_rng) + permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:loc], np_rng) new_samples += [permuted, [eod]] curr_start_position = loc + 1 # jump over the EOD token # Permute the segment after the last EOD - permuted = self._fim_split_and_permute_sequence(sample.ids[curr_start_position:], np_rng) + permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:], np_rng) new_samples.append(permuted) - sample.ids = np.concatenate(new_samples) + sample.token_ids = np.concatenate(new_samples) else: - sample.ids = self._fim_split_and_permute_sequence(sample.ids, np_rng) + sample.token_ids = self._fim_split_and_permute_sequence(sample.token_ids, np_rng) # Truncate or pad sequence to max-length - diff = sample.ids.shape[0] - sample_len + diff = sample.token_ids.shape[0] - sample_len if diff > 0: # too long - sample.ids = sample.ids[:sample_len] + sample.token_ids = sample.token_ids[:sample_len] elif diff < 0: # too short - sample.ids = np.concatenate([sample.ids, np.full((-1 * diff), self._pad_tok_id)]) + sample.token_ids = np.concatenate([sample.token_ids, np.full((-1 * diff), self._pad_tok_id)]) - assert sample.ids.shape[0] == sample_len + assert sample.token_ids.shape[0] == sample_len return sample def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 30ae5887..fb9823b7 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,4 +1,3 @@ -import dataclasses import pathlib import struct import typing @@ -6,23 +5,12 @@ import numpy as np from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -@dataclasses.dataclass -class GPTMemmapDocument: - text: np.ndarray - spans: np.ndarray - - 
-@dataclasses.dataclass -class GPTMemmapSample: - ids: np.ndarray - spans: np.ndarray - - class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, @@ -105,7 +93,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get(self, idx, offset=0, length=None) -> GPTMemmapSample: + def get(self, idx, offset=0, length=None) -> GPTSample: ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, @@ -117,7 +105,7 @@ def get(self, idx, offset=0, length=None) -> GPTMemmapSample: if span[0] < offset + len(ids) and span[1] >= offset: spans.append([max(span[0], offset) - offset, min(span[1], offset + len(ids) - 1) - offset]) # return (ids, np.array(spans, dtype=np.int32).reshape(-1, 2)) - return GPTMemmapSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) + return GPTSample(token_ids=ids, ignore_loss_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self) -> str: @@ -147,7 +135,7 @@ def get_span_sizes(self) -> np.ndarray: return self._num_spans @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTMemmapDocument]): + def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata dtype = None num_documents = 0 @@ -166,21 +154,21 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for document in documents: # Infer dtype from the first document if dtype is None: - dtype = document.text.dtype + dtype = document.token_ids.dtype assert dtype is not None, "Document dtype could not be inferred from the data." # Ensure all documents have the same dtype - assert document.text.dtype == dtype, f"Expected dtype {dtype}, got {document.text.dtype}." + assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
# Write document to binary file - bin_stream.write(document.text.tobytes(order="C")) + bin_stream.write(document.token_ids.tobytes(order="C")) # Update metadata - doc_length = len(document.text) + doc_length = len(document.token_ids) lengths.append(doc_length) pointers.append(offset) - num_spans.append(len(document.spans)) - spans.append(document.spans) + num_spans.append(len(document.ignore_loss_spans)) + spans.append(document.ignore_loss_spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 47dd1817..4b821368 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -47,7 +47,7 @@ def __getitem__(self, idx) -> GPTSample: end = np.random.RandomState(np_seed).randint(start, len(ids)) spans.append([start, end]) prev_end = end - return GPTSample(ids=ids, spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) + return GPTSample(token_ids=ids, ignore_loss_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index da0bf8e6..a0c8ccdb 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -24,8 +24,8 @@ @dataclasses.dataclass class GPTSample: - ids: np.ndarray - spans: np.ndarray + token_ids: np.ndarray + ignore_loss_spans: np.ndarray class GPTSampledIndexedDataset(SampledDataset): @@ -204,14 +204,14 @@ def __getitem__(self, idx: int) -> typing.Any: sample_spans = [] span_offset = 0 for sample in sample_list: - sample_ids.extend(sample.ids) - for span in sample.spans: + sample_ids.extend(sample.token_ids) + for span in sample.ignore_loss_spans: sample_spans.append([span[0] + span_offset, span[1] + span_offset]) - span_offset += len(sample.ids) + span_offset += len(sample.token_ids) sample_ids = np.array(sample_ids, dtype=np.int64) sample_spans = np.array(sample_spans, dtype=np.int32).reshape(-1, 2) - return GPTSample(ids=sample_ids, spans=sample_spans) + return GPTSample(token_ids=sample_ids, ignore_loss_spans=sample_spans) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a3219b56..76fb3928 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -9,7 +9,8 @@ import tqdm import transformers -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.tokenizer import Tokenizer @@ -54,7 +55,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> dict[str, typing.An def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTMemmapDocument( + yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index dbbf419a..e271c01e 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -16,7 +16,7 @@ def __init__(self, config: TokenizerConfig): self.tokenizer: PreTrainedTokenizerFast = 
PreTrainedTokenizerFast.from_pretrained( pretrained_model_name_or_path=config.path, errors="replace", max_len=None ) - self.special_tokens_mode = config.special_tokens_mode + self.sequence_delimiters = config.sequence_delimiters if self.tokenizer.eos_token_id is None: raise ValueError("Tokenizer does not have an EOS token.") self.eod_id = self.tokenizer.eos_token_id @@ -36,19 +36,19 @@ def inv_vocab(self) -> dict[int, str]: return self._inv_vocab def tokenize(self, text: str, beginning_of_text=True, end_of_text=True) -> list[int]: - if self.special_tokens_mode == SequenceDelimiters.eos_only: + if self.sequence_delimiters == SequenceDelimiters.eos_only: return self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) - elif self.special_tokens_mode == SequenceDelimiters.bos_only: + elif self.sequence_delimiters == SequenceDelimiters.bos_only: return ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode( text, add_special_tokens=False ) - elif self.special_tokens_mode == SequenceDelimiters.bos_eos: + elif self.sequence_delimiters == SequenceDelimiters.bos_eos: return ( ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) ) - elif self.special_tokens_mode == SequenceDelimiters.no_delimiters: + elif self.sequence_delimiters == SequenceDelimiters.no_delimiters: return self.tokenizer.encode(text, add_special_tokens=False) else: # TODO: How do we handle when beginning_of_text=False or end_of_text=False? diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 921216ff..02982402 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -203,14 +203,14 @@ def preprocess( sequence_first = common_kwargs[TransformerKwargs.sequence_first] sequence_length = common_kwargs[TransformerKwargs.sequence_length] - batch.ids = batch.ids.to( + batch.token_ids = batch.token_ids.to( device=self._tensor_space.distributed.device, dtype=torch.int64, non_blocking=True, ) if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.ids = batch.ids.transpose(0, 1).contiguous() + batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.create_tensors(sequence_length) @@ -224,10 +224,10 @@ def preprocess( for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.ids[sequence_k - sequence_q : sequence_k] + tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -240,13 +240,13 @@ def preprocess( } if phase != PhaseType.inference: if sequence_first: - labels = batch.ids[sequence_k - sequence_q + 1 : sequence_k + 1] + labels = batch.token_ids[sequence_k - sequence_q + 1 : sequence_k + 1] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch.ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config - for i, spans_i in enumerate(batch.spans): + for i, spans_i in enumerate(batch.ignore_loss_spans): mask_indices = ( torch.cat([torch.arange(s - 1, e) for s, e in spans_i]) if len(spans_i) diff --git a/tests/common.py b/tests/common.py index 87a7b978..fb1a0699 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,8 @@ import pytest import torch -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -244,7 +245,7 @@ def get_test_dataset( end = random.Random(doc_seed).randint(start, len(doc) - 1) spans.append([start, end]) prev_end = end - documents[idx] = GPTMemmapDocument(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) + documents[idx] = GPTSample(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) GPTMemmapDataset.write_dataset(prefix, documents) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index 6cb22eab..079a6173 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -4,14 +4,15 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset, GPTMemmapDocument +from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_gpt_memmap_dataset(dtype): documents = [ - GPTMemmapDocument(text, spans) + GPTSample(text, spans) for text, spans in zip( [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)], np.array([[]] * 100, dtype=np.int32), @@ -24,13 +25,13 @@ def test_gpt_memmap_dataset(dtype): for i, document in enumerate(documents): memmap_sample = dataset.get(i) assert np.array_equal( - memmap_sample.ids, document.text, equal_nan=True - ), f"Mismatch for document {i}: {document.text} != {dataset.get(i)}." - if len(document.spans) > 0: + memmap_sample.token_ids, document.token_ids, equal_nan=True + ), f"Mismatch for document {i}: {document.token_ids} != {dataset.get(i)}." + if len(document.ignore_loss_spans) > 0: assert np.array_equal( - memmap_sample.spans, document.spans, equal_nan=True - ), f"Mismatch for non-empty spans {i}: {document.spans} != {dataset.get(i)}." + memmap_sample.ignore_loss_spans, document.ignore_loss_spans, equal_nan=True + ), f"Mismatch for non-empty spans {i}: {document.ignore_loss_spans} != {dataset.get(i)}." else: assert np.array_equal( - memmap_sample.spans.flatten(), document.spans.flatten(), equal_nan=True - ), f"Mismatch for empty spans {i}: {document.spans} != {dataset.get(i)}." + memmap_sample.ignore_loss_spans.flatten(), document.ignore_loss_spans.flatten(), equal_nan=True + ), f"Mismatch for empty spans {i}: {document.ignore_loss_spans} != {dataset.get(i)}." 
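The patch above routes loss-masking spans end to end: a GPTSample carries token_ids together with [start, end] spans over token positions, and fast_llm/models/gpt/model.py masks the corresponding next-token labels with -100 so that cross-entropy ignores them. The short sketch below is not part of any patch in this series; it only illustrates that indexing. The helper name make_masked_labels and the sample values are assumptions, and the (start - 1, end) range mirrors the torch.arange(s - 1, e) call used in the patch.

import numpy as np
import torch

def make_masked_labels(token_ids: np.ndarray, spans: np.ndarray, ignore_index: int = -100) -> torch.Tensor:
    # Next-token labels: token_ids shifted left by one position.
    labels = torch.from_numpy(token_ids[1:].astype(np.int64))
    for start, end in spans:
        # A span [start, end] covers token positions start..end; the labels that
        # predict those tokens sit one position earlier, hence (start - 1, end).
        labels[max(start - 1, 0) : end] = ignore_index
    return labels

# Example: mask the span covering token positions 2..3.
token_ids = np.array([5, 17, 42, 8, 99, 3], dtype=np.uint16)
spans = np.array([[2, 3]], dtype=np.int32)
print(make_masked_labels(token_ids, spans))  # tensor([  17, -100, -100,   99,    3])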
From 0583deca315f68cfc65076e7c8c2c97e382865be Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 03:40:32 +0000 Subject: [PATCH 22/45] make loss masking opt-in in cross-entropy --- fast_llm/functional/cross_entropy.py | 55 +++++++++++++++------ fast_llm/functional/triton/cross_entropy.py | 8 +-- fast_llm/layers/language_model/config.py | 5 ++ fast_llm/layers/language_model/head.py | 3 ++ 4 files changed, 52 insertions(+), 19 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 62f120f4..abf0325e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -13,6 +13,8 @@ def torch_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, + ignore_index: int = -100, + apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -27,7 +29,7 @@ def torch_cross_entropy_forward_backward( if grad_output is None: loss = None else: - loss = torch.nn.functional.cross_entropy(logits_, target).mean() + loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean() loss.backward(torch.full_like(loss, grad_output)) loss.detach_() return loss.detach(), logits_.grad.detach().to(logits.dtype) @@ -39,6 +41,8 @@ def fused_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, + ignore_index: int = -100, + apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -47,6 +51,10 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. + if apply_loss_mask: + loss_mask = target != ignore_index + target = target[loss_mask].unsqueeze(1) + logits = logits[loss_mask] target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: @@ -64,7 +72,10 @@ def fused_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - grad = exp_logits.to(logits.dtype) + if apply_loss_mask: + grad = torch.where(loss_mask, exp_logits.to(logits.dtype), 0) + else: + grad = exp_logits.to(logits.dtype) loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean() @@ -78,6 +89,8 @@ def parallel_cross_entropy_forward_backward( grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0, + ignore_index: int = -100, + apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. @@ -85,6 +98,10 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). 
# TODO: Optimize, overlap/combine reductions + if apply_loss_mask: + loss_mask = target != ignore_index + target = target[loss_mask].unsqueeze(1) + logits = logits[loss_mask] target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0] @@ -113,7 +130,10 @@ def parallel_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits2 *= logits_scale_factor - grad = exp_logits2.to(logits.dtype) + if apply_loss_mask: + grad = torch.where(target_mask, exp_logits2.to(logits.dtype), 0) + else: + grad = exp_logits2.to(logits.dtype) predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) @@ -137,6 +157,7 @@ def cross_entropy_forward_backward( implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, ignore_index: int = -100, + apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -144,21 +165,23 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. """ - loss_mask = target != ignore_index - target = target[loss_mask] - logits = logits[loss_mask] if group: Assert.eq(implementation, CrossEntropyImpl.fused) - loss, grad_logits = parallel_cross_entropy_forward_backward( - logits, target, grad_output, group, logits_scale_factor=logits_scale_factor + return parallel_cross_entropy_forward_backward( + logits, + target, + grad_output, + group, + logits_scale_factor=logits_scale_factor, + ignore_index=ignore_index, + apply_loss_mask=apply_loss_mask, ) else: - loss, grad_logits = _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor + return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( + logits, + target, + grad_output, + logits_scale_factor=logits_scale_factor, + ignore_index=ignore_index, + apply_loss_mask=apply_loss_mask, ) - if grad_logits is not None: - grad = torch.zeros((loss_mask.size(0), *logits.shape[1:]), dtype=logits.dtype, device=logits.device) - grad.index_put_((loss_mask,), grad_logits) - return loss, grad - else: - return loss, grad_logits diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 9e6e697f..5d087d78 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -17,6 +17,7 @@ def triton_cross_entropy_forward_backward_kernel( grad_logits_stride_0, logits_scale_factor: tl.constexpr, block_size: tl.constexpr, + ignore_index: tl.constexpr, ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) @@ -35,7 +36,7 @@ def triton_cross_entropy_forward_backward_kernel( label_idx = tl.load(labels_ptr + block_idx) label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) - if label_idx < 0: + if label_idx < 0 or label_idx == ignore_index: loss = 0.0 else: loss = tl.log(sum_exp_logits) + max_logits - label_logits @@ -47,7 +48,7 @@ def triton_cross_entropy_forward_backward_kernel( exp_logits = exp_logits / sum_exp_logits if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - if label_idx < 0: + if label_idx < 0 or label_idx == ignore_index: grad_losses = 0.0 grad_logits = grad_losses * tl.where(col_offsets == label_idx, exp_logits - 1.0, exp_logits) tl.store(grad_logits_ptr + col_offsets, grad_logits, mask=mask) @@ -56,9 +57,9 @@ def triton_cross_entropy_forward_backward_kernel( def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, - loss_mask: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, + ignore_index: int = -100, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -90,5 +91,6 @@ def triton_cross_entropy_forward_backward( logits_scale_factor, block_size=block_size, num_warps=num_warps, + ignore_index=ignore_index, ) return losses.mean(), None if grad_output is None else grad_logits diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467c..1f6f54f0 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -173,6 +173,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + apply_loss_mask: bool = Field( + default=False, + desc="Enable loss-masking for cross-entropy computation", + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.transformer.init_method_std is None: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4c03e393..3eae12d5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -83,6 +83,8 @@ def __init__( else: self._cross_entropy_impl = CrossEntropyImpl.fused + self._apply_loss_mask = config.apply_loss_mask + self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) def forward( @@ -243,6 +245,7 @@ def _logits_cross_entropy_forward_backward( grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, + apply_loss_mask=self._apply_loss_mask, ) # TODO: de-allocate earlier. 
del logits From 7c40bf22b4df4479848a62baf9ade97ac12a4963 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 05:17:34 +0000 Subject: [PATCH 23/45] make spans opt-in during prepare --- fast_llm/data/dataset/gpt/memmap.py | 15 +++++--- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 38 +++++++++++++------ 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index fb9823b7..c34e1aee 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -104,7 +104,6 @@ def get(self, idx, offset=0, length=None) -> GPTSample: for span in self._spans[idx]: if span[0] < offset + len(ids) and span[1] >= offset: spans.append([max(span[0], offset) - offset, min(span[1], offset + len(ids) - 1) - offset]) - # return (ids, np.array(spans, dtype=np.int32).reshape(-1, 2)) return GPTSample(token_ids=ids, ignore_loss_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property @@ -167,8 +166,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP doc_length = len(document.token_ids) lengths.append(doc_length) pointers.append(offset) - num_spans.append(len(document.ignore_loss_spans)) - spans.append(document.ignore_loss_spans) + if document.ignore_loss_spans is not None: + num_spans.append(len(document.ignore_loss_spans)) + spans.append(document.ignore_loss_spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 @@ -176,14 +176,19 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP lengths = np.array(lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) - spans = np.vstack(spans, dtype=np.int32) + if len(spans) > 0: + spans = np.vstack(spans, dtype=np.int32) + else: + spans = np.array(spans, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 adds number of spans and spans to the index file. 
+ # Version 2 optionally adds loss-masking spans idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" dict[str, list[typing.Any]]: + input_ids = [ + np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + for text in batch[self._config.dataset.field] + ] + num_tokens = [len(x) for x in input_ids] + return { + "input_ids": input_ids, + "num_tokens": num_tokens, + } + + def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids, token_spans = map( list, zip( @@ -54,11 +65,15 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> dict[str, typing.An shard_output_path = self._config.output_path / prefix def _document_generator(): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) + if "token_spans" in shard_dataset.column_names and self._config.dataset.spans_field is not None: + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + ) + else: + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -144,15 +159,16 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if ( - self._config.dataset.spans_field is not None - and self._config.dataset.spans_field not in dataset.column_names - ): - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.") + if self._config.dataset.spans_field is not None: + if self._config.dataset.spans_field not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.") + tokenize_fn = self._tokenize_batch_with_spans + else: + tokenize_fn = self._tokenize_batch # Tokenize the dataset in parallel tokenized_dataset = dataset.map( - self._tokenize_batch, + tokenize_fn, batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", From 1998b9f2f14364580859f1d1b80e23206abeb2f5 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 10:40:10 +0000 Subject: [PATCH 24/45] make spans opt-in for train --- fast_llm/data/data/gpt/config.py | 5 +++ fast_llm/data/data/gpt/data.py | 9 +++-- fast_llm/data/dataset/gpt/config.py | 10 ++++- fast_llm/data/dataset/gpt/fim.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 60 +++++++++++++--------------- fast_llm/data/dataset/gpt/random.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 30 ++++++++------ fast_llm/functional/cross_entropy.py | 8 ++-- fast_llm/models/gpt/model.py | 15 +++---- 9 files changed, 78 insertions(+), 63 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 27a65b31..214b7666 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -42,6 +42,11 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Multiprocessing context. 
Do not touch.", hint=FieldHint.expert, ) + use_loss_masking_spans: bool = Field( + default=False, + desc="Read and use loss masking spans from the dataset, if present.", + hint=FieldHint.feature, + ) def _validate(self) -> None: if not self.datasets: diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 8dadc3bc..6621c547 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -29,13 +29,15 @@ @dataclasses.dataclass class GPTBatch: token_ids: torch.Tensor - ignore_loss_spans: list[torch.Tensor] + loss_masking_spans: list[torch.Tensor] def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) - stacked_spans = [torch.from_numpy(sample.ignore_loss_spans) for sample in batch] - return GPTBatch(token_ids=torch.from_numpy(stacked_ids), ignore_loss_spans=stacked_spans) + stacked_spans = None + if batch[0].loss_masking_spans is not None: + stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans) class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): @@ -97,6 +99,7 @@ def setup( sequence_length=self._max_sequence_length, vocab_size=self._vocab_size, tokenizer=self._tokenizer, + use_loss_masking_spans=self._config.use_loss_masking_spans, ) dataset = self._config.datasets[phase].build_and_sample(sampling_config) self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index da8eb3ca..a724f5c2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -33,6 +33,7 @@ class GPTSamplingConfig(SamplingConfig): sequence_length: int vocab_size: int tokenizer: "Tokenizer" + use_loss_masking_spans: bool = False @config_class() @@ -128,11 +129,16 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", hint=FieldHint.core, ) + use_loss_masking_spans: bool = Field( + default=False, + desc="Read and use loss masking spans from the dataset, if present.", + hint=FieldHint.feature, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path) + return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.use_loss_masking_spans) @config_class() @@ -382,7 +388,7 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: dataset_configs = [ GPTDatasetSliceConfig( # TODO: this duplicates memmap datasets for each phase. - dataset=GPTMemmapDatasetConfig(path=prefix), + dataset=GPTMemmapDatasetConfig(path=prefix, use_loss_masking_spans=config.use_loss_masking_spans), begin=phase_splits[phase_index], end=phase_splits[phase_index + 1], ) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 11513e76..20f795ea 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -46,7 +46,7 @@ def name(self) -> str: def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # FIM # TODO: permute segments in sample_list, before concatenating. 
- if self._config.rate > 0.0 and sample.ignore_loss_spans.size > 0: + if self._config.rate > 0.0 and sample.loss_masking_spans is not None: raise NotImplementedError("FIM is currently not compatible with loss masking.") sample_len = sample.token_ids.shape[0] eod = self._tokenizer.eod diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c34e1aee..eea816b2 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -20,17 +20,22 @@ class GPTMemmapDataset(GPTIndexedDataset): See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details. """ - def __init__(self, name: str, prefix: pathlib.Path | str): - self._init(name, prefix) + def __init__(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False): + self._init(name, prefix, use_loss_masking_spans) - def _init(self, name: str, prefix: pathlib.Path | str) -> None: + def _init(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) + self._read_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER) self._version = struct.unpack(" None: offset=offset + self._document_sizes.nbytes, ) - # Spans are introduced in version 2. Datasets tokenized with version 1 do not contain span information and - # compute loss on all tokens by default - if self._version == 1: - self._num_spans = np.zeros(self._num_documents, dtype=np.int32) - self._spans = [np.array([], dtype=np.int32).reshape(-1, 2)] * self._num_documents - elif self._version == 2: + if self._read_spans: self._num_spans = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, ) - spans = [] - offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - for n_spans in self._num_spans: - span = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=n_spans * 2, - offset=offset, - ).reshape(-1, 2) - spans.append(span) - offset += span.nbytes - self._spans = spans - else: - raise ValueError(f"Unsupported version for gpt_memmap dataset: {self._version}.") + self._span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes + self._num_spans_cumsum = np.cumsum(self._num_spans, dtype=np.int64) self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -94,17 +82,25 @@ def __del__(self): del self._index_bin_buffer_mmap def get(self, idx, offset=0, length=None) -> GPTSample: - ids = np.frombuffer( + token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) - spans = [] - for span in self._spans[idx]: - if span[0] < offset + len(ids) and span[1] >= offset: - spans.append([max(span[0], offset) - offset, min(span[1], offset + len(ids) - 1) - offset]) - return GPTSample(token_ids=ids, ignore_loss_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) + spans = None + if self._read_spans: + spans = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._num_spans[idx] * 2, + offset=self._span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) 
+ # adjust the spans for the offset and length + spans = spans[(spans[:, 0] < offset + len(token_ids)) & (spans[:, 1] >= offset)] + spans[:, 0] = np.maximum(spans[:, 0], offset) - offset + spans[:, 1] = np.minimum(spans[:, 1], offset + len(token_ids) - 1) - offset + return GPTSample(token_ids=token_ids, loss_masking_spans=spans) @property def name(self) -> str: @@ -166,9 +162,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP doc_length = len(document.token_ids) lengths.append(doc_length) pointers.append(offset) - if document.ignore_loss_spans is not None: - num_spans.append(len(document.ignore_loss_spans)) - spans.append(document.ignore_loss_spans) + if document.loss_masking_spans is not None: + num_spans.append(len(document.loss_masking_spans)) + spans.append(document.loss_masking_spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 4b821368..ea314cb1 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -47,7 +47,7 @@ def __getitem__(self, idx) -> GPTSample: end = np.random.RandomState(np_seed).randint(start, len(ids)) spans.append([start, end]) prev_end = end - return GPTSample(token_ids=ids, ignore_loss_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) + return GPTSample(token_ids=ids, loss_masking_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 9630b8dc..7f8d0398 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -25,7 +25,7 @@ @dataclasses.dataclass class GPTSample: token_ids: np.ndarray - ignore_loss_spans: np.ndarray | None = None + loss_masking_spans: np.ndarray | None = None class GPTSampledIndexedDataset(SampledDataset): @@ -200,18 +200,22 @@ def __getitem__(self, idx: int) -> typing.Any: for doc in range(doc_f, doc_l + 1) ] - sample_ids = [] - sample_spans = [] - span_offset = 0 - for sample in sample_list: - sample_ids.extend(sample.token_ids) - for span in sample.ignore_loss_spans: - sample_spans.append([span[0] + span_offset, span[1] + span_offset]) - span_offset += len(sample.token_ids) - sample_ids = np.array(sample_ids, dtype=np.int64) - sample_spans = np.array(sample_spans, dtype=np.int32).reshape(-1, 2) - - return GPTSample(token_ids=sample_ids, ignore_loss_spans=sample_spans) + if sample_list[0].loss_masking_spans is not None: + sample_ids = [] + sample_spans = [] + span_offset = 0 + for sample in sample_list: + sample_ids.extend(sample.token_ids) + for span in sample.loss_masking_spans: + sample_spans.append([span[0] + span_offset, span[1] + span_offset]) + span_offset += len(sample.token_ids) + sample_ids = np.array(sample_ids, dtype=np.int64) + sample_spans = np.array(sample_spans, dtype=np.int32).reshape(-1, 2) + else: + sample_ids = np.concatenate([sample.token_ids for sample in sample_list]) + sample_spans = None + + return GPTSample(token_ids=sample_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index abf0325e..4809c678 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -53,7 +53,7 @@ def fused_cross_entropy_forward_backward( # Way faster and more memory-efficient than the pytorch version. 
if apply_loss_mask: loss_mask = target != ignore_index - target = target[loss_mask].unsqueeze(1) + target = target[loss_mask] logits = logits[loss_mask] target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() @@ -73,7 +73,7 @@ def fused_cross_entropy_forward_backward( exp_logits *= logits_scale_factor if apply_loss_mask: - grad = torch.where(loss_mask, exp_logits.to(logits.dtype), 0) + grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0) else: grad = exp_logits.to(logits.dtype) @@ -100,7 +100,7 @@ def parallel_cross_entropy_forward_backward( # TODO: Optimize, overlap/combine reductions if apply_loss_mask: loss_mask = target != ignore_index - target = target[loss_mask].unsqueeze(1) + target = target[loss_mask] logits = logits[loss_mask] target = target.unsqueeze(1) @@ -131,7 +131,7 @@ def parallel_cross_entropy_forward_backward( exp_logits2 *= logits_scale_factor if apply_loss_mask: - grad = torch.where(target_mask, exp_logits2.to(logits.dtype), 0) + grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0) else: grad = exp_logits2.to(logits.dtype) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 02982402..6438d4f1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -246,13 +246,14 @@ def preprocess( labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config - for i, spans_i in enumerate(batch.ignore_loss_spans): - mask_indices = ( - torch.cat([torch.arange(s - 1, e) for s, e in spans_i]) - if len(spans_i) - else torch.tensor([], dtype=torch.int64) - ) - labels[i, mask_indices] = -100 + if batch.loss_masking_spans is not None: + for i, spans_i in enumerate(batch.loss_masking_spans): + mask_indices = ( + torch.cat([torch.arange(s - 1, e) for s, e in spans_i]) + if len(spans_i) + else torch.tensor([], dtype=torch.int64) + ) + labels[i, mask_indices] = -100 kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) From 913a9d3ff1929ce0c78e9ac3248d6e22843c93d7 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 17:59:28 +0000 Subject: [PATCH 25/45] revert tests and random dataset --- fast_llm/data/dataset/gpt/random.py | 17 +- tests/common.py | 14 - tests/test_dataset.py | 487 +++++----------------------- tests/test_memmap_dataset.py | 22 +- 4 files changed, 90 insertions(+), 450 deletions(-) diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index ea314cb1..142dca71 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -2,7 +2,6 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.dataset.gpt.sampled import GPTSample class GPTRandomDataset(SamplableDataset): @@ -32,22 +31,10 @@ def __init__(self, config: GPTSamplingConfig, name: str): def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx) -> GPTSample: - np_seed = self._seed + 48576439 + 74593 * idx - ids = np.random.RandomState(np_seed).randint( + def __getitem__(self, idx) -> np.ndarray: + return np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( 0, self._vocab_size, size=(self._sequence_length + 
1,), dtype=np.int64 ) - n_spans = np.random.RandomState(np_seed).randint(0, 3) - spans = [] - prev_end = -1 - for _ in range(n_spans): - if prev_end >= len(ids) - 1: - break - start = np.random.RandomState(np_seed).randint(prev_end + 1, len(ids)) - end = np.random.RandomState(np_seed).randint(start, len(ids)) - spans.append([start, end]) - prev_end = end - return GPTSample(token_ids=ids, loss_masking_spans=np.array(spans, dtype=np.int32).reshape(-1, 2)) @property def name(self) -> str: diff --git a/tests/common.py b/tests/common.py index fb1a0699..69048f8c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -11,7 +11,6 @@ import torch from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -233,19 +232,6 @@ def get_test_dataset( documents = [ np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size for document in documents ] - for idx, doc in enumerate(documents): - doc_seed = seed + idx - n_spans = random.Random(doc_seed).randint(0, 5) - spans = [] - prev_end = -1 - for _ in range(n_spans): - if prev_end >= len(doc) - 1: - break - start = random.Random(doc_seed).randint(prev_end + 1, len(doc) - 1) - end = random.Random(doc_seed).randint(start, len(doc) - 1) - spans.append([start, end]) - prev_end = end - documents[idx] = GPTSample(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) GPTMemmapDataset.write_dataset(prefix, documents) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6c5bb586..394553e5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -71,7 +71,6 @@ def get_test_data_and_samples( seed: int = 54983, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - consumed_samples: int = 0, vocab_size=TEST_VOCAB_SIZE, ): distributed_config = DistributedConfig(seed=seed) @@ -83,8 +82,8 @@ def get_test_data_and_samples( batch_config.setup(distributed_config) batch_config.validate() samples = { - phase: list(data.get_iterator(batch_config, phase, consumed_samples=consumed_samples, num_workers=0)) - for phase, n_samples in samples_per_phase.items() + phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + for phase, samples in samples_per_phase.items() } return data, samples @@ -101,18 +100,12 @@ def _get_test_dataset_concatenated_memmap(): return get_test_concatenated_memmap_dataset(_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, 4) -RANDOM_DATASET_EXPECTED_SAMPLES_IDS = [ +RANDOM_DATASET_EXPECTED_SAMPLES = [ [3954, 4105, 6766, 859, 5494, 1675, 1303, 6913], [1654, 5701, 32, 1662, 7053, 3487, 1861, 1502], [5409, 6240, 5504, 7458, 7667, 3955, 3151, 3912], [5640, 6131, 7750, 2699, 1349, 2585, 7113, 6981], ] -RANDOM_DATASET_EXPECTED_SAMPLES_SPANS = [ - [[2, 4], [7, 7]], - [[6, 6], [7, 7]], - [[1, 2]], - [], -] def test_gpt_random_dataset(): @@ -121,17 +114,9 @@ def test_gpt_random_dataset(): get_sampling_config(4, sequence_length=7) ) Assert.eq(len(sampled), 4) - sampled_ids = np.stack([sampled[i].ids for i in range(4)]) - sampled_spans = np.vstack([sampled[i].spans for i in range(4)]) - Assert.all_equal( - sampled_ids, - np.stack(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), - ) Assert.all_equal( - sampled_spans, - np.vstack( - [np.array(x, dtype=sampled_spans.dtype).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS] - ), + np.stack([sampled[i] for i in range(4)]), + 
np.array(RANDOM_DATASET_EXPECTED_SAMPLES), ) @@ -148,36 +133,27 @@ def test_gpt_random_data(): sequence_length=7, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES), ) def test_gpt_random_data_legacy(): _, samples = get_test_data_and_samples({"format": "random"}, {PhaseType.training: 4}, sequence_length=7) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(RANDOM_DATASET_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in RANDOM_DATASET_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(RANDOM_DATASET_EXPECTED_SAMPLES), ) # Most documents are too long to write here, we test a few known short ones. MEMMAP_DATASET_EXPECTED_LENGTH = 6153 MEMMAP_DATASET_EXPECTED_TOKENS = 508327 -MEMMAP_DATASET_EXPECTED_SPANS = 9138 MEMMAP_DATASET_EXPECTED_SAMPLES = { - 9: ([], []), - 10: ([80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], [[4, 6], [8, 9]]), - 13: ([78, 727, 74, 317, 1358, 89], []), - 15: ([78], [[0, 0]]), + 9: [], + 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], + 13: [78, 727, 74, 317, 1358, 89], + 15: [78], } @@ -187,19 +163,14 @@ def test_gpt_memmap(cache_directory): get_test_dataset() dataset = _get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() Assert.eq(len(dataset), MEMMAP_DATASET_EXPECTED_LENGTH) - doc_sizes = dataset.get_document_sizes() - span_sizes = dataset.get_span_sizes() - Assert.eq(doc_sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) - Assert.eq(span_sizes.sum(), MEMMAP_DATASET_EXPECTED_SPANS) - Assert.all_equal([len(dataset.get(i).ids) for i in range(100)], doc_sizes[:100]) - Assert.all_equal([len(dataset.get(i).spans) for i in range(100)], span_sizes[:100]) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) + Assert.all_equal([len(dataset.get(i)) for i in range(100)], sizes[:100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - ds_sample = dataset.get(i) - Assert.all_equal(ds_sample.ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal(ds_sample.spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) + Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) -GPT_SAMPLED_EXPECTED_SAMPLES_IDS = [ +GPT_SAMPLED_EXPECTED_SAMPLES = [ [1725, 74, 207, 1635, 4440, 2774], [359, 489, 4266, 2052, 5351, 80], [374, 7534, 87, 1073, 79, 480], @@ -210,17 +181,6 @@ def test_gpt_memmap(cache_directory): [330, 155, 2449, 1136, 1106, 5370], ] -GPT_SAMPLED_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [[0, 2]], - [], - [], - [], - [[0, 5]], - [], - [], -] - def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. 
@@ -230,14 +190,8 @@ def test_gpt_sampled(): ) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), ) @@ -256,12 +210,8 @@ def test_gpt_sampled_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), ) @@ -272,16 +222,12 @@ def test_gpt_sampled_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_SAMPLED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SAMPLED_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_SAMPLED_EXPECTED_SAMPLES), ) -GPT_CONCATENATED_EXPECTED_SAMPLES_IDS = [ +GPT_CONCATENATED_EXPECTED_SAMPLES = [ [243, 498, 7172, 777, 306, 74], [821, 6042, 89, 977, 4797, 499], [387, 74, 330, 328, 1858, 484], @@ -293,18 +239,6 @@ def test_gpt_sampled_data_legacy(): ] -GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [], - [], - [[0, 5]], - [], - [[0, 0]], - [], - [], -] - - def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. 
get_test_dataset() @@ -313,28 +247,18 @@ def test_gpt_concatenate(): GPTConcatenatedDatasetConfig, ).build() Assert.eq(len(dataset), 3 * MEMMAP_DATASET_EXPECTED_LENGTH) - doc_sizes = dataset.get_document_sizes() - span_sizes = dataset.get_span_sizes() - Assert.eq(doc_sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) - Assert.eq(span_sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_SPANS) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) for i in range(3): begin = i * MEMMAP_DATASET_EXPECTED_LENGTH - Assert.all_equal([len(dataset.get(begin + i).ids) for i in range(100)], doc_sizes[begin : begin + 100]) - Assert.all_equal([len(dataset.get(begin + i).spans) for i in range(100)], span_sizes[begin : begin + 100]) + Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(begin + i).ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal(dataset.get(begin + i).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) + Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_CONCATENATED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), ) @@ -352,30 +276,19 @@ def test_gpt_concatenate_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_CONCATENATED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_CONCATENATED_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), ) -GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS = [ +GPT_SLICE_EXPECTED_TRAINING_SAMPLES = [ [2625, 76, 2625, 2639, 74, 243], [207, 481, 5546, 74, 414, 498], [74, 333, 1963, 310, 5337, 3628], [79, 2361, 80, 2012, 84, 480], ] -GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS = [ - [], - [], - [[0, 2]], - [], -] - -GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS = [ +GPT_SLICE_EXPECTED_VALIDATION_SAMPLES = [ [2352, 3687, 2311, 4900, 542, 3732], [2551, 5283, 900, 3140, 328, 68], [7979, 2283, 329, 727, 2740, 2818], @@ -386,17 +299,6 @@ def test_gpt_concatenate_data(): [243, 3712, 86, 476, 80, 2547], ] -GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS = [ - [], - [], - [[0, 5]], - [[0, 3]], - [], - [], - [], - [], -] - def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. 
@@ -407,27 +309,15 @@ def test_gpt_slice(): GPTDatasetSliceConfig, ).build() Assert.eq(len(dataset), 9) - doc_sizes = dataset.get_document_sizes() - span_sizes = dataset.get_span_sizes() - Assert.all_equal([len(dataset.get(i).ids) for i in range(9)], doc_sizes[:9]) - Assert.all_equal([len(dataset.get(i).spans) for i in range(9)], span_sizes[:9]) + sizes = dataset.get_document_sizes() + Assert.all_equal([len(dataset.get(i)) for i in range(9)], sizes[:9]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i - 9).ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal(dataset.get(i - 9).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) + Assert.all_equal(dataset.get(i - 9), np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [ - np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) - for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS - ] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), ) @@ -459,20 +349,12 @@ def test_gpt_slice_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.validation]]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.validation]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS]), - ) - Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS), + np.stack(samples[PhaseType.validation]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), ) @@ -484,38 +366,29 @@ def test_gpt_slice_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.validation]]), - np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_IDS), + np.stack(samples[PhaseType.validation]), + np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.validation]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_VALIDATION_SAMPLES_SPANS]), - ) - Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_SLICE_EXPECTED_TRAINING_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_SLICE_EXPECTED_TRAINING_SAMPLES), ) COMPOSED_DATASET_EXPECTED_LENGTH = 24806 COMPOSED_DATASET_EXPECTED_TOKENS = 2033639 -COMPOSED_DATASET_EXPECTED_SPANS = 37008 COMPOSED_DATASET_EXPECTED_SAMPLES = { **MEMMAP_DATASET_EXPECTED_SAMPLES, - 6930: ([65, 2327], [[0, 0], [1, 1]]), - 11962: ([7078, 2713, 1431], [[2, 2]]), - 15958: ([207], [[0, 0]]), - 19362: ([69], [[0, 0]]), - 
24098: ([555, 668, 70], [[1, 2]]), + 6930: [65, 2327], + 11962: [7078, 2713, 1431], + 15958: [207], + 19362: [69], + 24098: [555, 668, 70], } -GPT_COMPOSED_EXPECTED_SAMPLES_IDS = [ +GPT_COMPOSED_EXPECTED_SAMPLES = [ [1411, 819, 6791, 7022, 285, 249], [329, 328, 512, 1985, 3069, 7838], [5158, 1023, 8171, 798, 1431, 313], @@ -526,17 +399,6 @@ def test_gpt_slice_data_legacy(): [79, 6042, 577, 225, 207, 207], ] -GPT_COMPOSED_EXPECTED_SAMPLES_SPANS = [ - [], - [], - [[0, 5]], - [], - [], - [], - [], - [], -] - def test_gpt_compose(): # Make sure dataset splitting works and check for unintended changes in behavior. @@ -547,27 +409,17 @@ def test_gpt_compose(): GPTConcatenatedMemmapConfig, ).build() Assert.eq(len(dataset), COMPOSED_DATASET_EXPECTED_LENGTH) - doc_sizes = dataset.get_document_sizes() - span_sizes = dataset.get_span_sizes() - Assert.eq(doc_sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) - Assert.eq(span_sizes.sum(), COMPOSED_DATASET_EXPECTED_SPANS) - Assert.all_equal([len(dataset.get(i).ids) for i in range(0, len(dataset), 20)], doc_sizes[::20]) - Assert.all_equal([len(dataset.get(i).spans) for i in range(0, len(dataset), 20)], span_sizes[::20]) + sizes = dataset.get_document_sizes() + Assert.eq(sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) + Assert.all_equal([len(dataset.get(i)) for i in range(0, len(dataset), 20)], sizes[::20]) for i, sample in COMPOSED_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i).ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal(dataset.get(i).spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) + Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) - print(np.stack([sampled[i].ids for i in range(8)]).tolist()) + print(np.stack([sampled[i] for i in range(8)]).tolist()) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_COMPOSED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_COMPOSED_EXPECTED_SAMPLES_SPANS] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES), ) @@ -586,12 +438,8 @@ def test_gpt_composed_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_COMPOSED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_COMPOSED_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_COMPOSED_EXPECTED_SAMPLES), ) @@ -606,17 +454,6 @@ def test_gpt_composed_data(): [409, 5091, 328, 1378, 5483, 88], ] -GPT_BLENDED_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [[1, 5]], - [[0, 2]], - [], - [], - [], - [], - [[0, 5]], -] - def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. 
@@ -635,15 +472,9 @@ def test_gpt_blended(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), + np.stack([sampled[i] for i in range(8)]), np.array(GPT_BLENDED_EXPECTED_SAMPLES), ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) for x in GPT_BLENDED_EXPECTED_SAMPLES_SPANS] - ), - ) def test_gpt_blended_data(): @@ -666,16 +497,12 @@ def test_gpt_blended_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), + np.stack(samples[PhaseType.training]), np.array(GPT_BLENDED_EXPECTED_SAMPLES), ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_EXPECTED_SAMPLES_SPANS]), - ) -GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_IDS = [ +GPT_BLENDED_LEGACY_EXPECTED_SAMPLES = [ [1725, 74, 207, 1635, 4440, 2774], [328, 80, 263, 890, 1797, 88], [359, 489, 4266, 2052, 5351, 80], @@ -686,17 +513,6 @@ def test_gpt_blended_data(): [409, 5091, 328, 1378, 5483, 88], ] -GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [], - [[0, 2]], - [], - [], - [], - [], - [[0, 5]], -] - def test_gpt_blended_data_legacy(): get_test_dataset() @@ -711,16 +527,12 @@ def test_gpt_blended_data_legacy(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_LEGACY_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_BLENDED_LEGACY_EXPECTED_SAMPLES), ) -GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS = [ +GPT_BLENDED_MIXED_EXPECTED_SAMPLES = [ [1725, 74, 207, 1635, 4440, 2774], [916, 6683, 7685, 1277, 5106, 378], [359, 489, 4266, 2052, 5351, 80], @@ -731,17 +543,6 @@ def test_gpt_blended_data_legacy(): [2210, 8179, 73, 2582, 897, 1178], ] -GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [], - [[0, 2]], - [], - [], - [], - [], - [], -] - def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. 
@@ -759,17 +560,8 @@ def test_gpt_blended_mixed(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].ids for i in range(8)]), - np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in range(8)]), - np.vstack( - [ - np.array(x, dtype=sampled[0].spans.dtype).reshape(-1, 2) - for x in GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS - ] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), ) @@ -788,16 +580,12 @@ def test_gpt_blended_mixed_data(): sequence_length=5, ) Assert.all_equal( - np.stack([batch.ids[0] for batch in samples[PhaseType.training]]), - np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES_IDS), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples[PhaseType.training]]), - np.vstack([np.array(x, dtype=np.int32).reshape(-1, 2) for x in GPT_BLENDED_MIXED_EXPECTED_SAMPLES_SPANS]), + np.stack(samples[PhaseType.training]), + np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), ) -GPT_FIM_EXPECTED_SAMPLES_IDS = [ +GPT_FIM_EXPECTED_SAMPLES = [ [1725, 74, 207, 1635, 4440, 2774], [359, 489, 4266, 2052, 5351, 80], [86, 89, 22255, 1073, 79, 480], @@ -808,30 +596,13 @@ def test_gpt_blended_mixed_data(): [86, 89, 1461, 87, 330, 7876], ] -GPT_FIM_EXPECTED_SAMPLES_SPANS = [ - [[0, 5]], - [[0, 2]], - [], - [], - [], - [[0, 5]], - [], - [], -] - -GPT_FIM_VALID_IDS = [2, 3, 4, 6, 7] - def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. get_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. sampling_config = get_sampling_config( - 8, - sequence_length=5, - tokenizer=Tokenizer( - TokenizerConfig.from_dict({"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}) - ), + 8, sequence_length=5, tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})) ) sampled = _get_dataset_config( { @@ -848,54 +619,13 @@ def test_gpt_fim(): Assert.eq(len(sampled), 8) # TODO: Does this output make sense? 
Assert.all_equal( - np.stack([sampled[i].ids for i in GPT_FIM_VALID_IDS]), - np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS]), - ) - Assert.all_equal( - np.vstack([sampled[i].spans for i in GPT_FIM_VALID_IDS]), - np.vstack( - [ - np.array(x, dtype=sampled[GPT_FIM_VALID_IDS[0]].spans.dtype).reshape(-1, 2) - for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS] - ] - ), + np.stack([sampled[i] for i in range(8)]), + np.array(GPT_FIM_EXPECTED_SAMPLES), ) def test_gpt_fim_data(): - _, samples1 = get_test_data_and_samples( - { - "datasets": { - "Training": { - "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, - "rate": 0.5, - "prefix_token": "w", - "middle_token": "x", - "pad_token": "y", - "suffix_token": "z", - } - }, - "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, - }, - {PhaseType.training: 5}, - sequence_length=5, - consumed_samples=2, - ) - Assert.all_equal( - np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]), - np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]]), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]), - np.vstack( - [ - np.array(x, dtype=np.int32).reshape(-1, 2) - for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]] - ] - ), - ) - _, samples2 = get_test_data_and_samples( + _, samples = get_test_data_and_samples( { "datasets": { "Training": { @@ -908,77 +638,30 @@ def test_gpt_fim_data(): "suffix_token": "z", } }, - "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, + "tokenizer": {"path": TOKENIZER_PATH}, }, {PhaseType.training: 8}, sequence_length=5, - consumed_samples=6, - ) - Assert.all_equal( - np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]), - np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]), - np.vstack( - [ - np.array(x, dtype=np.int32).reshape(-1, 2) - for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]] - ] - ), + np.stack(samples[PhaseType.training]), + np.array(GPT_FIM_EXPECTED_SAMPLES), ) def test_gpt_fim_data_legacy(): - _, samples1 = get_test_data_and_samples( - { - "format": "list", - "path": [str(DATASET_PREFIX)], - "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, - "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, - "split": [1, 0, 0], - }, - {PhaseType.training: 5}, - sequence_length=5, - consumed_samples=2, - ) - Assert.all_equal( - np.stack([batch.ids[0] for batch in samples1[PhaseType.training]]), - np.array( - [GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[:3]], - ), - ) - Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples1[PhaseType.training]]), - np.vstack( - [ - np.array(x, dtype=np.int32).reshape(-1, 2) - for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[:3]] - ] - ), - ) - _, samples2 = get_test_data_and_samples( + _, samples = get_test_data_and_samples( { "format": "list", "path": [str(DATASET_PREFIX)], "fim": {"rate": 0.5, "prefix_token": "w", "middle_token": "x", "pad_token": "y", "suffix_token": "z"}, - "tokenizer": {"path": TOKENIZER_PATH, "special_tokens_mode": "tokenizer_default"}, + "tokenizer": {"path": TOKENIZER_PATH}, "split": [1, 0, 0], }, {PhaseType.training: 8}, sequence_length=5, - consumed_samples=6, - ) - 
Assert.all_equal( - np.stack([batch.ids[0] for batch in samples2[PhaseType.training]]), - np.array([GPT_FIM_EXPECTED_SAMPLES_IDS[i] for i in GPT_FIM_VALID_IDS[3:]]), ) Assert.all_equal( - np.vstack([batch.spans[0] for batch in samples2[PhaseType.training]]), - np.vstack( - [ - np.array(x, dtype=np.int32).reshape(-1, 2) - for x in [GPT_FIM_EXPECTED_SAMPLES_SPANS[i] for i in GPT_FIM_VALID_IDS[3:]] - ] - ), + np.stack(samples[PhaseType.training]), + np.array(GPT_FIM_EXPECTED_SAMPLES), ) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index 079a6173..261f87e4 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -5,33 +5,17 @@ import pytest from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_gpt_memmap_dataset(dtype): - documents = [ - GPTSample(text, spans) - for text, spans in zip( - [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)], - np.array([[]] * 100, dtype=np.int32), - ) - ] + documents = [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - memmap_sample = dataset.get(i) assert np.array_equal( - memmap_sample.token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document.token_ids} != {dataset.get(i)}." - if len(document.ignore_loss_spans) > 0: - assert np.array_equal( - memmap_sample.ignore_loss_spans, document.ignore_loss_spans, equal_nan=True - ), f"Mismatch for non-empty spans {i}: {document.ignore_loss_spans} != {dataset.get(i)}." - else: - assert np.array_equal( - memmap_sample.ignore_loss_spans.flatten(), document.ignore_loss_spans.flatten(), equal_nan=True - ), f"Mismatch for empty spans {i}: {document.ignore_loss_spans} != {dataset.get(i)}." + dataset.get(i), document, equal_nan=True + ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." 
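
The rewritten memmap test above exercises the simpler storage contract in force at this point in the series: GPTMemmapDataset.write_dataset consumes plain token arrays and get hands the same array back. A minimal round-trip sketch of that contract, mirroring the test (the toy documents, dataset name, and temporary paths below are illustrative only, not part of the patch):

    import pathlib
    import tempfile

    import numpy as np

    from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

    # Each document is a bare 1-D array of token ids at this stage of the series.
    documents = [np.arange(16, dtype=np.uint16), np.arange(8, dtype=np.uint16)]
    with tempfile.TemporaryDirectory() as temp_dir:
        prefix = pathlib.Path(temp_dir) / "dataset"
        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
        dataset = GPTMemmapDataset(name="roundtrip", prefix=prefix)
        # Reading a document back returns the token array directly.
        assert np.array_equal(dataset.get(0), documents[0])

The patches that follow replace this bare-array contract with a GPTSample wrapper so that loss-masking spans can travel with the token ids.
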
From 23dc7eb299282151a59b348492e2ccfde108715a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 19:05:29 +0000 Subject: [PATCH 26/45] partially fix existing tests --- fast_llm/data/dataset/gpt/random.py | 7 +++++-- tests/common.py | 4 +++- tests/test_dataset.py | 26 ++++++++++++++------------ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 142dca71..1ca5f073 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -2,6 +2,7 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.gpt.sampled import GPTSample class GPTRandomDataset(SamplableDataset): @@ -32,8 +33,10 @@ def __len__(self) -> int: return self._num_samples def __getitem__(self, idx) -> np.ndarray: - return np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + return GPTSample( + np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( + 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + ) ) @property diff --git a/tests/common.py b/tests/common.py index 69048f8c..822f2bcb 100644 --- a/tests/common.py +++ b/tests/common.py @@ -11,6 +11,7 @@ import torch from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.models.gpt.config import ( LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, @@ -230,7 +231,8 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) documents = [ - np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size for document in documents + GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) + for document in documents ] GPTMemmapDataset.write_dataset(prefix, documents) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 394553e5..493dd58a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -82,7 +82,9 @@ def get_test_data_and_samples( batch_config.setup(distributed_config) batch_config.validate() samples = { - phase: [batch[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + phase: [ + batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + ] for phase, samples in samples_per_phase.items() } return data, samples @@ -115,7 +117,7 @@ def test_gpt_random_dataset(): ) Assert.eq(len(sampled), 4) Assert.all_equal( - np.stack([sampled[i] for i in range(4)]), + np.stack([sampled[i].token_ids for i in range(4)]), np.array(RANDOM_DATASET_EXPECTED_SAMPLES), ) @@ -165,9 +167,9 @@ def test_gpt_memmap(cache_directory): Assert.eq(len(dataset), MEMMAP_DATASET_EXPECTED_LENGTH) sizes = dataset.get_document_sizes() Assert.eq(sizes.sum(), MEMMAP_DATASET_EXPECTED_TOKENS) - Assert.all_equal([len(dataset.get(i)) for i in range(100)], sizes[:100]) + Assert.all_equal([len(dataset.get(i).token_ids) for i in range(100)], sizes[:100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(i).token_ids, np.array(sample, dtype=np.uint16)) GPT_SAMPLED_EXPECTED_SAMPLES = [ @@ -190,7 +192,7 @@ def test_gpt_sampled(): ) Assert.eq(len(sampled), 8) Assert.all_equal( 
- np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_SAMPLED_EXPECTED_SAMPLES), ) @@ -257,7 +259,7 @@ def test_gpt_concatenate(): sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_Ids for i in range(8)]), np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), ) @@ -316,7 +318,7 @@ def test_gpt_slice(): sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_SLICE_EXPECTED_VALIDATION_SAMPLES), ) @@ -416,9 +418,9 @@ def test_gpt_compose(): Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) - print(np.stack([sampled[i] for i in range(8)]).tolist()) + print(np.stack([sampled[i].token_ids for i in range(8)]).tolist()) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_COMPOSED_EXPECTED_SAMPLES), ) @@ -472,7 +474,7 @@ def test_gpt_blended(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_BLENDED_EXPECTED_SAMPLES), ) @@ -560,7 +562,7 @@ def test_gpt_blended_mixed(): ).build_and_sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_BLENDED_MIXED_EXPECTED_SAMPLES), ) @@ -619,7 +621,7 @@ def test_gpt_fim(): Assert.eq(len(sampled), 8) # TODO: Does this output make sense? 
Assert.all_equal( - np.stack([sampled[i] for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_FIM_EXPECTED_SAMPLES), ) From 6712d5e28300ca96043e26e40c6bd8ba4feccfcf Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 19:15:52 +0000 Subject: [PATCH 27/45] fix existing tests --- fast_llm/data/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- tests/test_dataset.py | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index d525d6ea..4f1bdb66 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -43,7 +43,7 @@ class TokenizerConfig(Config): hint=FieldHint.core, ) sequence_delimiters: SequenceDelimiters = Field( - default=SequenceDelimiters.bos_only, + default=SequenceDelimiters.tokenizer_default, desc="Boundary tokens (bos/eos) to use for tokenizing sequences", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 7f8d0398..c1b58813 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -212,7 +212,7 @@ def __getitem__(self, idx: int) -> typing.Any: sample_ids = np.array(sample_ids, dtype=np.int64) sample_spans = np.array(sample_spans, dtype=np.int32).reshape(-1, 2) else: - sample_ids = np.concatenate([sample.token_ids for sample in sample_list]) + sample_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) sample_spans = None return GPTSample(token_ids=sample_ids, loss_masking_spans=sample_spans) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 493dd58a..e82a2ae4 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -253,13 +253,13 @@ def test_gpt_concatenate(): Assert.eq(sizes.sum(), 3 * MEMMAP_DATASET_EXPECTED_TOKENS) for i in range(3): begin = i * MEMMAP_DATASET_EXPECTED_LENGTH - Assert.all_equal([len(dataset.get(begin + i)) for i in range(100)], sizes[begin : begin + 100]) + Assert.all_equal([len(dataset.get(begin + i).token_ids) for i in range(100)], sizes[begin : begin + 100]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(begin + i), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(begin + i).token_ids, np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( - np.stack([sampled[i].token_Ids for i in range(8)]), + np.stack([sampled[i].token_ids for i in range(8)]), np.array(GPT_CONCATENATED_EXPECTED_SAMPLES), ) @@ -312,9 +312,9 @@ def test_gpt_slice(): ).build() Assert.eq(len(dataset), 9) sizes = dataset.get_document_sizes() - Assert.all_equal([len(dataset.get(i)) for i in range(9)], sizes[:9]) + Assert.all_equal([len(dataset.get(i).token_ids) for i in range(9)], sizes[:9]) for i, sample in MEMMAP_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i - 9), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(i - 9).token_ids, np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) Assert.all_equal( @@ -413,9 +413,9 @@ def test_gpt_compose(): Assert.eq(len(dataset), COMPOSED_DATASET_EXPECTED_LENGTH) sizes = dataset.get_document_sizes() Assert.eq(sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS) - Assert.all_equal([len(dataset.get(i)) for i in range(0, len(dataset), 20)], sizes[::20]) + Assert.all_equal([len(dataset.get(i).token_ids) 
for i in range(0, len(dataset), 20)], sizes[::20]) for i, sample in COMPOSED_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16)) + Assert.all_equal(dataset.get(i).token_ids, np.array(sample, dtype=np.uint16)) sampled = dataset.sample(get_sampling_config(8, sequence_length=5)) Assert.eq(len(sampled), 8) print(np.stack([sampled[i].token_ids for i in range(8)]).tolist()) From 6802627dbe35ac13f71fba88ce901d686776d396 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 19:17:40 +0000 Subject: [PATCH 28/45] remove get_span_sizes --- fast_llm/data/dataset/gpt/indexed.py | 6 ------ fast_llm/data/dataset/gpt/memmap.py | 8 -------- 2 files changed, 14 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index e3ddbeda..2c158bff 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -36,9 +36,6 @@ def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return self._dataset.get_document_sizes()[self._begin : self._end] - def get_span_sizes(self) -> np.ndarray: - return self._dataset.get_span_sizes()[self._begin : self._end] - class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset @@ -48,6 +45,3 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get_span_sizes(self) -> np.ndarray: - return np.concatenate([dataset.get_span_sizes() for dataset in self._datasets]) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index eea816b2..a0b950e5 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -121,14 +121,6 @@ def get_document_sizes(self) -> np.ndarray: """ return self._document_sizes - def get_span_sizes(self) -> np.ndarray: - """ - The number of spans in each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. 
- """ - return self._num_spans - @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata From fbf51575c3c9efce1e051906c8f1c3276a063b2a Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 19:21:43 +0000 Subject: [PATCH 29/45] typing for custom model --- fast_llm/models/custom/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index e732e41d..c206ef40 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -2,6 +2,7 @@ import torch +from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import Layer, LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig @@ -48,7 +49,7 @@ def preprocess_meta( def preprocess( self, - batch: torch.Tensor, + batch: GPTBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, From 4bcb488840cd08260aa392e668a26c7ba511c185 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 19:25:02 +0000 Subject: [PATCH 30/45] fix memmap tests --- tests/test_memmap_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_memmap_dataset.py b/tests/test_memmap_dataset.py index 261f87e4..1ea8fd91 100644 --- a/tests/test_memmap_dataset.py +++ b/tests/test_memmap_dataset.py @@ -5,17 +5,18 @@ import pytest from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_gpt_memmap_dataset(dtype): - documents = [np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype) for _ in range(100)] + documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): assert np.array_equal( - dataset.get(i), document, equal_nan=True + dataset.get(i).token_ids, document.token_ids, equal_nan=True ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." 
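
With the fixes above applied, documents flow through the memmap writer and reader as GPTSample objects rather than bare arrays, and callers read token ids through the token_ids attribute. A minimal sketch of the updated round trip, again mirroring the test above (the toy documents, dataset name, and temporary paths are illustrative only):

    import pathlib
    import tempfile

    import numpy as np

    from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
    from fast_llm.data.dataset.gpt.sampled import GPTSample

    # Documents are wrapped in GPTSample; only the token ids are populated here.
    documents = [GPTSample(np.arange(16, dtype=np.uint16)) for _ in range(4)]
    with tempfile.TemporaryDirectory() as temp_dir:
        prefix = pathlib.Path(temp_dir) / "dataset"
        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
        dataset = GPTMemmapDataset(name="roundtrip", prefix=prefix)
        # get() now returns a GPTSample, so the token ids are compared explicitly.
        assert np.array_equal(dataset.get(0).token_ids, documents[0].token_ids)

Keeping the wrapper at the storage layer leaves room for the loss-masking spans that the next patch starts writing and checking alongside the token ids.
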
From 8494b6a16e417f6f747636fc8da101fcf73da4a0 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 21:17:25 +0000 Subject: [PATCH 31/45] test for spans --- fast_llm/data/dataset/gpt/memmap.py | 2 +- tests/common.py | 35 +++++++++++++++++++++++++++++ tests/test_dataset.py | 28 +++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a0b950e5..ac2574a6 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -62,7 +62,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: b offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, ) self._span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - self._num_spans_cumsum = np.cumsum(self._num_spans, dtype=np.int64) + self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) diff --git a/tests/common.py b/tests/common.py index 822f2bcb..0273d4e9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -259,6 +259,41 @@ def get_test_concatenated_memmap_dataset( index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) +def get_test_dataset_with_spans( + prefix: pathlib.Path = DATASET_PREFIX, + seed: int = 1234, + num_tokens: int = TEST_DATASET_TOKENS, + characters: str = TEST_CHARACTERS, + vocab_size: int = TEST_VOCAB_SIZE, +): + if not TOKENIZER_FILE.is_file(): + import transformers + + transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + + if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): + import transformers + + documents = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) + for idx, doc in enumerate(documents): + doc = np.array(tokenizer(doc)["input_ids"], dtype=np.uint16) % vocab_size + doc_seed = seed + idx + n_spans = random.Random(doc_seed).randint(0, 5) + spans = [] + prev_end = -1 + for _ in range(n_spans): + if prev_end >= len(doc) - 1: + break + start = random.Random(doc_seed).randint(prev_end + 1, len(doc) - 1) + end = random.Random(doc_seed).randint(start, len(doc) - 1) + spans.append([start, end]) + prev_end = end + documents[idx] = GPTSample(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) + + GPTMemmapDataset.write_dataset(prefix, documents) + + def run_test_script( name: str, script: list[str], diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e82a2ae4..aaef2fcb 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -32,6 +32,7 @@ TOKENIZER_PATH, get_test_concatenated_memmap_dataset, get_test_dataset, + get_test_dataset_with_spans, ) @@ -45,6 +46,7 @@ def get_sampling_config( sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, tokenizer: Tokenizer | None = None, + use_loss_masking_spans=False, ) -> GPTSamplingConfig: # Config with convenient defaults. 
return GPTSamplingConfig( @@ -56,6 +58,7 @@ def get_sampling_config( sequence_length=sequence_length, vocab_size=vocab_size, tokenizer=tokenizer, + use_loss_masking_spans=use_loss_masking_spans, ) @@ -667,3 +670,28 @@ def test_gpt_fim_data_legacy(): np.stack(samples[PhaseType.training]), np.array(GPT_FIM_EXPECTED_SAMPLES), ) + + +_DATASET_PREFIX_SPANS = DATASET_PREFIX.with_name("with_spans") + +SPANS_DATASET_EXPECTED_SAMPLES = { + 9: ([], []), + 10: ([80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], [[4, 6], [8, 9]]), + 13: ([78, 727, 74, 317, 1358, 89], []), + 15: ([78], [[0, 0]]), +} + + +def test_gpt_data_with_spans(): + get_test_dataset_with_spans(prefix=_DATASET_PREFIX_SPANS) + dataset = _get_dataset_config( + { + "type": "memmap", + "path": _DATASET_PREFIX_SPANS, + "use_loss_masking_spans": True, + }, + GPTMemmapDatasetConfig, + ).build() + for i, sample in SPANS_DATASET_EXPECTED_SAMPLES.items(): + Assert.all_equal(np.array(sample[0], dtype=np.uint16), dataset.get(i).token_ids) + Assert.all_equal(np.array(sample[1]).reshape(-1, 2), dataset.get(i).loss_masking_spans) From 769d466fa90b77c4b1fa6cf72dad578716688bbe Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 30 Jan 2025 21:42:30 +0000 Subject: [PATCH 32/45] fix triton cross-entropy --- fast_llm/functional/triton/cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 5d087d78..664a038a 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -60,6 +60,7 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float = 1.0, ignore_index: int = -100, + apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, From 5cbb342327143e6207b495fc84ffd3b40491d24f Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Fri, 31 Jan 2025 09:01:12 -0800 Subject: [PATCH 33/45] Update data.py Co-authored-by: Torsten Scholak --- fast_llm/data/data/gpt/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6621c547..a6aca6fc 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -29,7 +29,7 @@ @dataclasses.dataclass class GPTBatch: token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] + loss_masking_spans: list[torch.Tensor] | None def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch: From a04cc94e6864114b508cbe20f79c2ba4244e5177 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 5 Feb 2025 08:03:21 +0000 Subject: [PATCH 34/45] fix loss mask in cross entropy --- fast_llm/functional/cross_entropy.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 4809c678..d92835fe 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -53,8 +53,6 @@ def fused_cross_entropy_forward_backward( # Way faster and more memory-efficient than the pytorch version. 
if apply_loss_mask: loss_mask = target != ignore_index - target = target[loss_mask] - logits = logits[loss_mask] target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: @@ -77,9 +75,11 @@ def fused_cross_entropy_forward_backward( else: grad = exp_logits.to(logits.dtype) - loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)).mean() + per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) + if apply_loss_mask: + per_sample_loss *= loss_mask - return loss, grad + return per_sample_loss.mean(), grad @torch.compile @@ -100,8 +100,6 @@ def parallel_cross_entropy_forward_backward( # TODO: Optimize, overlap/combine reductions if apply_loss_mask: loss_mask = target != ignore_index - target = target[loss_mask] - logits = logits[loss_mask] target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0] @@ -137,9 +135,11 @@ def parallel_cross_entropy_forward_backward( predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - loss = sum_exp_logits.log().sub(predicted_logits).mean() + per_sample_loss = sum_exp_logits.log().sub(predicted_logits) + if apply_loss_mask: + per_sample_loss *= loss_mask - return loss, grad + return per_sample_loss.mean(), grad _CROSS_ENTROPY_IMPLEMENTATIONS = { From 2f2495dc51cbf4ba647587991973920f166bc127 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 5 Feb 2025 20:58:17 +0000 Subject: [PATCH 35/45] review comments --- fast_llm/data/config.py | 13 ----- fast_llm/data/dataset/gpt/config.py | 9 +--- fast_llm/data/dataset/gpt/fim.py | 40 +++++++-------- fast_llm/data/dataset/gpt/memmap.py | 49 +++++++++++-------- fast_llm/data/dataset/gpt/sampled.py | 14 ++++-- fast_llm/data/dataset/indexed.py | 6 ++- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 10 ++-- fast_llm/data/tokenizer.py | 38 +++++--------- tests/test_dataset.py | 7 +-- 10 files changed, 89 insertions(+), 99 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 4f1bdb66..1586d370 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -16,14 +16,6 @@ class MultiprocessingContext(str, enum.Enum): TokenizerFromFile = "TokenizerFromFile" -class SequenceDelimiters(str, enum.Enum): - tokenizer_default = "tokenizer_default" - bos_only = "bos_only" - eos_only = "eos_only" - bos_eos = "bos_eos" - no_delimiters = "no_delimiters" - - @config_class() class TokenizerConfig(Config): """ @@ -42,8 +34,3 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) - sequence_delimiters: SequenceDelimiters = Field( - default=SequenceDelimiters.tokenizer_default, - desc="Boundary tokens (bos/eos) to use for tokenizing sequences", - hint=FieldHint.core, - ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index a724f5c2..e337840a 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -129,16 +129,11 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", hint=FieldHint.core, ) - use_loss_masking_spans: bool = Field( - default=False, - desc="Read and use loss masking spans from the dataset, if present.", - hint=FieldHint.feature, - ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return 
GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.use_loss_masking_spans) + return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path) @config_class() @@ -388,7 +383,7 @@ def build_and_sample(self, config: GPTSamplingConfig) -> SampledDataset: dataset_configs = [ GPTDatasetSliceConfig( # TODO: this duplicates memmap datasets for each phase. - dataset=GPTMemmapDatasetConfig(path=prefix, use_loss_masking_spans=config.use_loss_masking_spans), + dataset=GPTMemmapDatasetConfig(path=prefix), begin=phase_splits[phase_index], end=phase_splits[phase_index + 1], ) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 20f795ea..998ad88c 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -18,6 +18,8 @@ def __init__( dataset: SampledDataset, sampling_config: GPTSamplingConfig, ): + if sampling_config.use_loss_masking_spans: + raise NotImplementedError("FIM is currently not compatible with loss masking.") self._config = config self._dataset = dataset self._seed = sampling_config.seed @@ -36,21 +38,21 @@ def __len__(self) -> int: return len(self._dataset) def __getitem__(self, idx: int) -> np.ndarray: - sample = self._fim(self._dataset[idx], np.random.RandomState(seed=(self._seed + idx) % MAX_SEED)) - return sample + fim_token_ids = self._fim( + self._dataset[idx].token_ids, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED) + ) + return GPTSample(fim_token_ids) @property def name(self) -> str: return f"{self._dataset.name}_fim" - def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: + def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # FIM # TODO: permute segments in sample_list, before concatenating. - if self._config.rate > 0.0 and sample.loss_masking_spans is not None: - raise NotImplementedError("FIM is currently not compatible with loss masking.") - sample_len = sample.token_ids.shape[0] + sample_len = sample.shape[0] eod = self._tokenizer.eod - segment_breaks = np.argwhere(sample.token_ids == eod) # split sample by document + segment_breaks = np.argwhere(sample == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example curr_start_position = 0 @@ -60,26 +62,26 @@ def _fim(self, sample: GPTSample, np_rng: np.random.RandomState) -> GPTSample: # Only permute non-empty segments. 
if loc - curr_start_position > 0: # permute {prefix, suffix, middle} or {suffix, prefix, middle} - permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:loc], np_rng) + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc], np_rng) new_samples += [permuted, [eod]] curr_start_position = loc + 1 # jump over the EOD token # Permute the segment after the last EOD - permuted = self._fim_split_and_permute_sequence(sample.token_ids[curr_start_position:], np_rng) + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) new_samples.append(permuted) - sample.token_ids = np.concatenate(new_samples) + sample = np.concatenate(new_samples) else: - sample.token_ids = self._fim_split_and_permute_sequence(sample.token_ids, np_rng) + sample = self._fim_split_and_permute_sequence(sample, np_rng) # Truncate or pad sequence to max-length - diff = sample.token_ids.shape[0] - sample_len + diff = sample.shape[0] - sample_len if diff > 0: # too long - sample.token_ids = sample.token_ids[:sample_len] + sample = sample[:sample_len] elif diff < 0: # too short - sample.token_ids = np.concatenate([sample.token_ids, np.full((-1 * diff), self._pad_tok_id)]) + sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) - assert sample.token_ids.shape[0] == sample_len + assert sample.shape[0] == sample_len return sample def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: @@ -153,11 +155,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end_of_text=False)], dtype=np.int64) - middle = np.array( - [*self._tokenizer.tokenize(middle, beginning_of_text=False, end_of_text=False)], dtype=np.int64 - ) - suffix = np.array([*self._tokenizer.tokenize(suffix, beginning_of_text=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) + middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ac2574a6..6ecb8667 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -20,14 +20,14 @@ class GPTMemmapDataset(GPTIndexedDataset): See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details. 
""" - def __init__(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False): - self._init(name, prefix, use_loss_masking_spans) + def __init__(self, name: str, prefix: pathlib.Path | str): + self._init(name, prefix) - def _init(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: bool = False) -> None: + def _init(self, name: str, prefix: pathlib.Path | str) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._read_spans = False + self._has_spans = 0 with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER) @@ -35,7 +35,6 @@ def _init(self, name: str, prefix: pathlib.Path | str, use_loss_masking_spans: b assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}." if self._version == 2: self._has_spans = struct.unpack(" GPTSample: + def get( + self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) - spans = None - if self._read_spans: - spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=self._span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) + sample_spans = None + if use_loss_masking_spans and self._spans is not None: + sample_spans = self._spans[idx] # adjust the spans for the offset and length - spans = spans[(spans[:, 0] < offset + len(token_ids)) & (spans[:, 1] >= offset)] - spans[:, 0] = np.maximum(spans[:, 0], offset) - offset - spans[:, 1] = np.minimum(spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=spans) + sample_spans = sample_spans[ + (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + ] + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index c1b58813..00a67874 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -45,6 +45,7 @@ def __init__( self._num_samples = sampling_config.num_samples self._sequence_length = sampling_config.sequence_length self._seed = sampling_config.seed + self._use_loss_masking_spans = sampling_config.use_loss_masking_spans if sampling_config.cache_directory is None: log_main_rank( @@ -136,13 +137,16 @@ def _sample(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: def __getstate__( self, - ) -> tuple[GPTIndexedDataset, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray]: + ) -> tuple[ + GPTIndexedDataset, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray, pathlib.Path | np.ndarray | bool + ]: if hasattr(self, "_doc_idx_filename"): return ( self._indexed_dataset, self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, + self._use_loss_masking_spans, ) else: return ( @@ -150,15 +154,17 @@ def __getstate__( self._doc_idx, self._sample_idx, self._shuffle_idx, + self._use_loss_masking_spans, ) - def __setstate__(self, state: 
tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path]) -> None: + def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Path, pathlib.Path, bool]) -> None: if isinstance(state[1], pathlib.Path): ( self._indexed_dataset, self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, + self._use_loss_masking_spans, ) = state else: ( @@ -166,6 +172,7 @@ def __setstate__(self, state: tuple[GPTIndexedDataset, pathlib.Path, pathlib.Pat self._doc_idx, self._sample_idx, self._shuffle_idx, + self._use_loss_masking_spans, ) = state def _load_mappings(self) -> None: @@ -196,11 +203,12 @@ def __getitem__(self, idx: int) -> typing.Any: self._doc_idx[doc].item(), offset=(doc == doc_f) * offset_f, length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, + use_loss_masking_spans=self._use_loss_masking_spans, ) for doc in range(doc_f, doc_l + 1) ] - if sample_list[0].loss_masking_spans is not None: + if self._use_loss_masking_spans: sample_ids = [] sample_spans = [] span_offset = 0 diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 8a652dda..09ed5277 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -46,13 +46,15 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get(self, document: int, offset: int = 0, length: int | None = None) -> typing.Any: + def get( + self, document: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + ) -> typing.Any: """ Get the sample (document) with the given index (in the dataset slice), optionally sub-sampled to a specific offset (starting point) and maximum length (end = min(offset + length, sample_length). 
""" - return self._dataset.get(document + self._begin, offset, length) + return self._dataset.get(document + self._begin, offset, length, use_loss_masking_spans) def __len__(self) -> int: return self._end - self._begin diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index e331ed8f..e02c51b5 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -46,7 +46,7 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Field of the dataset to use.", hint=FieldHint.optional, ) - spans_field: None | str = Field( + loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) data_type: DataType | None = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 4ebdcee2..f27d7147 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -46,7 +46,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict for input_ids, token_spans in [ self._tokenizer.tokenize_with_spans(text, char_spans) for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.spans_field] + batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans] ) ] ] @@ -65,7 +65,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> dict[str, typing.An shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._config.dataset.spans_field is not None: + if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -159,9 +159,9 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.spans_field is not None: - if self._config.dataset.spans_field not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.spans_field}'.") + if self._config.dataset.loss_masking_spans is not None: + if self._config.dataset.loss_masking_spans not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") tokenize_fn = self._tokenize_batch_with_spans else: tokenize_fn = self._tokenize_batch diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index e271c01e..28e105ee 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -2,7 +2,7 @@ import torch from transformers import PreTrainedTokenizerFast -from fast_llm.data.config import SequenceDelimiters, TokenizerConfig +from fast_llm.data.config import TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank @@ -16,12 +16,12 @@ def __init__(self, config: TokenizerConfig): self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained( pretrained_model_name_or_path=config.path, errors="replace", max_len=None ) - self.sequence_delimiters = config.sequence_delimiters if self.tokenizer.eos_token_id is None: raise ValueError("Tokenizer does not have an EOS token.") + if self.tokenizer.bos_token_id is None: + raise 
ValueError("Tokenizer does not have an BOS token.") self.eod_id = self.tokenizer.eos_token_id self.bod_id = self.tokenizer.bos_token_id - self._inv_vocab = {v: k for k, v in self.vocab.items()} @property def vocab_size(self) -> int: @@ -35,24 +35,12 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, beginning_of_text=True, end_of_text=True) -> list[int]: - if self.sequence_delimiters == SequenceDelimiters.eos_only: - return self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end_of_text else []) - elif self.sequence_delimiters == SequenceDelimiters.bos_only: - return ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) + self.tokenizer.encode( - text, add_special_tokens=False - ) - elif self.sequence_delimiters == SequenceDelimiters.bos_eos: - return ( - ([self.bod_id] if (self.bod_id is not None and beginning_of_text) else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end_of_text else []) - ) - elif self.sequence_delimiters == SequenceDelimiters.no_delimiters: - return self.tokenizer.encode(text, add_special_tokens=False) - else: - # TODO: How do we handle when beginning_of_text=False or end_of_text=False? - return self.tokenizer.encode(text) + def tokenize(self, text: str, begin=True, end=True) -> list[int]: + return ( + ([self.bod_id] if begin else []) + + self.tokenizer.encode(text, add_special_tokens=False) + + ([self.eod_id] if end else []) + ) def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] @@ -67,21 +55,21 @@ def tokenize_with_spans( for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) beginning_of_text = False input_ids.extend(tokenized_text) curr_text = text[start : end + 1] if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) else: - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=False) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) beginning_of_text = False token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) input_ids.extend(tokenized_text) char_pos = end + 1 if char_pos < len(text): curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, beginning_of_text=beginning_of_text, end_of_text=True) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) input_ids.extend(tokenized_text) return input_ids, token_spans diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 80b025ec..28e8b921 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -754,10 +754,11 @@ def test_gpt_data_with_spans(): { "type": "memmap", "path": _DATASET_PREFIX_SPANS, - "use_loss_masking_spans": True, }, GPTMemmapDatasetConfig, ).build() for i, sample in SPANS_DATASET_EXPECTED_SAMPLES.items(): - Assert.all_equal(np.array(sample[0], dtype=np.uint16), dataset.get(i).token_ids) - Assert.all_equal(np.array(sample[1]).reshape(-1, 2), dataset.get(i).loss_masking_spans) + Assert.all_equal(np.array(sample[0], dtype=np.uint16), dataset.get(i, use_loss_masking_spans=True).token_ids) + Assert.all_equal( + 
np.array(sample[1]).reshape(-1, 2), dataset.get(i, use_loss_masking_spans=True).loss_masking_spans + ) From 348a17c3887ae9ff1a9d552377eeba4fb46418e2 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 5 Feb 2025 21:40:25 +0000 Subject: [PATCH 36/45] fix fused cross-entropy --- fast_llm/functional/cross_entropy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index d92835fe..38453c3c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -53,6 +53,8 @@ def fused_cross_entropy_forward_backward( # Way faster and more memory-efficient than the pytorch version. if apply_loss_mask: loss_mask = target != ignore_index + # ignore_index can go out of bounds, so clamp targets after getting the mask + target = target.clamp(min=0, max=logits.size(-1) - 1) target = target.unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: From 06f81c79ea96625ddf2b11cca62f53efd908f8bd Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 5 Feb 2025 21:47:30 +0000 Subject: [PATCH 37/45] cleaner collating --- fast_llm/data/data/gpt/data.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a6aca6fc..486219d8 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -3,6 +3,7 @@ import pathlib import typing import warnings +from functools import partial import numpy as np import torch @@ -32,12 +33,12 @@ class GPTBatch: loss_masking_spans: list[torch.Tensor] | None -def gpt_data_collate_fn(batch: list[GPTSample]) -> GPTBatch: +def gpt_data_collate_fn(batch: list[GPTSample], use_loss_masking_spans: bool) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None - if batch[0].loss_masking_spans is not None: + if use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans) + return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=None) class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): @@ -138,7 +139,7 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=gpt_data_collate_fn, + collate_fn=partial(gpt_data_collate_fn, use_loss_masking_spans=self._config.use_loss_masking_spans), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) From 15b903332814a8c6dd74282adc45800a0dfff13a Mon Sep 17 00:00:00 2001 From: root Date: Wed, 5 Feb 2025 23:24:04 +0000 Subject: [PATCH 38/45] fix collate --- fast_llm/data/data/gpt/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 486219d8..ccfbf92a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -38,7 +38,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], use_loss_masking_spans: bool) -> stacked_spans = None if use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=None) + return GPTBatch(token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans) class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): From 
b277c8d67858653719a66838e83f81a2360afc28 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Feb 2025 07:35:24 +0000 Subject: [PATCH 39/45] run pre-commit on all --- tests/data/common.py | 18 ++++++++++++++---- tests/data/test_loss_masking_spans.py | 10 ++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/data/common.py b/tests/data/common.py index 3d88c8cf..de566f63 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -83,7 +83,9 @@ def compare_indexed_dataset( Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() Assert.eq(sizes.sum(), num_tokens) - Assert.all_equal([len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)]) + Assert.all_equal( + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + ) for i, sample in samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(sample, dtype=np.uint16)) @@ -92,11 +94,19 @@ def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list Assert.eq(len(sampled), len(expected_samples)) Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) -def compare_indexed_dataset_with_spans(dataset: GPTIndexedDataset, length: int, num_tokens: int, samples: dict[int, tuple[list[int], list[list[int]]]]): + +def compare_indexed_dataset_with_spans( + dataset: GPTIndexedDataset, length: int, num_tokens: int, samples: dict[int, tuple[list[int], list[list[int]]]] +): Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() Assert.eq(sizes.sum(), num_tokens) - Assert.all_equal([len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)]) + Assert.all_equal( + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + ) for i, sample in samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal(dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, np.array(sample[1], dtype=np.int32).reshape(-1, 2)) + Assert.all_equal( + dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, + np.array(sample[1], dtype=np.int32).reshape(-1, 2), + ) diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 7492fdd8..a8c63463 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -1,7 +1,7 @@ +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from tests.common import DATASET_PREFIX, get_test_dataset_with_spans -from tests.data.common import get_dataset_config, compare_indexed_dataset_with_spans +from tests.data.common import compare_indexed_dataset_with_spans, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig _DATASET_PREFIX_SPANS = DATASET_PREFIX.with_name("with_spans") @@ -22,9 +22,11 @@ def test_gpt_data_with_spans(): }, GPTMemmapDatasetConfig, ).build() - compare_indexed_dataset_with_spans(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, SPANS_DATASET_EXPECTED_SAMPLES) + compare_indexed_dataset_with_spans( + dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, SPANS_DATASET_EXPECTED_SAMPLES + ) # for i, sample in SPANS_DATASET_EXPECTED_SAMPLES.items(): # Assert.all_equal(np.array(sample[0], dtype=np.uint16), dataset.get(i, use_loss_masking_spans=True).token_ids) # 
Assert.all_equal( # np.array(sample[1]).reshape(-1, 2), dataset.get(i, use_loss_masking_spans=True).loss_masking_spans - # ) \ No newline at end of file + # ) From a72c8130c6b6e87daeaae6e6eb8fe509f8979805 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Feb 2025 00:38:16 -0500 Subject: [PATCH 40/45] misc --- fast_llm/functional/cross_entropy.py | 44 +++++---------------- fast_llm/functional/triton/cross_entropy.py | 8 +--- 2 files changed, 12 insertions(+), 40 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 38453c3c..55e996bf 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -13,13 +13,12 @@ def torch_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index: int = -100, - apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. + TODO: loss masking only works for this method if the masking index is set to -100. """ # Torch compile doesn't understand this. with torch.enable_grad(): @@ -29,7 +28,7 @@ def torch_cross_entropy_forward_backward( if grad_output is None: loss = None else: - loss = torch.nn.functional.cross_entropy(logits_, target, ignore_index=ignore_index).mean() + loss = torch.nn.functional.cross_entropy(logits_, target).mean() loss.backward(torch.full_like(loss, grad_output)) loss.detach_() return loss.detach(), logits_.grad.detach().to(logits.dtype) @@ -41,8 +40,6 @@ def fused_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index: int = -100, - apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -51,11 +48,9 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. - if apply_loss_mask: - loss_mask = target != ignore_index - # ignore_index can go out of bounds, so clamp targets after getting the mask - target = target.clamp(min=0, max=logits.size(-1) - 1) - target = target.unsqueeze(1) + loss_mask = target < 0 + # Ignore_index can go out of bounds, so set masked values to zero. 
+ target = target.unsqueeze(1) * loss_mask logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: logits_norm *= logits_scale_factor @@ -72,14 +67,9 @@ def fused_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - if apply_loss_mask: - grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0) - else: - grad = exp_logits.to(logits.dtype) + grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0) - per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) - if apply_loss_mask: - per_sample_loss *= loss_mask + per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask return per_sample_loss.mean(), grad @@ -91,8 +81,6 @@ def parallel_cross_entropy_forward_backward( grad_output: float | None, group: ProcessGroup, logits_scale_factor: float = 1.0, - ignore_index: int = -100, - apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. @@ -100,8 +88,7 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). # TODO: Optimize, overlap/combine reductions - if apply_loss_mask: - loss_mask = target != ignore_index + loss_mask = target < 0 target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0] @@ -130,16 +117,11 @@ def parallel_cross_entropy_forward_backward( if logits_scale_factor != 1.0: exp_logits2 *= logits_scale_factor - if apply_loss_mask: - grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0) - else: - grad = exp_logits2.to(logits.dtype) + grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0) predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - per_sample_loss = sum_exp_logits.log().sub(predicted_logits) - if apply_loss_mask: - per_sample_loss *= loss_mask + per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask return per_sample_loss.mean(), grad @@ -158,8 +140,6 @@ def cross_entropy_forward_backward( group: ProcessGroup | None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, - ignore_index: int = -100, - apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -175,8 +155,6 @@ def cross_entropy_forward_backward( grad_output, group, logits_scale_factor=logits_scale_factor, - ignore_index=ignore_index, - apply_loss_mask=apply_loss_mask, ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( @@ -184,6 +162,4 @@ def cross_entropy_forward_backward( target, grad_output, logits_scale_factor=logits_scale_factor, - ignore_index=ignore_index, - apply_loss_mask=apply_loss_mask, ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 664a038a..9835cb0e 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -17,7 +17,6 @@ def triton_cross_entropy_forward_backward_kernel( grad_logits_stride_0, logits_scale_factor: tl.constexpr, block_size: tl.constexpr, - ignore_index: tl.constexpr, ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) @@ -36,7 +35,7 @@ def triton_cross_entropy_forward_backward_kernel( label_idx = tl.load(labels_ptr + block_idx) label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) - if label_idx < 0 or label_idx == ignore_index: + if label_idx < 0: loss = 0.0 else: loss = tl.log(sum_exp_logits) + max_logits - label_logits @@ -48,7 +47,7 @@ def triton_cross_entropy_forward_backward_kernel( exp_logits = exp_logits / sum_exp_logits if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - if label_idx < 0 or label_idx == ignore_index: + if label_idx < 0: grad_losses = 0.0 grad_logits = grad_losses * tl.where(col_offsets == label_idx, exp_logits - 1.0, exp_logits) tl.store(grad_logits_ptr + col_offsets, grad_logits, mask=mask) @@ -59,8 +58,6 @@ def triton_cross_entropy_forward_backward( target: torch.Tensor, grad_output: float | None, logits_scale_factor: float = 1.0, - ignore_index: int = -100, - apply_loss_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -92,6 +89,5 @@ def triton_cross_entropy_forward_backward( logits_scale_factor, block_size=block_size, num_warps=num_warps, - ignore_index=ignore_index, ) return losses.mean(), None if grad_output is None else grad_logits From a13bf2da7aa6f1d1ab5bce37dcdd7d01d46c0978 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Feb 2025 00:40:05 -0500 Subject: [PATCH 41/45] misc --- fast_llm/functional/cross_entropy.py | 6 +----- fast_llm/functional/triton/cross_entropy.py | 5 +---- fast_llm/layers/language_model/config.py | 5 ----- fast_llm/layers/language_model/head.py | 3 --- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 55e996bf..3db1f99d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -150,11 +150,7 @@ def cross_entropy_forward_backward( if group: Assert.eq(implementation, CrossEntropyImpl.fused) return parallel_cross_entropy_forward_backward( - logits, - target, - grad_output, - group, - logits_scale_factor=logits_scale_factor, + logits, target, grad_output, group, logits_scale_factor=logits_scale_factor ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 9835cb0e..7a00df15 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -54,10 +54,7 @@ def triton_cross_entropy_forward_backward_kernel( def triton_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - grad_output: float | None, - logits_scale_factor: float = 1.0, + logits, target, grad_output: float | None, logits_scale_factor: float = 1.0 ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 1f6f54f0..8e3a467c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -173,11 +173,6 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - apply_loss_mask: bool = Field( - default=False, - desc="Enable loss-masking for cross-entropy 
computation", - hint=FieldHint.feature, - ) def _validate(self) -> None: if self.transformer.init_method_std is None: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 3eae12d5..4c03e393 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -83,8 +83,6 @@ def __init__( else: self._cross_entropy_impl = CrossEntropyImpl.fused - self._apply_loss_mask = config.apply_loss_mask - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) def forward( @@ -245,7 +243,6 @@ def _logits_cross_entropy_forward_backward( grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, - apply_loss_mask=self._apply_loss_mask, ) # TODO: de-allocate earlier. del logits From c1bbadf408a07dd2de50171bdf0de91fae44b5dd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Feb 2025 00:40:54 -0500 Subject: [PATCH 42/45] misc --- fast_llm/functional/cross_entropy.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 3db1f99d..311aa6bb 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -154,8 +154,5 @@ def cross_entropy_forward_backward( ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, - target, - grad_output, - logits_scale_factor=logits_scale_factor, + logits, target, grad_output, logits_scale_factor=logits_scale_factor ) From 69a59c4ff293bcf3fc3c0e3b9c2d14e3496e5fd1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Feb 2025 01:41:56 -0500 Subject: [PATCH 43/45] Simplfy tests --- tests/common.py | 52 +++++++-------------------- tests/data/common.py | 30 ++++++---------- tests/data/test_loss_masking_spans.py | 32 ----------------- tests/data/test_memmap.py | 24 +++++++++++++ 4 files changed, 47 insertions(+), 91 deletions(-) delete mode 100644 tests/data/test_loss_masking_spans.py diff --git a/tests/common.py b/tests/common.py index 0273d4e9..9e82ab54 100644 --- a/tests/common.py +++ b/tests/common.py @@ -218,6 +218,7 @@ def get_test_dataset( num_tokens: int = TEST_DATASET_TOKENS, characters: str = TEST_CHARACTERS, vocab_size: int = TEST_VOCAB_SIZE, + max_spans: int = 0, ): if not TOKENIZER_FILE.is_file(): import transformers @@ -227,14 +228,20 @@ def get_test_dataset( if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): import transformers - documents = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - documents = [ - GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) - for document in documents + samples = [ + GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts ] - GPTMemmapDataset.write_dataset(prefix, documents) + if max_spans > 0: + lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) + for sample, span in zip(samples, spans): + span = np.unique(span) + sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + + GPTMemmapDataset.write_dataset(prefix, samples) def get_test_concatenated_memmap_dataset( @@ -259,41 +266,6 
@@ def get_test_concatenated_memmap_dataset( index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)]) -def get_test_dataset_with_spans( - prefix: pathlib.Path = DATASET_PREFIX, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, -): - if not TOKENIZER_FILE.is_file(): - import transformers - - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) - - if not (prefix.with_suffix(".idx").is_file() and prefix.with_suffix(".bin").is_file()): - import transformers - - documents = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - for idx, doc in enumerate(documents): - doc = np.array(tokenizer(doc)["input_ids"], dtype=np.uint16) % vocab_size - doc_seed = seed + idx - n_spans = random.Random(doc_seed).randint(0, 5) - spans = [] - prev_end = -1 - for _ in range(n_spans): - if prev_end >= len(doc) - 1: - break - start = random.Random(doc_seed).randint(prev_end + 1, len(doc) - 1) - end = random.Random(doc_seed).randint(start, len(doc) - 1) - spans.append([start, end]) - prev_end = end - documents[idx] = GPTSample(doc, np.array(spans, dtype=np.int32).reshape(-1, 2)) - - GPTMemmapDataset.write_dataset(prefix, documents) - - def run_test_script( name: str, script: list[str], diff --git a/tests/data/common.py b/tests/data/common.py index de566f63..29800437 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -78,7 +78,11 @@ def get_test_data_and_compare_samples( def compare_indexed_dataset( - dataset: GPTIndexedDataset, length: int, num_tokens: int, samples: dict[int, list[int]] + dataset: GPTIndexedDataset, + length: int, + num_tokens: int, + samples: dict[int, list[int]], + loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() @@ -87,26 +91,14 @@ def compare_indexed_dataset( [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] ) for i, sample in samples.items(): - Assert.all_equal(dataset.get(i).token_ids, np.array(sample, dtype=np.uint16)) + dataset_sample = dataset.get(i, use_loss_masking_spans=loss_masking_spans is not None) + Assert.all_equal(dataset_sample.token_ids, np.array(sample, dtype=np.uint16)) + if loss_masking_spans: + Assert.all_equal( + dataset_sample.loss_masking_spans, np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2) + ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int]]) -> None: Assert.eq(len(sampled), len(expected_samples)) Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) - - -def compare_indexed_dataset_with_spans( - dataset: GPTIndexedDataset, length: int, num_tokens: int, samples: dict[int, tuple[list[int], list[list[int]]]] -): - Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() - Assert.eq(sizes.sum(), num_tokens) - Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] - ) - for i, sample in samples.items(): - Assert.all_equal(dataset.get(i).token_ids, np.array(sample[0], dtype=np.uint16)) - Assert.all_equal( - dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, - np.array(sample[1], dtype=np.int32).reshape(-1, 2), - ) diff --git a/tests/data/test_loss_masking_spans.py 
b/tests/data/test_loss_masking_spans.py deleted file mode 100644 index a8c63463..00000000 --- a/tests/data/test_loss_masking_spans.py +++ /dev/null @@ -1,32 +0,0 @@ -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from tests.common import DATASET_PREFIX, get_test_dataset_with_spans -from tests.data.common import compare_indexed_dataset_with_spans, get_dataset_config -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS - -_DATASET_PREFIX_SPANS = DATASET_PREFIX.with_name("with_spans") - -SPANS_DATASET_EXPECTED_SAMPLES = { - 9: ([], []), - 10: ([80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], [[4, 6], [8, 9]]), - 13: ([78, 727, 74, 317, 1358, 89], []), - 15: ([78], [[0, 0]]), -} - - -def test_gpt_data_with_spans(): - get_test_dataset_with_spans(prefix=_DATASET_PREFIX_SPANS) - dataset = get_dataset_config( - { - "type": "memmap", - "path": _DATASET_PREFIX_SPANS, - }, - GPTMemmapDatasetConfig, - ).build() - compare_indexed_dataset_with_spans( - dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, SPANS_DATASET_EXPECTED_SAMPLES - ) - # for i, sample in SPANS_DATASET_EXPECTED_SAMPLES.items(): - # Assert.all_equal(np.array(sample[0], dtype=np.uint16), dataset.get(i, use_loss_masking_spans=True).token_ids) - # Assert.all_equal( - # np.array(sample[1]).reshape(-1, 2), dataset.get(i, use_loss_masking_spans=True).loss_masking_spans - # ) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 67f4eb91..c6af54bb 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -41,3 +41,27 @@ def test_gpt_memmap(cache_directory): get_test_dataset() dataset = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) + + +MEMMAP_DATASET_SPANS = { + 9: [], + 10: [[0, 4], [6, 8]], + 13: [[1, 2]], + 15: [], +} + +_DATASET_PREFIX_SPANS = DATASET_PREFIX.with_name("with_spans") + + +def test_gpt_data_with_spans(): + get_test_dataset(prefix=DATASET_PREFIX.with_name("with_spans"), max_spans=5) + dataset = get_dataset_config( + { + "type": "memmap", + "path": _DATASET_PREFIX_SPANS, + }, + GPTMemmapDatasetConfig, + ).build() + compare_indexed_dataset( + dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS + ) From 3aa735beaae16e9f34bb36c5421516c6821150df Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 7 Feb 2025 01:51:35 -0500 Subject: [PATCH 44/45] fix --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/models/gpt/huggingface.py | 5 ++++- fast_llm/models/gpt/model.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index ccfbf92a..eb486dd6 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -30,7 +30,7 @@ @dataclasses.dataclass class GPTBatch: token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] | None + loss_masking_spans: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], use_loss_masking_spans: bool) -> GPTBatch: diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index fd61abcd..e4db9b07 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -4,6 +4,7 @@ import torch import transformers.modeling_outputs +from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.distributed.config import PhaseType from 
fast_llm.engine.huggingface.config import HuggingfaceModelConfig from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel @@ -66,7 +67,9 @@ def forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self._fast_llm_model.base_model.preprocess(input_ids, phase=PhaseType.inference, iteration=iteration) + batch = self._fast_llm_model.base_model.preprocess( + GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration + ) ((_, kwargs),) = batch if past_key_values is not None: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6438d4f1..8aa68333 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -196,7 +196,7 @@ def preprocess( assert self._is_setup if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch, phase) + preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size From 0821abe8bce6fecae342f49b5f316e047d3803c9 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 7 Feb 2025 19:43:25 +0000 Subject: [PATCH 45/45] fix loss mask --- fast_llm/functional/cross_entropy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 311aa6bb..e87581f1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -48,9 +48,9 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. - loss_mask = target < 0 + loss_mask = target >= 0 # Ignore_index can go out of bounds, so set masked values to zero. - target = target.unsqueeze(1) * loss_mask + target = (target * loss_mask).unsqueeze(1) logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() if logits_scale_factor != 1.0: logits_norm *= logits_scale_factor @@ -88,7 +88,7 @@ def parallel_cross_entropy_forward_backward( """ # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). # TODO: Optimize, overlap/combine reductions - loss_mask = target < 0 + loss_mask = target >= 0 target = target.unsqueeze(1) logits_max = torch.max(logits, dim=-1)[0]
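
For reference, the loss-masking convention that PATCH 40 through PATCH 45 converge on can be restated outside the Fast-LLM code base: positions excluded from the loss carry a negative target id (such as -100), there is no separate ignore_index or apply_loss_mask argument, and masked positions contribute zero loss and zero gradient. The sketch below is illustrative only; it assumes plain PyTorch and re-derives the fused formulation from PATCH 45's `loss_mask = target >= 0`, and it is not part of the patches above.

    # Minimal sketch of the loss-masking convention, assuming plain PyTorch.
    # Negative target ids mark masked positions, as in the fused kernel after PATCH 45.
    import torch

    def masked_cross_entropy_reference(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Positions to ignore carry a negative id, mirroring `loss_mask = target >= 0`.
        loss_mask = target >= 0
        # Masked ids may be out of range (e.g. -100), so zero them before the gather.
        safe_target = (target * loss_mask).unsqueeze(1)
        # Subtract the per-row max for numerical stability, then form log-sum-exp.
        logits_norm = logits.sub(logits.max(dim=-1, keepdim=True).values).float()
        sum_exp_logits = logits_norm.exp().sum(dim=-1)
        # Per-position loss: logsumexp(logits) - logit[target], zeroed where masked.
        per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, safe_target).squeeze(1)) * loss_mask
        # The mean runs over *all* positions, masked ones included, as in the fused kernel.
        return per_sample_loss.mean()

    logits = torch.randn(6, 11)
    target = torch.tensor([3, 7, -100, 1, -100, 5])
    # Compare against PyTorch's cross-entropy with the ignored positions zeroed out.
    expected = torch.nn.functional.cross_entropy(logits, target, ignore_index=-100, reduction="none")
    expected = torch.where(target >= 0, expected, torch.zeros_like(expected)).mean()
    print(torch.allclose(masked_cross_entropy_reference(logits, target), expected.float(), atol=1e-5))

Encoding the mask in the sign of the target id keeps the fused, parallel, and Triton kernels free of extra arguments (the `label_idx < 0` check in the Triton kernel serves the same purpose), at the cost of requiring callers to mark masked labels with a negative id before computing the loss.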