Merged
Changes from all commits
47 commits
4b606b0
Generalize config classes
jlamypoirier Apr 30, 2025
4a67660
cli
jlamypoirier Apr 30, 2025
531f67d
Merge branch 'main' into generalize_dynamic_classes
jlamypoirier May 2, 2025
1823407
misc
jlamypoirier May 5, 2025
fe7acd9
stuff
jlamypoirier May 7, 2025
94e56e1
combine data source inputs to data_source
nitsanluke May 7, 2025
bee7a4b
Merge remote-tracking branch 'origin/main' into generalize_dynamic_cl…
jlamypoirier May 7, 2025
d41be60
stuff
jlamypoirier May 7, 2025
6a30d76
Merge branch 'generalize_dynamic_classes' into restructure_dataset_co…
nitsanluke May 8, 2025
ec35a50
fixes
jlamypoirier May 8, 2025
1dab7de
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
36b42b9
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
c6876ac
merge
nitsanluke May 8, 2025
eadd49a
Merge branch 'restructure_dataset_config_for_multi_source' of github.…
nitsanluke May 8, 2025
a5b06d8
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
272c63f
Merge branch 'restructure_dataset_config_for_multi_source' of github.…
nitsanluke May 8, 2025
cbcde98
remove duplicate
nitsanluke May 8, 2025
49f9929
adding prompt completion config
nitsanluke May 8, 2025
694181f
name change
nitsanluke May 8, 2025
3e4b746
merge
nitsanluke May 8, 2025
f2ec355
adding concat logic
nitsanluke May 8, 2025
fdf44d3
adding ClassVar type
nitsanluke May 8, 2025
2c790a5
Merge branch 'restructure_dataset_config_for_multi_source' into conca…
nitsanluke May 8, 2025
4c84d20
updates to comments
nitsanluke May 8, 2025
bc0cb30
include loss masking span always available
nitsanluke May 9, 2025
903488e
Update fast_llm/data/preparator/gpt_memmap/prepare.py
nitsanluke May 9, 2025
c8e20ea
address comments
nitsanluke May 9, 2025
0909768
rename to _text_column
nitsanluke May 14, 2025
1a6b78b
remove default_factory for source_schema
nitsanluke May 14, 2025
662f318
minor comment
nitsanluke May 14, 2025
26eef54
merge update
nitsanluke May 14, 2025
8457540
Merge branch 'main' into restructure_dataset_config_for_multi_source
nitsanluke Jun 3, 2025
0ce7571
reset to main
nitsanluke Jun 3, 2025
bc09402
Megatron-LM reset to main
nitsanluke Jun 3, 2025
62bdeee
remove comment
nitsanluke Jun 3, 2025
10d4ccb
Merge branch 'restructure_dataset_config_for_multi_source' into conca…
nitsanluke Jun 3, 2025
66edf33
update error
nitsanluke Jun 3, 2025
3b113d8
Merge branch 'main' into concat_inputs
nitsanluke Jun 16, 2025
236b908
Merge branch 'main' into concat_inputs
nitsanluke Jun 16, 2025
d8cb9f3
Merge branch 'main' into concat_inputs
nitsanluke Jun 26, 2025
dc6f7f2
update masking spans col
nitsanluke Jun 27, 2025
ee8ac01
Merge branch 'main' into concat_inputs
tscholak Jul 11, 2025
02295fe
Merge branch 'main' into concat_inputs
nitsanluke Jul 23, 2025
a300deb
address comments
nitsanluke Jul 23, 2025
2ec9ef8
Merge branch 'concat_inputs' of github.com:ServiceNow/Fast-LLM into c…
nitsanluke Jul 23, 2025
6fe76a0
Merge branch 'main' into concat_inputs
nitsanluke Jul 24, 2025
e9e69cf
update comment
nitsanluke Jul 24, 2025
19 changes: 19 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -30,6 +30,25 @@ class SourceSchemaConfig(Config):
pass


@config_class(dynamic_type={SourceSchemaConfig: "prompt_completion"})
class PromptCompletionConfig(SourceSchemaConfig):
prompt_column: str = Field(
default="prompt",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
completion_column: str = Field(
default="completion",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
delimiter: str = Field(
default="",
desc="Delimiter between prompt and completion.",
hint=FieldHint.optional,
)


@config_class(dynamic_type={SourceSchemaConfig: "text_column"})
class TextColumnConfig(SourceSchemaConfig):
input_column: str = Field(
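A usage illustration (editor's addition, not part of the diff): a minimal sketch of selecting the new source schema, assuming the dynamic_type registration above exposes it under the key "prompt_completion" and that the config class accepts dataclass-style keyword construction. The column names and delimiter here are hypothetical.

from fast_llm.data.preparator.gpt_memmap.config import PromptCompletionConfig

# Hypothetical values; prompt_column, completion_column, and delimiter are the
# actual fields defined in the diff above.
schema = PromptCompletionConfig(
    prompt_column="question",
    completion_column="answer",
    delimiter="\n",
)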
92 changes: 64 additions & 28 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -24,7 +24,11 @@
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.config import DatasetPreparator
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
from fast_llm.data.preparator.gpt_memmap.config import (
GPTMemmapDatasetPreparatorConfig,
PromptCompletionConfig,
TextColumnConfig,
)
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -50,6 +54,30 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[
"num_tokens": num_tokens,
}

def _tokenize_prompt_completion_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
"""
Tokenize the prompt and completion columns separately, then concatenate them.
Returns input_ids, token_spans (the inclusive index range of the prompt tokens), and num_tokens.
"""
prompt_col = self._config.dataset.source_schema.prompt_column
completion_col = self._config.dataset.source_schema.completion_column
delimiter = self._config.dataset.source_schema.delimiter
input_ids = []
token_spans = []
for prompt, completion in zip(batch[prompt_col], batch[completion_col]):
prompt_tokens = self._tokenizer.tokenize(prompt, begin=True, end=False)
completion_tokens = self._tokenizer.tokenize(f"{delimiter}{completion}", begin=False, end=True)
combined = prompt_tokens + completion_tokens
input_ids.append(np.array(combined, dtype=self._data_type.numpy))
token_spans.append(np.array((0, len(prompt_tokens) - 1), dtype=np.int32).reshape(-1, 2))

num_tokens = [len(x) for x in input_ids]
return {
"input_ids": input_ids,
"token_spans": token_spans,
"num_tokens": num_tokens,
}

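# Editor's sketch (not part of the diff), illustrating the span convention above
# under a toy assumption of one token per character and no special tokens:
#   prompt     = "hi"      -> prompt_tokens has length 2
#   completion = " there"  -> completion_tokens has length 6 (empty delimiter)
#   input_ids has length 8 and token_spans == [[0, 1]]: the inclusive index
#   range of the prompt tokens, which downstream loss masking can consume.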
def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
input_ids, token_spans = map(
list,
@@ -143,7 +171,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon
shard_output_path = self._config.output_path / prefix

def _document_generator():
if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
if "token_spans" in shard_dataset.column_names:
for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
yield GPTSample(
np.array(item["input_ids"], dtype=self._data_type.numpy),
@@ -289,37 +317,46 @@ def run(self) -> None:
)

# Set data column and loss masking spans column based on source schema
if isinstance(self._config.dataset.source_schema, TextColumnConfig):
self._text_column = self._config.dataset.source_schema.input_column
self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column
source_schema = self._config.dataset.source_schema
if isinstance(source_schema, TextColumnConfig):
self._text_column = source_schema.input_column
self._loss_masking_spans_column = source_schema.loss_masking_spans_column
elif isinstance(source_schema, PromptCompletionConfig):
Assert.incl(source_schema.prompt_column, dataset.column_names)
Assert.incl(source_schema.completion_column, dataset.column_names)
tokenize_fn = self._tokenize_prompt_completion_batch
else:
raise ValueError(
f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'."
)

if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")
# TODO: Add a new schema for preference datasets then drop class vars _loss_masking_spans_column & _text_column
if isinstance(source_schema, TextColumnConfig):
if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")

if self._config.dataset.source_schema.loss_masking_spans_column is not None and (
self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
):
raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.")
if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None):
raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

# route tokenize function
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
if self._config.dataset.chosen_text not in dataset.column_names:
raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.")
if self._config.dataset.rejected_text not in dataset.column_names:
raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.")
tokenize_fn = self._tokenize_preference_batch_with_spans
else:
tokenize_fn = self._tokenize_batch
if self._config.dataset.source_schema.loss_masking_spans_column is not None and (
self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
):
raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.")
if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None):
raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

# route tokenize function
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
if self._config.dataset.chosen_text not in dataset.column_names:
raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.")
if self._config.dataset.rejected_text not in dataset.column_names:
raise ValueError(
f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'."
)
tokenize_fn = self._tokenize_preference_batch_with_spans
else:
tokenize_fn = self._tokenize_batch

# Tokenize the dataset in parallel
tokenized_dataset = dataset.map(
Expand All @@ -331,7 +368,6 @@ def run(self) -> None:

# Calculate total number of tokens
total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens"))

# Split dataset into shards based on number of tokens
num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard))
shards = [
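For reference, a self-contained sketch (editor's addition) of the prompt/completion tokenization path introduced above. The stub tokenizer is hypothetical (one token per character, no special tokens); only the concatenation and span logic mirrors _tokenize_prompt_completion_batch.

import numpy as np

# Stub tokenizer: one id per character; begin/end flags are accepted but add
# nothing, so the token arithmetic below is easy to check by hand.
class StubTokenizer:
    def tokenize(self, text: str, begin: bool = False, end: bool = False) -> list[int]:
        return [ord(c) for c in text]

def tokenize_prompt_completion(batch: dict[str, list[str]], delimiter: str = "") -> dict[str, list]:
    tokenizer = StubTokenizer()
    input_ids, token_spans = [], []
    for prompt, completion in zip(batch["prompt"], batch["completion"]):
        prompt_tokens = tokenizer.tokenize(prompt, begin=True, end=False)
        completion_tokens = tokenizer.tokenize(f"{delimiter}{completion}", begin=False, end=True)
        input_ids.append(np.array(prompt_tokens + completion_tokens, dtype=np.int64))
        # Inclusive index range of the prompt tokens, matching the diff's convention.
        token_spans.append(np.array((0, len(prompt_tokens) - 1), dtype=np.int32).reshape(-1, 2))
    return {"input_ids": input_ids, "token_spans": token_spans, "num_tokens": [len(x) for x in input_ids]}

out = tokenize_prompt_completion({"prompt": ["Q: 2+2?"], "completion": [" A: 4"]}, delimiter="\n")
print(out["num_tokens"])   # [13]: 7 prompt tokens + 6 for "\n A: 4"
print(out["token_spans"])  # [array([[0, 6]], dtype=int32)]: prompt occupies indices 0..6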