diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 8c3b68a621..42813c1a49 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -250,6 +250,8 @@ gradient_checkpointing: true - 🔥neftune_noise_alpha: neftune添加的噪声系数。默认为0,通常可以设置为5、10、15。 - 🔥use_liger_kernel: 是否启用[Liger](https://github.com/linkedin/Liger-Kernel)内核加速训练并减少显存消耗。默认为False。示例shell参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger)。 - 注意:liger_kernel不支持device_map,请使用DDP/DeepSpeed进行多卡训练。liger_kernel目前只支持`task_type='causal_lm'`。 +- use_tiled_mlp: 是否启用Tiled MLP进行内存高效的长序列训练。启用后,MLP层会被替换为分块实现,将序列分成多个shard进行计算以减少显存占用。默认为False。 +- tiled_mlp_num_shards: Tiled MLP计算时将序列分成的shard数量。默认为None,即设置为4。较大的值可以减少显存但可能增加计算时间。 - average_tokens_across_devices: 是否在设备之间进行token数平均。如果设置为True,将使用all_reduce同步`num_tokens_in_batch`以进行精确的损失计算。默认为False。 - max_grad_norm: 梯度裁剪。默认为1.。 - 注意:日志中的grad_norm记录的是裁剪前的值。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 9316c7ecd5..53812b4eaa 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -252,6 +252,8 @@ Other important parameters: - 🔥neftune_noise_alpha: Noise magnitude for NEFTune. Default is 0. Common values: 5, 10, 15. - 🔥use_liger_kernel: Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) kernel to accelerate training and reduce GPU memory consumption. Defaults to False. Example shell script can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/liger). - Note: Liger kernel does not support `device_map`. Use DDP or DeepSpeed for multi-GPU training. Currently, liger_kernel only supports `task_type='causal_lm'`. +- use_tiled_mlp: Whether to enable Tiled MLP for memory-efficient long-sequence training. When enabled, MLP layers are replaced with a tiled implementation that processes the sequence in chunks to reduce memory usage. Defaults to False. +- tiled_mlp_num_shards: Number of shards the sequence is split into for tiled MLP computation. Defaults to None, in which case 4 is used. Larger values reduce memory but may increase computation time. - average_tokens_across_devices: Whether to average token counts across devices. If `True`, `num_tokens_in_batch` is synchronized via `all_reduce` for accurate loss computation. Default is `False`. - max_grad_norm: Gradient clipping. Default is 1. - Note: The logged `grad_norm` reflects the value **before** clipping.
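To make the two new flags concrete, here is a minimal sketch of the idea behind tiled MLP (an illustration of the technique only, not the implementation added by this PR): the sequence dimension is split into `num_shards` chunks and the MLP is applied one chunk at a time, so the large `seq_len × intermediate_size` activation is never materialized all at once.

```python
# Illustrative sketch only (not the code added by this PR): run an MLP
# shard-by-shard along the sequence dimension to cap peak activation memory.
import torch
import torch.nn as nn


def tiled_mlp_forward(mlp: nn.Module, x: torch.Tensor, num_shards: int = 4) -> torch.Tensor:
    # x: [batch, seq_len, hidden_size]; each shard covers roughly seq_len / num_shards tokens
    shards = torch.chunk(x, chunks=num_shards, dim=-2)
    return torch.cat([mlp(shard) for shard in shards], dim=-2)


mlp = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64))
x = torch.randn(2, 1024, 64)
assert tiled_mlp_forward(mlp, x, num_shards=4).shape == x.shape
```

A larger `num_shards` shrinks the per-chunk intermediate activation further at the cost of more, smaller matrix multiplications, which is the trade-off the parameter description above refers to.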
diff --git a/examples/train/tiled_mlp/fsdp2.json b/examples/train/tiled_mlp/fsdp2.json new file mode 100644 index 0000000000..18cce13780 --- /dev/null +++ b/examples/train/tiled_mlp/fsdp2.json @@ -0,0 +1,25 @@ +{ + "compute_environment": "LOCAL_MACHINE", + "debug": false, + "distributed_type": "FSDP", + "downcast_bf16": "no", + "fsdp_config": { + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_cpu_ram_efficient_loading": true, + "fsdp_reshard_after_forward": true, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_activation_checkpointing": true, + "fsdp_version": 2 + }, + "machine_rank": 0, + "main_training_function": "main", + "mixed_precision": "bf16", + "num_machines": 1, + "num_processes": 2, + "rdzv_backend": "static", + "same_network": true, + "tpu_env": [], + "tpu_use_cluster": false, + "tpu_use_sudo": false, + "use_cpu": false +} diff --git a/examples/train/tiled_mlp/train_deepspeed.sh b/examples/train/tiled_mlp/train_deepspeed.sh new file mode 100644 index 0000000000..244677b3ac --- /dev/null +++ b/examples/train/tiled_mlp/train_deepspeed.sh @@ -0,0 +1,24 @@ +CUDA_VISIBLE_DEVICES=0,1 \ +NPROC_PER_NODE=2 \ +swift sft \ + --model Qwen/Qwen3-4B \ + --dataset swift/self-cognition#200 \ + --train_type full \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-5 \ + --weight_decay 0.1 \ + --gradient_accumulation_steps 1 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --use_tiled_mlp true \ + --tiled_mlp_num_shards 4 \ + --deepspeed zero3 diff --git a/examples/train/tiled_mlp/train_fsdp2.sh b/examples/train/tiled_mlp/train_fsdp2.sh new file mode 100644 index 0000000000..5d2372602d --- /dev/null +++ b/examples/train/tiled_mlp/train_fsdp2.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# FSDP2 training with tiled MLP +# Requires accelerate config with fsdp_version: 2 + +# First, create the accelerate config (fsdp2.json) or use the one in examples/train/multi-gpu/fsdp2_lora/ + +# FSDP2 with tiled MLP +accelerate launch --config_file fsdp2.json \ + -m swift sft \ + --model Qwen/Qwen3-4B \ + --dataset swift/self-cognition#200 \ + --train_type full \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-5 \ + --gradient_checkpointing false \ + --weight_decay 0.1 \ + --gradient_accumulation_steps 1 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 2048 \ + --output_dir output \ + --system 'You are a helpful assistant.' 
\ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --use_tiled_mlp true \ + --tiled_mlp_num_shards 4 diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 841bdb9ffa..5e15c73d88 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -51,6 +51,10 @@ def _prepare_generation_config(self): @RayHelper.function(group='default') def _prepare_model_tokenizer(self, **kwargs): args = self.args + # Apply tiled MLP before model instantiation + if getattr(args, 'use_tiled_mlp', False): + from swift.plugin.tiled_mlp import apply_tiled_mlp + apply_tiled_mlp(args.model_type, num_shards=getattr(args, 'tiled_mlp_num_shards', None)) self.model, self.processor = args.get_model_processor(**kwargs) if args.sequence_parallel_size > 1: from swift.trainers.sequence_parallel import sequence_parallel diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index a612651ea0..23fb43b72f 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -17,6 +17,7 @@ from .rm_plugin import rm_plugins from .env import envs, Env from .context_manager import context_managers, ContextManager + from .tiled_mlp import (TiledSwiGLUMLP, apply_tiled_mlp, is_fsdp2_enabled, is_fsdp1_enabled, get_tiled_mlp_mode) else: _import_structure = { @@ -34,6 +35,8 @@ 'rm_plugin': ['rm_plugins'], 'env': ['envs', 'Env'], 'context_manager': ['context_managers', 'ContextManager'], + 'tiled_mlp': + ['TiledSwiGLUMLP', 'apply_tiled_mlp', 'is_fsdp2_enabled', 'is_fsdp1_enabled', 'get_tiled_mlp_mode'], } import sys diff --git a/swift/plugin/tiled_mlp.py b/swift/plugin/tiled_mlp.py new file mode 100644 index 0000000000..bc0b7ea2ba --- /dev/null +++ b/swift/plugin/tiled_mlp.py @@ -0,0 +1,452 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +""" +Tiled MLP implementation for memory-efficient training. 
+ +- FSDP2: Uses custom TiledMLP implementation (this file) +- DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP +- DeepSpeed/Single NPU: Uses LigerTiledSwiGLUMLP with native PyTorch _mlp_forward +- FSDP1: Raises error (not compatible) +""" +import os +import threading +from typing import List, Optional + +import torch +import torch.nn as nn + +from swift.utils import get_logger + +logger = get_logger() + +# ============================================================================ +# NPU Detection +# ============================================================================ + +IS_NPU = False +try: + import torch_npu # noqa: F401 + IS_NPU = torch.npu.is_available() +except ImportError: + pass + + +def is_npu_available() -> bool: + """Check if NPU is available.""" + return IS_NPU + + +# ============================================================================ +# FSDP2 Compatible TiledMLP Implementation +# ============================================================================ + + +class GradientAccumulator: + """Gradient accumulator for TiledMLP (FSDP2 compatible)""" + + def __init__(self, params: List[torch.nn.Parameter], total_shards: int, dtype: torch.dtype = None): + self.params = params + self.total_shards = total_shards + self.grad_accumulation_dtype = dtype or torch.float32 + self.accumulated_grads = {} + self.hooks = [] + self.lock = threading.Lock() + + for param in self.params: + if param.grad is not None: + self.accumulated_grads[param] = param.grad.to(self.grad_accumulation_dtype) + param.grad = None + else: + self.accumulated_grads[param] = torch.zeros_like(param, dtype=self.grad_accumulation_dtype) + + def install_hooks(self, is_last_shard: bool): + self._remove_hooks() + + def create_hook(param): + + def hook(grad): + with self.lock: + grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) + self.accumulated_grads[param] += grad_to_accum_dtype + + if is_last_shard: + param.grad = None # Critical: prevent double accumulation + final_grad = self.accumulated_grads[param].to(param.dtype) + return final_grad + return None + + return hook + + for param in self.params: + if param.requires_grad: + hook = param.register_hook(create_hook(param)) + self.hooks.append(hook) + + def _remove_hooks(self): + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + def cleanup(self): + self._remove_hooks() + + +class TiledMLPFunction(torch.autograd.Function): + """TiledMLP autograd function for FSDP2 compatibility""" + + @staticmethod + def forward(ctx, fn, self, x, shards, compute_params): + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + # Split on dim=-2 (seqlen dimension) + x_shards = list(torch.chunk(x, chunks=shards, dim=-2)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=-2) + return output_unsharded + + @staticmethod + def backward(ctx, *grads): + fn = ctx.fn + (x, ) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + x.requires_grad_(x_requires_grad) + + # Flatten to [bs*seqlen, hidden_size] + hidden_size = x.shape[-1] + x_shape_orig = x.shape + x = x.view(-1, hidden_size) + incoming_grad = grads[0].view(-1, hidden_size) + + # Pre-allocate input gradient + x_grad = torch.zeros_like(x) + + # Split on dim=0 + x_shards = list(torch.chunk(x, chunks=shards, 
dim=0)) + + grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype) + + for i, x_shard in enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + shard_step = x_shards[i].shape[0] + shard_offset = i * x_shards[0].shape[0] + + # narrow(0, ...) creates a view that can correctly receive gradients + x_shard.grad = x_grad.narrow(0, shard_offset, shard_step) + incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step) + + is_last_shard = i + 1 == shards + grad_accumulator.install_hooks(is_last_shard) + + with torch.enable_grad(): + output = fn(self, x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + grad_accumulator.cleanup() + del grad_accumulator + + # Restore original shape + x_grad = x_grad.view(x_shape_orig) if x_requires_grad else None + return (None, None, x_grad, None, None) + + +class TiledSwiGLUMLP(nn.Module): + """ + Memory-efficient SwiGLU MLP using tiled computation for FSDP2. + + This module combines SwiGLU activation with tiled processing to handle + very long sequences efficiently. The forward pass is recomputed during + backward to save memory. + + Args: + config: Model configuration with hidden_size and intermediate_size attributes + num_shards: Number of shards to split the sequence. If None, automatically + calculated as ceil(seqlen / hidden_size) + """ + + def __init__(self, config, num_shards: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_shards = num_shards or 4 # Default to 4 shards + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act = nn.SiLU() + + def _mlp_forward(self, module, x): + """Internal MLP forward function for tiled computation.""" + gate = module.gate_proj(x) + up = module.up_proj(x) + return module.down_proj(module.act(gate) * up) + + def forward(self, x): + """ + Forward pass with tiled computation. 
+ + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_size] + or [seq_len, hidden_size] + Returns: + Output tensor of the same shape as input + """ + compute_params = [ + self.gate_proj.weight, + self.up_proj.weight, + self.down_proj.weight, + ] + return TiledMLPFunction.apply( + self._mlp_forward, + self, + x, + self.num_shards, + compute_params, + ) + + +# ============================================================================ +# Environment Detection Functions +# ============================================================================ + + +def is_fsdp2_enabled() -> bool: + """Check if FSDP2 is enabled via accelerate config.""" + # Check environment variable set by accelerate + if os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() == 'true': + # Check fsdp_version from accelerate config + # FSDP_VERSION is set by accelerate when fsdp_version is specified in config + fsdp_version = os.environ.get('FSDP_VERSION', '1') + if fsdp_version == '2': + return True + # Also check accelerate state if available + try: + from accelerate import PartialState + state = PartialState() + if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None: + # Check if fsdp_version is 2 in the plugin + if hasattr(state.fsdp_plugin, 'fsdp_version'): + return state.fsdp_plugin.fsdp_version == 2 + except Exception: + pass + return False + + +def is_fsdp1_enabled() -> bool: + """Check if FSDP1 is enabled via accelerate config.""" + if os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() == 'true': + fsdp_version = os.environ.get('FSDP_VERSION', '1') + if fsdp_version == '2': + return False + # Also check accelerate state if available + try: + from accelerate import PartialState + state = PartialState() + if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None: + if hasattr(state.fsdp_plugin, 'fsdp_version'): + return state.fsdp_plugin.fsdp_version != 2 + except Exception: + pass + return True + return False + + +def is_deepspeed_enabled() -> bool: + """Check if DeepSpeed is enabled.""" + from swift.utils import is_deepspeed_enabled as _is_deepspeed_enabled + return _is_deepspeed_enabled() + + +def get_tiled_mlp_mode() -> str: + """ + Determine which tiled MLP implementation to use. + + Returns: + 'fsdp2': Use custom TiledSwiGLUMLP implementation + 'liger': Use liger_kernel's LigerTiledSwiGLUMLP + 'error': FSDP1 detected, should raise error + """ + if is_fsdp2_enabled(): + return 'fsdp2' + elif is_fsdp1_enabled(): + return 'error' + else: + # DeepSpeed, Single GPU, or DDP - use liger kernel + return 'liger' + + +# ============================================================================ +# MLP Replacement Functions +# ============================================================================ + +# Supported model types for tiled MLP +SUPPORTED_MODEL_TYPES = { + 'qwen2', + 'qwen2_5', + 'qwen3', + 'qwen3_vl', +} + + +def _get_mlp_class_for_model(model_type: str) -> str: + """Get the MLP class name for different model architectures.""" + # Map model types to their MLP class names + mlp_class_mapping = { + 'qwen2': 'Qwen2MLP', + 'qwen2_5': 'Qwen2MLP', + 'qwen3': 'Qwen3MLP', + 'qwen3_vl': 'Qwen3VLTextMLP', + } + + if model_type in mlp_class_mapping: + return mlp_class_mapping[model_type] + + # Fallback: capitalize model_type and append 'MLP' + # e.g., 'mistral' -> 'MistralMLP' + return model_type.capitalize() + 'MLP' + + +def apply_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """ + Apply tiled MLP replacement before model instantiation. 
+ + This function should be called BEFORE loading the model to replace + the MLP class in the transformers module. + + Args: + model_type: The model type (e.g., 'llama', 'qwen2') + num_shards: Number of shards for tiled computation + + Raises: + ValueError: If FSDP1 is detected (not compatible) + """ + mode = get_tiled_mlp_mode() + + if mode == 'error': + raise ValueError('Tiled MLP is not compatible with FSDP1. ' + 'Please use FSDP2 (set fsdp_version: 2 in accelerate config) or DeepSpeed.') + + if mode == 'fsdp2': + _apply_custom_tiled_mlp(model_type, num_shards) + elif mode == 'liger': + _apply_liger_tiled_mlp(model_type, num_shards) + + +def _apply_custom_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """Apply custom FSDP2-compatible tiled MLP.""" + num_shards = num_shards or 4 + mlp_class_name = _get_mlp_class_for_model(model_type) + + # Get the transformers module for this model + model_module = _get_transformers_module(model_type) + if model_module is None: + raise ValueError(f'Tiled MLP: Could not find transformers module for model_type={model_type}. ' + f'Supported model types: {SUPPORTED_MODEL_TYPES}') + + # Check if MLP class exists in the module + original_mlp_class = getattr(model_module, mlp_class_name, None) + if original_mlp_class is None: + raise ValueError(f'Tiled MLP: Could not find {mlp_class_name} in {model_module.__name__}. ' + f'model_type={model_type} may not be supported.') + + # Create a wrapper class that uses TiledSwiGLUMLP + class TiledMLPWrapper(TiledSwiGLUMLP): + + def __init__(self, config, **kwargs): + super().__init__(config, num_shards=num_shards) + + # Replace the MLP class + setattr(model_module, mlp_class_name, TiledMLPWrapper) + logger.info(f'Tiled MLP: Replaced {mlp_class_name} with TiledSwiGLUMLP (FSDP2 mode, num_shards={num_shards})') + + +def _apply_liger_tiled_mlp(model_type: str, num_shards: Optional[int] = None): + """ + Apply liger_kernel's tiled MLP implementation. + + For NPU: Uses a subclass with native PyTorch _mlp_forward to replace + LigerSiLUMulFunction (Triton kernel) which is not compatible with NPU. + """ + try: + from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP + except ImportError: + raise ImportError('Tiled MLP: liger_kernel not installed or LigerTiledSwiGLUMLP not available. ' + 'Please install liger-kernel: pip install liger-kernel') + + num_shards = num_shards or 4 + mlp_class_name = _get_mlp_class_for_model(model_type) + + model_module = _get_transformers_module(model_type) + if model_module is None: + raise ValueError(f'Tiled MLP: Could not find transformers module for model_type={model_type}. ' + f'Supported model types: {SUPPORTED_MODEL_TYPES}') + + # Check if MLP class exists in the module + original_mlp_class = getattr(model_module, mlp_class_name, None) + if original_mlp_class is None: + raise ValueError(f'Tiled MLP: Could not find {mlp_class_name} in {model_module.__name__}. 
' + f'model_type={model_type} may not be supported.') + + if is_npu_available(): + # NPU: Use subclass with native PyTorch _mlp_forward + # LigerSiLUMulFunction (Triton kernel) is not compatible with NPU + class NPULigerTiledMLPWrapper(LigerTiledSwiGLUMLP): + """LigerTiledSwiGLUMLP with native PyTorch _mlp_forward for NPU.""" + + def __init__(self, config, **kwargs): + super().__init__(config, num_shards=num_shards) + self.act = nn.SiLU() # Add activation for native implementation + + def _mlp_forward(self, module, x): + """Native PyTorch implementation replacing LigerSiLUMulFunction.""" + gate = module.gate_proj(x) + up = module.up_proj(x) + return module.down_proj(self.act(gate) * up) + + setattr(model_module, mlp_class_name, NPULigerTiledMLPWrapper) + logger.info(f'Tiled MLP: Replaced {mlp_class_name} with NPULigerTiledSwiGLUMLP ' + f'(liger mode + NPU native PyTorch, num_shards={num_shards})') + else: + # GPU: Use original Liger kernel + class LigerTiledMLPWrapper(LigerTiledSwiGLUMLP): + + def __init__(self, config, **kwargs): + super().__init__(config, num_shards=num_shards) + + setattr(model_module, mlp_class_name, LigerTiledMLPWrapper) + logger.info(f'Tiled MLP: Replaced {mlp_class_name} with LigerTiledSwiGLUMLP ' + f'(liger mode, num_shards={num_shards})') + + +def _get_transformers_module(model_type: str): + """Get the transformers modeling module for a given model type.""" + import importlib + + module_mapping = { + 'qwen2': 'transformers.models.qwen2.modeling_qwen2', + 'qwen2_5': 'transformers.models.qwen2.modeling_qwen2', + 'qwen3': 'transformers.models.qwen3.modeling_qwen3', + 'qwen3_vl': 'transformers.models.qwen3_vl.modeling_qwen3_vl', + } + + module_name = module_mapping.get(model_type) + + # Fallback: try to construct module name from model_type + if module_name is None: + base_type = model_type + module_name = f'transformers.models.{base_type}.modeling_{base_type}' + + try: + return importlib.import_module(module_name) + except ImportError: + return None diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 5640bc6a6c..336cc67bae 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -53,6 +53,14 @@ class TrainArgumentsMixin: dataloader_prefetch_factor (Optional[int]): The number of batches loaded in advance by each worker. Defaults to None. use_liger_kernel (bool): Whether to use the Liger kernel for optimization. Defaults to False. + use_tiled_mlp (bool): Whether to use tiled MLP for memory-efficient training. When enabled, the MLP layers + are replaced with a tiled implementation that processes sequences in chunks to reduce memory usage. + - FSDP2: Uses custom TiledSwiGLUMLP implementation (compatible) + - DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP + - FSDP1: Raises error (not compatible) + Defaults to False. + tiled_mlp_num_shards (Optional[int]): Number of shards to split the sequence for tiled MLP computation. + If None, defaults to 4. Larger values reduce memory but may increase computation time. Defaults to None. check_model (bool): If True, checks local model files for corruption or modification and provides a warning. Should be set to False in an offline environment. Defaults to True. acc_strategy (Literal['token', 'seq']): The strategy for calculating accuracy during training and validation. 
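The `_prepare_model_tokenizer` hook shown earlier in this diff consumes these fields by calling `apply_tiled_mlp` before `get_model_processor` builds the model. A standalone sketch of the same pattern, assuming the import path introduced by this PR (the model name and shard count are illustrative):

```python
# Sketch of using the plugin outside the swift trainer. apply_tiled_mlp comes
# from swift/plugin/tiled_mlp.py added in this PR; the model and shard count
# here are only illustrative.
import torch
from transformers import AutoModelForCausalLM

from swift.plugin.tiled_mlp import apply_tiled_mlp

# Must run BEFORE from_pretrained so the patched MLP class is the one used
# when the decoder layers are constructed. On a plain single-GPU/DeepSpeed
# setup this takes the liger path and therefore needs liger-kernel installed.
apply_tiled_mlp('qwen3', num_shards=4)

model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-4B', torch_dtype=torch.bfloat16)
```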
@@ -115,6 +123,8 @@ class TrainArgumentsMixin: dataloader_persistent_workers: bool = False dataloader_prefetch_factor: Optional[int] = None use_liger_kernel: bool = False + use_tiled_mlp: bool = False + tiled_mlp_num_shards: Optional[int] = None # extra check_model: bool = True
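With the two fields wired into `TrainArgumentsMixin`, the same switches should also be reachable from ms-swift's Python entry point rather than only the shell scripts above; a sketch assuming the documented `sft_main`/`TrainArguments` API (dataset and hyperparameters are illustrative):

```python
# Sketch: enabling tiled MLP via the Python API instead of the CLI. Assumes
# swift.llm exposes sft_main/TrainArguments as in the ms-swift docs; model,
# dataset and shard count are illustrative.
from swift.llm import TrainArguments, sft_main

result = sft_main(
    TrainArguments(
        model='Qwen/Qwen3-4B',
        dataset=['swift/self-cognition#200'],
        train_type='full',
        torch_dtype='bfloat16',
        max_length=2048,
        output_dir='output',
        use_tiled_mlp=True,      # new flag added by this PR
        tiled_mlp_num_shards=4,  # None would also resolve to 4
    ))
```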
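Finally, a quick shape-and-gradient sanity check one might run against the FSDP2-oriented `TiledSwiGLUMLP` in isolation (a sketch; `SimpleNamespace` stands in for a real model config carrying `hidden_size` and `intermediate_size`):

```python
# Sanity-check sketch for TiledSwiGLUMLP from swift/plugin/tiled_mlp.py
# (sizes are arbitrary).
from types import SimpleNamespace

import torch

from swift.plugin.tiled_mlp import TiledSwiGLUMLP

config = SimpleNamespace(hidden_size=64, intermediate_size=128)
mlp = TiledSwiGLUMLP(config, num_shards=4)

x = torch.randn(2, 32, 64, requires_grad=True)
out = mlp(x)          # forward runs shard-by-shard; the output shape is unchanged
assert out.shape == x.shape

out.sum().backward()  # backward recomputes each shard and accumulates the weight grads
assert x.grad is not None and x.grad.shape == x.shape
assert mlp.gate_proj.weight.grad is not None
```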