Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions swift/llm/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,35 @@ def apply_liger(model_type: str):
'by running `pip install -U liger-kernel`')


def apply_cce(model_type: str):
try:
from cut_cross_entropy.transformers import cce_patch
from swift.llm import ModelType
except ImportError:
raise ImportError('Please upgrade cut-cross-entropy to apply cce kernels to this model '
'by running `pip install -U cut-cross-entropy`')

model_type_map = {
ModelType.llama: 'llama',
ModelType.llama3: 'llama',
ModelType.llama3_1: 'llama',
ModelType.llama3_2: 'llama',
ModelType.mistral: 'mistral',
ModelType.phi3: 'phi3',
ModelType.gemma2: 'gemma2',
ModelType.qwen2: 'qwen2',
ModelType.qwen2_5: 'qwen2',
}

cce_model_type = model_type_map.get(model_type)
if cce_model_type:
cce_patch(cce_model_type)
return

supported_models = ', '.join(sorted(set(model_type_map.values())))
raise ValueError(f'Unsupported cce model_type: {model_type}. Supported types: {supported_models}')


def get_multimodal_target_regex(
model,
*,
Expand Down Expand Up @@ -375,6 +404,9 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t
# Apply liger
apply_liger(args.model_type)

if args.use_cce and 'use_cce' not in inspect.signature(TrainingArguments).parameters:
apply_cce(args.model_type)

if args.is_adapter:
if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners:
# Fix the name of the layer in xcomposer that contains Plora.
Expand Down
12 changes: 11 additions & 1 deletion swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

from swift.plugin import loss_mapping
from swift.utils import get_dist_setting, get_logger, is_liger_available, is_mp, json_parse_to_dict
from swift.utils import get_dist_setting, get_logger, is_cce_available, is_liger_available, is_mp, json_parse_to_dict
from .optimizers.galore import GaLoreConfig

logger = get_logger()
Expand Down Expand Up @@ -53,6 +53,7 @@ class TrainArgumentsMixin:
dataloader_prefetch_factor (Optional[int]): The number of batches loaded in advance by each worker. Defaults
to None.
use_liger_kernel (bool): Whether to use the Liger kernel for optimization. Defaults to False.
use_cce (bool): Whether to use ml-cross-entropy fused kernels for optimization. Defaults to False.
check_model (bool): If True, checks local model files for corruption or modification and provides a warning.
Should be set to False in an offline environment. Defaults to True.
acc_strategy (Literal['token', 'seq']): The strategy for calculating accuracy during training and validation.
Expand Down Expand Up @@ -115,6 +116,7 @@ class TrainArgumentsMixin:
dataloader_persistent_workers: bool = False
dataloader_prefetch_factor: Optional[int] = None
use_liger_kernel: bool = False
use_cce: bool = False

# extra
check_model: bool = True
Expand Down Expand Up @@ -163,11 +165,18 @@ def _init_liger(self):
except Exception:
pass

def _init_cce(self):
if self.use_cce:
assert is_cce_available(), 'use_cce requires cut-cross-entropy, try `pip install cut-cross-entropy`'

def __post_init__(self):
if is_mp() and self.use_liger_kernel:
raise ValueError('liger_kernel does not support device_map. '
'Please use DDP/DeepSpeed for multi-GPU training.')

if self.use_cce and self.use_liger_kernel:
logger.warning('Enabling both use_cce and use_liger_kernel may lead to duplicated kernel patches.')

if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None):
self.optimizer = 'multimodal'
if self.gradient_accumulation_steps is None:
Expand All @@ -181,6 +190,7 @@ def __post_init__(self):
if self.gradient_checkpointing_kwargs:
self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs)
self._init_liger()
self._init_cce()
if self.dataloader_num_workers is None:
if platform.system() == 'Windows':
self.dataloader_num_workers = 0
Expand Down
6 changes: 3 additions & 3 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if (self.label_smoother is not None or compute_loss_func is not None or loss_scale is not None
or self.args.enable_dft_loss or self.args.enable_channel_loss
or self.template.sequence_parallel_size > 1) and 'labels' in inputs:
if self.args.use_liger_kernel:
logger.warning_once('The cross_entropy loss function defined in Liger Kernel will not '
'take effect, potentially leading to increased GPU memory consumption.')
if self.args.use_liger_kernel or getattr(self.args, 'use_cce', False):
logger.warning_once('The cross_entropy loss function defined in Liger Kernel or ml-cross-entropy will '
'not take effect, potentially leading to increased GPU memory consumption.')
labels = inputs.pop('labels')
outputs = model(**inputs)
if getattr(outputs, 'aux_loss', None) is not None:
Expand Down
11 changes: 11 additions & 0 deletions swift/ui/llm_grpo/llm_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ class LLMGRPO(LLMTrain):
'en': 'Liger kernel can reduce memory usage'
}
},
'use_cce': {
'label': {
'zh': '使用CCE加速',
'en': 'Use CCE acceleration'
},
'info': {
'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子',
'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels'
}
},
'sequence_parallel_size': {
'label': {
'zh': '序列并行大小',
Expand Down Expand Up @@ -233,6 +243,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
gr.Textbox(elem_id='seed', scale=4)
gr.Dropdown(elem_id='torch_dtype', scale=4)
gr.Checkbox(elem_id='use_liger_kernel', scale=4)
gr.Checkbox(elem_id='use_cce', scale=4)
gr.Textbox(elem_id='sequence_parallel_size', lines=1, scale=4)
with gr.Row():
gr.Dropdown(
Expand Down
11 changes: 11 additions & 0 deletions swift/ui/llm_rlhf/llm_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ class LLMRLHF(LLMTrain):
'en': 'Liger kernel can reduce memory usage'
}
},
'use_cce': {
'label': {
'zh': '使用CCE加速',
'en': 'Use CCE acceleration'
},
'info': {
'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子',
'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels'
}
},
'sequence_parallel_size': {
'label': {
'zh': '序列并行大小',
Expand Down Expand Up @@ -246,6 +256,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
gr.Textbox(elem_id='seed', scale=2)
gr.Dropdown(elem_id='torch_dtype', scale=2)
gr.Checkbox(elem_id='use_liger_kernel', scale=2)
gr.Checkbox(elem_id='use_cce', scale=2)
with gr.Row():
gr.Dropdown(
elem_id='gpu_id',
Expand Down
17 changes: 17 additions & 0 deletions swift/ui/llm_train/llm_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ class LLMTrain(BaseUI):
'en': 'Liger kernel can reduce memory usage'
}
},
'use_cce': {
'label': {
'zh': '使用CCE加速',
'en': 'Use CCE acceleration'
},
'info': {
'zh': 'CCE (ml-cross-entropy) 提供融合的交叉熵算子',
'en': 'CCE (ml-cross-entropy) provides fused cross-entropy kernels'
}
},
'sequence_parallel_size': {
'label': {
'zh': '序列并行大小',
Expand Down Expand Up @@ -257,6 +267,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
gr.Textbox(elem_id='seed', scale=4)
gr.Dropdown(elem_id='torch_dtype', scale=4)
gr.Checkbox(elem_id='use_liger_kernel', scale=4)
gr.Checkbox(elem_id='use_cce', scale=4)
with gr.Row():
gr.Dropdown(
elem_id='gpu_id',
Expand Down Expand Up @@ -390,6 +401,9 @@ def train(cls, *args):
use_liger_kernel = kwargs.get('use_liger_kernel', None)
if use_liger_kernel:
kwargs.pop('use_liger_kernel')
use_cce = kwargs.get('use_cce', None)
if use_cce:
kwargs.pop('use_cce')
if other_kwargs.get('use_muon'):
kwargs['use_muon'] = other_kwargs.pop('use_muon')

Expand Down Expand Up @@ -428,6 +442,9 @@ def train(cls, *args):
if use_liger_kernel:
params += f'--use_liger_kernel {cls.quote}{use_liger_kernel}{cls.quote} '
command.extend(['--use_liger_kernel', f'{use_liger_kernel}'])
if use_cce:
params += f'--use_cce {cls.quote}{use_cce}{cls.quote} '
command.extend(['--use_cce', f'{use_cce}'])
if use_muon:
params += f'--optimizer {cls.quote}muon{cls.quote} '
command.extend(['--optimizer', 'muon'])
Expand Down
2 changes: 1 addition & 1 deletion swift/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .env import (get_dist_setting, get_hf_endpoint, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled,
is_dist, is_last_rank, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
from .import_utils import (is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available,
from .import_utils import (is_cce_available, is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available,
is_lmdeploy_available, is_megatron_available, is_swanlab_available, is_trl_available,
is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available)
from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
Expand Down
4 changes: 4 additions & 0 deletions swift/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def is_liger_available():
return importlib.util.find_spec('liger_kernel') is not None


def is_cce_available():
return importlib.util.find_spec('cut_cross_entropy') is not None


def is_swanlab_available():
return importlib.util.find_spec('swanlab') is not None

Expand Down
26 changes: 26 additions & 0 deletions tests/train/test_cce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
kwargs = {
'per_device_train_batch_size': 2,
'save_steps': 30,
'gradient_accumulation_steps': 2,
'num_train_epochs': 1,
}


def test_sft():
from swift.llm import sft_main, TrainArguments, infer_main, InferArguments
result = sft_main(
TrainArguments(
model='Qwen/Qwen2.5-0.5B-Instruct',
dataset=['swift/self-cognition#200'],
split_dataset_ratio=0.01,
use_cce=True,
**kwargs))
last_model_checkpoint = result['last_model_checkpoint']
infer_main(InferArguments(adapters=last_model_checkpoint, load_data_args=True))


if __name__ == '__main__':
test_sft()