diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py
index 4885e516..0f83c862 100644
--- a/fast_llm/engine/multi_stage/stage_base.py
+++ b/fast_llm/engine/multi_stage/stage_base.py
@@ -85,6 +85,10 @@ def __init__(
         # TODO: Separate fsdp for tied weights?
         self._fsdp_index = {name: i for i, fsdp in enumerate(self._fsdps) for name in fsdp.parameter_names}

+    @property
+    def requires_grad(self):
+        return any(fsdp.requires_grad for fsdp in self._fsdps)
+
     @property
     def mode(self) -> StageMode:
         assert self._is_setup
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index 0399116a..1d4b04c1 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -406,7 +406,7 @@ def _forward(self, context: BatchContext, step: Step) -> None:
             losses=context.losses,
             metrics=context.metrics,
         )
-        if context.is_training:
+        if step.backward_step is not None:
             context.contexts[step.backward_step.global_index] = grad_context
         self._record_compute(context, step)
         return output
diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py
index 87a12bfe..4c0e4371 100644
--- a/fast_llm/engine/schedule/schedule.py
+++ b/fast_llm/engine/schedule/schedule.py
@@ -141,7 +141,7 @@ def __init__(
             phase=self._phase,
         )

-        self._steps = self._create_steps()
+        self._steps, self._first_grad_stage = self._create_steps()

         self._create_index()

@@ -214,8 +214,8 @@ def _create_index(self) -> None:
         # Consistency checks
         step_map = self._step_map.copy()
         for data_index in range(self._batch_config.num_inputs):
-            for type_ in (StepType.forward, StepType.backward) if self._is_training else (StepType.forward,):
-                for stage in range(self._num_stages):
+            for type_ in (StepType.forward, StepType.backward):
+                for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
                     assert (
                         step_map.pop((type_, stage, data_index), None) is not None
                     ), f"Missing {type_.value} step with stage={stage}, data_index={data_index}"
@@ -225,7 +225,8 @@
         for i, step in enumerate(self._steps):
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
+                    if step.stage >= self._first_grad_stage:
+                        step.backward_step = self.get_step(StepType.backward, *step.map_index[1:])
                 else:
                     step.forward_step = self.get_step(StepType.forward, *step.map_index[1:])
             if step.type_ == StepType.forward and step.stage == 0:
@@ -236,7 +237,8 @@
                 step.prev_step = self.get_step(
                     step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:]
                 )
-            if step.type_ == StepType.backward and step.stage == 0:
+
+            if step.type_ == StepType.backward and step.stage == self._first_grad_stage:
                 step.next_step = None
             elif step.type_ == StepType.forward and step.stage == self._num_stages - 1:
                 step.next_step = self.get_step(StepType.backward, *step.map_index[1:]) if self._is_training else None
@@ -249,11 +251,15 @@
         for step in self._steps:
             if self._is_training:
                 if step.type_ == StepType.forward:
-                    Assert.gt(step.backward_step.global_index, step.global_index)
-                    Assert.is_(step.backward_step.forward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.gt(step.backward_step.global_index, step.global_index)
+                        Assert.is_(step.backward_step.forward_step, step)
+                    else:
+                        assert step.backward_step is None
                 else:
                     Assert.lt(step.forward_step.global_index, step.global_index)
-                    Assert.is_(step.forward_step.backward_step, step)
+                    if step.stage >= self._first_grad_stage:
+                        Assert.is_(step.forward_step.backward_step, step)
             if step.next_step is not None:
                 Assert.gt(step.next_step.global_index, step.global_index)
                 Assert.is_(step.next_step.prev_step, step)
@@ -303,7 +309,10 @@ def _setup_reduce_steps(self, grad_buffer_indices: dict[int, int]) -> None:
                 reduce_step.reduce_accumulate = reduction_count[reduce_step.stage] > 0
                 reduction_count[reduce_step.stage] += 1
         for stage, count in enumerate(reduction_count):
-            assert (count > 0) == (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            assert (count > 0) == (
+                stage >= self._first_grad_stage
+                and (stage % self._distributed.pipeline_parallel == self._distributed.pipeline_rank)
+            )

     def _setup_timeline(self) -> None:
         # TODO: Include network time
@@ -468,8 +477,16 @@ def get_data_index_split(
             micro_sequence,
         )

-    def _create_steps(self) -> list[Step]:
+    def _create_steps(self) -> tuple[list[Step], int]:
         steps = []
+        if self._is_training:
+            # The first stage(s) may not have any trainable parameters,
+            # in which case we shouldn't run the backward pass.
+            first_grad_stage = 0
+            while first_grad_stage < self._num_stages and not self._multi_stage.stages[first_grad_stage].requires_grad:
+                first_grad_stage += 1
+        else:
+            first_grad_stage = self._num_stages
         for depth_first_micro_batch in range(self._batch_config.depth_first_micro_batches):
             for stage in range(self._num_stages):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
@@ -485,7 +502,7 @@
                     )
                 )
         if self._is_training:
-            for stage in reversed(range(self._num_stages)):
+            for stage in reversed(range(first_grad_stage, self._num_stages)):
                 for breadth_first_micro_batch in range(self._batch_config.breadth_first_micro_batches):
                     for micro_sequence in reversed(range(self._batch_config.num_micro_sequences)):
                         steps.append(
@@ -498,4 +515,4 @@
                                 type_=StepType.backward,
                             )
                         )
-        return steps
+        return steps, first_grad_stage
diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py
index 6c7ad1df..45daa002 100644
--- a/fast_llm/functional/triton/normalization.py
+++ b/fast_llm/functional/triton/normalization.py
@@ -68,6 +68,7 @@ def triton_normalization_backward_kernel_1(
     n_cols,
     n_rows,
     has_bias: tl_constexpr,
+    parameter_grad: tl_constexpr,
     zero_centered: tl_constexpr,
     block_size: tl_constexpr,
     block_size_row: tl_constexpr,
@@ -108,18 +109,19 @@
     tl.store(grad_input_ptr + offsets, grad_input, mask=mask)

     # Parameter grad partial sums
-    parameter_offsets = tl.program_id(0) * n_cols + cols
-    grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
-    grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
-    grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]
+    if parameter_grad:
+        parameter_offsets = tl.program_id(0) * n_cols + cols
+        grad_weight_partial_ptr = grad_weight_partial_ptr + parameter_offsets
+        grad_weight_partial = (grad_output * input_normalized).to(weight.dtype)
+        grad_weight_partial = tl.sum(grad_weight_partial, axis=0)[None, :]

-    if has_bias:
-        grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
-        grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]
+        if has_bias:
+            grad_bias_partial_ptr = grad_bias_partial_ptr + parameter_offsets
+            grad_bias_partial = tl.sum(grad_output.to(weight.dtype), axis=0)[None, :]

-    tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
-    if has_bias:
-        tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask)  # noqa
+        tl.store(grad_weight_partial_ptr, grad_weight_partial, mask=col_mask)
+        if has_bias:
+            tl.store(grad_bias_partial_ptr, grad_bias_partial, mask=col_mask)  # noqa


 @triton_jit()
@@ -211,6 +213,11 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
     context.clear()

     has_bias = bias is not None
+    parameter_grad = weight.requires_grad
+    assert parameter_grad == hasattr(weight, "grad_buffer")
+    if has_bias:
+        assert parameter_grad == bias.requires_grad
+
     grad_output = grad_output.contiguous()

     n_rows = grad_output.shape[:-1].numel()
@@ -232,12 +239,17 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin

     grad_input = torch.empty_like(grad_output)

-    grad_is_zero = param_get_and_unset_is_zero(weight)
-    grad_weight = weight.grad_buffer
-    # TODO: Any point in making it full precision?
-    grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
+    if parameter_grad:
+        grad_is_zero = param_get_and_unset_is_zero(weight)
+        grad_weight = weight.grad_buffer
+        # TODO: Any point in making it full precision?
+        grad_weight_partial = grad_output.new_empty(num_blocks_row, n_cols)
+    else:
+        grad_is_zero = True
+        grad_weight = None
+        grad_weight_partial = None

-    if has_bias:
+    if has_bias and parameter_grad:
         assert param_get_and_unset_is_zero(bias) == grad_is_zero
         grad_bias = bias.grad_buffer
         grad_bias_partial = grad_output.new_empty(num_blocks_row, n_cols)
@@ -256,24 +268,26 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
         n_cols,
         n_rows,
         has_bias,
+        parameter_grad,
         zero_centered,
         block_size,
         block_size_row,
         num_warps=num_warps,
     )

-    triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
-        grad_weight_partial,
-        grad_bias_partial,
-        grad_weight,
-        grad_bias,
-        num_blocks_row,
-        n_cols,
-        has_bias,
-        not grad_is_zero,
-        block_size_m,
-        block_size_n,
-        num_ctas=1,
-    )
+    if parameter_grad:
+        triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](
+            grad_weight_partial,
+            grad_bias_partial,
+            grad_weight,
+            grad_bias,
+            num_blocks_row,
+            n_cols,
+            has_bias,
+            not grad_is_zero,
+            block_size_m,
+            block_size_n,
+            num_ctas=1,
+        )

     return grad_input
diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py
index fff0548c..71c15c9b 100644
--- a/fast_llm/layers/common/config.py
+++ b/fast_llm/layers/common/config.py
@@ -7,6 +7,7 @@

 if typing.TYPE_CHECKING:
     from fast_llm.engine.config_utils.tensor_space import TensorDim
+    from fast_llm.layers.common.linear import LinearBase, LinearLike
     from fast_llm.layers.common.normalization import LayerNorm, RMSNorm


@@ -115,3 +116,59 @@ def _from_dict(
         cls._handle_renamed_field(default, "normalization_implementation", "implementation")
         cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range")
         return super()._from_dict(default, strict, flat)
+
+
+class PeftType(str, enum.Enum):
+    # TODO : Use a dynamic config type instead.
+    none = "none"
+    lora = "lora"
+
+
+@config_class()
+class PeftArchitectureConfig(BaseModelArchitectureConfig):
+    _abstract = False
+
+
+@config_class()
+class PeftConfig(PeftArchitectureConfig, BaseModelConfig):
+    # TODO: Architecture/non-architecture split might not make much sense here.
+
+    type: PeftType = Field(
+        default=PeftType.none,
+        desc="The type of parameter-efficient fine-tuning to use. Only LoRA is supported at the moment.",
+        hint=FieldHint.core,
+    )
+    rank: int = Field(
+        default=8,
+        desc="The LoRA rank, i.e. the size of the intermediate dimension.",
+        hint=FieldHint.stability,
+    )
+    alpha: float = Field(
+        default=8.0,
+        desc="The LoRA scaling parameter.",
+        hint=FieldHint.stability,
+    )
+    dropout: float = Field(
+        default=0.0,
+        desc="Dropout rate for LoRA.",
+        hint=FieldHint.stability,
+    )
+
+    def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike":
+        if self.type == PeftType.none:
+            return linear
+        elif self.type == PeftType.lora:
+            from fast_llm.layers.common.peft import lora_linear
+
+            # TODO: Init method?
+            return lora_linear(
+                linear,
+                linear.weight.param_init_method,
+                linear.weight.param_init_method,
+                self.rank,
+                self.alpha,
+                self.dropout,
+                **kwargs,
+            )
+        else:
+            raise NotImplementedError(self.type)
diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py
index afd0d96d..cd19a47a 100644
--- a/fast_llm/layers/common/linear.py
+++ b/fast_llm/layers/common/linear.py
@@ -4,14 +4,13 @@
 import torch

 from fast_llm.engine.config_utils.tensor_space import TensorDim
+from fast_llm.functional.autograd import wrap_forward_backward
 from fast_llm.functional.linear import (
     input_parallel_linear_autograd,
     input_parallel_linear_backward,
     input_parallel_linear_forward,
-    linear_autograd,
     linear_backward,
     linear_forward,
-    output_parallel_linear_autograd,
     output_parallel_linear_backward,
     output_parallel_linear_forward,
 )
@@ -20,7 +19,22 @@
 logger = logging.getLogger(__name__)


-class LinearBase(torch.nn.Module):
+class LinearLike(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward = wrap_forward_backward(self.forward_only, self.backward)
+
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        return self._forward(input_)
+
+    def forward_only(self, input_: torch.Tensor):
+        raise NotImplementedError()
+
+    def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tensor:
+        raise NotImplementedError()
+
+
+class LinearBase(LinearLike):
     """
     A base module for linear layers holding weights and biases.
""" @@ -41,6 +55,7 @@ def __init__( self._transposed_weight = transposed_weight self._in_dim = in_dim self._out_dim = out_dim + self._weight_init_method = weight_init_method self.weight = ParameterMeta.from_dims( (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), init_method=weight_init_method, @@ -91,9 +106,6 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return linear_autograd(input_, weight=self.weight, bias=self.bias, transposed_weight=self._transposed_weight) - def forward_only( self, input_: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: @@ -133,16 +145,6 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return output_parallel_linear_autograd( - input_, - weight=self.weight, - bias=self.bias, - group=self._out_dim.parallel_group, - sequence_parallel=self._sequence_parallel, - transposed_weight=self._transposed_weight, - ) - def forward_only(self, input_) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: return output_parallel_linear_forward( input_, @@ -190,6 +192,7 @@ def __init__( ) def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: Use self._forward instead (broken). return input_parallel_linear_autograd( input_, weight=self.weight, @@ -200,15 +203,17 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | No ) def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None, tuple[typing.Any, ...]]: + group = self._in_dim.parallel_group output, context = input_parallel_linear_forward( input_, weight=self.weight, - bias=None if self._group else self.bias, - group=self._in_dim.parallel_group, + bias=None if group else self.bias, + group=group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, ) - return output, self.bias if self._group else None, context + return output, self.bias if group else None, context def backward(self, grad_output: torch.Tensor, context: tuple[typing.Any, ...]) -> torch.Tensor: # noqa + # TODO: Needs grad_bias as input too? return input_parallel_linear_backward(grad_output, context) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 25e8090c..04123014 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -71,11 +71,14 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa output, weight, bias, inv_var = ctx.saved_tensors + # TODO: Gradients may be computed unnecessarily. grad_input, grad_weight, grad_bias, _, _ = fast_layer_norm.ln_bwd( grad_output, output, None, inv_var, weight, bias, True ) - accumulate_gradient(weight, grad_weight) - accumulate_gradient(bias, grad_bias) + if weight.requires_grad: + accumulate_gradient(weight, grad_weight) + if bias.requires_grad: + accumulate_gradient(bias, grad_bias) return grad_input, None, None, None, None @@ -100,11 +103,14 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa output, weight, bias, inv_var = ctx.saved_tensors + # TODO: Gradients may be computed unnecessarily. 
         grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
             grad_output, None, inv_var, output, ctx.normalized_shape, weight, bias, ctx.eps, True
         )
-        accumulate_gradient(weight, grad_weight)
-        accumulate_gradient(bias, grad_bias)
+        if weight.requires_grad:
+            accumulate_gradient(weight, grad_weight)
+        if bias.requires_grad:
+            accumulate_gradient(bias, grad_bias)
         return grad_input, None, None, None, None

@@ -123,10 +129,12 @@ def forward(

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]:  # noqa
         output, weight, inv_var = ctx.saved_tensors
+        # TODO: Gradients may be computed unnecessarily.
         grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
             grad_output.contiguous(), inv_var, output, ctx.normalized_shape, weight, ctx.eps, True
         )
-        accumulate_gradient(weight, grad_weight)
+        if weight.requires_grad:
+            accumulate_gradient(weight, grad_weight)
         return grad_input, None, None, None

diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py
new file mode 100644
index 00000000..3a1966e5
--- /dev/null
+++ b/fast_llm/layers/common/peft.py
@@ -0,0 +1,90 @@
+import typing
+
+import torch
+
+from fast_llm.engine.config_utils.tensor_space import TensorDim
+from fast_llm.functional.autograd import wrap_forward_backward
+from fast_llm.layers.common.linear import Linear, LinearBase
+
+
+def lora_linear(
+    layer: LinearBase,
+    init_method_0,
+    init_method_1,
+    rank: int,
+    alpha: float,
+    dropout: float = 0.0,
+    out_channel_begin: int | None = None,
+    out_channel_end: int | None = None,
+):
+    layer.weight.requires_grad = False
+    in_dim = layer._in_dim
+    if in_dim.parallel_dim is not None:
+        assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
+        in_dim = TensorDim(in_dim.name, in_dim.global_size)
+    out_dim = layer._out_dim
+    if out_dim.parallel_dim is not None:
+        assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
+        out_dim = TensorDim(out_dim.name, out_dim.global_size)
+    if out_channel_begin is not None or out_channel_end is not None:
+        if out_channel_begin is None:
+            out_channel_begin = 0
+        if out_channel_end is None:
+            out_channel_end = out_dim.global_size
+        # TODO: This won't work with TP. Use Composite dim structure for proper split?
+        out_dim = TensorDim(out_dim.name, out_channel_end - out_channel_begin)
+
+    middle_dim = TensorDim("lora_middle", rank)
+
+    layer.lora_0 = Linear(
+        in_dim,
+        middle_dim,
+        bias=False,
+        weight_init_method=init_method_0,
+        transposed_weight=layer.transposed_weight,
+        lr_scale=layer.weight.lr_scale,
+    )
+    layer.lora_1 = Linear(
+        middle_dim,
+        out_dim,
+        bias=False,
+        weight_init_method=init_method_1,
+        transposed_weight=layer.transposed_weight,
+        lr_scale=layer.weight.lr_scale,
+    )
+    # TODO: Implement proper backward pass.
+    layer.lora_0.weight.auto_grad_accumulation = True
+    layer.lora_1.weight.auto_grad_accumulation = True
+
+    old_forward = layer._forward
+
+    def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        # TODO: torch compile?
+        input_ = input_.detach().requires_grad_()
+        with torch.enable_grad():
+            output = old_forward(input_)
+            if isinstance(output, tuple):
+                layer_out, tp_bias = output[0]
+                assert tp_bias is None
+            lora_out = (alpha / rank) * layer.lora_1(
+                layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_)
+            )
+            if out_channel_begin is None:
+                output = output + lora_out
+            else:
+                output.view(-1, layer_out.size(-1))[:, out_channel_begin:out_channel_end] += lora_out
+        return output.detach(), (input_, output)
+
+    def backward(
+        grad_output: torch.Tensor, context: torch.Tensor
+    ) -> tuple[torch.Tensor, typing.Callable[[], None] | None]:
+        # TODO: Implement proper backward pass.
+        input_, output = context
+        output.backward(grad_output)
+        return input_.grad
+
+    layer._forward = wrap_forward_backward(forward_only, backward)
+    layer.forward_only = forward_only
+    layer.backward = backward
+
+    return layer
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py
index 67e7eb53..1d9406ed 100644
--- a/fast_llm/layers/language_model/embedding.py
+++ b/fast_llm/layers/language_model/embedding.py
@@ -74,6 +74,13 @@ def __init__(
             allow_sequence_tensor_parallel=not config.parallel_embeddings,
         )

+        # PEFT.
+        self.word_embeddings_weight = self._config.transformer.peft.apply_weight(self.word_embeddings_weight)
+        if hasattr(self, "position_embeddings_weight"):
+            self.position_embeddings_weight = self._config.transformer.peft.apply_weight(
+                self.position_embeddings_weight
+            )
+
     @torch.compile
     def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor:
         Assert.eq(position_ids is not None, self._use_absolute_position_embeddings)
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 4c03e393..efca95b4 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -85,6 +85,11 @@ def __init__(

         self._forward = wrap_forward_backward(self._forward_backward, grad_is_context)

+        # PEFT.
+        self.final_norm = self._config.transformer.peft.apply_other(self.final_norm)
+        if hasattr(self, "output_weights"):
+            self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights)
+
     def forward(
         self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
     ) -> torch.Tensor:
diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index bfa17d1a..c7ae55c5 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -9,7 +9,12 @@
 from fast_llm.functional.rotary import apply_rotary_embeddings
 from fast_llm.functional.triton.rotary import triton_rotary_autograd_
 from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
-from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
+from fast_llm.layers.transformer.config import (
+    TransformerConfig,
+    TransformerDimNames,
+    TransformerKwargs,
+    TransformerSubLayerName,
+)
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor
 from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_
 from fast_llm.utils import Assert
@@ -137,6 +142,11 @@ def __init__(
             lr_scale=self._config.attention_lr_scale,
         )

+        # PEFT.
+        self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query)
+        self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value)
+        self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)
+
     def _attn_fused(
         self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
     ) -> torch.Tensor:
@@ -253,14 +263,14 @@ def _query_key_value_forward(
         return query, key_value, context

     def _query_key_value_backward(
-        self, query_grad: torch.Tensor, key_grad: torch.Tensor, context: dict
+        self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict
     ) -> torch.Tensor:
         # TODO: De-allocate qkv grads quicker.
         handle = None

         if self._tensor_space.distributed.sequence_data_group:
-            key_grad, handle = reduce_scatter_op(
-                key_grad,
+            key_value_grad, handle = reduce_scatter_op(
+                key_value_grad,
                 group=self._tensor_space.distributed.sequence_data_group,
                 dim=1 - context["sequence_first"],
                 async_op=True,
@@ -274,11 +284,11 @@ def _query_key_value_backward(

         if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group):
             if self._sequence_parallel:
-                key_grad = reduce_scatter_op(key_grad, group=group, dim=0)
+                key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0)
             else:
-                key_grad = reduce_op(key_grad, group=group)
+                key_value_grad = reduce_op(key_value_grad, group=group)

-        input_grad.add_(self.key_value.backward(key_grad, context.pop("key_value")))
+        input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value")))
         return input_grad

     def _decide_window_size(self) -> int | None:
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 1352c7f0..4806e37e 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -10,9 +10,21 @@
 from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
 from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig
-from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig
+from fast_llm.layers.common.config import (
+    NormalizationArchitectureConfig,
+    NormalizationConfig,
+    PeftArchitectureConfig,
+    PeftConfig,
+    PeftType,
+)
 from fast_llm.utils import Assert, div

+if typing.TYPE_CHECKING:
+    import torch
+
+    from fast_llm.layers.common.linear import LinearBase, LinearLike
+    from fast_llm.tensor import ParameterMeta
+
 logger = logging.getLogger(__name__)


@@ -160,6 +172,77 @@ class AddLinearBiasChoices(str, enum.Enum):
     only_attn_qkv = "only_attn_qkv"


+class TransformerSubLayerName(str, enum.Enum):
+    # TODO: Use this to replace AddLinearBiasChoices.
+    query = "query"
+    key = "key"
+    value_ = "value"
+    key_value = "key_value"
+    dense = "dense"
+    mlp_1 = "mlp_1"
+    mlp_2 = "mlp_2"
+
+
+@config_class()
+class TransformerPeftConfig(PeftConfig):
+    layers: list[TransformerSubLayerName] = Field(
+        default_factory=lambda: [TransformerSubLayerName.query, TransformerSubLayerName.value_],
+        desc="The layers on which to apply LoRA.",
+        hint=FieldHint.feature,
+    )
+    freeze_others: bool = Field(
+        default=True,
+        desc="Whether to freeze other layers during training.",
+    )
+
+    def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike":
+        if self.type != PeftType.none:
+            if layer_type is None or self.layers is None or layer_type in self.layers:
+                if layer_type == TransformerSubLayerName.key:
+                    return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2))
+                elif layer_type == TransformerSubLayerName.value_:
+                    return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2))
+                else:
+                    return super().apply_linear(linear)
+            elif self.freeze_others:
+                linear.weight.requires_grad = False
+        return linear
+
+    def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module":
+        if self.type != PeftType.none and self.freeze_others:
+            for parameter in module.parameters():
+                parameter.requires_grad = False
+        return module
+
+    def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta":
+        if self.type != PeftType.none and self.freeze_others:
+            parameter.requires_grad = False
+        return parameter
+
+    def _validate(self) -> None:
+        if self.type != PeftType.none:
+            if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers:
+                # TODO: Add MLP support.
+                raise NotImplementedError("LoRA not supported for MLP.")
+            if TransformerSubLayerName.dense in self.layers:
+                # TODO: Support InputParallelLinear (different output format).
+                raise NotImplementedError("LoRA not supported for attention dense layer.")
+            if (
+                sum(
+                    name in self.layers
+                    for name in (
+                        TransformerSubLayerName.key_value,
+                        TransformerSubLayerName.key,
+                        TransformerSubLayerName.value_,
+                    )
+                )
+                > 1
+            ):
+                raise ValueError(
+                    f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive."
+                )
+
+
 @config_class()
 class TransformerArchitectureConfig(BaseModelArchitectureConfig):
     _abstract = False
@@ -168,6 +251,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
     normalization: NormalizationArchitectureConfig = Field(
         desc="Configuration for the normalization layers architecture.",
         hint=FieldHint.core,
     )
+    peft: PeftArchitectureConfig = Field(
+        default_factory=PeftArchitectureConfig,
+        desc="Configuration for parameter-efficient fine-tuning.",
+        hint=FieldHint.core,
+    )
     num_layers: int = Field(
         default=12, desc="Number of layers in the transformer.", hint=FieldHint.core, valid=check_field(Assert.geq, 0)
     )
@@ -375,6 +463,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
 class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
     normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)
     rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig)
+    peft: TransformerPeftConfig = FieldUpdate(default_factory=TransformerPeftConfig)
     # Default: hidden_size**-0.5
     # TODO: Allow custom initialization (InitializationConfig?)
     init_method_std: float = Field(
@@ -613,6 +702,7 @@ def _validate(self) -> None:
         Assert.geq(self.attention_dropout, 0)
         Assert.geq(self.hidden_dropout, 0)
         Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts))
+
         for scale in self.mlp_lr_scale:
             if scale is not None:
                 Assert.geq(scale, 0)
diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py
index adc6242d..9b90beff 100644
--- a/fast_llm/layers/transformer/mlp.py
+++ b/fast_llm/layers/transformer/mlp.py
@@ -8,7 +8,7 @@
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd
 from fast_llm.layers.common.linear import LinearBase
-from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames
+from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName
 from fast_llm.tensor import init_normal_, init_zeros_
 from fast_llm.utils import Assert

@@ -58,6 +58,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
             lr_scale=tuple(config.mlp_lr_scale),
         )

+        # PEFT.
+        self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1)
+        self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2)
+

 class MLP(MLPBase):
     def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py
index 4780dd3a..b65be23f 100644
--- a/fast_llm/layers/transformer/transformer.py
+++ b/fast_llm/layers/transformer/transformer.py
@@ -46,6 +46,10 @@ def __init__(
             self._config, self._tensor_space, f"{self.name} mlp"
         )

+        # PEFT.
+        self.norm_1 = self._config.peft.apply_other(self.norm_1)
+        self.norm_2 = self._config.peft.apply_other(self.norm_2)
+
     @torch.compile
     def _bias_dropout_add(
         self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor
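
A minimal usage sketch, not part of the patch above: it shows how the new PEFT options might be enabled from Python. The field names (type, rank, alpha, dropout, layers, freeze_others) and enum values come from PeftType and TransformerPeftConfig in this diff; the keyword-argument construction style and any surrounding training setup are assumptions.

    # Hypothetical sketch (assumes the @config_class classes accept keyword arguments).
    from fast_llm.layers.common.config import PeftType
    from fast_llm.layers.transformer.config import TransformerPeftConfig, TransformerSubLayerName

    peft = TransformerPeftConfig(
        type=PeftType.lora,
        rank=8,  # size of the LoRA intermediate dimension
        alpha=8.0,  # LoRA scaling parameter
        dropout=0.0,
        layers=[TransformerSubLayerName.query, TransformerSubLayerName.value_],  # default selection
        freeze_others=True,  # freeze every non-LoRA parameter
    )

With freeze_others=True (the default), apply_weight and apply_other mark embedding, output-head and normalization parameters as requires_grad=False, and the schedule changes above then skip the backward pass and gradient reductions for leading pipeline stages that are left with no trainable parameters.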