Auto dataset concatenation prototype #128

Merged 4 commits on Jan 27, 2025
Changes from all commits
40 changes: 40 additions & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -4,6 +4,7 @@
import pathlib
import time
import typing
import warnings

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.data.dataset.abstract import SampledDataset
@@ -66,6 +67,10 @@ def _from_dict(
        if type_ is None:
            actual_cls = cls
        else:
            if type_ not in cls._registry:
                raise ValueError(
                    f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
                )
            actual_cls = cls._registry[type_]
        Assert.custom(issubclass, actual_cls, cls)
        if actual_cls == cls:
@@ -161,6 +166,41 @@ class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig):
    datasets: list[GPTSampledDatasetConfig] = FieldUpdate()


@config_class()
class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
    type_: typing.ClassVar[str | None] = "concatenated_memmap"
    path: pathlib.Path = Field(
        default=None,
        desc="The path to a dataset directory.",
        hint=FieldHint.core,
    )

    def build(self) -> "GPTConcatenatedDataset":
        assert self.path.is_dir()
        index_path = self.path / "index.txt"

        if index_path.is_file():
            prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()]
        else:
            warnings.warn(
                f"The dataset path {self.path} points to a directory."
                " The dataset will be indexed automatically, which may be unsafe."
                " We recommend using an index file instead."
            )
            prefixes = [
                path.with_suffix("")
                for path in self.path.iterdir()
                if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file()
            ]
        dataset_config = GPTConcatenatedDatasetConfig.from_dict(
            {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]}
        )
        return dataset_config.build()


@config_class()
class FimConfig(Config):
"""
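For context, a minimal usage sketch of the new `concatenated_memmap` dataset type added above. The directory path and shard names are illustrative, not taken from the PR; the expected layout is a directory of paired `.bin`/`.idx` files plus an optional `index.txt` listing one shard prefix per line.

import pathlib

from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig

# Hypothetical layout of /data/my_dataset:
#   shard_0.bin, shard_0.idx
#   shard_1.bin, shard_1.idx
#   index.txt   (contents: "shard_0\nshard_1\n")
config = GPTConcatenatedMemmapConfig.from_dict(
    {"type": "concatenated_memmap", "path": pathlib.Path("/data/my_dataset")}
)
dataset = config.build()  # builds a GPTConcatenatedDataset over shard_0 and shard_1

Without the index file, `build()` falls back to scanning the directory for `.idx`/`.bin` pairs and emits the warning above.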
10 changes: 6 additions & 4 deletions fast_llm/data/dataset/gpt/memmap.py
@@ -58,10 +58,12 @@ def __setstate__(self, state: tuple[str, pathlib.Path]):
        self._init(*state)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()  # noqa
        del self._bin_buffer_mmap
        self._index_bin_buffer_mmap._mmap.close()  # noqa
        del self._index_bin_buffer_mmap
        if hasattr(self, "_bin_buffer_mmap"):
            self._bin_buffer_mmap._mmap.close()  # noqa
            del self._bin_buffer_mmap
        if hasattr(self, "_index_bin_buffer"):
            self._index_bin_buffer_mmap._mmap.close()  # noqa
            del self._index_bin_buffer_mmap

    def get(self, idx, offset=0, length=None) -> np.ndarray:
        return np.frombuffer(
4 changes: 4 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -166,6 +166,10 @@ def run(self) -> None:
            output_file = self._config.output_path / "fast_llm_dataset.json"
            json.dump({"datasets": dataset_dicts}, output_file.open("w"))

            # Create an index file on rank 0
            index_file = self._config.output_path / "index.txt"
            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])

        # Finalize distributed processing
        if self._config.distributed.world_size > 1:
            torch.distributed.barrier()
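The index file written here is the same one the new `GPTConcatenatedMemmapConfig.build` reads back, so a prepared output directory can be consumed directly as a `concatenated_memmap` dataset. A sketch of the round trip, with illustrative paths and prefix names:

import pathlib

output_path = pathlib.Path("/data/prepared")  # illustrative output directory

# The preparator writes one shard prefix per line, e.g.:
#   shard_0_0
#   shard_1_0
index_path = output_path / "index.txt"

# The concatenated_memmap config reads the prefixes back and loads each one
# as {"type": "memmap", "path": prefix}:
prefixes = [output_path / line.strip() for line in index_path.open("r").readlines()]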
10 changes: 10 additions & 0 deletions fast_llm/utils.py
@@ -186,6 +186,7 @@ def not_custom(fn, *args, **kwargs):


class Registry[KeyType, ValueType]:
    # TODO: Inherit from dict instead?
    def __init__(self, name: str, data: dict[KeyType, ValueType]):
        self._name = name
        self._data = data.copy()
@@ -206,6 +207,15 @@ def keys(self) -> list[KeyType]:
    def __contains__(self, key: KeyType) -> bool:
        return key in self._data

    def __iter__(self) -> typing.Iterator[KeyType]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def items(self):
        return self._data.items()

    @property
    def name(self) -> str:
        return self._name
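A short sketch of the dict-like surface these additions give `Registry`; the registry name and contents below are made up for illustration:

from fast_llm.utils import Registry

registry = Registry("example", {"a": 1, "b": 2})

assert "a" in registry                              # __contains__ (pre-existing)
assert len(registry) == 2                           # __len__ (new)
assert list(registry) == ["a", "b"]                 # __iter__ (new)
assert dict(registry.items()) == {"a": 1, "b": 2}   # items() (new)

These additions let the registry be used as a plain read-only mapping (iteration, length, items) on top of the existing `keys()` and `__contains__`.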
36 changes: 30 additions & 6 deletions tests/common.py
@@ -35,7 +35,9 @@

TOKENIZER_PATH = TEST_RESULTS_PATH / "tokenizer" / "common"
TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json"
DATASET_PREFIX = TEST_RESULTS_PATH / "dataset" / "common"
DATASET_CACHE = TEST_RESULTS_PATH / "dataset"
DATASET_PREFIX = DATASET_CACHE / "common"
DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset" / "cache"

TEST_VOCAB_SIZE = 8192
# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6%
@@ -210,11 +212,11 @@


def get_test_dataset(
    prefix=DATASET_PREFIX,
    seed=1234,
    num_tokens=TEST_DATASET_TOKENS,
    characters=TEST_CHARACTERS,
    vocab_size=TEST_VOCAB_SIZE,
    prefix: pathlib.Path = DATASET_PREFIX,
    seed: int = 1234,
    num_tokens: int = TEST_DATASET_TOKENS,
    characters: str = TEST_CHARACTERS,
    vocab_size: int = TEST_VOCAB_SIZE,
):
    if not TOKENIZER_FILE.is_file():
        import transformers
@@ -233,6 +235,28 @@ def get_test_dataset(
    GPTMemmapDataset.write_dataset(prefix, documents)


def get_test_concatenated_memmap_dataset(
    path: pathlib.Path,
    num_files: int,
    seed: int = 1234,
    num_tokens: int = TEST_DATASET_TOKENS,
    characters: str = TEST_CHARACTERS,
    vocab_size: int = TEST_VOCAB_SIZE,
    seed_shift: int = 55,
):
    index_file = path / "index.txt"
    if not index_file.is_file():
        for i in range(num_files):
            get_test_dataset(
                prefix=path / f"dataset_{i}",
                seed=seed + i * seed_shift,
                num_tokens=num_tokens,
                characters=characters,
                vocab_size=vocab_size,
            )
        index_file.open("w").writelines([str(path / f"dataset_{i}") + "\n" for i in range(num_files)])


def run_test_script(
    name: str,
    script: list[str],
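For orientation, a sketch of what this helper lays out on disk and how the dataset tests below consume it; the path is illustrative:

import pathlib

from tests.common import get_test_concatenated_memmap_dataset

path = pathlib.Path("/tmp/concatenated_memmap_example")  # illustrative location
get_test_concatenated_memmap_dataset(path, num_files=4)
# Expected result: path/dataset_0 .. path/dataset_3 as .bin/.idx pairs, plus
# path/index.txt listing the four prefixes one per line, which is exactly the
# layout GPTConcatenatedMemmapConfig.build expects.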
106 changes: 93 additions & 13 deletions tests/test_dataset.py
@@ -11,6 +11,7 @@
from fast_llm.data.dataset.gpt.config import (
    GPTBlendedDatasetConfig,
    GPTConcatenatedDatasetConfig,
    GPTConcatenatedMemmapConfig,
    GPTDatasetSliceConfig,
    GPTFimSampledDatasetConfig,
    GPTMemmapDatasetConfig,
@@ -23,9 +24,15 @@
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.utils import Assert
from tests.common import DATASET_PREFIX, TEST_RESULTS_PATH, TEST_VOCAB_SIZE, TOKENIZER_PATH, get_test_dataset

DATASET_CACHE = TEST_RESULTS_PATH / "dataset" / "cache"
from tests.common import (
    DATASET_CACHE,
    DATASET_PREFIX,
    DATASET_SAMPLING_CACHE,
    TEST_VOCAB_SIZE,
    TOKENIZER_PATH,
    get_test_concatenated_memmap_dataset,
    get_test_dataset,
)


def get_sampling_config(
@@ -81,11 +88,16 @@ def get_test_data_and_samples(
    return data, samples


DATASET_PREFIX_MIX_1 = DATASET_PREFIX.with_name("blended_mix_1")
_DATASET_PREFIX_MIX_1 = DATASET_PREFIX.with_name("blended_mix_1")
_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap"


def _get_test_dataset_mix_1():
    return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345)

def get_test_dataset_1():
    return get_test_dataset(prefix=DATASET_PREFIX_MIX_1, seed=2345)

def _get_test_dataset_concatenated_memmap():
    return get_test_concatenated_memmap_dataset(_DATASET_PREFIX_MIX_CONCATENATED_MEMMAP, 4)


RANDOM_DATASET_EXPECTED_SAMPLES = [
@@ -145,7 +157,7 @@ def test_gpt_random_data_legacy():
}


@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_CACHE) / "test_memmap"))
@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_SAMPLING_CACHE) / "test_memmap"))
def test_gpt_memmap(cache_directory):
    # Make sure the memmap dataset works and check for unintended changes in behavior.
    get_test_dataset()
@@ -363,6 +375,74 @@ def test_gpt_slice_data_legacy():
    )


COMPOSED_DATASET_EXPECTED_LENGTH = 24806
COMPOSED_DATASET_EXPECTED_TOKENS = 2033639

COMPOSED_DATASET_EXPECTED_SAMPLES = {
    **MEMMAP_DATASET_EXPECTED_SAMPLES,
    6930: [65, 2327],
    11962: [7078, 2713, 1431],
    15958: [207],
    19362: [69],
    24098: [555, 668, 70],
}


GPT_COMPOSED_EXPECTED_SAMPLES = [
    [1411, 819, 6791, 7022, 285, 249],
    [329, 328, 512, 1985, 3069, 7838],
    [5158, 1023, 8171, 798, 1431, 313],
    [1073, 3917, 275, 480, 74, 1752],
    [207, 317, 269, 6662, 4357, 498],
    [74, 310, 277, 7091, 668, 367],
    [7828, 480, 89, 116, 4604, 69],
    [79, 6042, 577, 225, 207, 207],
]


def test_gpt_compose():
    # Make sure dataset concatenation works and check for unintended changes in behavior.
    _get_test_dataset_concatenated_memmap()
    # samples[9:18]
    dataset = _get_dataset_config(
        {"type": "concatenated_memmap", "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP},
        GPTConcatenatedMemmapConfig,
    ).build()
    Assert.eq(len(dataset), COMPOSED_DATASET_EXPECTED_LENGTH)
    sizes = dataset.get_document_sizes()
    Assert.eq(sizes.sum(), COMPOSED_DATASET_EXPECTED_TOKENS)
    Assert.all_equal([len(dataset.get(i)) for i in range(0, len(dataset), 20)], sizes[::20])
    for i, sample in COMPOSED_DATASET_EXPECTED_SAMPLES.items():
        Assert.all_equal(dataset.get(i), np.array(sample, dtype=np.uint16))
    sampled = dataset.sample(get_sampling_config(8, sequence_length=5))
    Assert.eq(len(sampled), 8)
    print(np.stack([sampled[i] for i in range(8)]).tolist())
    Assert.all_equal(
        np.stack([sampled[i] for i in range(8)]),
        np.array(GPT_COMPOSED_EXPECTED_SAMPLES),
    )


def test_gpt_composed_data():
    _get_test_dataset_concatenated_memmap()
    _, samples = get_test_data_and_samples(
        {
            "datasets": {
                "Training": {
                    "type": "composed",
                    "path": _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP,
                }
            }
        },
        {PhaseType.training: 8},
        sequence_length=5,
    )
    Assert.all_equal(
        np.stack(samples[PhaseType.training]),
        np.array(GPT_COMPOSED_EXPECTED_SAMPLES),
    )


GPT_BLENDED_EXPECTED_SAMPLES = [
    [1725, 74, 207, 1635, 4440, 2774],
    [2066, 207, 6436, 2360, 2210, 6633],
@@ -378,13 +458,13 @@ def test_gpt_slice_data_legacy():
def test_gpt_blended():
    # Make sure dataset blending works and check for unintended changes in behavior.
    get_test_dataset()
    get_test_dataset_1()
    _get_test_dataset_mix_1()
    sampled = _get_dataset_config(
        {
            "type": "blended",
            "datasets": [
                {"type": "memmap", "path": DATASET_PREFIX},
                {"type": "memmap", "path": DATASET_PREFIX_MIX_1},
                {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
            ],
            "weights": [0.75, 0.25],
        },
@@ -399,15 +479,15 @@

def test_gpt_blended_data():
    get_test_dataset()
    get_test_dataset_1()
    _get_test_dataset_mix_1()
    _, samples = get_test_data_and_samples(
        {
            "datasets": {
                "Training": {
                    "type": "blended",
                    "datasets": [
                        {"type": "memmap", "path": DATASET_PREFIX},
                        {"type": "memmap", "path": DATASET_PREFIX_MIX_1},
                        {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
                    ],
                    "weights": [0.75, 0.25],
                }
@@ -436,11 +516,11 @@ def test_gpt_blended_data():

def test_gpt_blended_data_legacy():
    get_test_dataset()
    get_test_dataset_1()
    _get_test_dataset_mix_1()
    _, samples = get_test_data_and_samples(
        {
            "format": "list",
            "path": ["0.75", str(DATASET_PREFIX), "0.25", str(DATASET_PREFIX_MIX_1)],
            "path": ["0.75", str(DATASET_PREFIX), "0.25", str(_DATASET_PREFIX_MIX_1)],
            "split": [1, 0, 0],
        },
        {PhaseType.training: 8},