2 changes: 1 addition & 1 deletion Megatron-LM
6 changes: 4 additions & 2 deletions examples/mistral.yaml
@@ -50,8 +50,10 @@ model:
add_linear_biases: false
init_method_std: 0.009021
hidden_dropout: 0.0
vocab_size: 32000
tie_word_embeddings: false
embeddings_layer:
vocab_size: 32000
output_layer:
tied_weight: false
multi_stage:
zero_stage: 2
distributed:
124 changes: 123 additions & 1 deletion fast_llm/engine/config_utils/initialization.py
@@ -1,17 +1,139 @@
import abc
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
import torch

from fast_llm.tensor import ParameterMeta


class Initializer(abc.ABC):
class Initialization(abc.ABC):
"""
A common base class for initializations and initialization configs so both may be used interchangeably.
"""

@abc.abstractmethod
def get_initializer(self) -> "Initializer":
pass


@config_class(registry=True)
class InitializationConfig(Config, Initialization):
_abstract = True
is_default: typing.ClassVar[bool] = False

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None:
# Default subclass.
return DefaultInitializationConfig._from_dict(default, strict, flat)
return super()._from_dict(default, strict=strict, flat=flat)


@config_class()
class DefaultInitializationConfig(InitializationConfig):
# A placeholder indicating that the class default should be used instead.
_abstract = False
is_default = True

def get_initializer(self) -> "Initializer":
raise NotImplementedError()


@config_class(dynamic_type={InitializationConfig: "fill"})
class FillInitializationConfig(InitializationConfig):
"""
Fill initialization: fill(value)
"""

_abstract = False

value: float = Field(
default=1,
desc="Initialization value.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)

def get_initializer(self) -> "Initializer":
return init_fill_(self.value)


@config_class(dynamic_type={InitializationConfig: "normal"})
class NormalInitializationConfig(InitializationConfig):
"""
Normal initialization: normal(mean, std).clamp(min,max)
"""

_abstract = False

std: float = Field(
default=1,
desc="Standard deviation for normal initialization.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
mean: float = Field(
default=0,
desc="Mean for normal initialization.",
hint=FieldHint.optional,
)
min: float | None = Field(
default=None,
desc="Min value for initialization clamping.",
hint=FieldHint.optional,
)
max: float | None = Field(
default=None,
desc="Min value for initialization clamping.",
hint=FieldHint.optional,
)

def get_initializer(self) -> "Initializer":
return init_normal_(self.mean, self.std, self.min, self.max)


@config_class(dynamic_type={InitializationConfig: "uniform"})
class UniformInitializationConfig(InitializationConfig):
"""
Uniform initialization: uniform(mean - scale, mean + scale)
"""

_abstract = False

scale: float = Field(
default=None,
desc="Initialization scale.",
hint=FieldHint.optional,
valid=check_field(Assert.geq, 0),
)
mean: float = Field(
default=0,
desc="Initialization mean.",
hint=FieldHint.optional,
)

def get_initializer(self) -> "Initializer":
return init_uniform_centered_(self.scale, self.mean)


class Initializer(Initialization):
@abc.abstractmethod
def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None:
pass

def get_initializer(self) -> "Initializer":
return self

requires_global_initialization = False


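A minimal usage sketch of the new initialization registry (assuming this branch of fast_llm is installed, and that Config.from_dict is the public counterpart of the _from_dict hook shown above; it is not spelled out in this diff):

# Sketch only: the "type" key selects the registered InitializationConfig subclass;
# omitting it falls back to DefaultInitializationConfig, whose is_default flag tells
# the parent layer to keep its own default initializer.
from fast_llm.engine.config_utils.initialization import InitializationConfig

normal = InitializationConfig.from_dict({"type": "normal", "std": 0.009021, "mean": 0.0})
assert not normal.is_default
initializer = normal.get_initializer()  # an Initializer, usable wherever one is expected

default = InitializationConfig.from_dict({})
assert default.is_default  # placeholder: the parent layer's default initialization applies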
66 changes: 55 additions & 11 deletions fast_llm/engine/config_utils/parameter.py
@@ -1,15 +1,52 @@
import math
import typing

from fast_llm.config import Config, Field, config_class
from fast_llm.engine.config_utils.initialization import Initializer
from fast_llm.config import Config, Field, FieldHint, config_class
from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.layers.common.peft.config import PeftConfig

if typing.TYPE_CHECKING:
from fast_llm.tensor import ParameterMeta


def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]):
# Remove `None` entries.
lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None)
if not lr_scales:
# Everything is None
return None
tuple_length = None
# Check if we have tuples, and determine the length.
for lr_scale in lr_scales:
if isinstance(lr_scale, tuple):
if tuple_length is None:
tuple_length = len(lr_scale)
else:
assert len(lr_scale) == tuple_length
if tuple_length is None:
# No tuple: simple product.
return math.prod(lr_scales)
else:
# Tuple(s): use recursion.
return tuple(
combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales])
for i in range(tuple_length)
)


@config_class()
class ParameterConfig(Config):
initialization: InitializationConfig = Field(
desc="If provided, override the default initialization method set by the parent layer.",
hint=FieldHint.feature,
)
lr_scale: float | None = Field(
default=None,
desc="Scaling factor for the parameter learning rate."
" Combines multiplicatively with the scale set by the parent layer, if applicable.",
hint=FieldHint.feature,
)
# TODO: Initialization, lr_scale

def _validate(self) -> None:
@@ -18,50 +55,57 @@ def _validate(self) -> None:
def get_parameter(
self,
dims: tuple[TensorDim, ...],
default_initializer: Initializer,
*,
default_initialization: Initialization,
lr_scale: float | None,
weight_decay: bool = True,
allow_sequence_tensor_parallel: bool = True,
peft: PeftConfig | None,
) -> "ParameterMeta":
from fast_llm.tensor import ParameterMeta

return ParameterMeta.from_dims(
out = ParameterMeta.from_dims(
dims,
init_method=default_initializer,
lr_scale=lr_scale,
init_method=default_initialization if self.initialization.is_default else self.initialization,
lr_scale=combine_lr_scales(lr_scale, self.lr_scale),
weight_decay=weight_decay,
allow_sequence_tensor_parallel=allow_sequence_tensor_parallel,
)
if peft is not None:
out = peft.apply_weight(out)
return out


@config_class()
class OptionalParameterConfig(ParameterConfig):
enabled: bool | None = Field(
default=None,
)
# TODO: Initialization, lr_scale

def _validate(self) -> None:
pass

def get_parameter(
self,
dims: tuple[TensorDim, ...],
default_initializer: Initializer,
*,
default_initialization: Initialization,
lr_scale: float | None,
weight_decay: bool = True,
allow_sequence_tensor_parallel: bool = True,
default_enabled: bool = False,
peft: PeftConfig | None,
) -> "ParameterMeta|None":
from fast_llm.tensor import ParameterMeta
pass

if (self.enabled is None and default_enabled) or self.enabled:
return ParameterMeta.from_dims(
return super().get_parameter(
dims,
init_method=default_initializer,
default_initialization=default_initialization,
lr_scale=lr_scale,
weight_decay=weight_decay,
allow_sequence_tensor_parallel=allow_sequence_tensor_parallel,
peft=peft,
)
else:
return None
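The behavior of combine_lr_scales follows directly from the implementation above; a small sketch (assuming this branch of fast_llm is installed):

from fast_llm.engine.config_utils.parameter import combine_lr_scales

# None entries are ignored; plain floats multiply together.
assert combine_lr_scales(None, None) is None
assert combine_lr_scales(0.5, None, 2.0) == 1.0

# Tuples (e.g. per-component scales) are combined element-wise, with scalar
# factors broadcast across the tuple.
assert combine_lr_scales(0.5, (1.0, 2.0)) == (0.5, 1.0)
assert combine_lr_scales((1.0, 2.0), (3.0, 4.0)) == (3.0, 8.0)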
6 changes: 5 additions & 1 deletion fast_llm/engine/multi_stage/stage.py
@@ -121,7 +121,11 @@ def forward(
# Last layer does not provide output
if output is not None:
meta = self._meta_outputs[i]
output_global, _ = meta.local_to_global(output.detach())
if output.shape == meta.shape:
output_global, _ = meta.local_to_global(output.detach())
else:
# TODO: Handle variable shape.
output_global = output
kwargs["hidden_states"][self._layer_range[i]] = {
"layer_type": type(layer).__name__,
"tensor": output_global,
53 changes: 33 additions & 20 deletions fast_llm/layers/attention/attention.py
@@ -11,8 +11,8 @@
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockConfig, BlockDimNames
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.utils import combine_lr_scales, div
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import div

try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa
@@ -57,12 +57,24 @@ def __init__(
config: ConfigType,
block_config: BlockConfig,
distributed_config: DistributedConfig,
*,
# TODO: Review `hidden_dim` and `block_index`
hidden_dim: TensorDim,
block_index: int,
name: str,
lr_scale: float | None,
peft: PeftConfig | None,
):
super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale)
super().__init__(
config,
block_config,
distributed_config,
hidden_dim=hidden_dim,
block_index=block_index,
name=name,
lr_scale=lr_scale,
peft=peft,
)
self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config)

self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
@@ -94,29 +106,34 @@ def __init__(

self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power)

lr_scale = combine_lr_scales(
self._lr_scale,
self._config.attention_lr_scale,
)

# TODO: Merge the query and key-value computations? (harder with sequence parallel.)
self.query = self._config.query_layer.get_layer(
hidden_dim,
query_dim,
default_weight_initializer=init_normal_(std=self._block_config.init_method_std),
default_weight_initialization=init_normal_(std=self._block_config.init_method_std),
default_add_bias=self._block_config.add_linear_biases,
default_apply_peft=True,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=self._peft,
)
# TODO: Use value config.
self.key_value = self._config.query_layer.get_layer(
self.key_value = self._config.key_layer.get_layer(
hidden_dim,
key_value_dim,
default_weight_initializer=init_normal_(std=self._block_config.init_method_std),
default_weight_initialization=init_normal_(std=self._block_config.init_method_std),
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=None if self._config.key_layer.apply_peft is None else self._peft,
)
if self._peft is not None and self._config.key_layer.apply_peft is None:
# Default: Apply to value only.
# TODO: Avoid this hack.
self.key_value = self._peft.apply_linear(
self.key_value, True, out_channel_begin=div(key_value_dim.global_size, 2)
)

self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

# Rotary embeddings.
Expand All @@ -126,19 +143,15 @@ def __init__(
self.dense = self._config.dense_layer.get_layer(
dense_dim,
hidden_dim,
default_weight_initializer=init_normal_(
default_weight_initialization=init_normal_(
std=self._block_config.init_method_std / max(2 * self._block_config.num_layers, 1) ** 0.5,
),
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
lr_scale=self._lr_scale,
peft=self._peft,
)

# PEFT.
self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query)
self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value)
self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)

if self._debug.enabled:
self._query_dims = (
BlockDimNames.batch,
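One note on the key-value PEFT default above: the fused key-value projection stacks key output channels first and value output channels second, so out_channel_begin = key_value_dim.global_size // 2 restricts the adapter to the value half. A standalone illustration of that arithmetic (dimension sizes are hypothetical, and this is not Fast-LLM code):

import torch

# Key channels occupy the first half of the fused output dimension, value channels the
# second, so an adapter applied from the midpoint onward touches only the value projection.
head_groups, kv_channels, hidden_size = 8, 128, 4096  # hypothetical sizes
key_value_size = 2 * head_groups * kv_channels        # fused [key | value] output channels
fused_weight = torch.randn(key_value_size, hidden_size)

out_channel_begin = key_value_size // 2               # start of the value half
value_half = fused_weight[out_channel_begin:]         # rows a value-only adapter would wrap
assert value_half.shape == (head_groups * kv_channels, hidden_size)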