Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Megatron-LM
168 changes: 114 additions & 54 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import contextlib
import copy
import dataclasses
Expand All @@ -11,7 +12,7 @@

import yaml

from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -137,7 +138,6 @@ def __init__(
default=dataclasses.MISSING,
default_factory=dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash=None,
compare: bool = True,
metadata=None,
Expand All @@ -146,12 +146,12 @@ def __init__(
if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
raise ValueError("cannot specify both default and default_factory")
if isinstance(default_factory, type) and issubclass(default_factory, Config):
default_factory = _ConfigFactory(default_factory)
raise ValueError("Config classes should not be used as `default_factory`")
super().__init__(
default=default,
default_factory=default_factory,
init=init,
repr=repr,
repr=False,
hash=hash,
compare=compare,
metadata=metadata,
Expand Down Expand Up @@ -223,20 +223,6 @@ def valid(x):
return valid


class _ConfigFactory:
"""
A dataclass default factory that prevents early validation.
Validation is still done through the parent config if needed.
"""

def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]):
self._factory = factory

def __call__(self):
with NoAutoValidate():
return self._factory()


class ValidationError(ValueError):
pass

Expand All @@ -257,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
return cls


def config_class(cls=None):
def config_class[
T: Config
](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
Expand All @@ -267,7 +255,7 @@ def wrap(cls):
if hasattr(cls, "__post_init__"):
raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")

wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))

wrapped_init = cls.__init__

Expand All @@ -280,20 +268,31 @@ def __init__(self, **kwargs):
if _AUTO_VALIDATE:
self.validate()

cls.__init__ = __init__
wrapped.__init__ = __init__

wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None

if dynamic_type is not None:
for cls_, name in dynamic_type.items():
print(cls_, name, wrapped)
cls_.register_subclass(name, wrapped)

return wrapped

# See if we're being called as @config_class or @config_class().
if cls is None:
# We're called with parens.
return wrap
return wrap

# We're called as @config_class without parens.
return wrap(cls)

class ConfigMeta(abc.ABCMeta):
    """
    Metaclass that routes keyword-only instantiation through `cls._from_dict`.

    Calling a class that uses this metaclass forwards the keyword arguments, as a dict,
    to `cls._from_dict`, so that dynamic class selection and nested config instantiation
    are applied. The private `_from_dict_check` keyword marks a call that should
    bypass this redirection and use the regular construction path;
    it is removed from the arguments before they are forwarded.
    """

    def __call__(cls: "type[Config]", **kwargs):
        # The marker is always stripped, whichever path is taken.
        direct_construction = kwargs.pop("_from_dict_check", False)
        if direct_construction:
            # Regular instantiation with the remaining keyword arguments.
            return super().__call__(**kwargs)
        # Default path: let `_from_dict` build the instance.
        return cls._from_dict(kwargs)

@dataclasses.dataclass()
class Config:

@dataclasses.dataclass(kw_only=True, repr=False)
class Config(metaclass=ConfigMeta):
"""
An advanced `dataclass` with basic type checking, validation and argparse support.
Typically, a subclass will:
Expand All @@ -307,14 +306,17 @@ class Config:
# Set to true to prevent instantiation.
_abstract: typing.ClassVar[bool] = False
# Keep track of whether an instance has been validated
_validated: bool = Field(init=False, repr=False)
_validated: bool = Field(init=False)
# Keep track of unknown fields so they can be reported during validation.
_unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False)
_unknown_fields: dict[str, typing.Any] = Field(init=False)
# Keep track of explicitly set fields to ensure they get serialized and used as config updates.
_explicit_fields: set[str] = Field(init=False, repr=False)
_explicit_fields: set[str] = Field(init=False)
# Used within `_set_implicit_default` to set implicit defaults for fields
# without them being automatically added to `_explicit_fields`.
_setting_implicit_default: bool | None = Field(init=False, repr=False)
_setting_implicit_default: bool | None = Field(init=False)

# A registry for all the config classes.
_registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None

def __setattr__(self, key: str, value: typing.Any) -> None:
"""
Expand All @@ -339,7 +341,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None:
)
else:
field = self.get_field(key)
if field.init and field._field_type != dataclasses._FIELD_CLASSVAR:
if field.init and field._field_type == dataclasses._FIELD:
# Adding to explicit field list except within `_set_implicit_default` context,
# during dataclass initialization (`_setting_implicit_default` not yet set)
# and during automated config validation (`_setting_implicit_default=None`)
Expand All @@ -358,17 +360,28 @@ def __delattr__(self, key: str) -> None:
super().__delattr__(key)

@contextlib.contextmanager
def _set_implicit_default(self, _value: bool | int = True):
def _set_implicit_default(self, _value: bool | None = True):
assert self._setting_implicit_default is False
self._setting_implicit_default = _value
yield
self._setting_implicit_default = False

def validate[T](self: T, *, _is_validating: bool = False) -> T:
def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
"""
Validate a class and mark it as read-only
This should not be overridden in derived classes.
"""
# Should be handled in `from_dict`, but can fail if instantiating directly.
try:
expected_class = self.get_subclass(self.type)
except KeyError as e:
# Delayed instantiation error in `from_dict`.
raise ValidationError(*e.args)

if expected_class is not None:
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.is_(self.__class__, expected_class)

if not self._validated:
try:
self._validate()
Expand All @@ -388,11 +401,16 @@ def _validate(self) -> None:
Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
self._check_abstract()
if self._abstract:
raise ValidationError(f"{type(self).__name__} is abstract")
if not self.__class_validated__:
raise ValidationError(
f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
errors = []
with self._set_implicit_default(None):
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
value = getattr(self, name)
if isinstance(value, Tag):
Expand Down Expand Up @@ -610,11 +628,7 @@ def _add_field_to_args(
all_fields: bool = False,
serializable: bool = True,
) -> None:
if (
field is not None
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
and not all_fields
):
if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields:
# Exclude class variables and derived fields unless requested explicitly.
return
explicit_field = (
Expand Down Expand Up @@ -677,6 +691,9 @@ def to_copy[
) -> T:
return self.from_dict(self, *updates, strict=strict, update_type=update_type)

def __repr__(self):
return self.to_logs(log_fn=str)

def to_logs[
T
](
Expand Down Expand Up @@ -739,16 +756,24 @@ def _from_dict(
flat: bool = False,
) -> typing.Self:
# TODO v0.3: Remove flat format
out_arg_dict = {}
out_arg_dict = {"_from_dict_check": True}

# TODO v0.3: Remove backward compatibility fix
if "__class__" in default:
del default["__class__"]

try:
actual_cls = cls.get_subclass(default.get("type"))
if actual_cls is not None and actual_cls is not cls:
return actual_cls._from_dict(default, strict=strict, flat=flat)
except KeyError:
# Postpone error to validation.
pass

# Do not validate yet in case the root class sets cross-dependencies in validation.
with NoAutoValidate():
for name, field in cls.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
if flat:
if isinstance(field.type, type) and issubclass(field.type, Config):
Expand Down Expand Up @@ -869,22 +894,51 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
f"Config comparison errors:\n " + "\n".join(errors),
log_fn=log_fn,
)
return None

@classmethod
def _check_abstract(cls) -> None:
if cls._abstract:
raise ValidationError(f"{cls.__name__} is abstract")
if not cls.__class_validated__:
raise ValidationError(
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
    """
    Register `cls_` under `name` in the registry of `cls`.

    `cls_` must be a subclass of `cls`, and `cls` must have a registry (`_registry` not `None`).
    If `name` is already registered by a class with the same name and module as `cls_`,
    the existing entry is replaced. Any other existing entry is a conflict.

    Raises:
        NotImplementedError: If `cls` has no registry.
        KeyError: If `name` is already registered by a different class.
    """
    Assert.custom(issubclass, cls_, cls)
    if cls._registry is None:
        raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..")
    if name in cls._registry:
        old_cls = cls._registry[name]
        if old_cls.__name__ == cls_.__name__ and old_cls.__module__ == cls_.__module__:
            # Same class name and module: treated as a redefinition of the same class.
            # NOTE(review): the entry is deleted first, presumably because the registry
            # rejects overwriting an existing key; confirm against `Registry`.
            del cls._registry[name]
        else:
            # Report the class that currently owns the entry, not the registry owner.
            raise KeyError(
                f"{cls.__name__} class registry already has an entry {name} from class {old_cls.__name__}."
            )
    cls._registry[name] = cls_

@classmethod
def get_subclass(cls, name: str | None):
    """
    Resolve the registered class named `name`, as seen from `cls`.

    Every `Config` class in the MRO of `cls` that has a registry containing `name` is consulted.
    The first match must be a subclass of `cls`, and every later match must be that same class.

    Returns:
        `None` if `name` is `None`, otherwise the registered class.

    Raises:
        KeyError: If `name` is unknown, resolves to a class that is not a subclass of `cls`,
            or resolves to different classes in different registries.
    """
    # TODO: Make it case-insensitive?
    if name is None:
        return None
    resolved = None
    for base_class in cls.__mro__:
        # Only `Config` classes that own a registry knowing `name` are relevant.
        if not issubclass(base_class, Config):
            continue
        registry = base_class._registry
        if registry is None or name not in registry:
            continue
        candidate = registry[name]
        if resolved is None:
            resolved = candidate
            if not issubclass(resolved, cls):
                raise KeyError(f" {resolved.__name__} is not a subclass of {cls.__name__} (from type {name})")
        elif candidate is not resolved:
            # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
            # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
            raise KeyError(
                f"Ambiguous type `{name}` for base class {cls.__name__}."
                f" ({resolved.__name__} vs {candidate})"
            )
    if resolved is None:
        raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
    return resolved

def __init_subclass__(cls):
"""
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
for base_class in cls.__mro__:
if issubclass(base_class, Config):
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
f" Make sure to use the @config_class decorator."
Expand Down Expand Up @@ -913,7 +967,6 @@ def __init_subclass__(cls):
valid=value.pop("valid", base_class_field.valid),
default=value.pop("default", base_class_field.default),
default_factory=value.pop("default_factory", base_class_field.default_factory),
repr=value.pop("repr", base_class_field.repr),
hash=value.pop("hash", base_class_field.hash),
compare=value.pop("compare", base_class_field.compare),
metadata=value.pop("metadata", base_class_field.metadata),
Expand All @@ -928,6 +981,13 @@ def __init_subclass__(cls):
# dataclasses expects an annotation, so we use the one from the base class.
cls.__annotations__[name] = base_class_field.type

# Type for the field. At the end of class definition to avoid shadowing builtin.
type: str | None = Field(
default=None,
desc="The config class name.",
hint=FieldHint.feature,
)


class Configurable[ConfigType: Config]:
config_class: typing.ClassVar[type[Config]] = Config
Expand Down
4 changes: 1 addition & 3 deletions fast_llm/data/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,4 @@ class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingData]]

sampling: SamplingConfig = Field(
default_factory=SamplingConfig, desc="Default configuration for dataset sampling."
)
sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.")
3 changes: 1 addition & 2 deletions fast_llm/data/data/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
_abstract = False

tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
Expand All @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
sampling: GPTSamplingConfig = FieldUpdate()
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
Expand Down
2 changes: 0 additions & 2 deletions fast_llm/data/dataset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):

_abstract = True
sampling: SamplingConfig = Field(
default_factory=SamplingConfig,
desc="Optional override to sampling configuration parameters.",
hint=FieldHint.core,
)
dataset: SampledDatasetConfig = Field(
default_factory=SampledDatasetConfig,
desc="The dataset to sample from.",
hint=FieldHint.core,
)
Expand Down
Loading