diff --git a/examples/models/gpt_oss/mcore.sh b/examples/models/gpt_oss/mcore.sh
index e60541b0a7..d5ff76df36 100644
--- a/examples/models/gpt_oss/mcore.sh
+++ b/examples/models/gpt_oss/mcore.sh
@@ -1,5 +1,6 @@
 # mcore>=0.15
 # 2 * 40GiB
+# dataset format: https://github.com/modelscope/ms-swift/pull/5277
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
diff --git a/scripts/benchmark/exp_utils.py b/scripts/benchmark/exp_utils.py
index b7209691c8..3823ffab7a 100644
--- a/scripts/benchmark/exp_utils.py
+++ b/scripts/benchmark/exp_utils.py
@@ -10,7 +10,7 @@ import json

 import torch

-from swift.llm import ExportArguments
+from swift.pipelines import ExportArguments
 from swift.utils import find_free_port, get_device_count, get_logger

 logger = get_logger()
diff --git a/scripts/utils/run_template.py b/scripts/utils/run_template.py
index 6c6b0445e4..989befe9f7 100644
--- a/scripts/utils/run_template.py
+++ b/scripts/utils/run_template.py
@@ -1,4 +1,4 @@
-from swift.llm import TemplateType
+from swift.template import TemplateType

 if __name__ == '__main__':
     template_name_list = TemplateType.get_template_name_list()
diff --git a/swift/llm/argument/__init__.py b/swift/arguments/__init__.py
similarity index 100%
rename from swift/llm/argument/__init__.py
rename to swift/arguments/__init__.py
diff --git a/swift/llm/argument/app_args.py b/swift/arguments/app_args.py
similarity index 96%
rename from swift/llm/argument/app_args.py
rename to swift/arguments/app_args.py
index 08ae5e8045..fa8a1a9085 100644
--- a/swift/llm/argument/app_args.py
+++ b/swift/arguments/app_args.py
@@ -2,9 +2,9 @@
 from dataclasses import dataclass
 from typing import Literal, Optional

+from swift.model import get_matched_model_meta
+from swift.template import get_template_meta
 from swift.utils import find_free_port, get_logger
-from ..model import get_matched_model_meta
-from ..template import get_template_meta
 from .deploy_args import DeployArguments
 from .webui_args import WebUIArguments
diff --git a/swift/llm/argument/base_args/__init__.py b/swift/arguments/base_args/__init__.py
similarity index 100%
rename from swift/llm/argument/base_args/__init__.py
rename to swift/arguments/base_args/__init__.py
diff --git a/swift/llm/argument/base_args/base_args.py b/swift/arguments/base_args/base_args.py
similarity index 98%
rename from swift/llm/argument/base_args/base_args.py
rename to swift/arguments/base_args/base_args.py
index 3a86b1e59a..79b09ffb62 100644
--- a/swift/llm/argument/base_args/base_args.py
+++ b/swift/arguments/base_args/base_args.py
@@ -6,9 +6,9 @@ import json

 from swift.hub import get_hub
-from swift.llm import Processor, Template, get_model_tokenizer, get_template, load_by_unsloth, safe_snapshot_download
-from swift.llm.utils import get_ckpt_dir
+from swift.model import get_ckpt_dir, get_model_tokenizer, load_by_unsloth, safe_snapshot_download
 from swift.plugin import extra_tuners
+from swift.template import Processor, Template, get_template
 from swift.utils import (check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master,
                          json_parse_to_dict, set_device, use_hf_hub)
 from .data_args import DataArguments
diff --git a/swift/llm/argument/base_args/data_args.py b/swift/arguments/base_args/data_args.py
similarity index 99%
rename from swift/llm/argument/base_args/data_args.py
rename to swift/arguments/base_args/data_args.py
index 7214af050e..6d7fba521a 100644
--- a/swift/llm/argument/base_args/data_args.py
+++ b/swift/arguments/base_args/data_args.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from typing import List, Literal, Optional, Union

-from swift.llm import DATASET_MAPPING, register_dataset_info
+from swift.dataset import DATASET_MAPPING, register_dataset_info
 from swift.utils import get_logger, json_parse_to_dict

 logger = get_logger()
diff --git a/swift/llm/argument/base_args/generation_args.py b/swift/arguments/base_args/generation_args.py
similarity index 98%
rename from swift/llm/argument/base_args/generation_args.py
rename to swift/arguments/base_args/generation_args.py
index 2e030938d4..5225f7d613 100644
--- a/swift/llm/argument/base_args/generation_args.py
+++ b/swift/arguments/base_args/generation_args.py
@@ -54,7 +54,7 @@ def _init_stream(self):
     def get_request_config(self):
         if getattr(self, 'task_type') != 'causal_lm':
             return
-        from swift.llm import RequestConfig
+        from swift.pipelines import RequestConfig

         return RequestConfig(
             max_tokens=self.max_new_tokens,
diff --git a/swift/llm/argument/base_args/model_args.py b/swift/arguments/base_args/model_args.py
similarity index 98%
rename from swift/llm/argument/base_args/model_args.py
rename to swift/arguments/base_args/model_args.py
index 238216a81f..f1fed725cb 100644
--- a/swift/llm/argument/base_args/model_args.py
+++ b/swift/arguments/base_args/model_args.py
@@ -9,7 +9,7 @@
 import torch
 from transformers.utils import is_torch_mps_available

-from swift.llm import MODEL_MAPPING, HfConfigFactory, get_model_info_meta, get_model_name
+from swift.model import MODEL_MAPPING, HfConfigFactory, get_model_info_meta, get_model_name
 from swift.utils import get_dist_setting, get_logger, json_parse_to_dict

 logger = get_logger()
@@ -116,7 +116,7 @@ def _init_max_memory(self):

     def _init_torch_dtype(self) -> None:
         """If torch_dtype is None, find a proper dtype by the train_type/GPU"""
-        from swift.llm import TrainArguments
+        from swift.pipelines import TrainArguments

         self.torch_dtype: Optional[torch.dtype] = HfConfigFactory.to_torch_dtype(self.torch_dtype)
         self.torch_dtype: torch.dtype = self._init_model_info()
diff --git a/swift/llm/argument/base_args/quant_args.py b/swift/arguments/base_args/quant_args.py
similarity index 98%
rename from swift/llm/argument/base_args/quant_args.py
rename to swift/arguments/base_args/quant_args.py
index 868981b983..5fc36c3328 100644
--- a/swift/llm/argument/base_args/quant_args.py
+++ b/swift/arguments/base_args/quant_args.py
@@ -5,7 +5,7 @@

 import torch

-from swift.llm import HfConfigFactory
+from swift.model import HfConfigFactory
 from swift.utils import get_modules_to_not_convert


@@ -72,7 +72,7 @@ def get_quantization_config(self):
         if not hasattr(self, 'model_info'):
             return
         from transformers import FineGrainedFP8Config
-        from swift.llm import get_model_tokenizer
+        from swift.model import get_model_tokenizer
         with torch.device('meta'):
             hf_model, _ = get_model_tokenizer(self.model_dir, model_type=self.model_type, return_dummy_model=True)
         modules_to_not_convert = get_modules_to_not_convert(hf_model)
diff --git a/swift/llm/argument/base_args/ray_args.py b/swift/arguments/base_args/ray_args.py
similarity index 100%
rename from swift/llm/argument/base_args/ray_args.py
rename to swift/arguments/base_args/ray_args.py
diff --git a/swift/llm/argument/base_args/template_args.py b/swift/arguments/base_args/template_args.py
similarity index 99%
rename from swift/llm/argument/base_args/template_args.py
rename to swift/arguments/base_args/template_args.py
index 724ce7b311..4d2256d6fc 100644
--- a/swift/llm/argument/base_args/template_args.py
+++ b/swift/arguments/base_args/template_args.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from typing import Literal, Optional

-from swift.llm import TEMPLATE_MAPPING
+from swift.template import TEMPLATE_MAPPING
 from swift.utils import get_logger

 logger = get_logger()
diff --git a/swift/llm/argument/base_args/utils.py b/swift/arguments/base_args/utils.py
similarity index 100%
rename from swift/llm/argument/base_args/utils.py
rename to swift/arguments/base_args/utils.py
diff --git a/swift/llm/argument/deploy_args.py b/swift/arguments/deploy_args.py
similarity index 99%
rename from swift/llm/argument/deploy_args.py
rename to swift/arguments/deploy_args.py
index da4995a463..627f1a2676 100644
--- a/swift/llm/argument/deploy_args.py
+++ b/swift/arguments/deploy_args.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from typing import Literal, Optional

-from swift.llm import safe_snapshot_download
+from swift.model import safe_snapshot_download
 from swift.utils import find_free_port, get_device_count, get_logger
 from .base_args import BaseArguments
 from .infer_args import InferArguments
diff --git a/swift/llm/argument/eval_args.py b/swift/arguments/eval_args.py
similarity index 100%
rename from swift/llm/argument/eval_args.py
rename to swift/arguments/eval_args.py
diff --git a/swift/llm/argument/export_args.py b/swift/arguments/export_args.py
similarity index 99%
rename from swift/llm/argument/export_args.py
rename to swift/arguments/export_args.py
index 4e1f55b890..ffe3056b66 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/arguments/export_args.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist

-from swift.llm import HfConfigFactory
+from swift.model import HfConfigFactory
 from swift.utils import get_logger, init_process_group, set_default_ddp_config
 from .base_args import BaseArguments, to_abspath
 from .merge_args import MergeArguments
diff --git a/swift/llm/argument/infer_args.py b/swift/arguments/infer_args.py
similarity index 100%
rename from swift/llm/argument/infer_args.py
rename to swift/arguments/infer_args.py
diff --git a/swift/llm/argument/merge_args.py b/swift/arguments/merge_args.py
similarity index 100%
rename from swift/llm/argument/merge_args.py
rename to swift/arguments/merge_args.py
diff --git a/swift/llm/argument/rlhf_args.py b/swift/arguments/rlhf_args.py
similarity index 99%
rename from swift/llm/argument/rlhf_args.py
rename to swift/arguments/rlhf_args.py
index 44c28228ad..cd3373dd4f 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/arguments/rlhf_args.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Literal, Optional

-from swift.llm import MODEL_MAPPING
+from swift.model import MODEL_MAPPING
 from swift.trainers import GRPOArgumentsMixin, RLHFArgumentsMixin
 from swift.utils import get_current_device, get_logger, is_master, is_mp, json_parse_to_dict, set_default_ddp_config
 from .train_args import TrainArguments
@@ -413,7 +413,7 @@ def _init_external_vllm(self):
         if self.rlhf_type not in rlhf_support_vllm_types or (self.vllm_server_host is None
                                                              and self.vllm_server_base_url is None):
             return
-        from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
+        from swift.trainers.rlhf_trainers.vllm_client import VLLMClient
         if is_master():
             logger.info('Start connecting to vLLM server')
         self.vllm_client = VLLMClient(
diff --git a/swift/llm/argument/sampling_args.py b/swift/arguments/sampling_args.py
similarity index 99%
rename from swift/llm/argument/sampling_args.py
rename to swift/arguments/sampling_args.py
index 81459de3e1..a086508d14 100644
--- a/swift/llm/argument/sampling_args.py
+++ b/swift/arguments/sampling_args.py
@@ -7,8 +7,8 @@

 import json

-from swift.llm import BaseArguments
 from swift.utils import get_logger
+from .base_args import BaseArguments

 logger = get_logger()
diff --git a/swift/llm/argument/train_args.py b/swift/arguments/train_args.py
similarity index 100%
rename from swift/llm/argument/train_args.py
rename to swift/arguments/train_args.py
diff --git a/swift/llm/argument/tuner_args.py b/swift/arguments/tuner_args.py
similarity index 100%
rename from swift/llm/argument/tuner_args.py
rename to swift/arguments/tuner_args.py
diff --git a/swift/llm/argument/webui_args.py b/swift/arguments/webui_args.py
similarity index 100%
rename from swift/llm/argument/webui_args.py
rename to swift/arguments/webui_args.py
diff --git a/swift/cli/app.py b/swift/cli/app.py
index ec4e79741a..34af9c9de1 100644
--- a/swift/cli/app.py
+++ b/swift/cli/app.py
@@ -1,4 +1,4 @@
-from swift.llm import app_main
+from swift.pipelines import app_main

 if __name__ == '__main__':
     app_main()
diff --git a/swift/cli/deploy.py b/swift/cli/deploy.py
index 7cf4784fb0..86583704e4 100644
--- a/swift/cli/deploy.py
+++ b/swift/cli/deploy.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import deploy_main
+from swift.pipelines import deploy_main

 if __name__ == '__main__':
     deploy_main()
diff --git a/swift/cli/eval.py b/swift/cli/eval.py
index 402305ea4e..8f39f5a63e 100644
--- a/swift/cli/eval.py
+++ b/swift/cli/eval.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import eval_main
+from swift.pipelines import eval_main

 if __name__ == '__main__':
     eval_main()
diff --git a/swift/cli/export.py b/swift/cli/export.py
index 508f1be4bf..8b6a210648 100644
--- a/swift/cli/export.py
+++ b/swift/cli/export.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import export_main
+from swift.pipelines import export_main

 if __name__ == '__main__':
     export_main()
diff --git a/swift/cli/infer.py b/swift/cli/infer.py
index 2dce4f3acf..ae763a5ffd 100644
--- a/swift/cli/infer.py
+++ b/swift/cli/infer.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import infer_main
+from swift.pipelines import infer_main

 if __name__ == '__main__':
     infer_main()
diff --git a/swift/cli/merge_lora.py b/swift/cli/merge_lora.py
index 6b01f88249..2369e80499 100644
--- a/swift/cli/merge_lora.py
+++ b/swift/cli/merge_lora.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import ExportArguments, SwiftPipeline, merge_lora
+from swift.pipelines import ExportArguments, SwiftPipeline, merge_lora


 class SwiftMergeLoRA(SwiftPipeline):
diff --git a/swift/cli/pt.py b/swift/cli/pt.py
index 60477214b1..2fa10ed738 100644
--- a/swift/cli/pt.py
+++ b/swift/cli/pt.py
@@ -3,5 +3,5 @@
 if __name__ == '__main__':
     from swift.cli.utils import try_use_single_device_mode
     try_use_single_device_mode()
-    from swift.llm import pt_main
+    from swift.pipelines import pt_main
     pt_main()
diff --git a/swift/cli/rlhf.py b/swift/cli/rlhf.py
index 5d8400fc5a..09bf3879fd 100644
--- a/swift/cli/rlhf.py
+++ b/swift/cli/rlhf.py
@@ -3,5 +3,5 @@
 if __name__ == '__main__':
     from swift.cli.utils import try_use_single_device_mode
     try_use_single_device_mode()
-    from swift.llm import rlhf_main
+    from swift.pipelines import rlhf_main
     rlhf_main()
diff --git a/swift/cli/rollout.py b/swift/cli/rollout.py
index 631ac6f8d9..89f99f9b09 100644
--- a/swift/cli/rollout.py
+++ b/swift/cli/rollout.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import rollout_main
+from swift.pipelines import rollout_main

 if __name__ == '__main__':
     rollout_main()
diff --git a/swift/cli/sample.py b/swift/cli/sample.py
index c2ae0b326c..a0a9e6a927 100644
--- a/swift/cli/sample.py
+++ b/swift/cli/sample.py
@@ -3,5 +3,5 @@
 if __name__ == '__main__':
     from swift.ray import try_init_ray
     try_init_ray()
-    from swift.llm.sampling import sampling_main
+    from swift.pipelines import sampling_main
     sampling_main()
diff --git a/swift/cli/sft.py b/swift/cli/sft.py
index 27076381da..cd1bb51dc3 100644
--- a/swift/cli/sft.py
+++ b/swift/cli/sft.py
@@ -16,5 +16,5 @@ def try_init_unsloth():
     try_init_unsloth()
     from swift.ray import try_init_ray
     try_init_ray()
-    from swift.llm import sft_main
+    from swift.pipelines import sft_main
     sft_main()
diff --git a/swift/llm/dataset/__init__.py b/swift/dataset/__init__.py
similarity index 87%
rename from swift/llm/dataset/__init__.py
rename to swift/dataset/__init__.py
index 1c989f482f..b5bfa47df1 100644
--- a/swift/llm/dataset/__init__.py
+++ b/swift/dataset/__init__.py
@@ -2,15 +2,14 @@
 import datasets.fingerprint
 from datasets import Dataset as HfDataset

-from ..utils import get_temporary_cache_files_directory
 from . import dataset
 from .loader import DATASET_TYPE, load_dataset
 from .media import MediaResource
 from .preprocessor import (AlpacaPreprocessor, AutoPreprocessor, MessagesPreprocessor, ResponsePreprocessor,
                            RowPreprocessor)
 from .register import DATASET_MAPPING, DatasetMeta, SubsetDataset, register_dataset, register_dataset_info
 from .utils import (AddLengthPreprocessor, EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset,
-                    sample_dataset)
+                    get_temporary_cache_files_directory, sample_dataset)

 datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory
 datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory
diff --git a/swift/llm/dataset/data/dataset_info.json b/swift/dataset/data/dataset_info.json
similarity index 100%
rename from swift/llm/dataset/data/dataset_info.json
rename to swift/dataset/data/dataset_info.json
diff --git a/swift/llm/dataset/dataset/__init__.py b/swift/dataset/dataset/__init__.py
similarity index 100%
rename from swift/llm/dataset/dataset/__init__.py
rename to swift/dataset/dataset/__init__.py
diff --git a/swift/llm/dataset/dataset/llm.py b/swift/dataset/dataset/llm.py
similarity index 100%
rename from swift/llm/dataset/dataset/llm.py
rename to swift/dataset/dataset/llm.py
diff --git a/swift/llm/dataset/dataset/mllm.py b/swift/dataset/dataset/mllm.py
similarity index 100%
rename from swift/llm/dataset/dataset/mllm.py
rename to swift/dataset/dataset/mllm.py
diff --git a/swift/llm/dataset/indexed_dataset.py b/swift/dataset/indexed_dataset.py
similarity index 100%
rename from swift/llm/dataset/indexed_dataset.py
rename to swift/dataset/indexed_dataset.py
diff --git a/swift/llm/dataset/loader.py b/swift/dataset/loader.py
similarity index 100%
rename from swift/llm/dataset/loader.py
rename to swift/dataset/loader.py
diff --git a/swift/llm/dataset/media.py b/swift/dataset/media.py
similarity index 100%
rename from swift/llm/dataset/media.py
rename to swift/dataset/media.py
diff --git a/swift/llm/dataset/preprocessor/__init__.py b/swift/dataset/preprocessor/__init__.py
similarity index 100%
rename from swift/llm/dataset/preprocessor/__init__.py
rename to swift/dataset/preprocessor/__init__.py
diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/dataset/preprocessor/core.py
similarity index 99%
rename from swift/llm/dataset/preprocessor/core.py
rename to swift/dataset/preprocessor/core.py
index a75ab5b6bc..bcbdda9f93 100644
--- a/swift/llm/dataset/preprocessor/core.py
+++ b/swift/dataset/preprocessor/core.py
@@ -13,7 +13,7 @@
 from datasets import Sequence, Value
 from modelscope.hub.utils.utils import get_cache_dir

-from swift.llm import history_to_messages
+from swift.template import history_to_messages
 from swift.utils import get_logger, is_dist, is_master, safe_ddp_context

 DATASET_TYPE = Union[HfDataset, HfIterableDataset]
diff --git a/swift/llm/dataset/preprocessor/extra.py b/swift/dataset/preprocessor/extra.py
similarity index 100%
rename from swift/llm/dataset/preprocessor/extra.py
rename to swift/dataset/preprocessor/extra.py
diff --git a/swift/llm/dataset/register.py b/swift/dataset/register.py
similarity index 100%
rename from swift/llm/dataset/register.py
rename to swift/dataset/register.py
diff --git a/swift/llm/dataset/utils.py b/swift/dataset/utils.py
similarity index 92%
rename from swift/llm/dataset/utils.py
rename to swift/dataset/utils.py
index 159ccf8953..97987cb993 100644
--- a/swift/llm/dataset/utils.py
+++ b/swift/dataset/utils.py
@@ -1,12 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
 import math
 import multiprocessing as mp
+import os
+import tempfile
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch.distributed as dist
 from datasets import Dataset as HfDataset
+from modelscope.hub.utils.utils import get_cache_dir
 from torch.utils.data import Dataset, IterableDataset
 from tqdm import tqdm
@@ -17,7 +21,7 @@

 logger = get_logger()

 if TYPE_CHECKING:
-    from swift.llm import Template
+    from swift.template import Template


 def sample_dataset(
@@ -323,3 +327,27 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         encoded = super().preprocess(row)
         row['length'] = encoded['length']
         return row
+
+
+TEMP_DIR_POOL = {}
+
+
+def get_temporary_cache_files_directory(prefix=None):
+    if prefix is None:
+        import datasets.config
+        prefix = datasets.config.TEMP_CACHE_DIR_PREFIX
+    global TEMP_DIR_POOL
+    if prefix in TEMP_DIR_POOL:
+        TEMP_DIR = TEMP_DIR_POOL[prefix]
+    else:
+        tmp_dir = os.path.join(get_cache_dir(), 'tmp')
+        os.makedirs(tmp_dir, exist_ok=True)
+        kwargs = {}
+        parameters = inspect.signature(tempfile.TemporaryDirectory.__init__).parameters
+        if 'ignore_cleanup_errors' in parameters:
+            kwargs['ignore_cleanup_errors'] = True
+        TEMP_DIR = tempfile.TemporaryDirectory(prefix=prefix, dir=tmp_dir, **kwargs)
+        logger.info(f'create tmp_dir: {TEMP_DIR.name}')
+        TEMP_DIR_POOL[prefix] = TEMP_DIR
+
+    return TEMP_DIR.name
diff --git a/swift/llm/infer/infer_engine/__init__.py b/swift/infer_engine/__init__.py
similarity index 100%
rename from swift/llm/infer/infer_engine/__init__.py
rename to swift/infer_engine/__init__.py
diff --git a/swift/llm/infer/infer_engine/base.py b/swift/infer_engine/base.py
similarity index 95%
rename from swift/llm/infer/infer_engine/base.py
rename to swift/infer_engine/base.py
index 866f1dc8c0..2a0f6f4750 100644
--- a/swift/llm/infer/infer_engine/base.py
+++ b/swift/infer_engine/base.py
@@ -2,9 +2,8 @@
 from abc import ABC, abstractmethod
 from typing import AsyncIterator, Iterator, List, Optional, Union

-from swift.llm import InferRequest
 from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
+from .protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig


 class BaseInferEngine(ABC):
diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/infer_engine/grpo_vllm_engine.py
similarity index 97%
rename from swift/llm/infer/infer_engine/grpo_vllm_engine.py
rename to swift/infer_engine/grpo_vllm_engine.py
index 7ba34f71f0..715e39c0d5 100644
--- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py
+++ b/swift/infer_engine/grpo_vllm_engine.py
@@ -6,10 +6,12 @@
 from PIL import Image
 from tqdm.asyncio import tqdm_asyncio

-from swift.llm import InferRequest, Template, VllmEngine
 from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput
+from swift.template import Template
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, InferRequest, RequestConfig,
+                       RolloutOutput)
 from .utils import AdapterRequest
+from .vllm_engine import VllmEngine

 try:
     os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
diff --git a/swift/llm/infer/infer_engine/infer_client.py b/swift/infer_engine/infer_client.py
similarity index 97%
rename from swift/llm/infer/infer_engine/infer_client.py
rename to swift/infer_engine/infer_client.py
index c86132b2d3..c6deade6fc 100644
--- a/swift/llm/infer/infer_engine/infer_client.py
+++ b/swift/infer_engine/infer_client.py
@@ -10,9 +10,9 @@
 from requests.exceptions import HTTPError

 from swift.plugin import Metric
-from ..protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest,
-                        ModelList, RequestConfig)
 from .infer_engine import InferEngine
+from .protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest,
+                       ModelList, RequestConfig)


 class InferClient(InferEngine):
diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/infer_engine/infer_engine.py
similarity index 96%
rename from swift/llm/infer/infer_engine/infer_engine.py
rename to swift/infer_engine/infer_engine.py
index 86c9583e40..4de8eb0d92 100644
--- a/swift/llm/infer/infer_engine/infer_engine.py
+++ b/swift/infer_engine/infer_engine.py
@@ -9,14 +9,13 @@

 from tqdm import tqdm

-from swift.llm import InferRequest, ProcessorMixin, get_template
-from swift.llm.template import Template
-from swift.llm.utils import get_ckpt_dir
+from swift.model import get_ckpt_dir
 from swift.plugin import Metric
-from swift.utils import get_logger
-from ..protocol import (ChatCompletionMessageToolCall, ChatCompletionResponse, ChatCompletionStreamResponse,
-                        RequestConfig, UsageInfo)
+from swift.template import Template, get_template
+from swift.utils import ProcessorMixin, get_logger
 from .base import BaseInferEngine
+from .protocol import (ChatCompletionMessageToolCall, ChatCompletionResponse, ChatCompletionStreamResponse,
+                       InferRequest, RequestConfig, UsageInfo)

 logger = get_logger()
@@ -37,7 +36,7 @@ def _post_init(self, template=None):
         ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None))
         logger.info('Create the default_template for the infer_engine')
         if ckpt_dir:
-            from swift.llm import BaseArguments
+            from swift.pipelines import BaseArguments
             args = BaseArguments.from_pretrained(ckpt_dir)
             self.default_template = args.get_template(self.processor)
         else:
diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/infer_engine/lmdeploy_engine.py
similarity index 98%
rename from swift/llm/infer/infer_engine/lmdeploy_engine.py
rename to swift/infer_engine/lmdeploy_engine.py
index edbca4f5e6..ecda872556 100644
--- a/swift/llm/infer/infer_engine/lmdeploy_engine.py
+++ b/swift/infer_engine/lmdeploy_engine.py
@@ -17,13 +17,14 @@
 from transformers import GenerationConfig
 from transformers.utils.versions import require_version

-from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
+from swift.model import get_model_tokenizer
 from swift.plugin import Metric
+from swift.template import Template, TemplateMeta
 from swift.utils import get_logger, get_seed
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig)
 from .infer_engine import InferEngine
 from .patch import patch_auto_config, patch_auto_tokenizer
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                       ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig)
 from .utils import InferStreamer

 try:
diff --git a/swift/llm/infer/infer_engine/patch.py b/swift/infer_engine/patch.py
similarity index 100%
rename from swift/llm/infer/infer_engine/patch.py
rename to swift/infer_engine/patch.py
diff --git a/swift/llm/infer/protocol.py b/swift/infer_engine/protocol.py
similarity index 77%
rename from swift/llm/infer/protocol.py
rename to swift/infer_engine/protocol.py
index 52ca01e005..2f6a93a241 100644
--- a/swift/llm/infer/protocol.py
+++ b/swift/infer_engine/protocol.py
@@ -12,8 +12,7 @@
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator

-from ..template import InferRequest
-from ..utils import Messages, Tool
+from swift.template import Messages, Tool


 def random_uuid() -> str:
@@ -35,6 +34,124 @@ class ModelList:
     object: str = 'list'


+@dataclass
+class InferRequest:
+    """
+    Data structure for inference requests.
+
+    Attributes:
+        messages (Messages):
+            The input conversation in messages format. Each message is a dict containing at least
+            a "role" field (e.g., "user", "assistant", "system") and a "content" field.
+            Example:
+                [{
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image",  # can also be audio/video
+                            "image": "",
+                        },
+                        {"type": "text", "text": "Please describe the picture."},
+                    ],
+                }]
+            The above is equivalent to:
+                [{"role": "user", "content": "Please describe the picture."}]
+            with an additional argument:
+                images = [""]
+
+        images (List[Union[str, Image.Image]]):
+            Optional, a list of images associated with the request.
+            Each image can be a URL, local path, base64 string, or PIL.Image object.
+
+        audios (List[str]):
+            Optional, a list of audio resources associated with the request.
+
+        videos (List[str]):
+            Optional, a list of video resources associated with the request.
+
+        tools (Optional[List[Tool]]):
+            An optional list of tools. These should be organized in the agent_template format for
+            tools requested by the system, for example 'react_en'.
+
+        objects (Dict[str, List[Any]]):
+            Container for additional multimodal objects, grouped by type (key).
+    """
+    messages: Messages
+
+    images: List[Union[str, Image.Image]] = field(default_factory=list)
+    audios: List[str] = field(default_factory=list)
+    videos: List[str] = field(default_factory=list)
+
+    tools: Optional[List[Tool]] = None
+    objects: Dict[str, List[Any]] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for key in ['images', 'audios', 'videos']:
+            val = getattr(self, key)
+            if isinstance(val, str):
+                setattr(self, key, [val])
+        assert isinstance(self.messages, list), f'messages: {self.messages}'
+
+    @staticmethod
+    def remove_response(messages) -> Optional[str]:
+        last_role = messages[-1]['role'] if messages else None
+        if last_role == 'assistant':
+            return messages.pop()['content']
+
+    @staticmethod
+    def _to_printable(obj, key: Optional[str] = None):
+        if isinstance(obj, str) and key not in {'content', 'text'} and len(obj) >= 1000:
+            return f'<<>>'
+        elif isinstance(obj, list):
+            res = []
+            for item in obj:
+                res.append(InferRequest._to_printable(item))
+            return res
+        elif isinstance(obj, dict):
+            res = {}
+            for k, v in obj.items():
+                res[k] = InferRequest._to_printable(v, key=k)
+            return res
+        return obj
+
+    def to_printable(self):
+        return InferRequest._to_printable(asdict(self))
+
+
+@dataclass
+class RolloutInferRequest(InferRequest):
+    """
+    An inference request class for rollout scenarios.
+
+    This class extends `InferRequest` and specifically overrides the `images` attribute
+    to be a list of strings for compatibility with POST requests. Each string may
+    represent an image URL or a Base64-encoded image.
+
+    Inherits all fields from `InferRequest`:
+        messages (Messages):
+            Input conversation messages, supporting multimodal content.
+        audios (List[str]):
+            List of audio resources associated with the request.
+        videos (List[str]):
+            List of video resources associated with the request.
+        tools (Optional[List[Tool]]):
+            List of tools, organized by the agent template (e.g. 'react_en').
+        objects (Dict[str, List[Any]]):
+            Optional container for additional multimodal objects.
+
+    Additional / Overridden fields:
+        images (List[str]):
+            List of image resources, each as a string (URL or base64).
+        data_dict (Dict):
+            Optional dictionary for extra request data.
+        uuid (Optional[str]):
+            Optional unique identifier for this request instance.
+    """
+    images: List[str] = field(default_factory=list)
+    data_dict: Dict = field(default_factory=dict)
+    uuid: Optional[str] = None
+
+
 @dataclass
 class RequestConfig:
     """NOTE: The following behavior is inconsistent with the OpenAI API.
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/infer_engine/pt_engine.py
similarity index 98%
rename from swift/llm/infer/infer_engine/pt_engine.py
rename to swift/infer_engine/pt_engine.py
index 09fbe5825f..db25d1f3eb 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/infer_engine/pt_engine.py
@@ -18,13 +18,15 @@
 from transformers import GenerationConfig, LogitsProcessorList
 from transformers.utils import is_torch_npu_available

-from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
+from swift.model import get_model_tokenizer, safe_snapshot_download
 from swift.plugin import Metric
+from swift.template import Template, TemplateMeta
 from swift.tuners import Swift
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
-                        EmbeddingResponseData, RequestConfig, random_uuid)
+from swift.utils import to_device
 from .infer_engine import InferEngine
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                       ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
+                       EmbeddingResponseData, InferRequest, RequestConfig, random_uuid)
 from .utils import AdapterRequest, InferStreamer, LogitsStreamer, TokensIteratorStreamer, prepare_generation_config
diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/infer_engine/sglang_engine.py
similarity index 96%
rename from swift/llm/infer/infer_engine/sglang_engine.py
rename to swift/infer_engine/sglang_engine.py
index fe02ac8a3f..eb979af94e 100644
--- a/swift/llm/infer/infer_engine/sglang_engine.py
+++ b/swift/infer_engine/sglang_engine.py
@@ -11,13 +11,14 @@
 from sglang.srt.server_args import ServerArgs
 from transformers import GenerationConfig

-from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
+from swift.model import get_model_tokenizer
 from swift.plugin import Metric
+from swift.template import Template, TemplateMeta
 from swift.utils import get_logger
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
-                        EmbeddingResponseData, RequestConfig, random_uuid)
 from .infer_engine import InferEngine
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                       ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
+                       EmbeddingResponseData, InferRequest, RequestConfig, random_uuid)
 from .utils import InferStreamer

 logger = get_logger()
diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/infer_engine/utils.py
similarity index 99%
rename from swift/llm/infer/infer_engine/utils.py
rename to swift/infer_engine/utils.py
index 7c14b9bb6a..3c48bb4846 100644
--- a/swift/llm/infer/infer_engine/utils.py
+++ b/swift/infer_engine/utils.py
@@ -17,9 +17,9 @@
 from transformers import GenerationConfig, LogitsProcessor
 from transformers.generation.streamers import BaseStreamer

-from swift.llm.model.register import fix_do_sample_warning
+from swift.model.register import fix_do_sample_warning
 from swift.utils import get_current_device, get_device, get_device_count, get_node_setting, set_device
-from ..protocol import RequestConfig
+from .protocol import RequestConfig


 @dataclass
diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/infer_engine/vllm_engine.py
similarity index 98%
rename from swift/llm/infer/infer_engine/vllm_engine.py
rename to swift/infer_engine/vllm_engine.py
index 96d2818cff..db6644b561 100644
--- a/swift/llm/infer/infer_engine/vllm_engine.py
+++ b/swift/infer_engine/vllm_engine.py
@@ -13,14 +13,15 @@
 from transformers import GenerationConfig
 from transformers.utils import is_torch_npu_available

-from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
+from swift.model import get_model_tokenizer
 from swift.plugin import Metric
+from swift.template import Template, TemplateMeta
 from swift.utils import get_device, get_dist_setting, get_logger, is_dist
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
-                        EmbeddingResponseData, RequestConfig, random_uuid)
 from .infer_engine import InferEngine
 from .patch import patch_auto_config, patch_auto_tokenizer
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                       ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
+                       EmbeddingResponseData, InferRequest, RequestConfig, random_uuid)
 from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm_memory_leak

 logger = get_logger()
diff --git a/swift/llm/utils.py b/swift/llm/utils.py
deleted file mode 100644
index b80c54752e..0000000000
--- a/swift/llm/utils.py
+++ /dev/null
@@ -1,336 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import inspect
-import os
-import shutil
-import tempfile
-from types import MethodType
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from modelscope.hub.utils.utils import get_cache_dir
-from peft import PeftModel
-from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
-from transformers import ProcessorMixin as HfProcessorMixin
-
-from swift.utils import deep_getattr, get_logger
-
-try:
-    from transformers import BaseImageProcessor
-    Processor = Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, HfProcessorMixin]
-except ImportError:
-    Processor = Union[PreTrainedTokenizerBase, FeatureExtractionMixin, HfProcessorMixin]
-
-if 'TOKENIZERS_PARALLELISM' not in os.environ:
-    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-
-logger = get_logger()
-
-Tool = Dict[str, Union[str, Dict]]
-History = List[Union[Tuple[str, str], List[str]]]
-Message = Dict[str, Union[str, List[Dict[str, Any]], List[int], None]]
-Messages = List[Message]
-
-
-class ProcessorMixin:
-
-    @property
-    def tokenizer(self):
-        tokenizer = self.processor
-        if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'):
-            tokenizer = tokenizer.tokenizer
-        return tokenizer
-
-    @tokenizer.setter
-    def tokenizer(self, value):
-        if self.processor is self.tokenizer:
-            self.processor = value
-        elif self.tokenizer is not value:
-            raise AttributeError('Please use `self.processor` for assignment.')
-
-
-def to_float_dtype(data: Any, dtype: torch.dtype) -> Any:
-    """Change the float inputs to a dtype"""
-    if isinstance(data, Mapping):
-        return type(data)({k: to_float_dtype(v, dtype) for k, v in data.items()})
-    elif isinstance(data, (tuple, list)):
-        return type(data)(to_float_dtype(v, dtype) for v in data)
-    elif isinstance(data, torch.Tensor) and torch.is_floating_point(data):
-        return data.to(dtype=dtype)
-    else:
-        return data
-
-
-def to_device(data: Any, device: Union[str, torch.device, int], non_blocking: bool = False) -> Any:
-    """Move inputs to a device"""
-    if isinstance(data, Mapping):
-        return type(data)({k: to_device(v, device, non_blocking) for k, v in data.items()})
-    elif isinstance(data, (tuple, list)):
-        return type(data)(to_device(v, device, non_blocking) for v in data)
-    elif isinstance(data, torch.Tensor):
-        return data.to(device=device, non_blocking=non_blocking)
-    else:
-        return data
-
-
-def set_generation_config(model: nn.Module, generation_config: GenerationConfig) -> None:
-    old_generation_config = getattr(model, 'generation_config', None)
-    old_generation_priority_config = ['no_repeat_ngram_size', 'num_beams']
-    if old_generation_config is not None:
-        for k, old_v in old_generation_config.to_dict().items():
-            if k.startswith('_'):
-                continue
-            v = getattr(generation_config, k, None)
-            if k in old_generation_priority_config or old_v is not None and v is None:
-                setattr(generation_config, k, old_v)
-    model.generation_config = generation_config
-
-
-def find_module_list(model) -> Optional[nn.ModuleList]:
-    module_lists = []
-    for m in model.modules():
-        if hasattr(m, 'gradient_checkpointing') or m.__class__.__name__ == 'CheckpointWrapper':
-            return
-        if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
-                and 'mlp' not in m[0].__class__.__name__.lower()):  # fix moe
-            module_lists.append(m)
-    if module_lists:
-        return max(module_lists, key=lambda x: len(x))
-
-
-def _kwargs_to_args(func, args, kwargs) -> Optional[List[Any]]:
-    parameters = inspect.signature(func).parameters
-    args = list(args)
-    parameters = list(parameters.items())[len(args):]
-    for key, param in parameters:
-        if key in kwargs:
-            args.append(kwargs[key])
-        elif param.default != param.empty:
-            args.append(param.default)
-        else:
-            return
-    return args
-
-
-def _add_gradient_checkpointing(module_list):
-
-    requires_grad = None
-
-    def _new_forward(self, *args, **kwargs):
-        nonlocal requires_grad
-        if requires_grad is None:
-            requires_grad = any(p.requires_grad for p in self.parameters())
-
-        new_args = _kwargs_to_args(self.__old_forward, args, kwargs)
-        if new_args is not None and self.gradient_checkpointing and self.training:
-            if new_args and isinstance(new_args[0], torch.Tensor) and requires_grad and not new_args[0].requires_grad:
-                new_args[0].requires_grad_(True)
-            layer_ret = self._gradient_checkpointing_func(self.__old_forward, *new_args)
-            logger.info_once('Successfully using dynamic gradient checkpointing.')
-        else:
-            layer_ret = self.__old_forward(*args, **kwargs)
-        return layer_ret
-
-    for module in module_list:
-        module.gradient_checkpointing = False
-        if hasattr(module, '_old_forward'):  # device_map
-            __old_forward = module._old_forward
-            module._old_forward = MethodType(_new_forward, module)
-        else:
-            __old_forward = module.forward
-            module.forward = MethodType(_new_forward, module)
-        module.__old_forward = __old_forward
-
-
-def dynamic_gradient_checkpointing(model, including_vit: bool = False) -> None:
-    from .model import ModelMeta
-    if isinstance(model, PeftModel):
-        model = model.model
-    model_meta: ModelMeta = getattr(model, 'model_meta', None)
-    if model_meta is not None and model_meta.is_multimodal and model_meta.model_arch:
-        tower_names = model_meta.model_arch.language_model.copy()
-        if including_vit:
-            tower_names += model_meta.model_arch.vision_tower
-    else:
-        tower_names = [None]
-
-    model.supports_gradient_checkpointing = True
-    for tower_name in tower_names:
-        if tower_name is None:
-            model_tower = model
-        else:
-            model_tower = deep_getattr(model, tower_name)
-        model_tower.supports_gradient_checkpointing = True
-        module_list = find_module_list(model_tower)
-        if module_list is None:
-            continue
-        _add_gradient_checkpointing(module_list)
-        logger.info(f'Automatically add gradient_checkpointing to {model_tower.__class__}.')
-
-
-def history_to_messages(history: History,
-                        system: Optional[str] = None,
-                        roles: Optional[List[List[str]]] = None) -> 'Messages':
-    """
-    history: [['query1', 'response1'], ['query2', 'response2']]
-        or [['query1', 'response1'], ['query2', None]]
-    """
-    messages = []
-    if not roles:
-        roles = [['user', 'assistant']] * len(history)
-    else:
-        assert len(roles) == len(history), f'len(roles): {len(roles)}, len(history): {len(history)}'
-    if system is not None:
-        messages.append({'role': 'system', 'content': system})
-
-    for role, h in zip(roles, history):
-        assert isinstance(h, (list, tuple))
-        if h[0] is not None:
-            messages.append({'role': role[0], 'content': h[0]})
-        if h[1] is not None:
-            messages.append({'role': role[1], 'content': h[1]})
-    return messages
-
-
-def messages_to_history(messages: 'Messages') -> Dict[str, Any]:
-    system = None
-    messages = messages.copy()
-    if messages[0]['role'] == 'system':
-        system = messages[0]['content']
-        messages = messages[1:]
-    if len(messages) % 2 == 1:
-        messages.append({'role': 'assistant', 'content': None})
-    history = []
-    history_roles = []
-    for user_message, assistant_message in zip(messages[::2], messages[1::2]):
-        assert user_message['role'] in {'tool', 'user'}, f'user_message {user_message}'
-        assert assistant_message['role'] == 'assistant', f'assistant_message: {assistant_message}'
-        history.append([user_message['content'], assistant_message['content']])
-        history_roles.append([user_message['role'], assistant_message['role']])
-    query, response = history.pop() if history else (None, None)
-    query_role = history_roles.pop()[0] if history_roles else None
-    return {
-        'history': history,
-        'history_roles': history_roles,
-        'query': query,
-        'query_role': query_role,
-        'response': response,
-        'system': system,
-    }
-
-
-def save_checkpoint(model: Optional[PreTrainedModel],
-                    processor: 'Processor',
-                    output_dir: str,
-                    *,
-                    safe_serialization: bool = True,
-                    max_shard_size: Union[int, str] = '5GB',
-                    model_dirs: List[str] = None,
-                    additional_saved_files: Optional[List[str]] = None) -> None:
-    if model is not None:
-        if model.__class__.__name__ != 'SentenceTransformer':
-            model.save_pretrained(output_dir, safe_serialization=safe_serialization, max_shard_size=max_shard_size)
-        else:
-            model.save_pretrained(output_dir, safe_serialization=safe_serialization)
-            # copy sentencetransformers files
-            from swift.utils import copy_files_by_pattern
-            copy_files_by_pattern(model.model_dir, output_dir, '*.py')
-            copy_files_by_pattern(model.model_dir, output_dir, '*.json')
-    processor.save_pretrained(output_dir)
-
-    if model_dirs is None:
-        model_dirs = []
-    else:
-        model_dirs = model_dirs.copy()
-    if model and model.model_dir and model.model_dir not in model_dirs:
-        model_dirs.append(model.model_dir)
-    for src_file in (additional_saved_files or []) + ['preprocessor_config.json', 'args.json']:
-        tgt_path = os.path.join(output_dir, src_file)
-        if os.path.exists(tgt_path) and src_file == 'args.json':
-            continue
-        for model_dir in model_dirs:
-            src_path: str = os.path.join(model_dir, src_file)
-            if os.path.isfile(src_path):
-                shutil.copy(src_path, tgt_path)
-                break
-            elif os.path.isdir(src_path):
-                shutil.copytree(src_path, tgt_path)
-                break
-
-
-TEMP_DIR_POOL = {}
-
-
-def get_temporary_cache_files_directory(prefix=None):
-    if prefix is None:
-        import datasets.config
-        prefix = datasets.config.TEMP_CACHE_DIR_PREFIX
-    global TEMP_DIR_POOL
-    if prefix in TEMP_DIR_POOL:
-        TEMP_DIR = TEMP_DIR_POOL[prefix]
-    else:
-        tmp_dir = os.path.join(get_cache_dir(), 'tmp')
-        os.makedirs(tmp_dir, exist_ok=True)
-        kwargs = {}
-        parameters = inspect.signature(tempfile.TemporaryDirectory.__init__).parameters
-        if 'ignore_cleanup_errors' in parameters:
-            kwargs['ignore_cleanup_errors'] = True
-        TEMP_DIR = tempfile.TemporaryDirectory(prefix=prefix, dir=tmp_dir, **kwargs)
-        logger.info(f'create tmp_dir: {TEMP_DIR.name}')
-        TEMP_DIR_POOL[prefix] = TEMP_DIR
-
-    return TEMP_DIR.name
-
-
-def get_ckpt_dir(model_dir: str, adapters_dir: Optional[List[str]]) -> str:
-    model_dirs = (adapters_dir or []).copy()
-    if model_dir:
-        model_dirs.append(model_dir)
-    # The adapter takes higher priority.
-    ckpt_dir = None
-    for model_dir in model_dirs:
-        if os.path.exists(os.path.join(model_dir, 'args.json')):
-            ckpt_dir = model_dir
-            break
-    return ckpt_dir
-
-
-def update_generation_config_eos_token(generation_config, template):
-    if generation_config is None:
-        return
-    stop_words = template.template_meta.stop_words
-    eos_token_id = generation_config.eos_token_id
-    if eos_token_id is None:
-        eos_token_id = []
-    elif isinstance(eos_token_id, int):
-        eos_token_id = [eos_token_id]
-    modified = False
-    for stop_word in stop_words:
-        if stop_word is None:
-            continue
-        if isinstance(stop_word, str):
-            stop_word = template._tokenize(stop_word)
-        if isinstance(stop_word, (list, tuple)) and len(stop_word) == 1 and stop_word[0] not in eos_token_id:
-            eos_token_id.append(stop_word[0])
-            modified = True
-    if modified:
-        generation_config.eos_token_id = eos_token_id
-
-
-def get_packed_seq_params(position_ids: torch.Tensor):
-    assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
-    position_ids_f = position_ids.flatten()
-    indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
-
-    cu_seqlens = torch.cat([
-        indices_q[position_ids_f == 0],
-        torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
-    ])
-
-    max_length = cu_seqlens.diff().max()  # position_ids_f.max() + 1
-    return {
-        'cu_seq_lens_q': cu_seqlens,
-        'cu_seq_lens_k': cu_seqlens,
-        'max_length_q': max_length,
-        'max_length_k': max_length,
-    }
diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py
index 8b89165b73..e78e82d97b 100644
--- a/swift/megatron/trainers/utils.py
+++ b/swift/megatron/trainers/utils.py
@@ -18,8 +18,7 @@
 from packaging import version

 from swift.llm import get_packed_seq_params as _get_packed_seq_params
-from swift.llm import to_device
-from swift.utils import get_logger
+from swift.utils import get_logger, to_device
 from swift.utils.torch_utils import empty_cache, get_current_device

 mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
diff --git a/swift/llm/model/__init__.py b/swift/model/__init__.py
similarity index 80%
rename from swift/llm/model/__init__.py
rename to swift/model/__init__.py
index 34f45685a1..accdd8e5bb 100644
--- a/swift/llm/model/__init__.py
+++ b/swift/model/__init__.py
@@ -1,14 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from transformers.utils import is_torch_npu_available

-from . import model
+from . import models
 from .constant import LLMModelType, MLLMModelType, ModelType
 from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch
 from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map,
                        get_default_torch_dtype, get_matched_model_meta, get_model_info_meta, get_model_name,
                        get_model_tokenizer, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn,
                        load_by_unsloth, register_model)
-from .utils import HfConfigFactory, ModelInfo, get_llm_model, git_clone_github, safe_snapshot_download
+from .utils import (HfConfigFactory, ModelInfo, get_ckpt_dir, get_llm_model, git_clone_github, safe_snapshot_download,
+                    save_checkpoint)

 if is_torch_npu_available():
     from . import npu_patcher
diff --git a/swift/llm/model/constant.py b/swift/model/constant.py
similarity index 97%
rename from swift/llm/model/constant.py
rename to swift/model/constant.py
index e460a5f5a7..8c025c1681 100644
--- a/swift/llm/model/constant.py
+++ b/swift/model/constant.py
@@ -12,17 +12,12 @@ class LLMModelType:
     qwen2_moe = 'qwen2_moe'
     qwq_preview = 'qwq_preview'
     qwq = 'qwq'
+    # TODO: others
     qwen3 = 'qwen3'
-    qwen3_guard = 'qwen3_guard'
-    qwen3_thinking = 'qwen3_thinking'
-    qwen3_nothinking = 'qwen3_nothinking'
-    qwen3_coder = 'qwen3_coder'
     qwen3_moe = 'qwen3_moe'
-    qwen3_moe_thinking = 'qwen3_moe_thinking'
     qwen3_next = 'qwen3_next'
-    qwen3_next_thinking = 'qwen3_next_thinking'
-    qwen3_emb = 'qwen3_emb'
+    qwen3_emb = 'qwen3_emb'

     qwen2_gte = 'qwen2_gte'
     codefuse_qwen = 'codefuse_qwen'
@@ -168,6 +163,7 @@ class MLLMModelType:
     qwen2_audio = 'qwen2_audio'
     qwen3_vl = 'qwen3_vl'
     qwen3_moe_vl = 'qwen3_moe_vl'
+    qvq = 'qvq'
     qwen2_gme = 'qwen2_gme'

     ovis1_6 = 'ovis1_6'
diff --git a/swift/llm/model/model_arch.py b/swift/model/model_arch.py
similarity index 100%
rename from swift/llm/model/model_arch.py
rename to swift/model/model_arch.py
diff --git a/swift/llm/model/model/__init__.py b/swift/model/models/__init__.py
similarity index 100%
rename from swift/llm/model/model/__init__.py
rename to swift/model/models/__init__.py
diff --git a/swift/llm/model/model/baai.py b/swift/model/models/baai.py
similarity index 99%
rename from swift/llm/model/model/baai.py
rename to swift/model/models/baai.py
index 3ce71f3704..c1ea2d6612 100644
--- a/swift/llm/model/model/baai.py
+++ b/swift/model/models/baai.py
@@ -5,7 +5,7 @@

 from transformers import AutoModel, AutoModelForSequenceClassification

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from swift.utils import get_device
 from ..constant import MLLMModelType, RerankerModelType
 from ..model_arch import ModelArch
diff --git a/swift/llm/model/model/baichuan.py b/swift/model/models/baichuan.py
similarity index 99%
rename from swift/llm/model/model/baichuan.py
rename to swift/model/models/baichuan.py
index 599faee7b2..38b8447af2 100644
--- a/swift/llm/model/model/baichuan.py
+++ b/swift/model/models/baichuan.py
@@ -6,7 +6,7 @@
 from torch import Tensor
 from transformers import AutoConfig

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from swift.utils import get_logger
 from ..constant import LLMModelType
 from ..model_arch import ModelArch
diff --git a/swift/llm/model/model/baidu.py b/swift/model/models/baidu.py
similarity index 98%
rename from swift/llm/model/model/baidu.py
rename to swift/model/models/baidu.py
index 9294334ed3..0816d4f378 100644
--- a/swift/llm/model/model/baidu.py
+++ b/swift/model/models/baidu.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from transformers.dynamic_module_utils import get_class_from_dynamic_module

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from swift.utils import get_logger
 from ..constant import LLMModelType, MLLMModelType
 from ..model_arch import ModelArch
diff --git a/swift/llm/model/model/bert.py b/swift/model/models/bert.py
similarity index 98%
rename from swift/llm/model/model/bert.py
rename to swift/model/models/bert.py
index 7d35259481..10f24ff7f7 100644
--- a/swift/llm/model/model/bert.py
+++ b/swift/model/models/bert.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from swift.utils import get_logger
 from ..constant import BertModelType, RerankerModelType
 from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model
diff --git a/swift/llm/model/model/codefuse.py b/swift/model/models/codefuse.py
similarity index 98%
rename from swift/llm/model/model/codefuse.py
rename to swift/model/models/codefuse.py
index 02294a3b4b..58925d5c25 100644
--- a/swift/llm/model/model/codefuse.py
+++ b/swift/model/models/codefuse.py
@@ -3,7 +3,7 @@

 from transformers import AutoTokenizer

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from ..constant import LLMModelType
 from ..model_arch import ModelArch
 from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
diff --git a/swift/llm/model/model/deepseek.py b/swift/model/models/deepseek.py
similarity index 99%
rename from swift/llm/model/model/deepseek.py
rename to swift/model/models/deepseek.py
index 166472fff6..b83e55f956 100644
--- a/swift/llm/model/model/deepseek.py
+++ b/swift/model/models/deepseek.py
@@ -2,7 +2,7 @@
 import sys
 from typing import Any, Dict

-from swift.llm import TemplateType
+from swift.template import TemplateType
 from ..constant import LLMModelType, MLLMModelType
 from ..model_arch import ModelArch
 from ..patcher import patch_output_clone, patch_output_to_input_device
diff --git a/swift/llm/model/model/gemma.py b/swift/model/models/gemma.py
similarity index 99%
rename from swift/llm/model/model/gemma.py
rename to swift/model/models/gemma.py
index cf91b0aff0..fe18467f6c 100644
--- a/swift/llm/model/model/gemma.py
+++ b/swift/model/models/gemma.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch from ..patcher import patch_output_to_input_device diff --git a/swift/llm/model/model/glm.py b/swift/model/models/glm.py similarity index 99% rename from swift/llm/model/model/glm.py rename to swift/model/models/glm.py index de8e1bf69b..1d15382430 100644 --- a/swift/llm/model/model/glm.py +++ b/swift/model/models/glm.py @@ -9,7 +9,7 @@ from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.models.auto.tokenization_auto import get_tokenizer_config -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_device_count, get_dist_setting, get_logger from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/internlm.py b/swift/model/models/internlm.py similarity index 99% rename from swift/llm/model/model/internlm.py rename to swift/model/models/internlm.py index cfeb14f8fc..7d381ece1e 100644 --- a/swift/llm/model/model/internlm.py +++ b/swift/model/models/internlm.py @@ -6,7 +6,7 @@ import torch from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType, MLLMModelType, RMModelType from ..model_arch import ModelArch from ..patcher import patch_output_clone, patch_output_to_input_device diff --git a/swift/llm/model/model/llama.py b/swift/model/models/llama.py similarity index 99% rename from swift/llm/model/model/llama.py rename to swift/model/models/llama.py index bda64c2269..a10b399d35 100644 --- a/swift/llm/model/model/llama.py +++ b/swift/model/models/llama.py @@ -5,7 +5,7 @@ from transformers import AutoConfig -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_device from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/llava.py b/swift/model/models/llava.py similarity index 99% rename from swift/llm/model/model/llava.py rename to swift/model/models/llava.py index daf7c68dae..f6583bbb33 100644 --- a/swift/llm/model/model/llava.py +++ b/swift/model/models/llava.py @@ -7,7 +7,7 @@ from transformers import AutoConfig from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import MLLMModelType from ..model_arch import ModelArch from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, diff --git a/swift/llm/model/model/llm.py b/swift/model/models/llm.py similarity index 99% rename from swift/llm/model/model/llm.py rename to swift/model/models/llm.py index 0b7a308fc9..5ec2eb7691 100644 --- a/swift/llm/model/model/llm.py +++ b/swift/model/models/llm.py @@ -3,7 +3,7 @@ from transformers import AutoTokenizer -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import LLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/mamba.py b/swift/model/models/mamba.py similarity index 97% rename from swift/llm/model/model/mamba.py rename to swift/model/models/mamba.py index 8517e3ab63..eabd8c5eaf 100644 --- a/swift/llm/model/model/mamba.py +++ b/swift/model/models/mamba.py @@ -1,7 +1,7 
@@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import LLMModelType from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model diff --git a/swift/llm/model/model/microsoft.py b/swift/model/models/microsoft.py similarity index 99% rename from swift/llm/model/model/microsoft.py rename to swift/model/models/microsoft.py index 98693fcc99..7d5024b7b4 100644 --- a/swift/llm/model/model/microsoft.py +++ b/swift/model/models/microsoft.py @@ -5,7 +5,7 @@ from transformers import AutoConfig -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_device, get_env_args from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/minicpm.py b/swift/model/models/minicpm.py similarity index 99% rename from swift/llm/model/model/minicpm.py rename to swift/model/models/minicpm.py index f1e85e6dde..801f74e35b 100644 --- a/swift/llm/model/model/minicpm.py +++ b/swift/model/models/minicpm.py @@ -6,7 +6,7 @@ from transformers import AutoConfig from transformers.utils import strtobool -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_env_args from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/minimax.py b/swift/model/models/minimax.py similarity index 99% rename from swift/llm/model/model/minimax.py rename to swift/model/models/minimax.py index 2162c05b62..4b29d5bf0f 100644 --- a/swift/llm/model/model/minimax.py +++ b/swift/model/models/minimax.py @@ -6,7 +6,7 @@ from transformers import AutoConfig, AutoProcessor from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_device, get_device_count, get_dist_setting, get_logger from ..constant import LLMModelType, MLLMModelType from ..patcher import patch_ignore_check_imports diff --git a/swift/llm/model/model/mistral.py b/swift/model/models/mistral.py similarity index 99% rename from swift/llm/model/model/mistral.py rename to swift/model/models/mistral.py index e028176591..a836af4823 100644 --- a/swift/llm/model/model/mistral.py +++ b/swift/model/models/mistral.py @@ -3,7 +3,7 @@ from transformers import AutoProcessor, AutoTokenizer -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, diff --git a/swift/llm/model/model/mllm.py b/swift/model/models/mllm.py similarity index 98% rename from swift/llm/model/model/mllm.py rename to swift/model/models/mllm.py index 80c096c05a..27f09ff8bb 100644 --- a/swift/llm/model/model/mllm.py +++ b/swift/model/models/mllm.py @@ -5,8 +5,7 @@ import torch from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType -from swift.llm.model.model.qwen import get_model_tokenizer_qwen2_vl +from swift.template import TemplateType from swift.utils import get_logger from ..constant import MLLMModelType, RerankerModelType from ..model_arch import ModelArch @@ -14,7 +13,7 @@ from ..register import (Model, ModelGroup, ModelMeta, 
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, register_model) from ..utils import ModelInfo, use_submodel_func -from .qwen import patch_qwen_vl_utils +from .qwen import get_model_tokenizer_qwen2_vl, patch_qwen_vl_utils logger = get_logger() diff --git a/swift/llm/model/model/moonshot.py b/swift/model/models/moonshot.py similarity index 98% rename from swift/llm/model/model/moonshot.py rename to swift/model/models/moonshot.py index 9a9ef32791..5eeb90ab6d 100644 --- a/swift/llm/model/model/moonshot.py +++ b/swift/model/models/moonshot.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch from ..patcher import patch_get_input_embeddings diff --git a/swift/llm/model/model/mplug.py b/swift/model/models/mplug.py similarity index 99% rename from swift/llm/model/model/mplug.py rename to swift/model/models/mplug.py index ff282127f6..e398eeaf5f 100644 --- a/swift/llm/model/model/mplug.py +++ b/swift/model/models/mplug.py @@ -7,7 +7,7 @@ from transformers import AutoConfig from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/openbuddy.py b/swift/model/models/openbuddy.py similarity index 98% rename from swift/llm/model/model/openbuddy.py rename to swift/model/models/openbuddy.py index 5ed08a42f0..d10751c447 100644 --- a/swift/llm/model/model/openbuddy.py +++ b/swift/model/models/openbuddy.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
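The qwen.py hunk just below also moves `to_device` out of the old `swift.llm` facade into `swift.utils`. A minimal usage sketch under that assumption (the batch contents here are illustrative):

import torch

from swift.utils import to_device  # new home per the qwen.py hunk below

batch = {'input_ids': torch.tensor([[1, 2, 3]]), 'attention_mask': torch.tensor([[1, 1, 1]])}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch = to_device(batch, device)  # recursively moves nested tensors to the target device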
-from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import LLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/model/qwen.py b/swift/model/models/qwen.py similarity index 93% rename from swift/llm/model/model/qwen.py rename to swift/model/models/qwen.py index 2a8d3f072c..342ab68fb4 100644 --- a/swift/llm/model/model/qwen.py +++ b/swift/model/models/qwen.py @@ -13,8 +13,8 @@ from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils.versions import require_version -from swift.llm import TemplateType, to_device -from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger, is_deepspeed_enabled +from swift.template import TemplateType +from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger, is_deepspeed_enabled, to_device from ..constant import LLMModelType, MLLMModelType, RerankerModelType, RMModelType from ..model_arch import ModelArch from ..patcher import patch_fixed_device, patch_get_input_embeddings, patch_output_clone @@ -507,125 +507,98 @@ def _get_cast_dtype(self) -> torch.dtype: ModelMeta( LLMModelType.qwen3, [ - ModelGroup([ - Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'), - Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'), - Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'), - Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'), - Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'), - # instruct - Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'), - Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'), - Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'), - Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'), - Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'), - Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'), - # fp8 - Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'), - Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'), - Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'), - Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'), - Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'), - Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'), - # awq - Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'), - Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'), - Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'), - Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'), - # swift - Model('swift/Qwen3-32B-AWQ'), - ]), - ], - TemplateType.qwen3, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3ForCausalLM'], - requires=['transformers>=4.51'], - model_arch=ModelArch.llama)) - -register_model( - ModelMeta( - LLMModelType.qwen3_moe, - [ - ModelGroup([ - Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'), - # instruct - Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'), - Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'), - # fp8 - Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'), - Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'), - # awq - Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'), - Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'), - ]), - ModelGroup([ - Model('iic/Tongyi-DeepResearch-30B-A3B', 'Alibaba-NLP/Tongyi-DeepResearch-30B-A3B'), - ]) - ], - TemplateType.qwen3, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3MoeForCausalLM'], - requires=['transformers>=4.51'], - )) - -register_model( - ModelMeta( - LLMModelType.qwen3_guard, - [ + ModelGroup( + [ + Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'), + 
Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'), + Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'), + Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'), + Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'), + # instruct + Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'), + Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'), + Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'), + Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'), + Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'), + Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'), + ], + template=TemplateType.qwen3_nothinking), + ModelGroup( + [ + # fp8 + Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'), + Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'), + Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'), + Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'), + Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'), + Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'), + # awq + Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'), + Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'), + Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'), + Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'), + # swift + Model('swift/Qwen3-32B-AWQ'), + ], + template=TemplateType.qwen3_mixed), ModelGroup([ Model('Qwen/Qwen3Guard-Gen-0.6B', 'Qwen/Qwen3Guard-Gen-0.6B'), Model('Qwen/Qwen3Guard-Gen-4B', 'Qwen/Qwen3Guard-Gen-4B'), Model('Qwen/Qwen3Guard-Gen-8B', 'Qwen/Qwen3Guard-Gen-8B'), - ]) - ], - TemplateType.qwen3_guard, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3ForCausalLM'], - requires=['transformers>=4.51'], - )) - -register_model( - ModelMeta( - LLMModelType.qwen3_thinking, - [ + ], + template=TemplateType.qwen3_guard), + ModelGroup([ + Model('Qwen/Qwen3-4B-Instruct-2507', 'Qwen/Qwen3-4B-Instruct-2507'), + Model('Qwen/Qwen3-4B-Instruct-2507-FP8', 'Qwen/Qwen3-4B-Instruct-2507-FP8'), + ], + template=TemplateType.qwen3_nothinking), ModelGroup([ Model('Qwen/Qwen3-4B-Thinking-2507', 'Qwen/Qwen3-4B-Thinking-2507'), Model('Qwen/Qwen3-4B-Thinking-2507-FP8', 'Qwen/Qwen3-4B-Thinking-2507-FP8'), - ]), + ], + template=TemplateType.qwen3_thinking), ], - TemplateType.qwen3_thinking, + None, get_model_tokenizer_with_flash_attn, architectures=['Qwen3ForCausalLM'], requires=['transformers>=4.51'], - )) + model_arch=ModelArch.llama)) register_model( ModelMeta( - LLMModelType.qwen3_nothinking, + LLMModelType.qwen3_moe, [ ModelGroup([ - Model('Qwen/Qwen3-30B-A3B-Instruct-2507', 'Qwen/Qwen3-30B-A3B-Instruct-2507'), - Model('Qwen/Qwen3-30B-A3B-Instruct-2507-FP8', 'Qwen/Qwen3-30B-A3B-Instruct-2507-FP8'), - Model('Qwen/Qwen3-235B-A22B-Instruct-2507', 'Qwen/Qwen3-235B-A22B-Instruct-2507'), - Model('Qwen/Qwen3-235B-A22B-Instruct-2507-FP8', 'Qwen/Qwen3-235B-A22B-Instruct-2507-FP8'), - # awq - Model('swift/Qwen3-235B-A22B-Instruct-2507-AWQ'), - ]), + Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'), + ], + template=TemplateType.qwen3_nothinking), + ModelGroup( + [ + # instruct + Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'), + Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'), + # fp8 + Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'), + Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'), + # awq + Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'), + Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'), + ], + template=TemplateType.qwen3_mixed), ModelGroup([ - Model('Qwen/Qwen3-4B-Instruct-2507', 'Qwen/Qwen3-4B-Instruct-2507'), - Model('Qwen/Qwen3-4B-Instruct-2507-FP8', 'Qwen/Qwen3-4B-Instruct-2507-FP8'), - ]) - ], - 
TemplateType.qwen3_nothinking, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3MoeForCausalLM', 'Qwen3ForCausalLM'], - requires=['transformers>=4.51'], - )) - -register_model( - ModelMeta( - LLMModelType.qwen3_coder, - [ + Model('iic/Tongyi-DeepResearch-30B-A3B', 'Alibaba-NLP/Tongyi-DeepResearch-30B-A3B'), + ], + template=TemplateType.qwen3_mixed), + ModelGroup( + [ + Model('Qwen/Qwen3-30B-A3B-Instruct-2507', 'Qwen/Qwen3-30B-A3B-Instruct-2507'), + Model('Qwen/Qwen3-30B-A3B-Instruct-2507-FP8', 'Qwen/Qwen3-30B-A3B-Instruct-2507-FP8'), + Model('Qwen/Qwen3-235B-A22B-Instruct-2507', 'Qwen/Qwen3-235B-A22B-Instruct-2507'), + Model('Qwen/Qwen3-235B-A22B-Instruct-2507-FP8', 'Qwen/Qwen3-235B-A22B-Instruct-2507-FP8'), + # awq + Model('swift/Qwen3-235B-A22B-Instruct-2507-AWQ'), + ], + template=TemplateType.qwen3_nothinking), ModelGroup([ Model('Qwen/Qwen3-Coder-30B-A3B-Instruct', 'Qwen/Qwen3-Coder-30B-A3B-Instruct'), Model('Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8', 'Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8'), @@ -633,9 +606,20 @@ def _get_cast_dtype(self) -> torch.dtype: Model('Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8', 'Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8'), Model('swift/Qwen3-Coder-480B-A35B-Instruct-AWQ'), ], + template=TemplateType.qwen3_coder, tags=['coding']), + ModelGroup( + [ + Model('Qwen/Qwen3-30B-A3B-Thinking-2507', 'Qwen/Qwen3-30B-A3B-Thinking-2507'), + Model('Qwen/Qwen3-30B-A3B-Thinking-2507-FP8', 'Qwen/Qwen3-30B-A3B-Thinking-2507-FP8'), + Model('Qwen/Qwen3-235B-A22B-Thinking-2507', 'Qwen/Qwen3-235B-A22B-Thinking-2507'), + Model('Qwen/Qwen3-235B-A22B-Thinking-2507-FP8', 'Qwen/Qwen3-235B-A22B-Thinking-2507-FP8'), + # awq + Model('swift/Qwen3-235B-A22B-Thinking-2507-AWQ'), + ], + template=TemplateType.qwen3_thinking), ], - TemplateType.qwen3_coder, + None, get_model_tokenizer_with_flash_attn, architectures=['Qwen3MoeForCausalLM'], requires=['transformers>=4.51'], @@ -643,44 +627,20 @@ def _get_cast_dtype(self) -> torch.dtype: register_model( ModelMeta( - LLMModelType.qwen3_moe_thinking, + LLMModelType.qwen3_next, [ ModelGroup([ - Model('Qwen/Qwen3-30B-A3B-Thinking-2507', 'Qwen/Qwen3-30B-A3B-Thinking-2507'), - Model('Qwen/Qwen3-30B-A3B-Thinking-2507-FP8', 'Qwen/Qwen3-30B-A3B-Thinking-2507-FP8'), - Model('Qwen/Qwen3-235B-A22B-Thinking-2507', 'Qwen/Qwen3-235B-A22B-Thinking-2507'), - Model('Qwen/Qwen3-235B-A22B-Thinking-2507-FP8', 'Qwen/Qwen3-235B-A22B-Thinking-2507-FP8'), - # awq - Model('swift/Qwen3-235B-A22B-Thinking-2507-AWQ'), - ]), + Model('Qwen/Qwen3-Next-80B-A3B-Instruct'), + Model('Qwen/Qwen3-Next-80B-A3B-Instruct-FP8'), + ], + template=TemplateType.qwen3_nothinking), + ModelGroup([ + Model('Qwen/Qwen3-Next-80B-A3B-Thinking'), + Model('Qwen/Qwen3-Next-80B-A3B-Thinking-FP8'), + ], + template=TemplateType.qwen3_thinking) ], - TemplateType.qwen3_thinking, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3MoeForCausalLM'], - requires=['transformers>=4.51'], - )) - -register_model( - ModelMeta( - LLMModelType.qwen3_next, - [ModelGroup([ - Model('Qwen/Qwen3-Next-80B-A3B-Instruct'), - Model('Qwen/Qwen3-Next-80B-A3B-Instruct-FP8'), - ])], - TemplateType.qwen3_nothinking, - get_model_tokenizer_with_flash_attn, - architectures=['Qwen3NextForCausalLM'], - requires=['transformers>=4.57'], - )) - -register_model( - ModelMeta( - LLMModelType.qwen3_next_thinking, - [ModelGroup([ - Model('Qwen/Qwen3-Next-80B-A3B-Thinking'), - Model('Qwen/Qwen3-Next-80B-A3B-Thinking-FP8'), - ])], - TemplateType.qwen3_thinking, + None, get_model_tokenizer_with_flash_attn, 
architectures=['Qwen3NextForCausalLM'], requires=['transformers>=4.57'], @@ -727,7 +687,7 @@ def patch_qwen_vl_utils(vision_process): if _read_video_decord is not None: def _new_read_video_decord(ele: dict): - from swift.llm import load_file + from swift.template import load_file ele['video'] = load_file(ele['video']) return _read_video_decord(ele) diff --git a/swift/llm/model/model/seed.py b/swift/model/models/seed.py similarity index 95% rename from swift/llm/model/model/seed.py rename to swift/model/models/seed.py index f139d7240f..4015281489 100644 --- a/swift/llm/model/model/seed.py +++ b/swift/model/models/seed.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import LLMModelType from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model diff --git a/swift/llm/model/model/skywork.py b/swift/model/models/skywork.py similarity index 98% rename from swift/llm/model/model/skywork.py rename to swift/model/models/skywork.py index a53b4ffa0f..0739cdce14 100644 --- a/swift/llm/model/model/skywork.py +++ b/swift/model/models/skywork.py @@ -2,7 +2,7 @@ from typing import Any, Dict -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType, RMModelType from ..model_arch import ModelArch from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model diff --git a/swift/llm/model/model/stepfun.py b/swift/model/models/stepfun.py similarity index 99% rename from swift/llm/model/model/stepfun.py rename to swift/model/models/stepfun.py index f7ccd65f3e..369607fc23 100644 --- a/swift/llm/model/model/stepfun.py +++ b/swift/model/models/stepfun.py @@ -6,7 +6,7 @@ from transformers import AutoModel -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import MLLMModelType from ..model_arch import ModelArch from ..patcher import patch_output_clone diff --git a/swift/llm/model/model/telechat.py b/swift/model/models/telechat.py similarity index 98% rename from swift/llm/model/model/telechat.py rename to swift/model/models/telechat.py index 2a7bfd5b1e..c7fd43bf47 100644 --- a/swift/llm/model/model/telechat.py +++ b/swift/model/models/telechat.py @@ -2,7 +2,7 @@ from transformers import GenerationConfig -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import LLMModelType from ..model_arch import ModelArch from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model diff --git a/swift/llm/model/model/valley.py b/swift/model/models/valley.py similarity index 98% rename from swift/llm/model/model/valley.py rename to swift/model/models/valley.py index 5a15156e59..d0e890e687 100644 --- a/swift/llm/model/model/valley.py +++ b/swift/model/models/valley.py @@ -4,7 +4,7 @@ from functools import partial, wraps from typing import Any, Dict -from swift.llm import TemplateType +from swift.template import TemplateType from ..constant import MLLMModelType from ..model_arch import ModelArch from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model diff --git a/swift/llm/model/model/yi.py b/swift/model/models/yi.py similarity index 99% rename from swift/llm/model/model/yi.py rename to swift/model/models/yi.py index f921c07609..278cceece7 100644 --- 
a/swift/llm/model/model/yi.py +++ b/swift/model/models/yi.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer -from swift.llm import TemplateType +from swift.template import TemplateType from swift.utils import get_logger from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch diff --git a/swift/llm/model/npu_patcher.py b/swift/model/npu_patcher.py similarity index 100% rename from swift/llm/model/npu_patcher.py rename to swift/model/npu_patcher.py diff --git a/swift/llm/model/patcher.py b/swift/model/patcher.py similarity index 99% rename from swift/llm/model/patcher.py rename to swift/model/patcher.py index 3c32b5a657..fa190d18af 100644 --- a/swift/llm/model/patcher.py +++ b/swift/model/patcher.py @@ -18,8 +18,8 @@ from transformers import PreTrainedModel, dynamic_module_utils, trainer from transformers.modeling_outputs import SequenceClassifierOutputWithPast -from swift.llm import deep_getattr, to_device, to_float_dtype -from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context +from swift.utils import (deep_getattr, get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context, to_device, + to_float_dtype) from swift.utils.torch_utils import (_get_max_memory, _sync_max_memory, get_cu_seqlens_from_position_ids, get_device_count, get_position_ids_from_cu_seqlens) from .utils import HfConfigFactory diff --git a/swift/llm/model/register.py b/swift/model/register.py similarity index 99% rename from swift/llm/model/register.py rename to swift/model/register.py index 85b4abcb3f..073bd68b5e 100644 --- a/swift/llm/model/register.py +++ b/swift/model/register.py @@ -47,6 +47,7 @@ class ModelGroup: models: List[Model] # Higher priority. If set to None, the attributes of the ModelMeta will be used. 
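The register.py hunk continuing below adds a per-group `template` field to `ModelGroup`, which takes priority over the `ModelMeta`-level template; this is what lets a single meta mix, e.g., `qwen3_thinking` and `qwen3_nothinking` groups in the qwen.py hunks above. A self-contained sketch of that resolution rule (class and function names here are illustrative, not the library's):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Group:
    template: Optional[str] = None

def resolve_template(meta_template: Optional[str], group: Group) -> Optional[str]:
    # The group-level value wins; fall back to the meta-level default.
    return group.template if group.template is not None else meta_template

assert resolve_template('qwen3', Group()) == 'qwen3'
assert resolve_template(None, Group(template='qwen3_thinking')) == 'qwen3_thinking'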
+ template: Optional[str] = None ignore_patterns: Optional[List[str]] = None requires: Optional[List[str]] = None tags: List[str] = field(default_factory=list) @@ -374,7 +375,7 @@ def get_model_tokenizer_from_local(model_dir: str, model._auto_class = automodel_class.__name__ if model_info.task_type == 'embedding' and automodel_class.__name__ != 'AutoModel': - from swift.llm.model.patcher import patch_output_normalizer + from swift.model.patcher import patch_output_normalizer patch_output_normalizer(model, model_meta=model_meta) init_strategy = kwargs.get('init_strategy') @@ -588,7 +589,7 @@ def get_matched_model_types(architectures: Optional[List[str]]) -> List[str]: def _read_args_json_model_type(model_dir): if not os.path.exists(os.path.join(model_dir, 'args.json')): return - from swift.llm import BaseArguments + from swift.pipelines import BaseArguments args = BaseArguments.from_pretrained(model_dir) return args.model_type diff --git a/swift/llm/model/utils.py b/swift/model/utils.py similarity index 89% rename from swift/llm/model/utils.py rename to swift/model/utils.py index 0fc7a6ee57..5e8380a886 100644 --- a/swift/llm/model/utils.py +++ b/swift/model/utils.py @@ -3,17 +3,19 @@ from dataclasses import dataclass from functools import wraps from types import MethodType -from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union import torch from accelerate.utils import find_device from modelscope.hub.utils.utils import get_cache_dir from torch import nn -from transformers import PretrainedConfig +from transformers import PretrainedConfig, PreTrainedModel from swift.hub import get_hub -from swift.llm import to_device -from swift.utils import deep_getattr, get_logger, is_local_master, safe_ddp_context, subprocess_run +from swift.utils import deep_getattr, get_logger, is_local_master, safe_ddp_context, subprocess_run, to_device + +if TYPE_CHECKING: + from swift.template import Processor logger = get_logger() @@ -539,3 +541,55 @@ def init_parameters(model: nn.Module, init_strategy: str) -> None: if InitModelStrategy.is_uninitialized(param): logger.info(f'Initializing parameters: {name}.') init_func(param) + + +def save_checkpoint(model: Optional[PreTrainedModel], + processor: 'Processor', + output_dir: str, + *, + safe_serialization: bool = True, + max_shard_size: Union[int, str] = '5GB', + model_dirs: List[str] = None, + additional_saved_files: Optional[List[str]] = None) -> None: + if model is not None: + if model.__class__.__name__ != 'SentenceTransformer': + model.save_pretrained(output_dir, safe_serialization=safe_serialization, max_shard_size=max_shard_size) + else: + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + # copy sentencetransformers files + from swift.utils import copy_files_by_pattern + copy_files_by_pattern(model.model_dir, output_dir, '*.py') + copy_files_by_pattern(model.model_dir, output_dir, '*.json') + processor.save_pretrained(output_dir) + + if model_dirs is None: + model_dirs = [] + else: + model_dirs = model_dirs.copy() + if model and model.model_dir and model.model_dir not in model_dirs: + model_dirs.append(model.model_dir) + for src_file in (additional_saved_files or []) + ['preprocessor_config.json', 'args.json']: + tgt_path = os.path.join(output_dir, src_file) + if os.path.exists(tgt_path) and src_file == 'args.json': + continue + for model_dir in model_dirs: + src_path: str = os.path.join(model_dir, src_file) + if 
os.path.isfile(src_path): + shutil.copy(src_path, tgt_path) + break + elif os.path.isdir(src_path): + shutil.copytree(src_path, tgt_path) + break + + +def get_ckpt_dir(model_dir: str, adapters_dir: Optional[List[str]]) -> str: + model_dirs = (adapters_dir or []).copy() + if model_dir: + model_dirs.append(model_dir) + # The adapter takes higher priority. + ckpt_dir = None + for model_dir in model_dirs: + if os.path.exists(os.path.join(model_dir, 'args.json')): + ckpt_dir = model_dir + break + return ckpt_dir diff --git a/swift/llm/__init__.py b/swift/pipelines/__init__.py similarity index 79% rename from swift/llm/__init__.py rename to swift/pipelines/__init__.py index d4a41b9db8..0788f9c631 100644 --- a/swift/llm/__init__.py +++ b/swift/pipelines/__init__.py @@ -17,8 +17,8 @@ RolloutArguments, GymRolloutArguments, RLHFArguments, WebUIArguments, BaseArguments, AppArguments, SamplingArguments) from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template, - TemplateInputs, TemplateMeta, get_template_meta, InferRequest, load_image, MaxLengthError, - load_file, draw_bbox, RolloutInferRequest) + TemplateInputs, TemplateMeta, get_template_meta, load_image, MaxLengthError, load_file, + draw_bbox) from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys, ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup, @@ -28,10 +28,6 @@ DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor, LazyLLMDataset, load_dataset, DATASET_TYPE, sample_dataset, RowPreprocessor, DatasetMeta, HfDataset, SubsetDataset) - from .utils import (deep_getattr, to_float_dtype, to_device, History, Messages, history_to_messages, - messages_to_history, Processor, save_checkpoint, ProcessorMixin, - get_temporary_cache_files_directory, get_cache_dir, dynamic_gradient_checkpointing, - get_packed_seq_params) from .base import SwiftPipeline from .data_loader import DataLoaderDispatcher, DataLoaderShard, BatchSamplerShard else: @@ -52,9 +48,19 @@ 'RolloutArguments', 'RLHFArguments', 'BaseArguments', 'AppArguments', 'SamplingArguments' ], 'template': [ - 'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template', - 'TemplateInputs', 'TemplateMeta', 'get_template_meta', 'InferRequest', 'load_image', 'MaxLengthError', - 'load_file', 'draw_bbox', 'RolloutInferRequest' + 'TEMPLATE_MAPPING', + 'Template', + 'Word', + 'get_template', + 'TemplateType', + 'register_template', + 'TemplateInputs', + 'TemplateMeta', + 'get_template_meta', + 'load_image', + 'MaxLengthError', + 'load_file', + 'draw_bbox', ], 'model': [ 'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory', @@ -69,12 +75,6 @@ 'DATASET_TYPE', 'sample_dataset', 'RowPreprocessor', 'ResponsePreprocessor', 'DatasetMeta', 'HfDataset', 'SubsetDataset' ], - 'utils': [ - 'deep_getattr', 'to_device', 'to_float_dtype', 'History', 'Messages', 'history_to_messages', - 'messages_to_history', 'Processor', 'save_checkpoint', 'ProcessorMixin', - 'get_temporary_cache_files_directory', 'get_cache_dir', 'dynamic_gradient_checkpointing', - 'get_packed_seq_params' - ], 'base': ['SwiftPipeline'], 'data_loader': ['DataLoaderDispatcher', 'DataLoaderShard', 'BatchSamplerShard'], } diff --git a/swift/llm/app/__init__.py b/swift/pipelines/app/__init__.py similarity index 100% rename from 
swift/llm/app/__init__.py rename to swift/pipelines/app/__init__.py
diff --git a/swift/llm/app/app.py b/swift/pipelines/app/app.py
similarity index 100%
rename from swift/llm/app/app.py
rename to swift/pipelines/app/app.py
diff --git a/swift/llm/app/build_ui.py b/swift/pipelines/app/build_ui.py
similarity index 98%
rename from swift/llm/app/build_ui.py
rename to swift/pipelines/app/build_ui.py
index ab1a5a9003..172a2bf636 100644
--- a/swift/llm/app/build_ui.py
+++ b/swift/pipelines/app/build_ui.py
@@ -4,6 +4,7 @@
 import gradio as gr
+from swift.infer_engine import InferRequest
 from swift.utils import get_file_mm_type
 from ..utils import History
 from .locale import locale_mapping
@@ -53,7 +54,6 @@ def _parse_text(text: str) -> str:
 async def model_chat(history: History, real_history: History, system: Optional[str], *, client, model: str,
                      request_config: Optional['RequestConfig']):
     if history:
-        from swift.llm import InferRequest
         messages = _history_to_messages(real_history, system)
 
         resp_or_gen = await client.infer_async(
@@ -102,7 +102,7 @@ def build_ui(base_url: str,
              studio_title: Optional[str] = None,
              lang: Literal['en', 'zh'] = 'en',
              default_system: Optional[str] = None):
-    from swift.llm import InferClient
+    from swift.pipelines import InferClient
     client = InferClient(base_url=base_url)
     model = model or client.models[0]
     studio_title = studio_title or model
diff --git a/swift/llm/app/locale.py b/swift/pipelines/app/locale.py
similarity index 100%
rename from swift/llm/app/locale.py
rename to swift/pipelines/app/locale.py
diff --git a/swift/llm/base.py b/swift/pipelines/base.py
similarity index 92%
rename from swift/llm/base.py
rename to swift/pipelines/base.py
index addd19de26..f721aaddf9 100644
--- a/swift/llm/base.py
+++ b/swift/pipelines/base.py
@@ -5,9 +5,8 @@
 from typing import List, Optional, Union
 
 import swift
-from swift.utils import get_logger, parse_args, seed_everything
+from swift.utils import ProcessorMixin, get_logger, parse_args, seed_everything
 from .argument import BaseArguments
-from .utils import ProcessorMixin
 
 logger = get_logger()
@@ -38,7 +37,7 @@ def _parse_args(self, args: Optional[Union[List[str], args_class]] = None) -> args_class:
     @staticmethod
     def _compat_dsw_gradio(args) -> None:
-        from swift.llm import WebUIArguments, AppArguments
+        from swift.pipelines import WebUIArguments, AppArguments
         if (isinstance(args, (WebUIArguments, AppArguments)) and 'JUPYTER_NAME' in os.environ
                 and 'dsw-' in os.environ['JUPYTER_NAME'] and 'GRADIO_ROOT_PATH' not in os.environ):
             os.environ['GRADIO_ROOT_PATH'] = f"/{os.environ['JUPYTER_NAME']}/proxy/{args.server_port}"
diff --git a/swift/llm/ds_config/zero0.json b/swift/pipelines/ds_config/zero0.json
similarity index 100%
rename from swift/llm/ds_config/zero0.json
rename to swift/pipelines/ds_config/zero0.json
diff --git a/swift/llm/ds_config/zero1.json b/swift/pipelines/ds_config/zero1.json
similarity index 100%
rename from swift/llm/ds_config/zero1.json
rename to swift/pipelines/ds_config/zero1.json
diff --git a/swift/llm/ds_config/zero2.json b/swift/pipelines/ds_config/zero2.json
similarity index 100%
rename from swift/llm/ds_config/zero2.json
rename to swift/pipelines/ds_config/zero2.json
diff --git a/swift/llm/ds_config/zero2_offload.json b/swift/pipelines/ds_config/zero2_offload.json
similarity index 100%
rename from swift/llm/ds_config/zero2_offload.json
rename to swift/pipelines/ds_config/zero2_offload.json
diff --git a/swift/llm/ds_config/zero3.json b/swift/pipelines/ds_config/zero3.json
similarity index 100%
rename from
swift/llm/ds_config/zero3.json rename to swift/pipelines/ds_config/zero3.json diff --git a/swift/llm/ds_config/zero3_offload.json b/swift/pipelines/ds_config/zero3_offload.json similarity index 100% rename from swift/llm/ds_config/zero3_offload.json rename to swift/pipelines/ds_config/zero3_offload.json diff --git a/swift/llm/eval/__init__.py b/swift/pipelines/eval/__init__.py similarity index 100% rename from swift/llm/eval/__init__.py rename to swift/pipelines/eval/__init__.py diff --git a/swift/llm/eval/eval.py b/swift/pipelines/eval/eval.py similarity index 100% rename from swift/llm/eval/eval.py rename to swift/pipelines/eval/eval.py diff --git a/swift/llm/eval/utils.py b/swift/pipelines/eval/utils.py similarity index 99% rename from swift/llm/eval/utils.py rename to swift/pipelines/eval/utils.py index 7e5000b397..02fddf7d71 100644 --- a/swift/llm/eval/utils.py +++ b/swift/pipelines/eval/utils.py @@ -18,8 +18,7 @@ from evalscope.api.tool import ToolChoice, ToolInfo from evalscope.models.utils.openai import chat_choices_from_openai -from ..infer import PtEngine, RequestConfig -from ..template import InferRequest +from swift.infer_engine import InferRequest, PtEngine, RequestConfig @dataclass diff --git a/swift/llm/export/__init__.py b/swift/pipelines/export/__init__.py similarity index 100% rename from swift/llm/export/__init__.py rename to swift/pipelines/export/__init__.py diff --git a/swift/llm/export/cached_dataset.py b/swift/pipelines/export/cached_dataset.py similarity index 94% rename from swift/llm/export/cached_dataset.py rename to swift/pipelines/export/cached_dataset.py index cb29f1a80c..cd43dbff82 100644 --- a/swift/llm/export/cached_dataset.py +++ b/swift/pipelines/export/cached_dataset.py @@ -4,8 +4,8 @@ import torch -from swift.llm import TEMPLATE_MAPPING, ExportArguments -from swift.llm.train import SwiftSft +from swift.pipelines import ExportArguments, SwiftSft +from swift.template import TEMPLATE_MAPPING from swift.utils import get_logger logger = get_logger() diff --git a/swift/llm/export/export.py b/swift/pipelines/export/export.py similarity index 96% rename from swift/llm/export/export.py rename to swift/pipelines/export/export.py index 23f9202735..d9e89ac0c5 100644 --- a/swift/llm/export/export.py +++ b/swift/pipelines/export/export.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List, Optional, Union -from swift.llm import ExportArguments, SwiftPipeline +from swift.pipelines import ExportArguments, SwiftPipeline from swift.tuners import swift_to_peft_format from swift.utils import get_logger from .cached_dataset import export_cached_dataset diff --git a/swift/llm/export/merge_lora.py b/swift/pipelines/export/merge_lora.py similarity index 94% rename from swift/llm/export/merge_lora.py rename to swift/pipelines/export/merge_lora.py index 2295741cfa..4720c7b58a 100644 --- a/swift/llm/export/merge_lora.py +++ b/swift/pipelines/export/merge_lora.py @@ -1,7 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
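The eval utils hunk above collapses two relative imports into the single consolidated `swift.infer_engine` surface. A minimal usage sketch of that surface, assuming `PtEngine`, `InferRequest`, and `RequestConfig` keep their pre-refactor signatures (the model id is illustrative):

from swift.infer_engine import InferRequest, PtEngine, RequestConfig

engine = PtEngine('Qwen/Qwen3-0.6B')  # any registered model id
request = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])
response = engine.infer([request], request_config=RequestConfig(max_tokens=64))[0]
print(response.choices[0].message.content)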
import os -from swift.llm import ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint +from swift.model import prepare_model_template, save_checkpoint +from swift.pipelines import ExportArguments, HfConfigFactory from swift.tuners import Swift from swift.utils import get_logger diff --git a/swift/llm/export/ollama.py b/swift/pipelines/export/ollama.py similarity index 96% rename from swift/llm/export/ollama.py rename to swift/pipelines/export/ollama.py index c706de25b1..cd68811a9f 100644 --- a/swift/llm/export/ollama.py +++ b/swift/pipelines/export/ollama.py @@ -2,7 +2,7 @@ import os from typing import List -from swift.llm import ExportArguments, PtEngine, RequestConfig, Template, prepare_model_template +from swift.pipelines import ExportArguments, PtEngine, RequestConfig, Template, prepare_model_template from swift.utils import get_logger logger = get_logger() diff --git a/swift/llm/export/quant.py b/swift/pipelines/export/quant.py similarity index 97% rename from swift/llm/export/quant.py rename to swift/pipelines/export/quant.py index ec4e9c51ad..759b4bd33f 100644 --- a/swift/llm/export/quant.py +++ b/swift/pipelines/export/quant.py @@ -9,9 +9,10 @@ from packaging import version from tqdm import tqdm -from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, load_dataset, - prepare_model_template, save_checkpoint, to_device) -from swift.utils import get_logger, get_model_parameter_info +from swift.dataset import load_dataset +from swift.model import HfConfigFactory, prepare_model_template, save_checkpoint +from swift.template import MaxLengthError +from swift.utils import ProcessorMixin, deep_getattr, get_logger, get_model_parameter_info, to_device logger = get_logger() diff --git a/swift/llm/infer/__init__.py b/swift/pipelines/infer/__init__.py similarity index 100% rename from swift/llm/infer/__init__.py rename to swift/pipelines/infer/__init__.py diff --git a/swift/llm/infer/deploy.py b/swift/pipelines/infer/deploy.py similarity index 97% rename from swift/llm/infer/deploy.py rename to swift/pipelines/infer/deploy.py index dba7f7efda..119adb066b 100644 --- a/swift/llm/infer/deploy.py +++ b/swift/pipelines/infer/deploy.py @@ -15,13 +15,12 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse -from swift.llm import AdapterRequest, DeployArguments, InferArguments -from swift.llm.infer.protocol import EmbeddingRequest, MultiModalRequestMixin +from swift.infer_engine import (ChatCompletionRequest, CompletionRequest, EmbeddingRequest, InferClient, Model, + ModelList, MultiModalRequestMixin) +from swift.pipelines import AdapterRequest, DeployArguments, InferArguments from swift.plugin import InferStats from swift.utils import JsonlWriter, get_logger from .infer import SwiftInfer -from .infer_engine import InferClient -from .protocol import ChatCompletionRequest, CompletionRequest, Model, ModelList logger = get_logger() diff --git a/swift/llm/infer/infer.py b/swift/pipelines/infer/infer.py similarity index 97% rename from swift/llm/infer/infer.py rename to swift/pipelines/infer/infer.py index 6747038384..507572426d 100644 --- a/swift/llm/infer/infer.py +++ b/swift/pipelines/infer/infer.py @@ -5,12 +5,12 @@ from datasets import Dataset as HfDataset from tqdm import tqdm -from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset +from swift.dataset import DatasetLoader, load_dataset, sample_dataset +from 
swift.infer_engine import AdapterRequest, InferRequest, PtEngine, RequestConfig +from swift.model import prepare_model_template +from swift.pipelines import InferArguments, SwiftPipeline from swift.plugin import InferStats, MeanMetric, compute_rouge_bleu from swift.utils import JsonlWriter, get_dist_setting, get_logger, is_dist, is_master, read_from_jsonl -from ..dataset.loader import DatasetLoader -from .infer_engine import AdapterRequest, PtEngine -from .protocol import RequestConfig from .utils import InferCliState, get_cached_dataset logger = get_logger() @@ -21,7 +21,7 @@ class SwiftInfer(SwiftPipeline): args: args_class def __init__(self, args: Optional[Union[List[str], InferArguments]] = None) -> None: - from swift.llm import merge_lora + from swift.pipelines import merge_lora super().__init__(args) args = self.args if args.merge_lora: diff --git a/swift/llm/infer/rollout.py b/swift/pipelines/infer/rollout.py similarity index 98% rename from swift/llm/infer/rollout.py rename to swift/pipelines/infer/rollout.py index 9084d2609f..e6310d8f3d 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/pipelines/infer/rollout.py @@ -22,13 +22,13 @@ from fastapi import FastAPI from trl.scripts.vllm_serve import WeightSyncWorkerExtension as HFWeightSyncWorkerExtension -from swift.llm import RolloutArguments, SwiftPipeline -from swift.llm.template.template_inputs import RolloutInferRequest +from swift.infer_engine import RolloutInferRequest +from swift.pipelines import RolloutArguments, SwiftPipeline from swift.plugin.multi_turn import RolloutScheduler, multi_turns -from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest, - UpdateAdapterRequest, UpdateFlattenedAdapterRequest, - UpdateFlattenedParamsRequest, check_vllm_version_ge, - patch_vllm_load_adapter) +from swift.trainers.rlhf_trainers.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest, + UpdateAdapterRequest, UpdateFlattenedAdapterRequest, + UpdateFlattenedParamsRequest, check_vllm_version_ge, + patch_vllm_load_adapter) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest diff --git a/swift/llm/infer/utils.py b/swift/pipelines/infer/utils.py similarity index 97% rename from swift/llm/infer/utils.py rename to swift/pipelines/infer/utils.py index 384a0e2b3b..3051880436 100644 --- a/swift/llm/infer/utils.py +++ b/swift/pipelines/infer/utils.py @@ -7,11 +7,10 @@ from datasets import load_from_disk -from swift.llm.utils import update_generation_config_eos_token -from swift.plugin import extra_tuners +from swift.plugins import extra_tuners +from swift.template import Messages, update_generation_config_eos_token from swift.tuners import Swift from swift.utils import get_logger -from ..utils import Messages logger = get_logger() diff --git a/swift/llm/sampling/__init__.py b/swift/pipelines/sampling/__init__.py similarity index 100% rename from swift/llm/sampling/__init__.py rename to swift/pipelines/sampling/__init__.py diff --git a/swift/pipelines/sampling/samplers/__init__.py b/swift/pipelines/sampling/samplers/__init__.py new file mode 100644 index 0000000000..c6b545a668 --- /dev/null +++ b/swift/pipelines/sampling/samplers/__init__.py @@ -0,0 +1,3 @@ +from .base import Sampler +from .distill_sampler import DistillSampler +from .vanilla_sampler import VanillaSampler diff --git a/swift/llm/sampling/base.py 
b/swift/pipelines/sampling/samplers/base.py similarity index 91% rename from swift/llm/sampling/base.py rename to swift/pipelines/sampling/samplers/base.py index 980176a8ed..ff12e462f1 100644 --- a/swift/llm/sampling/base.py +++ b/swift/pipelines/sampling/samplers/base.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List -from swift.llm import SamplingArguments +from swift.arguments import SamplingArguments +from swift.infer_engine import PtEngine from swift.plugin import orms, prms -from swift.ray.base import RayHelper +from swift.ray import RayHelper from swift.utils import get_logger logger = get_logger() @@ -33,7 +34,6 @@ def _prepare_prm(self): elif self.args.prm_model in prms: self.prm_model = prms[self.args.prm_model]() else: - from swift.llm import PtEngine self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64) @RayHelper.function(group='orm') @@ -44,7 +44,6 @@ def _prepare_orm(self): elif self.args.orm_model in orms: self.orm_model = orms[self.args.orm_model]() else: - from swift.llm import PtEngine self.orm_model = PtEngine(self.args.orm_model, max_batch_size=64) def _prepare_template(self) -> None: diff --git a/swift/llm/sampling/distill_sampler.py b/swift/pipelines/sampling/samplers/distill_sampler.py similarity index 97% rename from swift/llm/sampling/distill_sampler.py rename to swift/pipelines/sampling/samplers/distill_sampler.py index 51f7edd83b..5910db889e 100644 --- a/swift/llm/sampling/distill_sampler.py +++ b/swift/pipelines/sampling/samplers/distill_sampler.py @@ -4,10 +4,10 @@ from openai import OpenAI -from swift.llm.infer.protocol import InferRequest, RequestConfig -from swift.llm.sampling.vanilla_sampler import VanillaSampler +from swift.infer_engine import InferRequest, RequestConfig from swift.ray import RayHelper from .utils import get_messages_md5 +from .vanilla_sampler import VanillaSampler class OpenAIEngine: diff --git a/swift/llm/sampling/vanilla_sampler.py b/swift/pipelines/sampling/samplers/vanilla_sampler.py similarity index 97% rename from swift/llm/sampling/vanilla_sampler.py rename to swift/pipelines/sampling/samplers/vanilla_sampler.py index 216a524710..87e60411f2 100644 --- a/swift/llm/sampling/vanilla_sampler.py +++ b/swift/pipelines/sampling/samplers/vanilla_sampler.py @@ -5,10 +5,10 @@ import json import numpy as np -from swift.llm import RequestConfig -from swift.llm.sampling.base import Sampler +from swift.infer_engine import RequestConfig from swift.ray.base import RayHelper from swift.utils import get_logger +from .base import Sampler from .utils import get_messages_md5, get_reward logger = get_logger() @@ -25,13 +25,13 @@ def __init__(self, *args, **kwargs): @RayHelper.function(group='sampler') def _prepare_sampler(self): if self.args.sampler_engine == 'pt': - from swift.llm import PtEngine + from swift.infer_engine import PtEngine _Engine = PtEngine elif self.args.sampler_engine == 'vllm': - from swift.llm import VllmEngine + from swift.infer_engine import VllmEngine _Engine = VllmEngine elif self.args.sampler_engine == 'lmdeploy': - from swift.llm import LmdeployEngine + from swift.infer_engine import LmdeployEngine _Engine = LmdeployEngine elif self.args.sampler_engine == 'no': _Engine = None diff --git a/swift/llm/sampling/sampling.py b/swift/pipelines/sampling/sampling.py similarity index 94% rename from swift/llm/sampling/sampling.py rename to swift/pipelines/sampling/sampling.py index 61044f5422..5a58550e7e 100644 --- a/swift/llm/sampling/sampling.py +++ b/swift/pipelines/sampling/sampling.py @@ -6,8 +6,10 @@ import json 
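The vanilla_sampler.py hunk above keeps the engine imports lazy so that only the selected backend's dependencies are ever loaded. A standalone sketch of that selection pattern (the factory name is illustrative; engine paths are per this diff):

def make_engine(name: str, model: str):
    # Lazy imports, mirroring VanillaSampler._prepare_sampler above.
    if name == 'pt':
        from swift.infer_engine import PtEngine as Engine
    elif name == 'vllm':
        from swift.infer_engine import VllmEngine as Engine
    elif name == 'lmdeploy':
        from swift.infer_engine import LmdeployEngine as Engine
    elif name == 'no':
        return None  # sampling disabled
    else:
        raise ValueError(f'Unsupported sampler engine: {name}')
    return Engine(model)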
-from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
+from swift.dataset import load_dataset
+from swift.pipelines import SamplingArguments, SwiftPipeline
 from swift.utils import get_logger
+from .samplers import DistillSampler, VanillaSampler
 
 logger = get_logger()
@@ -27,10 +29,8 @@ def __init__(self, args: Optional[Union[List[str], SamplingArguments]] = None) -> None:
         self.cur_piece, self.total_piece = self.args.data_range
 
         if self.args.sampler_type == 'sample':
-            from swift.llm.sampling.vanilla_sampler import VanillaSampler
             self.sampler = VanillaSampler(self.args)
         elif self.args.sampler_type == 'distill':
-            from swift.llm.sampling.distill_sampler import DistillSampler
             self.sampler = DistillSampler(self.args)
         else:
             raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')
diff --git a/swift/llm/sampling/utils.py b/swift/pipelines/sampling/utils.py
similarity index 95%
rename from swift/llm/sampling/utils.py
rename to swift/pipelines/sampling/utils.py
index 0aa422004a..20bd8312ec 100644
--- a/swift/llm/sampling/utils.py
+++ b/swift/pipelines/sampling/utils.py
@@ -6,7 +6,7 @@
 import json
 import numpy as np
 
-from swift.llm import InferRequest, RequestConfig
+from swift.infer_engine import ChatCompletionResponse, InferEngine, InferRequest, RequestConfig
 from swift.utils import get_logger
 
 logger = get_logger()
@@ -38,7 +38,6 @@ def get_reward(model: Any,
         Index 0: The min-max normalized scores matched the infer_requests
         Index 1: The mask filtered by the threshold
     """
-    from swift.llm import InferEngine
     infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
     parameters = inspect.signature(infer_func).parameters
     gt_param = {}
@@ -47,7 +46,6 @@
     if isinstance(infer_requests[0], dict):
         infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
     rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
-    from swift.llm.infer.protocol import ChatCompletionResponse
     if isinstance(rewards[0], ChatCompletionResponse):
         print('reward:', rewards[0].choices[0].message.content)
     if isinstance(rewards[0].choices[0].message.content, str):
diff --git a/swift/llm/train/__init__.py b/swift/pipelines/train/__init__.py
similarity index 100%
rename from swift/llm/train/__init__.py
rename to swift/pipelines/train/__init__.py
diff --git a/swift/llm/train/callback.py b/swift/pipelines/train/callback.py
similarity index 100%
rename from swift/llm/train/callback.py
rename to swift/pipelines/train/callback.py
diff --git a/swift/llm/train/kto.py b/swift/pipelines/train/kto.py
similarity index 100%
rename from swift/llm/train/kto.py
rename to swift/pipelines/train/kto.py
diff --git a/swift/llm/train/pt.py b/swift/pipelines/train/pt.py
similarity index 100%
rename from swift/llm/train/pt.py
rename to swift/pipelines/train/pt.py
diff --git a/swift/llm/train/rlhf.py b/swift/pipelines/train/rlhf.py
similarity index 95%
rename from swift/llm/train/rlhf.py
rename to swift/pipelines/train/rlhf.py
index 5d16ebf4c7..bc0cdb45c2 100644
--- a/swift/llm/train/rlhf.py
+++ b/swift/pipelines/train/rlhf.py
@@ -3,13 +3,13 @@
 from contextlib import nullcontext
 from typing import List, Optional, Union
 
-from swift.llm import safe_snapshot_download
-from swift.plugin import Tuner, extra_tuners
+from swift.arguments import BaseArguments, RLHFArguments
+from swift.dataset import DatasetLoader
+from swift.model import HfConfigFactory, get_model_info_meta, safe_snapshot_download
+from swift.plugins import Tuner, extra_tuners
 from
swift.tuners import Swift -from swift.utils import get_logger, get_model_parameter_info -from swift.utils.utils import disable_deepspeed_zero3 -from ..argument import BaseArguments, RLHFArguments -from ..model import HfConfigFactory +from swift.utils import disable_deepspeed_zero3, get_logger, get_model_parameter_info +from ..infer.utils import prepare_adapter from .kto import prepare_kto_dataset from .sft import SwiftSft @@ -49,7 +49,6 @@ def _get_model_task_type(model_dir): return task_type, num_labels def _prepare_single_model(self, key, origin_key, model_type, model_revision): - from swift.llm.infer.utils import prepare_adapter args = self.args origin_key = origin_key or key model_id_or_path = getattr(args, f'{key}_model') @@ -57,7 +56,6 @@ def _prepare_single_model(self, key, origin_key, model_type, model_revision): return if model_type is None: - from swift.llm.model.register import get_model_info_meta model_info, _ = get_model_info_meta(model_id_or_path) model_type = model_info.model_type @@ -192,7 +190,6 @@ def _get_dataset(self): def _prepare_chord_sft_dataset(self): from ..dataset import load_dataset - from swift.llm.dataset.loader import DatasetLoader # prepare expert sft dataset for chord args = self.args diff --git a/swift/llm/train/sft.py b/swift/pipelines/train/sft.py similarity index 100% rename from swift/llm/train/sft.py rename to swift/pipelines/train/sft.py diff --git a/swift/llm/train/tuner.py b/swift/pipelines/train/tuner.py similarity index 98% rename from swift/llm/train/tuner.py rename to swift/pipelines/train/tuner.py index 286a9f4b04..b52d39514e 100644 --- a/swift/llm/train/tuner.py +++ b/swift/pipelines/train/tuner.py @@ -9,10 +9,12 @@ from packaging import version from transformers import TrainingArguments -from swift.llm import TrainArguments, deep_getattr +from swift.arguments import TrainArguments +from swift.model import ModelType from swift.plugin import Tuner, extra_tuners from swift.tuners import Swift -from swift.utils import activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters, get_logger +from swift.utils import (activate_parameters, deep_getattr, find_all_linears, find_embedding, find_norm, + freeze_parameters, get_logger) logger = get_logger() @@ -24,7 +26,6 @@ def apply_liger(model_type: str): apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen3, apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl, apply_liger_kernel_to_phi3, apply_liger_kernel_to_mllama) - from swift.llm import ModelType if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2): apply_liger_kernel_to_llama() elif model_type in (ModelType.mistral): diff --git a/swift/pipelines/utils.py b/swift/pipelines/utils.py new file mode 100644 index 0000000000..e8d314981c --- /dev/null +++ b/swift/pipelines/utils.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
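The new swift/pipelines/utils.py whose listing begins here implements dynamic gradient checkpointing; its `_kwargs_to_args` helper flattens keyword arguments into the flat positional list that checkpoint-style replay of `forward` needs. A self-contained re-derivation of that binding rule, with standalone names (the demo function is illustrative):

import inspect
from typing import Any, List, Optional

def kwargs_to_args(func, args, kwargs) -> Optional[List[Any]]:
    # Bind remaining parameters positionally: prefer the kwarg, then the
    # declared default; give up (return None) if a required one is missing.
    params = list(inspect.signature(func).parameters.items())[len(args):]
    out = list(args)
    for key, param in params:
        if key in kwargs:
            out.append(kwargs[key])
        elif param.default is not param.empty:
            out.append(param.default)
        else:
            return None
    return out

def f(x, y, scale=2.0):
    return (x + y) * scale

assert kwargs_to_args(f, (1,), {'y': 3}) == [1, 3, 2.0]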
+import inspect
+import os
+import shutil
+import tempfile
+from types import MethodType
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from modelscope.hub.utils.utils import get_cache_dir
+from peft import PeftModel
+from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
+
+from swift.utils import deep_getattr, get_logger
+from swift.template import Processor
+
+logger = get_logger()
+
+
+def set_generation_config(model: nn.Module, generation_config: GenerationConfig) -> None:
+    old_generation_config = getattr(model, 'generation_config', None)
+    old_generation_priority_config = ['no_repeat_ngram_size', 'num_beams']
+    if old_generation_config is not None:
+        for k, old_v in old_generation_config.to_dict().items():
+            if k.startswith('_'):
+                continue
+            v = getattr(generation_config, k, None)
+            if k in old_generation_priority_config or old_v is not None and v is None:
+                setattr(generation_config, k, old_v)
+    model.generation_config = generation_config
+
+
+def find_module_list(model) -> Optional[nn.ModuleList]:
+    module_lists = []
+    for m in model.modules():
+        if hasattr(m, 'gradient_checkpointing') or m.__class__.__name__ == 'CheckpointWrapper':
+            return
+        if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
+                and 'mlp' not in m[0].__class__.__name__.lower()):  # fix moe
+            module_lists.append(m)
+    if module_lists:
+        return max(module_lists, key=lambda x: len(x))
+
+
+def _kwargs_to_args(func, args, kwargs) -> Optional[List[Any]]:
+    parameters = inspect.signature(func).parameters
+    args = list(args)
+    parameters = list(parameters.items())[len(args):]
+    for key, param in parameters:
+        if key in kwargs:
+            args.append(kwargs[key])
+        elif param.default != param.empty:
+            args.append(param.default)
+        else:
+            return
+    return args
+
+
+def _add_gradient_checkpointing(module_list):
+
+    requires_grad = None
+
+    def _new_forward(self, *args, **kwargs):
+        nonlocal requires_grad
+        if requires_grad is None:
+            requires_grad = any(p.requires_grad for p in self.parameters())
+
+        new_args = _kwargs_to_args(self.__old_forward, args, kwargs)
+        if new_args is not None and self.gradient_checkpointing and self.training:
+            if new_args and isinstance(new_args[0], torch.Tensor) and requires_grad and not new_args[0].requires_grad:
+                new_args[0].requires_grad_(True)
+            layer_ret = self._gradient_checkpointing_func(self.__old_forward, *new_args)
+            logger.info_once('Successfully using dynamic gradient checkpointing.')
+        else:
+            layer_ret = self.__old_forward(*args, **kwargs)
+        return layer_ret
+
+    for module in module_list:
+        module.gradient_checkpointing = False
+        if hasattr(module, '_old_forward'):  # device_map
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
diff --git a/swift/plugin/__init__.py b/swift/plugins/__init__.py
similarity index 100%
rename from swift/plugin/__init__.py
rename to swift/plugins/__init__.py
diff --git a/swift/plugin/agent_template/__init__.py b/swift/plugins/agent_template/__init__.py
similarity index 100%
rename from swift/plugin/agent_template/__init__.py
rename to swift/plugins/agent_template/__init__.py
diff --git a/swift/plugin/agent_template/base.py b/swift/plugins/agent_template/base.py
similarity index 96%
rename from swift/plugin/agent_template/base.py
rename to swift/plugins/agent_template/base.py
index cf5f87bc36..039a0a4581 100644
--- a/swift/plugin/agent_template/base.py
+++ b/swift/plugins/agent_template/base.py
@@ -6,9 +6,8 @@
 import json
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
+from swift.infer_engine import Function
+from swift.template import Prompt, split_str_parts_by
 
 @dataclass
@@ -32,8 +31,6 @@ class ReactCompatMixin:
     @staticmethod
     def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
-        from swift.llm.template import split_str_parts_by
-        from swift.llm.infer import Function
         agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
         functions = []
         action_content = None
diff --git a/swift/plugin/agent_template/deepseek_v3_1.py b/swift/plugins/agent_template/deepseek_v3_1.py
similarity index 97%
rename from swift/plugin/agent_template/deepseek_v3_1.py
rename to swift/plugins/agent_template/deepseek_v3_1.py
index 7d0fb694ae..edc78b0e18 100644
--- a/swift/plugin/agent_template/deepseek_v3_1.py
+++ b/swift/plugins/agent_template/deepseek_v3_1.py
@@ -4,12 +4,10 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class DeepSeekV31AgentTemplate(BaseAgentTemplate):
diff --git a/swift/plugin/agent_template/extra.py b/swift/plugins/agent_template/extra.py
similarity index 100%
rename from swift/plugin/agent_template/extra.py
rename to swift/plugins/agent_template/extra.py
diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugins/agent_template/glm4.py
similarity index 97%
rename from swift/plugin/agent_template/glm4.py
rename to swift/plugins/agent_template/glm4.py
index fc3461fcb7..52dc97e17f 100644
--- a/swift/plugin/agent_template/glm4.py
+++ b/swift/plugins/agent_template/glm4.py
@@ -4,19 +4,16 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class GLM4AgentTemplate(BaseAgentTemplate):
     is_glm4_0414 = False
 
     @staticmethod
     def _find_function_call(single_content: str) -> Optional['Function']:
-        from swift.llm.infer import Function
         single_content = single_content.replace('<|observation|>', '')
         pattern = re.compile(r'([^\n`]*?)\n({.*?})(?=\w*\n|$)', re.DOTALL)
         matches = pattern.findall(single_content)
@@ -82,7 +79,6 @@ class GLM4_5AgentTemplate(BaseAgentTemplate):
     @staticmethod
     def _find_function_call(single_content: str) -> Optional['Function']:
-        from swift.llm.infer import Function
         single_content = single_content.strip()
         func_name_match = re.match(r'^([^\n<]+)', single_content)
         if not func_name_match:
diff --git a/swift/plugin/agent_template/hermes.py b/swift/plugins/agent_template/hermes.py
similarity index 96%
rename from swift/plugin/agent_template/hermes.py
rename to swift/plugins/agent_template/hermes.py
index 4527c92eca..d1ecf53d87 100644
--- a/swift/plugin/agent_template/hermes.py
+++ b/swift/plugins/agent_template/hermes.py
@@ -4,17 +4,14 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class HermesAgentTemplate(BaseAgentTemplate):
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         res_list =
diff --git a/swift/plugin/agent_template/hermes.py b/swift/plugins/agent_template/hermes.py
similarity index 96%
rename from swift/plugin/agent_template/hermes.py
rename to swift/plugins/agent_template/hermes.py
index 4527c92eca..d1ecf53d87 100644
--- a/swift/plugin/agent_template/hermes.py
+++ b/swift/plugins/agent_template/hermes.py
@@ -4,17 +4,14 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class HermesAgentTemplate(BaseAgentTemplate):
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         res_list = re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL)
         functions = []
         for res in res_list:
@@ -88,7 +85,6 @@ def _format_tool_calls(self, tool_call_messages):
 class HunyuanHermesAgentTemplate(HermesAgentTemplate):
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         res_list = re.findall(r'<tool_call>(.+?)\n```json(.+?)```</tool_call>', response, re.DOTALL)
         functions = []
         for name, arguments in res_list:
diff --git a/swift/plugin/agent_template/llama.py b/swift/plugins/agent_template/llama.py
similarity index 95%
rename from swift/plugin/agent_template/llama.py
rename to swift/plugins/agent_template/llama.py
index 3b7fff6d31..373f3aa7b9 100644
--- a/swift/plugin/agent_template/llama.py
+++ b/swift/plugins/agent_template/llama.py
@@ -4,12 +4,10 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class Llama3AgentTemplate(BaseAgentTemplate):
     eom_token = '<|eom_id|>'
@@ -18,7 +16,6 @@ class Llama3AgentTemplate(BaseAgentTemplate):
     eot_token = '<|eot_id|>'
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         if response.endswith(self.eom_token):
             response = response[:-len(self.eom_token)]
         functions = []
diff --git a/swift/plugin/agent_template/mistral.py b/swift/plugins/agent_template/mistral.py
similarity index 95%
rename from swift/plugin/agent_template/mistral.py
rename to swift/plugins/agent_template/mistral.py
index 120a3e21a9..b4525afc88 100644
--- a/swift/plugin/agent_template/mistral.py
+++ b/swift/plugins/agent_template/mistral.py
@@ -3,17 +3,14 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class MistralAgentTemplate(BaseAgentTemplate):
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         res_list = re.findall(r'\[TOOL_CALLS\]\[(.*?)\]', response, re.DOTALL)
         if not res_list:
             return []
diff --git a/swift/plugin/agent_template/qwen.py b/swift/plugins/agent_template/qwen.py
similarity index 100%
rename from swift/plugin/agent_template/qwen.py
rename to swift/plugins/agent_template/qwen.py
diff --git a/swift/plugin/agent_template/qwen3_coder.py b/swift/plugins/agent_template/qwen3_coder.py
similarity index 98%
rename from swift/plugin/agent_template/qwen3_coder.py
rename to swift/plugins/agent_template/qwen3_coder.py
index 585d551648..75194294cc 100644
--- a/swift/plugin/agent_template/qwen3_coder.py
+++ b/swift/plugins/agent_template/qwen3_coder.py
@@ -4,11 +4,9 @@
 import json
 
+from swift.infer_engine import Function
 from .hermes import HermesAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-
 def render_extra_keys(obj, handled_keys):
     """Helper function to render extra keys not explicitly handled"""
@@ -24,7 +22,6 @@ class Qwen3CoderAgentTemplate(HermesAgentTemplate):
     @staticmethod
     def _find_function_call(single_content: str) -> Optional['Function']:
-        from swift.llm.infer import Function
         single_content = single_content.strip()
         # Check whether the complete function tag is included
         if not single_content.startswith('<function'):
diff --git a/swift/plugin/agent_template/react.py b/swift/plugins/agent_template/react.py
similarity index 100%
rename from swift/plugin/agent_template/react.py
rename to swift/plugins/agent_template/react.py
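Note: a worked example of what the Hermes-style `get_toolcall` above extracts, assuming the reconstructed `<tool_call>` pattern; the response text is made up, and each JSON hit is then turned into a `Function`:

```python
import json
import re

response = ('Let me look that up.\n'
            '<tool_call>\n'
            '{"name": "get_weather", "arguments": {"city": "Hangzhou"}}\n'
            '</tool_call>')

for res in re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL):
    call = json.loads(res)
    print(call['name'], call['arguments'])
# -> get_weather {'city': 'Hangzhou'}
```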
diff --git a/swift/plugin/agent_template/seed_oss.py b/swift/plugins/agent_template/seed_oss.py
similarity index 97%
rename from swift/plugin/agent_template/seed_oss.py
rename to swift/plugins/agent_template/seed_oss.py
index 97d6891012..26e235007f 100644
--- a/swift/plugin/agent_template/seed_oss.py
+++ b/swift/plugins/agent_template/seed_oss.py
@@ -3,12 +3,10 @@
 import json
 
+from swift.infer_engine import Function
+from swift.template import Prompt
 from .base import BaseAgentTemplate
 
-if TYPE_CHECKING:
-    from swift.llm.infer import Function
-    from swift.llm.template import Prompt
-
 class SeedAgentTemplate(BaseAgentTemplate):
     TOOL_CALL_START = '<seed:tool_call>'
@@ -29,7 +27,6 @@ def _py_type(t: str) -> str:
         return SeedAgentTemplate._PY_TYPE_MAPPING.get(t, 'Any')
 
     def get_toolcall(self, response: str) -> List['Function']:
-        from swift.llm.infer import Function
         res_list = re.findall(rf'{self.TOOL_CALL_START}(.+?){self.TOOL_CALL_END}', response,
                               re.DOTALL)
         if not res_list:
diff --git a/swift/plugin/agent_template/toolbench.py b/swift/plugins/agent_template/toolbench.py
similarity index 100%
rename from swift/plugin/agent_template/toolbench.py
rename to swift/plugins/agent_template/toolbench.py
diff --git a/swift/plugin/callback.py b/swift/plugins/callback.py
similarity index 100%
rename from swift/plugin/callback.py
rename to swift/plugins/callback.py
diff --git a/swift/plugin/context_manager.py b/swift/plugins/context_manager.py
similarity index 94%
rename from swift/plugin/context_manager.py
rename to swift/plugins/context_manager.py
index 234a2f761c..fcf2c37fa4 100644
--- a/swift/plugin/context_manager.py
+++ b/swift/plugins/context_manager.py
@@ -2,8 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
-if TYPE_CHECKING:
-    from swift.llm.utils import Messages
+from swift.template import Messages
 
 class ContextManager(ABC):
diff --git a/swift/plugin/env.py b/swift/plugins/env.py
similarity index 97%
rename from swift/plugin/env.py
rename to swift/plugins/env.py
index c1706dfa55..b9579c0eac 100644
--- a/swift/plugin/env.py
+++ b/swift/plugins/env.py
@@ -2,11 +2,9 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 
+from swift.infer_engine import RolloutInferRequest
 from swift.plugin.orm import MathAccuracy
-
-if TYPE_CHECKING:
-    from swift.llm.template import RolloutInferRequest
-    from swift.llm.utils import Messages
+from swift.template import Messages
 
 class Env(ABC):
diff --git a/swift/plugin/loss.py b/swift/plugins/loss.py
similarity index 100%
rename from swift/plugin/loss.py
rename to swift/plugins/loss.py
diff --git a/swift/plugin/loss_scale/__init__.py b/swift/plugins/loss_scale/__init__.py
similarity index 100%
rename from swift/plugin/loss_scale/__init__.py
rename to swift/plugins/loss_scale/__init__.py
diff --git a/swift/plugin/loss_scale/config/agentflan.json b/swift/plugins/loss_scale/config/agentflan.json
similarity index 100%
rename from swift/plugin/loss_scale/config/agentflan.json
rename to swift/plugins/loss_scale/config/agentflan.json
diff --git a/swift/plugin/loss_scale/config/alpha_umi.json b/swift/plugins/loss_scale/config/alpha_umi.json
similarity index 100%
rename from swift/plugin/loss_scale/config/alpha_umi.json
rename to swift/plugins/loss_scale/config/alpha_umi.json
diff --git a/swift/plugin/loss_scale/config/hermes.json b/swift/plugins/loss_scale/config/hermes.json
similarity index 100%
rename from swift/plugin/loss_scale/config/hermes.json
rename to
swift/plugins/loss_scale/config/hermes.json diff --git a/swift/plugin/loss_scale/config/ignore_empty_think.json b/swift/plugins/loss_scale/config/ignore_empty_think.json similarity index 100% rename from swift/plugin/loss_scale/config/ignore_empty_think.json rename to swift/plugins/loss_scale/config/ignore_empty_think.json diff --git a/swift/plugin/loss_scale/config/qwen.json b/swift/plugins/loss_scale/config/qwen.json similarity index 100% rename from swift/plugin/loss_scale/config/qwen.json rename to swift/plugins/loss_scale/config/qwen.json diff --git a/swift/plugin/loss_scale/config/react.json b/swift/plugins/loss_scale/config/react.json similarity index 100% rename from swift/plugin/loss_scale/config/react.json rename to swift/plugins/loss_scale/config/react.json diff --git a/swift/plugin/loss_scale/loss_scale.py b/swift/plugins/loss_scale/loss_scale.py similarity index 98% rename from swift/plugin/loss_scale/loss_scale.py rename to swift/plugins/loss_scale/loss_scale.py index 216393a4db..96e2aa868e 100644 --- a/swift/plugin/loss_scale/loss_scale.py +++ b/swift/plugins/loss_scale/loss_scale.py @@ -4,8 +4,7 @@ import json -from swift.llm import Messages -from swift.llm.template.utils import ContextType +from swift.template import ContextType, Messages from .utils import calculate_loss_scale diff --git a/swift/plugin/loss_scale/utils.py b/swift/plugins/loss_scale/utils.py similarity index 98% rename from swift/plugin/loss_scale/utils.py rename to swift/plugins/loss_scale/utils.py index d60c592a5d..76d0d19edc 100644 --- a/swift/plugin/loss_scale/utils.py +++ b/swift/plugins/loss_scale/utils.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Tuple -from swift.llm.template import split_str_parts_by +from swift.template import split_str_parts_by def calculate_loss_scale(query: str, diff --git a/swift/plugin/metric.py b/swift/plugins/metric.py similarity index 100% rename from swift/plugin/metric.py rename to swift/plugins/metric.py diff --git a/swift/plugin/multi_turn.py b/swift/plugins/multi_turn.py similarity index 98% rename from swift/plugin/multi_turn.py rename to swift/plugins/multi_turn.py index 61bca23bf0..7b2e9b5165 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugins/multi_turn.py @@ -3,16 +3,12 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from swift.infer_engine import (ChatCompletionResponse, ChatCompletionResponseChoice, GRPOVllmEngine, RequestConfig, + RolloutInferRequest, RolloutOutput) from swift.plugin import ContextManager, Env, context_managers, envs +from swift.template import Messages, ThinkingTemplate from swift.utils import remove_response -if TYPE_CHECKING: - from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig, - RolloutOutput) - from swift.llm.template import RolloutInferRequest - from swift.llm.infer.infer_engine import GRPOVllmEngine - from swift.llm.utils import Messages - class RolloutScheduler(ABC): # Single Turn Rollout Scheduler @@ -79,7 +75,6 @@ async def async_infer(self, async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[str, Any]], request_config: 'RequestConfig', **kwargs): - from swift.llm.template import RolloutInferRequest if isinstance(infer_request, Dict): infer_request = RolloutInferRequest(**infer_request) @@ -101,7 +96,6 @@ async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[s async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', 
**kwargs) -> 'RolloutOutput': - from swift.llm.infer.protocol import RolloutOutput response: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, **kwargs) response_token_ids = response.choices[0].token_ids @@ -224,7 +218,6 @@ async def run(self, infer_request, request_config, **kwargs): # Must return RolloutOutput or List[RolloutOutput] ... """ - from swift.llm.infer.protocol import RolloutOutput current_request = infer_request current_turn = 1 rollout_infos = {} @@ -461,8 +454,6 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque Returns: List[RolloutOutput]: A list of RolloutOutput objects, one for each reasoning round. """ - from swift.llm.infer.protocol import RolloutOutput - current_request = infer_request current_turn = 1 rollout_outputs = [] @@ -530,7 +521,6 @@ def _is_thinking_template(self) -> bool: return False template = self.infer_engine.default_template - from swift.llm.template.template.utils import ThinkingTemplate return isinstance(template, ThinkingTemplate) @@ -726,7 +716,6 @@ async def _close_env_async(self, env: Env): async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> 'RolloutOutput': - from swift.llm.infer.protocol import RolloutOutput """ Execute the gym environment-based rollout: 1. Initialize environment and context manager diff --git a/swift/plugin/optimizer.py b/swift/plugins/optimizer.py similarity index 98% rename from swift/plugin/optimizer.py rename to swift/plugins/optimizer.py index f2429a2dc3..7cb53c306a 100644 --- a/swift/plugin/optimizer.py +++ b/swift/plugins/optimizer.py @@ -9,7 +9,7 @@ from transformers import Trainer from swift.trainers.optimizers.galore import create_optimizer_and_scheduler -from swift.utils import get_dist_setting, get_logger +from swift.utils import get_dist_setting, get_logger, git_clone_github if TYPE_CHECKING: from swift.trainers import TrainingArguments @@ -64,7 +64,6 @@ def create_lorap_optimizer(args: 'TrainingArguments', model, dataset): def create_muon_optimizer(args: 'TrainingArguments', model, dataset): - from swift.llm import git_clone_github if not args.local_repo_path: args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git') sys.path.append(os.path.join(args.local_repo_path, 'examples')) diff --git a/swift/plugin/orm.py b/swift/plugins/orm.py similarity index 99% rename from swift/plugin/orm.py rename to swift/plugins/orm.py index 778e698cba..92e9d4ab12 100644 --- a/swift/plugin/orm.py +++ b/swift/plugins/orm.py @@ -4,8 +4,7 @@ import json -if TYPE_CHECKING: - from swift.llm import InferRequest +from swift.infer_engine import InferRequest class ORM: diff --git a/swift/plugin/prm.py b/swift/plugins/prm.py similarity index 98% rename from swift/plugin/prm.py rename to swift/plugins/prm.py index 0976dfb692..c97fa143b9 100644 --- a/swift/plugin/prm.py +++ b/swift/plugins/prm.py @@ -3,8 +3,7 @@ import json -if TYPE_CHECKING: - from swift.llm import InferRequest +from swift.infer_engine import InferClient, InferRequest class PRM: @@ -94,7 +93,6 @@ def __call__(self, infer_requests: List[Union['InferRequest', Dict]], ground_tru class ClientPRM(PRM): def __init__(self, api_key=None, base_url=None, model=None): - from swift.llm import InferClient import os if api_key is None: api_key = os.getenv('DASHSCOPE_API_KEY') diff --git a/swift/plugin/rm_plugin.py b/swift/plugins/rm_plugin.py similarity index 97% rename from swift/plugin/rm_plugin.py rename to 
swift/plugins/rm_plugin.py index 32a76557d6..4fa9f7123a 100644 --- a/swift/plugin/rm_plugin.py +++ b/swift/plugins/rm_plugin.py @@ -5,11 +5,9 @@ import torch -from swift.llm import PtEngine, RequestConfig, Template, to_device -from swift.utils import get_logger - -if TYPE_CHECKING: - from swift.llm.infer.protocol import ChatCompletionResponse +from swift.infer_engine import ChatCompletionResponse, PtEngine, RequestConfig +from swift.template import Template +from swift.utils import get_logger, to_device logger = get_logger() diff --git a/swift/plugin/tuner.py b/swift/plugins/tuner.py similarity index 96% rename from swift/plugin/tuner.py rename to swift/plugins/tuner.py index f162929e46..847d73953c 100644 --- a/swift/plugin/tuner.py +++ b/swift/plugins/tuner.py @@ -4,12 +4,10 @@ import torch from peft import IA3Config, PeftModel, get_peft_model -from swift.llm import ModelKeys +from swift.arguments import TrainArguments +from swift.model import ModelKeys from swift.utils import find_all_linears -if TYPE_CHECKING: - from swift.llm import TrainArguments - class Tuner: diff --git a/swift/ray/base.py b/swift/ray/base.py index 44bfad191a..3a7fd443c9 100644 --- a/swift/ray/base.py +++ b/swift/ray/base.py @@ -8,10 +8,9 @@ import json import numpy as np -from swift.llm.argument.base_args.ray_args import RayArguments -from swift.ray.resource_manager import ResourceManager from swift.utils import find_free_port from swift.utils.utils import find_node_ip +from .resource_manager import ResourceManager T = TypeVar('T') @@ -27,8 +26,6 @@ class RayHelper: worker_cls: Dict = {} - args: RayArguments = None - worker_instance: Dict = {} initialized = False diff --git a/swift/llm/template/__init__.py b/swift/template/__init__.py similarity index 57% rename from swift/llm/template/__init__.py rename to swift/template/__init__.py index 8e5ac5a972..4555fa5a54 100644 --- a/swift/llm/template/__init__.py +++ b/swift/template/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import template +from . 
import template
+from . import templates
 from .base import MaxLengthError, Template
 from .constant import TemplateType
 from .grounding import draw_bbox
 from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template
-from .template_inputs import InferRequest, RolloutInferRequest, TemplateInputs
+from .template_inputs import TemplateInputs
 from .template_meta import TemplateMeta
-from .utils import Prompt, Word, split_str_parts_by
+from .utils import (History, Message, Messages, Processor, Prompt, Tool, Word, get_packed_seq_params,
+                    history_to_messages, messages_to_history, split_str_parts_by)
 from .vision_utils import load_file, load_image
diff --git a/swift/llm/template/base.py b/swift/template/base.py
similarity index 99%
rename from swift/llm/template/base.py
rename to swift/template/base.py
index 97c3fa7641..c1ce8f7c2e 100644
--- a/swift/llm/template/base.py
+++ b/swift/template/base.py
@@ -23,16 +23,15 @@
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import strtobool
 
-from swift.llm import to_device
-from swift.utils import get_env_args, get_logger
-from ..utils import Processor, ProcessorMixin
-from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
-from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
+from swift.utils import ProcessorMixin, get_env_args, get_logger, to_device
+from .template_inputs import StdTemplateInputs, TemplateInputs
+from .utils import Context, ContextType, Processor, StopWordsCriteria, fetch_one, findall, split_str_parts_by
 from .vision_utils import load_audio, load_batch, load_image, rescale_image
 
 logger = get_logger()
 if TYPE_CHECKING:
     from .template_meta import TemplateMeta
+    from swift.infer_engine import InferRequest
 
 class MaxLengthError(ValueError):
@@ -173,7 +172,7 @@ def _get_model(self):
         if self.model is not None:
             return self.model
         if self.dummy_model is None:
-            from swift.llm import get_model_tokenizer
+            from swift.model import get_model_tokenizer
             with torch.device('meta'):
                 self.dummy_model = get_model_tokenizer(self.model_info.model_dir, return_dummy_model=True)[0]
         return self.dummy_model
@@ -482,7 +481,7 @@ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
     @torch.inference_mode()
     def encode(self,
-               inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
+               inputs: Union[TemplateInputs, Dict[str, Any], 'InferRequest'],
                return_template_inputs: bool = False,
                return_length: bool = False) -> Dict[str, Any]:
         """The entrance method of Template!
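Note: with `InferRequest` now imported lazily (hunk below) and `get_template` taking a model id/path as its first argument (register.py hunk further down), end-to-end template usage after this refactor should look roughly like the sketch here. The module paths follow this patch; the model id, the `load_model=False` flag, and the exact return keys are illustrative assumptions:

```python
from swift.model import get_model_tokenizer
from swift.template import get_template

# load_model=False is assumed to materialize only the tokenizer/processor
_, tokenizer = get_model_tokenizer('Qwen/Qwen2.5-0.5B-Instruct', load_model=False)
# first positional argument is now a model id/path rather than a template name
template = get_template('Qwen/Qwen2.5-0.5B-Instruct', tokenizer)
encoded = template.encode({'messages': [
    {'role': 'user', 'content': 'Who are you?'},
    {'role': 'assistant', 'content': 'I am an assistant.'},
]})
print(list(encoded.keys()))  # e.g. ['input_ids', 'labels', ...]
```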
@@ -490,6 +489,7 @@ def encode(self, Returns: return {'input_ids': List[int], 'labels': Optional[List[int]], ...} """ + from swift.infer_engine import InferRequest assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: ' 'template.init_processor(processor).') if isinstance(inputs, InferRequest): @@ -1436,7 +1436,7 @@ def remove_post_encode_hook(self): return models def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: - from swift.llm import RowPreprocessor + from swift.dataset import RowPreprocessor if self.packing and isinstance(batch[0], list): batch = sum(batch, start=[]) num_samples = len(batch) @@ -1483,7 +1483,7 @@ def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[D @staticmethod def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]: - from swift.llm import RowPreprocessor + from swift.dataset import RowPreprocessor keys = keys or [] rows = RowPreprocessor.rows_to_batched(batch) return {k: rows[k] for k in keys if rows.get(k) is not None} diff --git a/swift/llm/template/constant.py b/swift/template/constant.py similarity index 99% rename from swift/llm/template/constant.py rename to swift/template/constant.py index 752e7f7c35..86758d6c5b 100644 --- a/swift/llm/template/constant.py +++ b/swift/template/constant.py @@ -12,8 +12,8 @@ class LLMTemplateType: qwen2_5 = 'qwen2_5' qwen2_5_math = 'qwen2_5_math' qwen2_5_math_prm = 'qwen2_5_math_prm' - qwen3 = 'qwen3' qwen3_guard = 'qwen3_guard' + qwen3_mixed = 'qwen3_mixed' qwen3_thinking = 'qwen3_thinking' qwen3_nothinking = 'qwen3_nothinking' qwen3_coder = 'qwen3_coder' diff --git a/swift/llm/template/grounding.py b/swift/template/grounding.py similarity index 100% rename from swift/llm/template/grounding.py rename to swift/template/grounding.py diff --git a/swift/llm/template/register.py b/swift/template/register.py similarity index 95% rename from swift/llm/template/register.py rename to swift/template/register.py index 91f9778169..93d0820954 100644 --- a/swift/llm/template/register.py +++ b/swift/template/register.py @@ -2,9 +2,9 @@ from typing import Dict, Literal, Optional -from ..utils import Processor from .base import Template from .template_meta import TemplateMeta +from .utils import Processor TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {} @@ -17,11 +17,12 @@ def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> def get_template( - template_type: str, + model_id_or_path: str, processor: Processor, default_system: Optional[str] = None, max_length: Optional[int] = None, *, + template_type: Optional[str] = None, truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise', max_pixels: Optional[int] = None, # h * w agent_template: Optional[str] = None, diff --git a/swift/llm/template/template_inputs.py b/swift/template/template_inputs.py similarity index 65% rename from swift/llm/template/template_inputs.py rename to swift/template/template_inputs.py index 4c4ea31f2e..5dd4a68464 100644 --- a/swift/llm/template/template_inputs.py +++ b/swift/template/template_inputs.py @@ -1,135 +1,17 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from copy import deepcopy
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Optional, Union
 
 import json
 from PIL import Image
 
 from swift.utils import get_logger
-from ..utils import Messages, Tool, messages_to_history
+from .utils import Messages, Tool, messages_to_history
 
 logger = get_logger()
 
-
-@dataclass
-class InferRequest:
-    """
-    Data structure for inference requests.
-
-    Attributes:
-        messages (Messages):
-            The input conversation in messages format. Each message is a dict containing at least
-            a "role" field (e.g., "user", "assistant", "system") and a "content" field.
-            Example:
-                [{
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image",  # can also be audio/video
-                            "image": "<url/path/base64/PIL.Image>",
-                        },
-                        {"type": "text", "text": "Please describe the picture."},
-                    ],
-                }]
-            The above is equivalent to:
-                [{"role": "user", "content": "<image>Please describe the picture."}]
-                with an additional argument:
-                images = ["<url/path/base64/PIL.Image>"]
-
-        images (List[Union[str, Image.Image]]):
-            Optional, a list of images associated with the request.
-            Each image can be a URL, local path, base64 string, or PIL.Image object.
-
-        audios (List[str]):
-            Optional, a list of audio resources associated with the request.
-
-        videos (List[str]):
-            Optional, a list of video resources associated with the request.
-
-        tools (Optional[List[Tool]]):
-            An optional list of tools. These should be organized in the agent_template format for
-            tools requested by the system, for example 'react_en'.
-
-        objects (Dict[str, List[Any]]):
-            Container for additional multimodal objects, grouped by type (key).
-    """
-    messages: Messages
-
-    images: List[Union[str, Image.Image]] = field(default_factory=list)
-    audios: List[str] = field(default_factory=list)
-    videos: List[str] = field(default_factory=list)
-
-    tools: Optional[List[Tool]] = None
-    objects: Dict[str, List[Any]] = field(default_factory=dict)
-
-    def __post_init__(self):
-        for key in ['images', 'audios', 'videos']:
-            val = getattr(self, key)
-            if isinstance(val, str):
-                setattr(self, key, [val])
-        assert isinstance(self.messages, list), f'messages: {self.messages}'
-
-    @staticmethod
-    def remove_response(messages) -> Optional[str]:
-        last_role = messages[-1]['role'] if messages else None
-        if last_role == 'assistant':
-            return messages.pop()['content']
-
-    @staticmethod
-    def _to_printable(obj, key: Optional[str] = None):
-        if isinstance(obj, str) and key not in {'content', 'text'} and len(obj) >= 1000:
-            return f'<<>>'
-        elif isinstance(obj, list):
-            res = []
-            for item in obj:
-                res.append(InferRequest._to_printable(item))
-            return res
-        elif isinstance(obj, dict):
-            res = {}
-            for k, v in obj.items():
-                res[k] = InferRequest._to_printable(v, key=k)
-            return res
-        return obj
-
-    def to_printable(self):
-        return InferRequest._to_printable(asdict(self))
-
-
-@dataclass
-class RolloutInferRequest(InferRequest):
-    """
-    An inference request class for rollout scenarios.
-
-    This class extends `InferRequest` and specifically overrides the `images` attribute
-    to be a list of strings for compatibility with POST requests. Each string may
-    represent an image URL or a Base64-encoded image.
-
-    Inherits all fields from `InferRequest`:
-        messages (Messages):
-            Input conversation messages, supporting multimodal content.
-        audios (List[str]):
-            List of audio resources associated with the request.
-        videos (List[str]):
-            List of video resources associated with the request.
- tools (Optional[List[Tool]]): - List of tools, organized by the agent template (e.g. 'react_en'). - objects (Dict[str, List[Any]]): - Optional container for additional multimodal objects. - - Additional / Overridden fields: - images (List[str]): - List of image resources, each as a string (URL or base64). - data_dict (Dict): - Optional dictionary for extra request data. - uuid (Optional[str]): - Optional unique identifier for this request instance. - """ - images: List[str] = field(default_factory=list) - data_dict: Dict = field(default_factory=dict) - uuid: Optional[str] = None - - @dataclass class StdTemplateInputs: # only user/tool/assistant diff --git a/swift/llm/template/template_meta.py b/swift/template/template_meta.py similarity index 100% rename from swift/llm/template/template_meta.py rename to swift/template/template_meta.py diff --git a/swift/llm/template/template/__init__.py b/swift/template/templates/__init__.py similarity index 100% rename from swift/llm/template/template/__init__.py rename to swift/template/templates/__init__.py diff --git a/swift/llm/template/template/baai.py b/swift/template/templates/baai.py similarity index 100% rename from swift/llm/template/template/baai.py rename to swift/template/templates/baai.py diff --git a/swift/llm/template/template/baidu.py b/swift/template/templates/baidu.py similarity index 100% rename from swift/llm/template/template/baidu.py rename to swift/template/templates/baidu.py diff --git a/swift/llm/template/template/bert.py b/swift/template/templates/bert.py similarity index 100% rename from swift/llm/template/template/bert.py rename to swift/template/templates/bert.py diff --git a/swift/llm/template/template/deepseek.py b/swift/template/templates/deepseek.py similarity index 100% rename from swift/llm/template/template/deepseek.py rename to swift/template/templates/deepseek.py diff --git a/swift/llm/template/template/dots.py b/swift/template/templates/dots.py similarity index 100% rename from swift/llm/template/template/dots.py rename to swift/template/templates/dots.py diff --git a/swift/llm/template/template/gemma.py b/swift/template/templates/gemma.py similarity index 100% rename from swift/llm/template/template/gemma.py rename to swift/template/templates/gemma.py diff --git a/swift/llm/template/template/glm.py b/swift/template/templates/glm.py similarity index 99% rename from swift/llm/template/template/glm.py rename to swift/template/templates/glm.py index 7711d10ddb..06ecdda9ae 100644 --- a/swift/llm/template/template/glm.py +++ b/swift/template/templates/glm.py @@ -4,12 +4,11 @@ import torch -from swift.llm import get_packed_seq_params from ..base import Template from ..constant import LLMTemplateType, MLLMTemplateType from ..register import TemplateMeta, register_template from ..template_inputs import StdTemplateInputs -from ..utils import Context, Prompt, Word, findall +from ..utils import Context, Prompt, Word, findall, get_packed_seq_params from ..vision_utils import load_batch, load_video_cogvlm2, load_video_hf from .utils import ThinkingTemplate diff --git a/swift/llm/template/template/idefics3.py b/swift/template/templates/idefics3.py similarity index 100% rename from swift/llm/template/template/idefics3.py rename to swift/template/templates/idefics3.py diff --git a/swift/llm/template/template/internlm.py b/swift/template/templates/internlm.py similarity index 100% rename from swift/llm/template/template/internlm.py rename to swift/template/templates/internlm.py diff --git a/swift/llm/template/template/internvl.py 
b/swift/template/templates/internvl.py similarity index 99% rename from swift/llm/template/template/internvl.py rename to swift/template/templates/internvl.py index 41d5c809f1..cbb766bbaa 100644 --- a/swift/llm/template/template/internvl.py +++ b/swift/template/templates/internvl.py @@ -213,7 +213,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: from transformers.image_utils import make_flat_list_of_images, concatenate_list from transformers.video_utils import make_batched_videos - from swift.llm.template.vision_utils import load_video_hf + from ..vision_utils import load_video_hf import numpy as np encoded = super(InternvlTemplate, self)._encode(inputs) input_ids = encoded['input_ids'] diff --git a/swift/llm/template/template/kwai.py b/swift/template/templates/kwai.py similarity index 99% rename from swift/llm/template/template/kwai.py rename to swift/template/templates/kwai.py index 0119afdb0e..9b58686af1 100644 --- a/swift/llm/template/template/kwai.py +++ b/swift/template/templates/kwai.py @@ -6,8 +6,7 @@ import numpy as np import torch -from swift.llm import to_device -from swift.utils import is_deepspeed_enabled +from swift.utils import is_deepspeed_enabled, to_device from ..base import Template from ..constant import MLLMTemplateType from ..register import register_template diff --git a/swift/llm/template/template/llama.py b/swift/template/templates/llama.py similarity index 100% rename from swift/llm/template/template/llama.py rename to swift/template/templates/llama.py diff --git a/swift/llm/template/template/llava.py b/swift/template/templates/llava.py similarity index 100% rename from swift/llm/template/template/llava.py rename to swift/template/templates/llava.py diff --git a/swift/llm/template/template/llm.py b/swift/template/templates/llm.py similarity index 100% rename from swift/llm/template/template/llm.py rename to swift/template/templates/llm.py diff --git a/swift/llm/template/template/megrez.py b/swift/template/templates/megrez.py similarity index 100% rename from swift/llm/template/template/megrez.py rename to swift/template/templates/megrez.py diff --git a/swift/llm/template/template/microsoft.py b/swift/template/templates/microsoft.py similarity index 100% rename from swift/llm/template/template/microsoft.py rename to swift/template/templates/microsoft.py diff --git a/swift/llm/template/template/midashenglm.py b/swift/template/templates/midashenglm.py similarity index 100% rename from swift/llm/template/template/midashenglm.py rename to swift/template/templates/midashenglm.py diff --git a/swift/llm/template/template/minicpm.py b/swift/template/templates/minicpm.py similarity index 100% rename from swift/llm/template/template/minicpm.py rename to swift/template/templates/minicpm.py diff --git a/swift/llm/template/template/minimax.py b/swift/template/templates/minimax.py similarity index 100% rename from swift/llm/template/template/minimax.py rename to swift/template/templates/minimax.py diff --git a/swift/llm/template/template/mistral.py b/swift/template/templates/mistral.py similarity index 99% rename from swift/llm/template/template/mistral.py rename to swift/template/templates/mistral.py index bfa040be7e..9ceafde1b2 100644 --- a/swift/llm/template/template/mistral.py +++ b/swift/template/templates/mistral.py @@ -147,7 +147,7 @@ def _get_new_tokens(i): class Mistral2506Template(Mistral2503Template): def _get_mistral_system(self): - from swift.llm import 
get_model_name
+        from swift.model import get_model_name
         model_dir = self.model_info.model_dir
         model_name = get_model_name(model_dir)
         file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt')
diff --git a/swift/llm/template/template/molmo.py b/swift/template/templates/molmo.py
similarity index 100%
rename from swift/llm/template/template/molmo.py
rename to swift/template/templates/molmo.py
diff --git a/swift/llm/template/template/moonshot.py b/swift/template/templates/moonshot.py
similarity index 100%
rename from swift/llm/template/template/moonshot.py
rename to swift/template/templates/moonshot.py
diff --git a/swift/llm/template/template/mplug.py b/swift/template/templates/mplug.py
similarity index 100%
rename from swift/llm/template/template/mplug.py
rename to swift/template/templates/mplug.py
diff --git a/swift/llm/template/template/openbuddy.py b/swift/template/templates/openbuddy.py
similarity index 100%
rename from swift/llm/template/template/openbuddy.py
rename to swift/template/templates/openbuddy.py
diff --git a/swift/llm/template/template/pixtral.py b/swift/template/templates/pixtral.py
similarity index 100%
rename from swift/llm/template/template/pixtral.py
rename to swift/template/templates/pixtral.py
diff --git a/swift/llm/template/template/qwen.py b/swift/template/templates/qwen.py
similarity index 99%
rename from swift/llm/template/template/qwen.py
rename to swift/template/templates/qwen.py
index d1a000a87f..94eb088cb1 100644
--- a/swift/llm/template/template/qwen.py
+++ b/swift/template/templates/qwen.py
@@ -13,14 +13,13 @@
 from torch import nn
 from transformers.integrations import is_deepspeed_zero3_enabled
 
-from swift.llm import get_packed_seq_params, to_float_dtype
-from swift.utils import get_env_args, is_deepspeed_enabled
+from swift.utils import get_env_args, is_deepspeed_enabled, to_float_dtype
 from ..base import Template
 from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import register_template
 from ..template_inputs import StdTemplateInputs
 from ..template_meta import TemplateMeta
-from ..utils import Context, Word, findall
+from ..utils import Context, Word, findall, get_packed_seq_params
 from ..vision_utils import load_audio, load_batch, load_video_ovis2, load_video_ovis2_5
 from .llama import Llama3TemplateMeta
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta, ThinkingTemplate
@@ -60,7 +59,7 @@ class Qwen3Template(ThinkingTemplate):
     no_think_prefix = '<think>\n\n</think>\n\n'
 
-register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=Qwen3Template))
+register_template(QwenTemplateMeta(LLMTemplateType.qwen3_mixed, default_system=None, template_cls=Qwen3Template))
 
 QWEN3_GUARD_TEMPLATE = (
     '<|im_start|>user\n'
diff --git a/swift/llm/template/template/seed.py b/swift/template/templates/seed.py
similarity index 98%
rename from swift/llm/template/template/seed.py
rename to swift/template/templates/seed.py
index aafaef277f..bd8990c8f7 100644
--- a/swift/llm/template/template/seed.py
+++ b/swift/template/templates/seed.py
@@ -7,10 +7,10 @@
 from torch import nn
 from transformers.utils import strtobool
 
-from swift.llm.template.constant import LLMTemplateType, MLLMTemplateType
-from swift.llm.template.template_inputs import StdTemplateInputs
 from swift.utils import is_deepspeed_enabled
+from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import Template, TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, Word, findall
 from .utils import
ChatmlTemplateMeta diff --git a/swift/llm/template/template/stepfun.py b/swift/template/templates/stepfun.py similarity index 100% rename from swift/llm/template/template/stepfun.py rename to swift/template/templates/stepfun.py diff --git a/swift/llm/template/template/utils.py b/swift/template/templates/utils.py similarity index 100% rename from swift/llm/template/template/utils.py rename to swift/template/templates/utils.py diff --git a/swift/llm/template/template/valley.py b/swift/template/templates/valley.py similarity index 100% rename from swift/llm/template/template/valley.py rename to swift/template/templates/valley.py diff --git a/swift/llm/template/template/yi.py b/swift/template/templates/yi.py similarity index 100% rename from swift/llm/template/template/yi.py rename to swift/template/templates/yi.py diff --git a/swift/llm/template/utils.py b/swift/template/utils.py similarity index 58% rename from swift/llm/template/utils.py rename to swift/template/utils.py index 172dcb3205..166db9529d 100644 --- a/swift/llm/template/utils.py +++ b/swift/template/utils.py @@ -1,9 +1,26 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os import re from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch -from transformers import PreTrainedTokenizerBase, StoppingCriteria +from transformers import FeatureExtractionMixin, PreTrainedTokenizerBase +from transformers import ProcessorMixin as HfProcessorMixin +from transformers import StoppingCriteria + +try: + from transformers import BaseImageProcessor + Processor = Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, HfProcessorMixin] +except ImportError: + Processor = Union[PreTrainedTokenizerBase, FeatureExtractionMixin, HfProcessorMixin] + +Tool = Dict[str, Union[str, Dict]] +History = List[Union[Tuple[str, str], List[str]]] +Message = Dict[str, Union[str, List[Dict[str, Any]], List[int], None]] +Messages = List[Message] + +if 'TOKENIZERS_PARALLELISM' not in os.environ: + os.environ['TOKENIZERS_PARALLELISM'] = 'false' Prompt = List[Union[str, List[int], List[str]]] Word = Union[str, List[int]] @@ -156,3 +173,95 @@ def split_str_parts_by(text: str, delimiters: List[str], regex_mode: bool = Fals for key, content in zip(parts[::2], parts[1::2]): res.append({'key': key, 'content': content}) return res + + +def history_to_messages(history: History, + system: Optional[str] = None, + roles: Optional[List[List[str]]] = None) -> 'Messages': + """ + history: [['query1', 'response1'], ['query2', 'response2']] + or [['query1', 'response1'], ['query2', None]] + """ + messages = [] + if not roles: + roles = [['user', 'assistant']] * len(history) + else: + assert len(roles) == len(history), f'len(roles): {len(roles)}, len(history): {len(history)}' + if system is not None: + messages.append({'role': 'system', 'content': system}) + + for role, h in zip(roles, history): + assert isinstance(h, (list, tuple)) + if h[0] is not None: + messages.append({'role': role[0], 'content': h[0]}) + if h[1] is not None: + messages.append({'role': role[1], 'content': h[1]}) + return messages + + +def messages_to_history(messages: 'Messages') -> Dict[str, Any]: + system = None + messages = messages.copy() + if messages[0]['role'] == 'system': + system = messages[0]['content'] + messages = messages[1::] + if len(messages) % 2 == 1: + messages.append({'role': 'assistant', 'content': None}) + history = [] + history_roles = [] + for user_message, assistant_message in zip(messages[::2], messages[1::2]): + assert 
user_message['role'] in {'tool', 'user'}, f'user_message {user_message}' + assert assistant_message['role'] == 'assistant', f'assistant_message: {assistant_message}' + history.append([user_message['content'], assistant_message['content']]) + history_roles.append([user_message['role'], assistant_message['role']]) + query, response = history.pop() if history else (None, None) + query_role = history_roles.pop()[0] if history_roles else None + return { + 'history': history, + 'history_roles': history_roles, + 'query': query, + 'query_role': query_role, + 'response': response, + 'system': system, + } + + +def get_packed_seq_params(position_ids: torch.Tensor): + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + position_ids_f = position_ids.flatten() + indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32) + + cu_seqlens = torch.cat([ + indices_q[position_ids_f == 0], + torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32), + ]) + + max_length = cu_seqlens.diff().max() # position_ids_f.max() + 1 + return { + 'cu_seq_lens_q': cu_seqlens, + 'cu_seq_lens_k': cu_seqlens, + 'max_length_q': max_length, + 'max_length_k': max_length, + } + + +def update_generation_config_eos_token(generation_config, template): + if generation_config is None: + return + stop_words = template.template_meta.stop_words + eos_token_id = generation_config.eos_token_id + if eos_token_id is None: + eos_token_id = [] + elif isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + modified = False + for stop_word in stop_words: + if stop_word is None: + continue + if isinstance(stop_word, str): + stop_word = template._tokenize(stop_word) + if isinstance(stop_word, (list, tuple)) and len(stop_word) == 1 and stop_word[0] not in eos_token_id: + eos_token_id.append(stop_word[0]) + modified = True + if modified: + generation_config.eos_token_id = eos_token_id diff --git a/swift/llm/template/vision_utils.py b/swift/template/vision_utils.py similarity index 100% rename from swift/llm/template/vision_utils.py rename to swift/template/vision_utils.py diff --git a/swift/trainers/__init__.py b/swift/trainers/__init__.py index 11ad8d8bfd..2944e489c8 100644 --- a/swift/trainers/__init__.py +++ b/swift/trainers/__init__.py @@ -16,13 +16,14 @@ if TYPE_CHECKING: from .arguments import (Seq2SeqTrainingArguments, TrainingArguments, RLHFArgumentsMixin, VllmArguments, GRPOArgumentsMixin, RolloutTrainerArgumentsMixin) - from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer, - RewardTrainer, GRPOTrainer, GKDTrainer) + from .rlhf_trainers import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer, + RewardTrainer, GRPOTrainer, GKDTrainer) from .rlhf_arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig, GKDConfig from .trainer_factory import TrainerFactory from .trainers import Seq2SeqTrainer, Trainer, EmbeddingTrainer, RerankerTrainer from .mixin import SwiftMixin from .utils import per_token_loss_func + from .data_loader import BatchSamplerShard, DataLoaderShard, DataLoaderDispatcher else: _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} @@ -33,7 +34,7 @@ ], 'rlhf_arguments': ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig', 'GKDConfig'], - 'rlhf_trainer': [ + 'rlhf_trainers': [ 'CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 
'PPOTrainer', 'RewardTrainer', 'GRPOTrainer', 'GKDTrainer' ], @@ -41,6 +42,7 @@ 'trainers': ['Seq2SeqTrainer', 'Trainer', 'EmbeddingTrainer', 'RerankerTrainer'], 'mixin': ['SwiftMixin'], 'utils': ['per_token_loss_func'], + 'data_loader': ['BatchSamplerShard', 'DataLoaderShard', 'DataLoaderDispatcher'], } import sys diff --git a/swift/llm/data_loader.py b/swift/trainers/data_loader.py similarity index 99% rename from swift/llm/data_loader.py rename to swift/trainers/data_loader.py index ca3fced1e5..0c5d46fce7 100644 --- a/swift/llm/data_loader.py +++ b/swift/trainers/data_loader.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from swift.llm import to_device +from swift.utils import to_device class BatchSamplerShard: diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index a4a3fca0f1..6a6a4e4bed 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -37,16 +37,16 @@ ParallelMode, Trainer, TrainerCallback, reissue_pt_warnings) from transformers.trainer_utils import IntervalStrategy -from swift.hub import get_hub -from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template, get_llm_model -from swift.llm.utils import update_generation_config_eos_token +from swift.model import HfConfigFactory, get_llm_model, save_checkpoint +from swift.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward from swift.plugin import MeanMetric, compute_acc, extra_tuners, get_loss_func, get_metric +from swift.template import Template, get_packed_seq_params, update_generation_config_eos_token from swift.tuners import SwiftModel -from swift.utils import (get_current_device, get_last_valid_indices, get_logger, is_dist, is_mp, is_mp_ddp, - ms_logger_context, seed_worker) -from ..llm.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward +from swift.utils import (deep_getattr, get_current_device, get_last_valid_indices, get_logger, is_dist, is_mp, + is_mp_ddp, ms_logger_context, seed_worker) from .arguments import TrainingArguments -from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model +from .data_loader import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard +from .utils import can_return_loss, dynamic_gradient_checkpointing, find_labels, get_function, is_instance_of_ms_model try: from trl import AutoModelForCausalLMWithValueHead @@ -393,7 +393,6 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): is_adapter = isinstance(self.model, (SwiftModel, PeftModel)) # tokenizer if not is_adapter: - from swift.llm import save_checkpoint additional_saved_files = self.model_meta.additional_saved_files save_checkpoint( None, @@ -757,7 +756,6 @@ def _new_checkpoint(*args, use_reentrant=None, **kwargs): pass def _prepare_gradient_checkpointing(self, model) -> None: - from swift.llm import HfConfigFactory, deep_getattr, dynamic_gradient_checkpointing args = self.args HfConfigFactory.set_model_config_attr(model, 'use_cache', False) if args.gradient_checkpointing or args.vit_gradient_checkpointing: @@ -1063,7 +1061,6 @@ def prepare_logits_to_keep(self, inputs): inputs['logits_to_keep'] = logits_to_keep def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor: - from swift.llm import get_packed_seq_params cu_seqlens = get_packed_seq_params(position_ids)['cu_seq_lens_q'] res_cu_seqlens = cu_seqlens.clone() if isinstance(logits_to_keep, torch.Tensor): diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py 
b/swift/trainers/rlhf_trainer/rlhf_mixin.py index 428799da62..bf4e30d6a5 100644 --- a/swift/trainers/rlhf_trainer/rlhf_mixin.py +++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py @@ -12,6 +12,8 @@ from trl.models.utils import prepare_deepspeed from trl.trainer.utils import selective_log_softmax +from swift.model import HfConfigFactory + class RLHFTrainerMixin: @@ -21,7 +23,6 @@ def __init__(self, *_args, **kwargs): from trl.trainer import disable_dropout_in_model - from swift.llm import HfConfigFactory self.ref_model = ref_model self._stored_metrics = defaultdict(lambda: defaultdict(list)) args = kwargs['args'] diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index a69fc0935b..400447fae5 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -26,9 +26,10 @@ from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback -from swift.llm import MultiModelKeys, RequestConfig, RolloutInferRequest, Template -from swift.llm.infer.protocol import ChatCompletionResponse, RolloutOutput +from swift.infer_engine import ChatCompletionResponse, GRPOVllmEngine, RequestConfig, RolloutInferRequest, RolloutOutput +from swift.model import MultiModelKeys from swift.plugin import MultiTurnScheduler, multi_turns +from swift.template import Template from swift.trainers import RolloutTrainerArgumentsMixin from swift.utils import get_logger, is_vllm_available, remove_response from swift.utils.torch_utils import get_current_device @@ -185,7 +186,6 @@ def _prepare_vllm(self): def _prepare_vllm_engine(self): """Create and configure vLLM engine for colocate mode""" from swift.tuners import Swift - from swift.llm.infer.infer_engine import GRPOVllmEngine args = self.args model = self.model steps_per_generation = args.steps_per_generation if hasattr(args, 'steps_per_generation') else 1 diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 05754d59a3..d26ba3e114 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -22,6 +22,7 @@ from torch import nn from torch.utils.data import DataLoader, RandomSampler +from swift.template import Messages from swift.tuners.lora import LoraConfig from swift.utils import gc_collect, get_logger, is_swanlab_available, is_vllm_available, is_wandb_available from swift.utils.torch_utils import get_torch_device @@ -31,8 +32,6 @@ if is_swanlab_available(): import swanlab -if TYPE_CHECKING: - from swift.llm.utils import Messages T = TypeVar('T') TensorLoRARequest = None @@ -1001,7 +1000,7 @@ def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor: Combined CHORD loss tensor """ from swift.trainers import per_token_loss_func - from swift.llm import to_device + from swift.utils import to_device current_step = trainer.state.global_step mu = mu_schedule_function(current_step, trainer.args.chord_mu_warmup_steps, trainer.args.chord_mu_decay_steps, diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 32cd32db0e..762686c534 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -17,9 +17,9 @@ from torch import nn from transformers.utils import is_torch_cuda_available -from swift.llm import AdapterRequest, RolloutInferRequest, Template -from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig, RolloutOutput +from swift.infer_engine import 
AdapterRequest, ChatCompletionResponse, RequestConfig, RolloutInferRequest, RolloutOutput from swift.plugin import Metric +from swift.template import Template from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available from .utils import peft_config_to_dict diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainers/__init__.py similarity index 99% rename from swift/trainers/rlhf_trainer/__init__.py rename to swift/trainers/rlhf_trainers/__init__.py index 829dba091b..b55c68cfcc 100644 --- a/swift/trainers/rlhf_trainer/__init__.py +++ b/swift/trainers/rlhf_trainers/__init__.py @@ -15,6 +15,7 @@ from .rlhf_mixin import RLHFTrainerMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection from .vllm_client import VLLMClient + else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], diff --git a/swift/trainers/rlhf_trainer/cpo_trainer.py b/swift/trainers/rlhf_trainers/cpo_trainer.py similarity index 100% rename from swift/trainers/rlhf_trainer/cpo_trainer.py rename to swift/trainers/rlhf_trainers/cpo_trainer.py diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainers/dpo_trainer.py similarity index 99% rename from swift/trainers/rlhf_trainer/dpo_trainer.py rename to swift/trainers/rlhf_trainers/dpo_trainer.py index b84a7051eb..ab54248847 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainers/dpo_trainer.py @@ -12,8 +12,7 @@ from trl.trainer.dpo_config import DPOConfig from trl.trainer.utils import RunningMoments -from swift.llm import to_device -from swift.utils import get_logger +from swift.utils import get_logger, to_device from ..mixin import DataLoaderMixin, SwiftMixin from .rlhf_mixin import RLHFTrainerMixin diff --git a/swift/trainers/rlhf_trainer/gkd_trainer.py b/swift/trainers/rlhf_trainers/gkd_trainer.py similarity index 99% rename from swift/trainers/rlhf_trainer/gkd_trainer.py rename to swift/trainers/rlhf_trainers/gkd_trainer.py index ae65eaf0f4..9b05067e25 100644 --- a/swift/trainers/rlhf_trainer/gkd_trainer.py +++ b/swift/trainers/rlhf_trainers/gkd_trainer.py @@ -17,7 +17,7 @@ from trl import GKDTrainer as HFGKDTrainer from trl import SFTTrainer as HFSFTTrainer -from swift.llm.template.template_inputs import TemplateInputs +from swift.template import TemplateInputs from swift.utils import (JsonlWriter, get_logger, is_swanlab_available, is_wandb_available, remove_response, unwrap_model_for_generation) from ..mixin import SwiftMixin @@ -262,7 +262,7 @@ def _prepare_batch_inputs(self, inputs: list) -> Dict[str, torch.Tensor]: encoded = template.encode(data, return_length=True) batch_encoded_inputs.append(encoded) - from swift.llm import to_device + from swift.utils import to_device batch_encoded = to_device(template.data_collator(batch_encoded_inputs), self.model.device) return batch_encoded diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainers/grpo_trainer.py similarity index 99% rename from swift/trainers/rlhf_trainer/grpo_trainer.py rename to swift/trainers/rlhf_trainers/grpo_trainer.py index 7818251f9d..37e7110b44 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainers/grpo_trainer.py @@ -39,11 +39,12 @@ from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd from trl.trainer.utils import selective_log_softmax -from swift.llm import RowPreprocessor, Template, to_device -from swift.llm.template.template_inputs import TemplateInputs -from 
swift.plugin import orms, rm_plugins +from swift.dataset import RowPreprocessor +from swift.infer_engine import PtEngine +from swift.plugins import orms, rm_plugins +from swift.template import Template, TemplateInputs from swift.utils import (JsonlWriter, get_logger, is_swanlab_available, is_wandb_available, remove_response, - seed_worker, unwrap_model_for_generation) + seed_worker, to_device, unwrap_model_for_generation) from ..mixin import SwiftMixin from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, @@ -120,7 +121,6 @@ def __init__(self, set_seed(args.seed, device_specific=True) if not self.args.use_vllm: - from swift.llm import PtEngine infer_template = copy(self.template) infer_template.padding_free = False infer_template.sequence_parallel_size = 1 diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainers/kto_trainer.py similarity index 100% rename from swift/trainers/rlhf_trainer/kto_trainer.py rename to swift/trainers/rlhf_trainers/kto_trainer.py diff --git a/swift/trainers/rlhf_trainer/orpo_trainer.py b/swift/trainers/rlhf_trainers/orpo_trainer.py similarity index 100% rename from swift/trainers/rlhf_trainer/orpo_trainer.py rename to swift/trainers/rlhf_trainers/orpo_trainer.py diff --git a/swift/trainers/rlhf_trainer/ppo_trainer.py b/swift/trainers/rlhf_trainers/ppo_trainer.py similarity index 100% rename from swift/trainers/rlhf_trainer/ppo_trainer.py rename to swift/trainers/rlhf_trainers/ppo_trainer.py diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainers/reward_trainer.py similarity index 100% rename from swift/trainers/rlhf_trainer/reward_trainer.py rename to swift/trainers/rlhf_trainers/reward_trainer.py diff --git a/swift/trainers/rlhf_trainers/rlhf_mixin.py b/swift/trainers/rlhf_trainers/rlhf_mixin.py new file mode 100644 index 0000000000..bf4e30d6a5 --- /dev/null +++ b/swift/trainers/rlhf_trainers/rlhf_mixin.py @@ -0,0 +1,181 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import inspect +from collections import defaultdict +from contextlib import contextmanager +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from transformers import PreTrainedModel +from trl.models.utils import prepare_deepspeed +from trl.trainer.utils import selective_log_softmax + +from swift.model import HfConfigFactory + + +class RLHFTrainerMixin: + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + *_args, + **kwargs): + from trl.trainer import disable_dropout_in_model + self.ref_model = ref_model + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + args = kwargs['args'] + self.beta = getattr(args, 'beta', 0.0) + if getattr(args, 'disable_dropout', False): + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.is_encoder_decoder = kwargs['template'].is_encoder_decoder + self._peft_has_been_casted_to_bf16 = False + self.generate_during_eval = getattr(args, 'generate_during_eval', False) + if self.is_encoder_decoder: + self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id') + self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id') + # not use + self.is_vision_model = False + self.label_pad_token_id = -100 + self.use_dpo_data_collator = True + super().__init__(model, *_args, **kwargs) + self.aux_loss_enabled = model.model_info.is_moe_model and args.router_aux_loss_coef > 0 + self.aux_loss_coef = args.router_aux_loss_coef + if ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + self.padding_value = self.tokenizer.pad_token_id + + def create_loss_and_metric(self, args): + return {} + + def _prepare_inputs(self, inputs): + inputs = super()._prepare_inputs(inputs) + if self.template.sequence_parallel_size > 1: + from swift.trainers.sequence_parallel import sequence_parallel + sequence_parallel.prepare_inputs(inputs) + return inputs + + def get_train_dataloader(self, *args, **kwargs): + train_dataloader = super().get_train_dataloader(*args, **kwargs) + base_dataloader = train_dataloader.base_dataloader if hasattr( + train_dataloader, 'base_dataloader') and isinstance(train_dataloader.base_dataloader, + DataLoader) else train_dataloader + if base_dataloader.worker_init_fn is not None and not isinstance( + base_dataloader.worker_init_fn, partial) and 'num_workers' in inspect.signature( + base_dataloader.worker_init_fn).parameters: + base_dataloader.worker_init_fn = partial( + base_dataloader.worker_init_fn, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index) + return train_dataloader + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + model_kwargs = batch.copy() + labels = model_kwargs.pop('labels', None) + if self.is_encoder_decoder: + model_kwargs['labels'] = labels + + if self.aux_loss_enabled: + model_kwargs['output_router_logits'] = True + outputs = model(**model_kwargs, use_cache=False) + model_kwargs['labels'] = labels + model_kwargs['chosen_labels'] = 
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        model_kwargs = batch.copy()
+        labels = model_kwargs.pop('labels', None)
+        if self.is_encoder_decoder:
+            model_kwargs['labels'] = labels
+
+        if self.aux_loss_enabled:
+            model_kwargs['output_router_logits'] = True
+        outputs = model(**model_kwargs, use_cache=False)
+        model_kwargs['labels'] = labels
+        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)  # just get shape
+        if outputs.logits.shape[1] != labels.shape[1]:
+            # for llava, the model returns logits for the entire sequence, including the image tokens
+            # (placed before the text tokens)
+            outputs.logits = outputs.logits[:, -labels.shape[1]:]
+        for key in ['input_ids', 'attention_mask', 'labels']:
+            model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
+        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
+            model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']
+
+        @contextmanager
+        def _patch_concatenated_forward():
+            _old_concatenated_inputs = self.concatenated_inputs
+            _old_model_call = model.__class__.__call__
+            self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
+            model.__class__.__call__ = lambda *args, **kwargs: outputs
+            try:
+                yield
+            finally:
+                self.concatenated_inputs = _old_concatenated_inputs
+                model.__class__.__call__ = _old_model_call
+
+        with _patch_concatenated_forward():
+            return super().concatenated_forward(model, model_kwargs)
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
+        # compat transformers>=4.46.*
+        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
+            loss = res[0] if return_outputs else res
+            loss = loss / self.args.gradient_accumulation_steps
+            return (loss, res[1:]) if return_outputs else loss
+        return res
+
+    def _get_train_sampler(self, train_dataset=None):
+        get_train_sampler = super()._get_train_sampler
+        parameters = inspect.signature(get_train_sampler).parameters
+        kwargs = {'train_dataset': train_dataset} if 'train_dataset' in parameters else {}
+        return get_train_sampler(**kwargs)
+
+    def get_per_token_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        label_pad_token_id=-100,
+        reduction='mean',
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
+                             f'and labels must have the same shape {labels.shape}')
+        loss_mask = labels != label_pad_token_id
+        labels = labels.clone()
+        labels[~loss_mask] = 0
+        if reduction == 'mean':
+            reduce_logits = logits.mean(-1)
+        elif reduction == 'sum':
+            reduce_logits = logits.sum(-1)
+        else:
+            raise ValueError(f'Invalid reduction: {reduction}')
+        if self.template.sequence_parallel_size == 1:
+            # https://github.com/huggingface/trl/pull/2799
+            # Reduce peak vram consumption with efficient selective log_softmax
+            per_token_logps = selective_log_softmax(logits, labels)
+            per_token_logps[~loss_mask] = 0
+            reduce_logits[~loss_mask] = 0
+            return per_token_logps, reduce_logits, loss_mask
+        else:
+            labels = labels.to(logits.device)
+            loss_mask = loss_mask.to(logits.device)
+            mean_logits = reduce_logits
+            per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+            from swift.trainers.sequence_parallel.utils import GatherLoss
+            from swift.trainers.sequence_parallel import sequence_parallel
+            position_ids = sequence_parallel.real_position_ids
+            total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, 1, position_ids)
+            total_mean_logits = sequence_parallel.gather(mean_logits, dim=1, position_ids=position_ids)
+            if position_ids is not None and position_ids.min() == -1:
+                _pos_mask = position_ids >= 0
+                total_per_token_logps = total_per_token_logps[_pos_mask].contiguous()
+                total_mean_logits = total_mean_logits[_pos_mask].contiguous()
+                total_loss_mask = total_loss_mask[_pos_mask].contiguous()
+
+            total_loss_mask = total_loss_mask.bool()
+            total_per_token_logps = total_per_token_logps * total_loss_mask
+
+            if total_per_token_logps.dim() == 1:
+                total_per_token_logps = total_per_token_logps.unsqueeze(0)
+                total_mean_logits = total_mean_logits.unsqueeze(0)
+                total_loss_mask = total_loss_mask.unsqueeze(0)
+            return total_per_token_logps, total_mean_logits, total_loss_mask
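For intuition about `get_per_token_logps`: in the non-sequence-parallel branch it reduces to a gather over `log_softmax`. A standalone sketch with toy shapes (plain PyTorch, written out explicitly instead of trl's fused `selective_log_softmax`):

```python
import torch

# Toy shapes: batch=2, seq_len=4, vocab=8; -100 marks positions excluded from the loss.
logits = torch.randn(2, 4, 8)
labels = torch.tensor([[1, 2, -100, 3], [0, -100, 5, 6]])

loss_mask = labels != -100
safe_labels = labels.clone()
safe_labels[~loss_mask] = 0  # any valid vocab index; zeroed out below

# What selective_log_softmax(logits, safe_labels) computes, spelled out:
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
print(per_token_logps.shape)  # torch.Size([2, 4])
```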
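`InferRequest.remove_response` supplies the labels for `predict_with_generate`. A small sketch of the intent, assuming the classmethod behaves like the `remove_response` utility imported elsewhere in this patch, popping the trailing assistant turn and returning its content (inferred from how `labels_list` is built above, not from documentation):

```python
from swift.infer_engine import InferRequest

messages = [{'role': 'user', 'content': 'What is 2 + 2?'},
            {'role': 'assistant', 'content': '4'}]
# Strips the assistant turn so the engine re-generates it; the returned
# text ('4') serves as the reference label for the prediction step.
label = InferRequest.remove_response(messages)
```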
diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py
index bae8b79928..fb3cfc849a 100644
--- a/swift/trainers/utils.py
+++ b/swift/trainers/utils.py
@@ -107,3 +107,29 @@ def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs
         target_probs = torch.exp(-loss)
         loss *= target_probs
     return loss
+
+
+def dynamic_gradient_checkpointing(model, including_vit: bool = False) -> None:
+    from swift.model import ModelMeta
+    if isinstance(model, PeftModel):
+        model = model.model
+    model_meta: ModelMeta = getattr(model, 'model_meta', None)
+    if model_meta is not None and model_meta.is_multimodal and model_meta.model_arch:
+        tower_names = model_meta.model_arch.language_model.copy()
+        if including_vit:
+            tower_names += model_meta.model_arch.vision_tower
+    else:
+        tower_names = [None]
+
+    model.supports_gradient_checkpointing = True
+    for tower_name in tower_names:
+        if tower_name is None:
+            model_tower = model
+        else:
+            model_tower = deep_getattr(model, tower_name)
+        model_tower.supports_gradient_checkpointing = True
+        module_list = find_module_list(model_tower)
+        if module_list is None:
+            continue
+        _add_gradient_checkpointing(module_list)
+        logger.info(f'Automatically add gradient_checkpointing to {model_tower.__class__}.')
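For orientation, a hedged usage sketch of the helper added above; the model ID and loader call are illustrative, and the helper itself walks the language tower (plus the vision tower when `including_vit=True`) and enables checkpointing on each `ModuleList` it finds:

```python
from swift.model import get_model_tokenizer  # per this refactor's layout
from swift.trainers.utils import dynamic_gradient_checkpointing

# Illustrative checkpoint; any registered multimodal model whose model_meta
# carries a model_arch should be handled the same way.
model, _ = get_model_tokenizer('Qwen/Qwen2-VL-2B-Instruct')
dynamic_gradient_checkpointing(model, including_vit=True)
```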
diff --git a/swift/tuners/llamapro.py b/swift/tuners/llamapro.py
index 9678038116..ab4481e1a9 100644
--- a/swift/tuners/llamapro.py
+++ b/swift/tuners/llamapro.py
@@ -6,7 +6,7 @@
 import torch
 from torch import nn

-from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
+from swift.model import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
 from swift.utils.logger import get_logger
 from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py
index 7669805bb6..8061af5585 100644
--- a/swift/tuners/utils.py
+++ b/swift/tuners/utils.py
@@ -20,10 +20,9 @@
 from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
 from peft.utils import _get_submodules

-from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
-from swift.utils import gc_collect
+from swift.model import MODEL_ARCH_MAPPING, ModelKeys
+from swift.utils import gc_collect, get_logger
 from swift.utils.constants import BIN_EXTENSIONS
-from swift.utils.logger import get_logger

 logger = get_logger()
diff --git a/swift/ui/app.py b/swift/ui/app.py
index 460f572ad7..c480a575c2 100644
--- a/swift/ui/app.py
+++ b/swift/ui/app.py
@@ -8,15 +8,15 @@
 from transformers.utils import strtobool

 import swift
-from swift.llm import (DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SamplingArguments, SwiftPipeline,
-                       WebUIArguments)
-from swift.ui.llm_eval.llm_eval import LLMEval
-from swift.ui.llm_export.llm_export import LLMExport
-from swift.ui.llm_grpo.llm_grpo import LLMGRPO
-from swift.ui.llm_infer.llm_infer import LLMInfer
-from swift.ui.llm_rlhf.llm_rlhf import LLMRLHF
-from swift.ui.llm_sample.llm_sample import LLMSample
-from swift.ui.llm_train.llm_train import LLMTrain
+from swift.arguments import (DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SamplingArguments,
+                             SwiftPipeline, WebUIArguments)
+from .llm_eval.llm_eval import LLMEval
+from .llm_export.llm_export import LLMExport
+from .llm_grpo.llm_grpo import LLMGRPO
+from .llm_infer.llm_infer import LLMInfer
+from .llm_rlhf.llm_rlhf import LLMRLHF
+from .llm_sample.llm_sample import LLMSample
+from .llm_train.llm_train import LLMTrain

 locale_dict = {
     'title': {
diff --git a/swift/ui/base.py b/swift/ui/base.py
index 957bebe162..7cbe1a07f7 100644
--- a/swift/ui/base.py
+++ b/swift/ui/base.py
@@ -15,7 +15,9 @@
 from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video
 from modelscope.hub.utils.utils import get_cache_dir

-from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta
+from swift.arguments import BaseArguments
+from swift.model import get_matched_model_meta
+from swift.template import TEMPLATE_MAPPING

 all_langs = ['zh', 'en']
 builder: Type['BaseUI'] = None
diff --git a/swift/ui/llm_eval/eval.py b/swift/ui/llm_eval/eval.py
index ded9038bba..963741b756 100644
--- a/swift/ui/llm_eval/eval.py
+++ b/swift/ui/llm_eval/eval.py
@@ -3,6 +3,7 @@

 import gradio as gr

+from swift.arguments import EvalArguments
 from swift.ui.base import BaseUI
 from swift.utils import get_logger

@@ -98,7 +99,6 @@ class Eval(BaseUI):
     @classmethod
     def do_build_ui(cls, base_tab: Type['BaseUI']):
         try:
-            from swift.llm.argument.eval_args import EvalArguments
             eval_dataset_dict = EvalArguments.list_eval_dataset()
             default_backend = EvalArguments.eval_backend
         except Exception as e:
diff --git a/swift/ui/llm_eval/llm_eval.py b/swift/ui/llm_eval/llm_eval.py
index 5819c4aa74..bc4cf0b8ab 100644
--- a/swift/ui/llm_eval/llm_eval.py
+++ b/swift/ui/llm_eval/llm_eval.py
@@ -13,7 +13,7 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import EvalArguments
+from swift.arguments import EvalArguments
 from swift.ui.base import BaseUI
 from swift.ui.llm_eval.eval import Eval
 from swift.ui.llm_eval.model import Model
diff --git a/swift/ui/llm_eval/model.py b/swift/ui/llm_eval/model.py
index 570afabf8c..a805841adc 100644
--- a/swift/ui/llm_eval/model.py
+++ b/swift/ui/llm_eval/model.py
@@ -4,8 +4,9 @@

 import gradio as gr

-from swift.llm import TEMPLATE_MAPPING, EvalArguments, ModelType
-from swift.llm.model.register import get_all_models
+from swift.arguments import EvalArguments
+from swift.model import ModelType, get_all_models
+from swift.template import TEMPLATE_MAPPING
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_export/export.py b/swift/ui/llm_export/export.py
index 6cd4a9d4da..376b57d16a 100644
--- a/swift/ui/llm_export/export.py
+++ b/swift/ui/llm_export/export.py
@@ -3,7 +3,7 @@

 import gradio as gr

-from swift.llm.dataset.register import get_dataset_list
+from swift.dataset import get_dataset_list
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_export/llm_export.py b/swift/ui/llm_export/llm_export.py
index 49fe8b799a..c96e727ae6 100644
--- a/swift/ui/llm_export/llm_export.py
+++ b/swift/ui/llm_export/llm_export.py
@@ -14,7 +14,7 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import ExportArguments
+from swift.arguments import ExportArguments
 from swift.ui.base import BaseUI
 from swift.ui.llm_export.export import Export
 from swift.ui.llm_export.model import Model
diff --git a/swift/ui/llm_export/model.py b/swift/ui/llm_export/model.py
index d42862f71d..2930f3ca93 100644
--- a/swift/ui/llm_export/model.py
+++ b/swift/ui/llm_export/model.py
@@ -4,8 +4,9 @@

 import gradio as gr

-from swift.llm import TEMPLATE_MAPPING, ExportArguments, ModelType
-from swift.llm.model.register import get_all_models
+from swift.arguments import ExportArguments
+from swift.model import ModelType, get_all_models
+from swift.template import TEMPLATE_MAPPING
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_grpo/external_rollout.py b/swift/ui/llm_grpo/external_rollout.py
index 4ca94aeaac..1d57c0e1e4 100644
--- a/swift/ui/llm_grpo/external_rollout.py
+++ b/swift/ui/llm_grpo/external_rollout.py
@@ -15,11 +15,11 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import DeployArguments, RLHFArguments, RolloutArguments
+from swift.pipelines import DeployArguments, RLHFArguments, RolloutArguments
 from swift.ui.base import BaseUI
-from swift.ui.llm_grpo.external_runtime import RolloutRuntime
 from swift.ui.llm_train.llm_train import run_command_in_background_with_popen
 from swift.utils import get_device_count, get_logger
+from .external_runtime import RolloutRuntime

 logger = get_logger()
diff --git a/swift/ui/llm_grpo/grpo_advanced.py b/swift/ui/llm_grpo/grpo_advanced.py
index 0077919a32..e54ef29a26 100644
--- a/swift/ui/llm_grpo/grpo_advanced.py
+++ b/swift/ui/llm_grpo/grpo_advanced.py
@@ -4,8 +4,9 @@

 import gradio as gr

-from swift.llm import BaseArguments, ModelType
-from swift.llm.model.register import get_all_models
+from swift.arguments import BaseArguments
+from swift.model import ModelType
+from swift.model.register import get_all_models
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_grpo/llm_grpo.py b/swift/ui/llm_grpo/llm_grpo.py
index 0cd13462fe..e0009fca2c 100644
--- a/swift/ui/llm_grpo/llm_grpo.py
+++ b/swift/ui/llm_grpo/llm_grpo.py
@@ -5,25 +5,25 @@
 import gradio as gr
 from packaging import version

-from swift.llm.argument.base_args.base_args import get_supported_tuners
-from swift.ui.base import BaseUI
-from swift.ui.llm_grpo.advanced import GRPOAdvanced
-from swift.ui.llm_grpo.dataset import GRPODataset
-from swift.ui.llm_grpo.external_rollout import LLMRollout
-from swift.ui.llm_grpo.grpo_advanced import GrpoAdvanced
-from swift.ui.llm_grpo.hyper import GRPOHyper
-from swift.ui.llm_grpo.model import GRPOModel
-from swift.ui.llm_grpo.optimizer import GRPOOptimizer
-from swift.ui.llm_grpo.quantization import GRPOQuantization
-from swift.ui.llm_grpo.report_to import GRPOReportTo
-from swift.ui.llm_grpo.reward import Reward
-from swift.ui.llm_grpo.rollout import Rollout
-from swift.ui.llm_grpo.runtime import GRPORuntime
-from swift.ui.llm_grpo.save import GRPOSave
-from swift.ui.llm_grpo.tuner import GRPOTuner
-from swift.ui.llm_train.llm_train import LLMTrain
-from swift.ui.llm_train.runtime import Runtime
+from swift.arguments import get_supported_tuners
 from swift.utils import get_device_count, get_logger
+from ..base import BaseUI
+from .advanced import GRPOAdvanced
+from .dataset import GRPODataset
+from .external_rollout import LLMRollout
+from .grpo_advanced import GrpoAdvanced
+from .hyper import GRPOHyper
+from .model import GRPOModel
+from .optimizer import GRPOOptimizer
+from .quantization import GRPOQuantization
+from .report_to import GRPOReportTo
+from .reward import Reward
+from .rollout import Rollout
+from .runtime import GRPORuntime
+from .save import GRPOSave
+from .tuner import GRPOTuner
+from ..llm_train.llm_train import LLMTrain
+from ..llm_train.runtime import Runtime

 logger = get_logger()
diff --git a/swift/ui/llm_grpo/reward.py b/swift/ui/llm_grpo/reward.py
index 7071efd154..bdfcbb8406 100644
--- a/swift/ui/llm_grpo/reward.py
+++ b/swift/ui/llm_grpo/reward.py
@@ -4,8 +4,8 @@

 import gradio as gr

-from swift.llm import BaseArguments, ModelType
-from swift.llm.model.register import get_all_models
+from swift.model import ModelType
+from swift.model.register import get_all_models
 from swift.ui.base import BaseUI
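The relative imports in `llm_grpo.py` above (and in the analogous `llm_rlhf.py` and `llm_sample.py` below) are normalized to match each file's location in the package; as a reminder of how Python resolves them for a module at `swift/ui/llm_grpo/llm_grpo.py` (one leading dot is the containing package, two dots its parent):

```python
from ..base import BaseUI                   # resolves to swift.ui.base
from .advanced import GRPOAdvanced          # resolves to swift.ui.llm_grpo.advanced
from ..llm_train.llm_train import LLMTrain  # resolves to swift.ui.llm_train.llm_train
```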
diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py
index 356b0b1ba8..12db9a15d7 100644
--- a/swift/ui/llm_infer/llm_infer.py
+++ b/swift/ui/llm_infer/llm_infer.py
@@ -15,7 +15,8 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import DeployArguments, InferArguments, InferClient, InferRequest, RequestConfig
+from swift.infer_engine import InferClient, InferRequest, RequestConfig
+from swift.pipelines import DeployArguments, InferArguments
 from swift.ui.base import BaseUI
 from swift.ui.llm_infer.model import Model
 from swift.ui.llm_infer.runtime import Runtime
diff --git a/swift/ui/llm_infer/model.py b/swift/ui/llm_infer/model.py
index ba1bbdb385..e2df1cec77 100644
--- a/swift/ui/llm_infer/model.py
+++ b/swift/ui/llm_infer/model.py
@@ -4,8 +4,10 @@

 import gradio as gr

-from swift.llm import TEMPLATE_MAPPING, DeployArguments, ModelType
-from swift.llm.model.register import get_all_models
+from swift.model import ModelType
+from swift.model.register import get_all_models
+from swift.pipelines import DeployArguments
+from swift.template import TEMPLATE_MAPPING
 from swift.ui.base import BaseUI
 from swift.ui.llm_infer.generate import Generate
diff --git a/swift/ui/llm_rlhf/llm_rlhf.py b/swift/ui/llm_rlhf/llm_rlhf.py
index d7f1f740c7..0e3ef1fee2 100644
--- a/swift/ui/llm_rlhf/llm_rlhf.py
+++ b/swift/ui/llm_rlhf/llm_rlhf.py
@@ -4,21 +4,21 @@

 import gradio as gr

-from swift.llm.argument.base_args.base_args import get_supported_tuners
-from swift.ui.base import BaseUI
-from swift.ui.llm_rlhf.advanced import RLHFAdvanced
-from swift.ui.llm_rlhf.dataset import RLHFDataset
-from swift.ui.llm_rlhf.hyper import RLHFHyper
-from swift.ui.llm_rlhf.model import RLHFModel
-from swift.ui.llm_rlhf.optimizer import RLHFOptimizer
-from swift.ui.llm_rlhf.quantization import RLHFQuantization
-from swift.ui.llm_rlhf.report_to import RLHFReportTo
-from swift.ui.llm_rlhf.rlhf import RLHF
-from swift.ui.llm_rlhf.runtime import RLHFRuntime
-from swift.ui.llm_rlhf.save import RLHFSave
-from swift.ui.llm_rlhf.tuner import RLHFTuner
-from swift.ui.llm_train.llm_train import LLMTrain
+from swift.arguments import get_supported_tuners
 from swift.utils import get_device_count, get_logger
+from ..base import BaseUI
+from .advanced import RLHFAdvanced
+from .dataset import RLHFDataset
+from .hyper import RLHFHyper
+from .model import RLHFModel
+from .optimizer import RLHFOptimizer
+from .quantization import RLHFQuantization
+from .report_to import RLHFReportTo
+from .rlhf import RLHF
+from .runtime import RLHFRuntime
+from .save import RLHFSave
+from .tuner import RLHFTuner
+from ..llm_train.llm_train import LLMTrain

 logger = get_logger()
diff --git a/swift/ui/llm_rlhf/rlhf.py b/swift/ui/llm_rlhf/rlhf.py
index 7aac23aeb6..c75a6f1cea 100644
--- a/swift/ui/llm_rlhf/rlhf.py
+++ b/swift/ui/llm_rlhf/rlhf.py
@@ -4,8 +4,8 @@

 import gradio as gr

-from swift.llm import ModelType
-from swift.llm.model.register import get_all_models
+from swift.model import ModelType
+from swift.model.register import get_all_models
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_sample/llm_sample.py b/swift/ui/llm_sample/llm_sample.py
index 7561f133c2..f17ff017b1 100644
--- a/swift/ui/llm_sample/llm_sample.py
+++ b/swift/ui/llm_sample/llm_sample.py
@@ -14,14 +14,14 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import SamplingArguments
-from swift.llm.dataset.register import get_dataset_list
-from swift.ui.base import BaseUI
-from swift.ui.llm_sample.model import Model
-from swift.ui.llm_sample.runtime import SampleRuntime
-from swift.ui.llm_sample.sample import Sample
-from swift.ui.llm_train.utils import run_command_in_background_with_popen
+from swift.dataset import get_dataset_list
+from swift.pipelines import SamplingArguments
 from swift.utils import get_device_count, get_logger
+from ..base import BaseUI
+from .model import Model
+from .runtime import SampleRuntime
+from .sample import Sample
+from ..llm_train.utils import run_command_in_background_with_popen

 logger = get_logger()
diff --git a/swift/ui/llm_sample/model.py b/swift/ui/llm_sample/model.py
index 5988bcd312..408e52d857 100644
--- a/swift/ui/llm_sample/model.py
+++ b/swift/ui/llm_sample/model.py
@@ -4,8 +4,10 @@

 import gradio as gr

-from swift.llm import TEMPLATE_MAPPING, ModelType, SamplingArguments
-from swift.llm.model.register import get_all_models
+from swift.model import ModelType
+from swift.model.register import get_all_models
+from swift.pipelines import SamplingArguments
+from swift.template import TEMPLATE_MAPPING
 from swift.ui.base import BaseUI
diff --git a/swift/ui/llm_train/dataset.py b/swift/ui/llm_train/dataset.py
index 833552e83e..e42556e92e 100644
--- a/swift/ui/llm_train/dataset.py
+++ b/swift/ui/llm_train/dataset.py
@@ -3,8 +3,8 @@

 import gradio as gr

-from swift.llm.dataset.register import get_dataset_list
-from swift.ui.base import BaseUI
+from swift.dataset import get_dataset_list
+from ..base import BaseUI


 class Dataset(BaseUI):
diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py
index b1b2b94f6d..e9f106bdbc 100644
--- a/swift/ui/llm_train/llm_train.py
+++ b/swift/ui/llm_train/llm_train.py
@@ -14,8 +14,8 @@
 from json import JSONDecodeError
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-from swift.llm import ExportArguments, RLHFArguments
-from swift.llm.argument.base_args.base_args import get_supported_tuners
+from swift.arguments import get_supported_tuners
+from swift.pipelines import ExportArguments, RLHFArguments
 from swift.ui.base import BaseUI
 from swift.ui.llm_train.advanced import Advanced
 from swift.ui.llm_train.dataset import Dataset
diff --git a/swift/ui/llm_train/model.py b/swift/ui/llm_train/model.py
index 4221f36c1b..e908010a36 100644
--- a/swift/ui/llm_train/model.py
+++ b/swift/ui/llm_train/model.py
@@ -4,9 +4,11 @@

 import gradio as gr

-from swift.llm import TEMPLATE_MAPPING, ModelType, RLHFArguments
-from swift.llm.model.register import get_all_models
-from swift.ui.base import BaseUI
+from swift.model import ModelType
+from swift.model.register import get_all_models
+from swift.pipelines import RLHFArguments
+from swift.template import TEMPLATE_MAPPING
+from ..base import BaseUI


 class Model(BaseUI):
diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py
index 4d9216bb70..093526c46c 100644
--- a/swift/utils/__init__.py
+++ b/swift/utils/__init__.py
@@ -8,14 +8,15 @@
 from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
 from .logger import get_logger, ms_logger_context
 from .np_utils import get_seed, stat_array, transform_jsonl_to_df
+from .processor_mixin import ProcessorMixin
 from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
 from .torch_utils import (Serializer, activate_parameters, check_shared_disk, disable_safe_ddp_context_use_barrier,
                           empty_cache, find_all_linears, find_embedding, find_layers, find_norm, freeze_parameters,
                           gc_collect, get_cu_seqlens_from_position_ids, get_current_device, get_device,
                           get_device_count, get_last_valid_indices, get_model_parameter_info, get_n_params_grads,
                           init_process_group, safe_ddp_context, seed_worker, set_default_ddp_config, set_device,
-                          show_layers, time_synchronize, unwrap_model_for_generation)
-from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
-                    format_time, get_env_args, get_modules_to_not_convert, import_external_file, json_parse_to_dict,
-                    lower_bound, parse_args, patch_getattr, read_multi_line, remove_response, seed_everything,
-                    split_list, subprocess_run, test_time, upper_bound)
+                          show_layers, time_synchronize, to_device, to_float_dtype, unwrap_model_for_generation)
+from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr,
+                    disable_deepspeed_zero3, find_free_port, format_time, get_env_args, get_modules_to_not_convert,
+                    import_external_file, json_parse_to_dict, lower_bound, parse_args, patch_getattr, read_multi_line,
+                    remove_response, seed_everything, split_list, subprocess_run, test_time, upper_bound)
diff --git a/swift/utils/processor_mixin.py b/swift/utils/processor_mixin.py
new file mode 100644
index 0000000000..7e65edf3b8
--- /dev/null
+++ b/swift/utils/processor_mixin.py
@@ -0,0 +1,18 @@
+from transformers import PreTrainedTokenizerBase
+
+
+class ProcessorMixin:
+
+    @property
+    def tokenizer(self):
+        tokenizer = self.processor
+        if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'):
+            tokenizer = tokenizer.tokenizer
+        return tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, value):
+        if self.processor is self.tokenizer:
+            self.processor = value
+        elif self.tokenizer is not value:
+            raise AttributeError('Please use `self.processor` for assignment.')
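The new `ProcessorMixin` gives any holder of a `processor` attribute a uniform `tokenizer` view. A behavioral sketch with an illustrative checkpoint and a stand-in multimodal processor (toy classes, not real swift types):

```python
from transformers import AutoTokenizer
from swift.utils import ProcessorMixin

class Wrapper(ProcessorMixin):
    def __init__(self, processor):
        self.processor = processor

tok = AutoTokenizer.from_pretrained('gpt2')  # illustrative checkpoint
assert Wrapper(tok).tokenizer is tok         # plain tokenizer: returned as-is

class FakeMMProcessor:                       # stands in for an image-text processor
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

assert Wrapper(FakeMMProcessor(tok)).tokenizer is tok  # wrapped: unwrapped via .tokenizer
```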
diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py
index 4b412e2dd9..c31366b06d 100644
--- a/swift/utils/torch_utils.py
+++ b/swift/utils/torch_utils.py
@@ -9,7 +9,7 @@
 from bisect import bisect_right
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -534,3 +534,27 @@ def unwrap_model_for_generation(
             add_hooks(model)
     else:
         yield unwrapped_model
+
+
+def to_float_dtype(data: Any, dtype: torch.dtype) -> Any:
+    """Change the float inputs to a dtype"""
+    if isinstance(data, Mapping):
+        return type(data)({k: to_float_dtype(v, dtype) for k, v in data.items()})
+    elif isinstance(data, (tuple, list)):
+        return type(data)(to_float_dtype(v, dtype) for v in data)
+    elif isinstance(data, torch.Tensor) and torch.is_floating_point(data):
+        return data.to(dtype=dtype)
+    else:
+        return data
+
+
+def to_device(data: Any, device: Union[str, torch.device, int], non_blocking: bool = False) -> Any:
+    """Move inputs to a device"""
+    if isinstance(data, Mapping):
+        return type(data)({k: to_device(v, device, non_blocking) for k, v in data.items()})
+    elif isinstance(data, (tuple, list)):
+        return type(data)(to_device(v, device, non_blocking) for v in data)
+    elif isinstance(data, torch.Tensor):
+        return data.to(device=device, non_blocking=non_blocking)
+    else:
+        return data
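Finally, a quick usage sketch of the two container helpers added above, on a toy nested batch:

```python
import torch

from swift.utils import to_device, to_float_dtype  # both exported by the __init__ hunk above

batch = {
    'input_ids': torch.tensor([[1, 2, 3]]),  # integer tensor: dtype left untouched
    'pixel_values': [torch.randn(3, 8, 8)],  # float tensor inside a list: cast
    'meta': {'scale': torch.tensor(1.0)},    # nested Mapping: handled recursively
}
batch = to_float_dtype(batch, torch.bfloat16)
batch = to_device(batch, 'cpu')  # 'cuda:0' on a GPU box; non_blocking pairs with pinned memory
print(batch['input_ids'].dtype, batch['pixel_values'][0].dtype)  # torch.int64 torch.bfloat16
```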