Merged
Changes from 9 commits
30 commits
4b606b0
Generalize config classes
jlamypoirier Apr 30, 2025
4a67660
cli
jlamypoirier Apr 30, 2025
531f67d
Merge branch 'main' into generalize_dynamic_classes
jlamypoirier May 2, 2025
1823407
misc
jlamypoirier May 5, 2025
fe7acd9
stuff
jlamypoirier May 7, 2025
94e56e1
combine data source inputs to data_source
nitsanluke May 7, 2025
bee7a4b
Merge remote-tracking branch 'origin/main' into generalize_dynamic_cl…
jlamypoirier May 7, 2025
d41be60
stuff
jlamypoirier May 7, 2025
6a30d76
Merge branch 'generalize_dynamic_classes' into restructure_dataset_co…
nitsanluke May 8, 2025
ec35a50
fixes
jlamypoirier May 8, 2025
1dab7de
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
36b42b9
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
c6876ac
merge
nitsanluke May 8, 2025
eadd49a
Merge branch 'restructure_dataset_config_for_multi_source' of github.…
nitsanluke May 8, 2025
a5b06d8
Update fast_llm/data/preparator/gpt_memmap/config.py
nitsanluke May 8, 2025
272c63f
Merge branch 'restructure_dataset_config_for_multi_source' of github.…
nitsanluke May 8, 2025
cbcde98
remove duplicate
nitsanluke May 8, 2025
694181f
name change
nitsanluke May 8, 2025
fdf44d3
adding ClassVar type
nitsanluke May 8, 2025
0909768
rename to _text_column
nitsanluke May 14, 2025
1a6b78b
remove default_factory for source_schema
nitsanluke May 14, 2025
662f318
minor comment
nitsanluke May 14, 2025
8457540
Merge branch 'main' into restructure_dataset_config_for_multi_source
nitsanluke Jun 3, 2025
0ce7571
reset to main
nitsanluke Jun 3, 2025
bc09402
Megatorn-LM reset to main
nitsanluke Jun 3, 2025
62bdeee
remvoe comment
nitsanluke Jun 3, 2025
28f48e1
update error msg
nitsanluke Jun 3, 2025
7d5bb2f
Merge branch 'main' into restructure_dataset_config_for_multi_source
nitsanluke Jun 16, 2025
51f88be
include checks for error msgs
nitsanluke Jun 16, 2025
48884e6
Merge branch 'main' into restructure_dataset_config_for_multi_source
nitsanluke Jun 16, 2025
78 changes: 67 additions & 11 deletions fast_llm/config.py
@@ -11,7 +11,7 @@

import yaml

from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log

logger = logging.getLogger(__name__)

@@ -257,7 +257,7 @@ def _process_config_class(cls: type["Config"]):
return cls


def config_class(cls=None):
def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
@@ -283,13 +283,7 @@ def __init__(self, **kwargs):
cls.__init__ = __init__
return wrapped

# See if we're being called as @config_class or @config_class().
if cls is None:
# We're called with parens.
return wrap

# We're called as @config_class without parens.
return wrap(cls)
return wrap


@dataclasses.dataclass()
@@ -316,6 +310,9 @@ class Config:
# without them being automatically added to `_explicit_fields`.
_setting_implicit_default: bool | None = Field(init=False, repr=False)

# A registry for all the config classes.
_registry: typing.ClassVar[Registry[str, type[typing.Self]]]

def __setattr__(self, key: str, value: typing.Any) -> None:
"""
Make the class read-only after validation.
@@ -358,7 +355,7 @@ def __delattr__(self, key: str) -> None:
super().__delattr__(key)

@contextlib.contextmanager
def _set_implicit_default(self, _value: bool | int = True):
def _set_implicit_default(self, _value: bool | None = True):
assert self._setting_implicit_default is False
self._setting_implicit_default = _value
yield
@@ -391,6 +388,10 @@ def _validate(self) -> None:
self._check_abstract()
errors = []
with self._set_implicit_default(None):
# Set the type field, overriding any explicitly provided value with the actual class name, for clarity and safety.
self.type = self.__class__.__name__
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.is_(self._registry[self.type], self.__class__)
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
continue
@@ -468,6 +469,13 @@ def _validate_element(cls, value, type_, name: str):
raise FieldTypeError(f"Not a type.")
elif issubclass(type_, Config):
cls._validate_element_type(value, type_, strict=False)
# If the value belongs to a proper subclass of `type_`,
# we need an explicitly set `type` field for serialization to remember the actual config class.
if type(value) != type_:
if value.type is None:
value.type = value.__class__.__name__
value._explicit_fields.add("type")

value.validate(_is_validating=True)
else:
value = cls._validate_simple(value, type_)
@@ -720,7 +728,18 @@ def from_dict(
for keys, value in update.items():
set_nested_dict_value(default, keys, value, update_type)

return cls._from_dict(default, strict)
type_ = default.get("type")
if type_ is None:
actual_cls = cls
else:
if type_ not in cls._registry:
raise ValueError(f"Unknown config type {type_}.")
actual_cls = cls._registry[type_]
if not issubclass(actual_cls, cls):
raise ValueError(
f"Config class {actual_cls.__name__} (from type {type_}) is not a subclass of {cls.__name__}"
)
return actual_cls._from_dict(default, strict=strict)

@classmethod
def from_flat_dict(
@@ -879,10 +898,40 @@ def _check_abstract(cls) -> None:
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)

@classmethod
def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
cls._registry[name] = cls_

@classmethod
def get_subclass(cls, name):
# TODO: Make it case-insensitive?
cls_ = None
for base_class in cls.__mro__:
if issubclass(base_class, Config) and name in base_class._registry:
if cls_ is None:
cls_ = base_class._registry[name]
if not issubclass(cls_, cls):
raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})")
elif base_class._registry[name] is not cls_:
# We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
# TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
raise RuntimeError(
f"Ambiguous type `{name}` for base class {cls.__name__}."
f" ({cls_.__name__} vs {base_class._registry[name]})"
)
if cls_ is None:
raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
return cls_

def __init_subclass__(cls):
"""
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
cls._registry = Registry[str, type[cls]](cls.__name__, {})
Config._registry[cls.__name__] = cls
short_name = cls.__name__.removesuffix("Config")
if short_name != cls.__name__:
Config._registry[short_name] = cls
for base_class in cls.__mro__:
if issubclass(base_class, Config):
assert cls.__class_validated__, (
@@ -928,6 +977,13 @@ def __init_subclass__(cls):
# dataclasses expects an annotation, so we use the one from the base class.
cls.__annotations__[name] = base_class_field.type

# The config class name. Declared at the end of the class body to avoid shadowing the `type` builtin.
type: str | None = Field(
default=None,
desc="The config class name.",
hint=FieldHint.core,
)


class Configurable[ConfigType: Config]:
config_class: typing.ClassVar[type[Config]] = Config
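With these changes any `Config` subclass becomes a dynamically resolvable type: `__init_subclass__` records the class under its name (and a short name without the `Config` suffix), `register_subclass` adds user-friendly aliases, and `from_dict` re-dispatches to the concrete class named by the `type` key. A minimal sketch of the intended lookup (not part of the diff); it uses the "memmap" alias registered in the dataset configs below and assumes `register_subclass` stores `cls_` under `name`, as its call sites imply:

    from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig

    # Resolve a registered subclass from its user-friendly alias.
    assert GPTSampledDatasetConfig.get_subclass("memmap") is GPTMemmapDatasetConfig
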
70 changes: 16 additions & 54 deletions fast_llm/data/dataset/gpt/config.py
@@ -23,7 +23,7 @@
SamplingParameters,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -94,59 +94,7 @@ class GPTSamplingData(SamplingData):

@config_class()
class GPTSampledDatasetConfig(SampledDatasetConfig):

# TODO: Generalize dynamic types?
_registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[
str, type["GPTDatasetConfig"]
]("gpt_dataset_class", {})
type_: typing.ClassVar[str | None] = None
type: str | None = Field(
default=None,
desc="The type of dataset.",
hint=FieldHint.core,
)

def _validate(self) -> None:
if self.type is None:
self.type = self.type_
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.eq(self.type, self.__class__.type_)
super()._validate()

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
type_ = default.get("type")
if type_ is None:
actual_cls = cls
else:
if type_ not in cls._registry:
raise ValueError(
f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
)
actual_cls = cls._registry[type_]
Assert.custom(issubclass, actual_cls, cls)
if actual_cls == cls:
return super()._from_dict(default, strict=strict, flat=flat)
else:
return actual_cls._from_dict(default, strict=strict, flat=flat)

def __init_subclass__(cls) -> None:
if cls._abstract and cls.type_ is not None:
# Abstract classes should not have a `type_`
raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.")
if cls.type_ is not None:
if cls.type_ in cls._registry:
raise ValueError(
f"Registry {cls._registry.name} already contains type {cls.type_}."
f" Make sure all classes either have a unique or `None` type."
)
GPTSampledDatasetConfig._registry[cls.type_] = cls
super().__init_subclass__()
pass


@config_class()
@@ -558,3 +506,17 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset:
if sampling.distributed.config.rank == 0:
time.sleep(self.sleep)
return GPTRandomDatasetConfig().build_and_sample(sampling)


# Add user-friendly names for the configs.
GPTSampledDatasetConfig.register_subclass("dummy", GPTRandomDatasetConfig)
GPTSampledDatasetConfig.register_subclass("memmap", GPTMemmapDatasetConfig)
GPTSampledDatasetConfig.register_subclass("concatenated", GPTConcatenatedDatasetConfig)
GPTSampledDatasetConfig.register_subclass("slice", GPTDatasetSliceConfig)
GPTSampledDatasetConfig.register_subclass("sampled", GPTSampledDatasetUpdateConfig)
GPTSampledDatasetConfig.register_subclass("blended", GPTBlendedDatasetConfig)
GPTSampledDatasetConfig.register_subclass("file", GPTDatasetFromFileConfig)
GPTSampledDatasetConfig.register_subclass("concatenated_memmap", GPTConcatenatedMemmapConfig)
GPTSampledDatasetConfig.register_subclass("fim", GPTFimSampledDatasetConfig)
GPTSampledDatasetConfig.register_subclass("legacy", GPTLegacyDatasetConfig)
GPTSampledDatasetConfig.register_subclass("test_slow", GPTTestSlowDatasetConfig)
23 changes: 17 additions & 6 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -24,6 +24,20 @@
MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"


class SourceSchemaConfig(Config):
pass

class TextColumnConfig(SourceSchemaConfig):
input_column: str = Field(
default="text",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
loss_masking_spans_column: None | str = Field(
default=None,
desc="Field containing character spans to mask for loss computation.",
hint=FieldHint.optional,
)

@config_class
class GPTHuggingfaceDatasetConfig(Config):
path: str = Field(
@@ -51,14 +65,11 @@ class GPTHuggingfaceDatasetConfig(Config):
desc="Split of the dataset to use.",
hint=FieldHint.optional,
)
field: str = Field(
default="text",
desc="Field of the dataset to use.",
data_source: SourceSchemaConfig = Field(
default_factory=TextColumnConfig,
desc="Configuration for the data source.",
hint=FieldHint.optional,
)
loss_masking_spans: None | str = Field(
default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
)
data_type: DataType | None = Field(
default=None,
desc="Data type of the dataset field."
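The flat `field` and `loss_masking_spans` options are folded into the nested `data_source` sub-config. A minimal sketch of the intended usage (not part of the diff), assuming `SourceSchemaConfig` and `TextColumnConfig` are decorated with `@config_class()` like the surrounding configs and accept keyword construction:

    from fast_llm.data.preparator.gpt_memmap.config import GPTHuggingfaceDatasetConfig, TextColumnConfig

    # Equivalent of the old flat options field="text", loss_masking_spans="spans".
    source = TextColumnConfig(input_column="text", loss_masking_spans_column="spans")
    dataset = GPTHuggingfaceDatasetConfig(path="my_org/my_dataset", data_source=source)
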
26 changes: 17 additions & 9 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -24,7 +24,7 @@
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.config import DatasetPreparator
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -37,11 +37,13 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D

_tokenizer: Tokenizer
_data_type: DataType
_data_column: str
_loss_masking_spans_column: str | None

def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
input_ids = [
np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
for text in batch[self._config.dataset.field]
for text in batch[self._data_column]
]
num_tokens = [len(x) for x in input_ids]
return {
@@ -61,7 +63,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict
for input_ids, token_spans in [
self._tokenizer.tokenize_with_spans(text, char_spans)
for text, char_spans in zip(
batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans]
batch[self._data_column], batch[self._loss_masking_spans_column]
)
]
]
@@ -80,7 +82,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon
shard_output_path = self._config.output_path / prefix

def _document_generator():
if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None:
if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
yield GPTSample(
np.array(item["input_ids"], dtype=self._data_type.numpy),
@@ -176,6 +178,12 @@ def run(self) -> None:
else self._config.dataset.data_type
)

if isinstance(self._config.dataset.data_source, TextColumnConfig):
self._data_column = self._config.dataset.data_source.input_column
self._loss_masking_spans_column = self._config.dataset.data_source.loss_masking_spans_column
else:
raise ValueError(f"Dataset data_source set incorrectly. data_source: '{self._config.dataset.data_source}'.")

# Initialize distributed processing
if self._config.distributed.world_size > 1:
torch.distributed.init_process_group(
@@ -212,11 +220,11 @@ def run(self) -> None:
num_shards=self._config.distributed.world_size,
index=self._config.distributed.rank,
)
if self._config.dataset.field not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.")
if self._config.dataset.loss_masking_spans is not None:
if self._config.dataset.loss_masking_spans not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.")
if self._data_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._data_column}'.")
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
else:
tokenize_fn = self._tokenize_batch
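The `isinstance` dispatch in `run()` is the single place a new source schema would have to be handled, which is the point of the `SourceSchemaConfig` abstraction. A hypothetical additional schema (names invented, not part of this PR) could look like:

    from fast_llm.config import Field, FieldHint, config_class
    from fast_llm.data.preparator.gpt_memmap.config import SourceSchemaConfig

    @config_class()
    class MessagesColumnConfig(SourceSchemaConfig):
        """Hypothetical schema for chat-style datasets; not part of this PR."""

        messages_column: str = Field(
            default="messages",
            desc="Field containing the list of chat messages.",
            hint=FieldHint.optional,
        )
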
2 changes: 2 additions & 0 deletions fast_llm/engine/checkpoint/state_dict.py
@@ -26,6 +26,8 @@

logger = logging.getLogger(__name__)

torch.distributed.gather


class StateDictCheckpointHandler(CheckpointHandler):
base_file_name: typing.ClassVar[str] = "model"