1 change: 1 addition & 0 deletions examples/models/gpt_oss/mcore.sh
@@ -1,5 +1,6 @@
# mcore>=0.15
# 2 * 40GiB
+# dataset format: https://github.com/modelscope/ms-swift/pull/5277
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
2 changes: 1 addition & 1 deletion scripts/benchmark/exp_utils.py
@@ -10,7 +10,7 @@
import json
import torch

-from swift.llm import ExportArguments
+from swift.pipelines import ExportArguments
from swift.utils import find_free_port, get_device_count, get_logger

logger = get_logger()
2 changes: 1 addition & 1 deletion scripts/utils/run_template.py
@@ -1,4 +1,4 @@
-from swift.llm import TemplateType
+from swift.template import TemplateType

if __name__ == '__main__':
template_name_list = TemplateType.get_template_name_list()
File renamed without changes.
File changed (path not shown).
@@ -2,9 +2,9 @@
from dataclasses import dataclass
from typing import Literal, Optional

+from swift.model import get_matched_model_meta
+from swift.template import get_template_meta
from swift.utils import find_free_port, get_logger
-from ..model import get_matched_model_meta
-from ..template import get_template_meta
from .deploy_args import DeployArguments
from .webui_args import WebUIArguments

File changed (path not shown).
@@ -6,9 +6,9 @@
import json

from swift.hub import get_hub
-from swift.llm import Processor, Template, get_model_tokenizer, get_template, load_by_unsloth, safe_snapshot_download
-from swift.llm.utils import get_ckpt_dir
+from swift.model import get_ckpt_dir, get_model_tokenizer, load_by_unsloth, safe_snapshot_download
from swift.plugin import extra_tuners
+from swift.template import Processor, Template, get_template
from swift.utils import (check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master,
json_parse_to_dict, set_device, use_hf_hub)
from .data_args import DataArguments
File changed (path not shown).
@@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union

-from swift.llm import DATASET_MAPPING, register_dataset_info
+from swift.dataset import DATASET_MAPPING, register_dataset_info
from swift.utils import get_logger, json_parse_to_dict

logger = get_logger()
File changed (path not shown).
@@ -54,7 +54,7 @@ def _init_stream(self):
def get_request_config(self):
if getattr(self, 'task_type') != 'causal_lm':
return
-from swift.llm import RequestConfig
+from swift.pipelines import RequestConfig

return RequestConfig(
max_tokens=self.max_new_tokens,
File changed (path not shown).
@@ -9,7 +9,7 @@
import torch
from transformers.utils import is_torch_mps_available

-from swift.llm import MODEL_MAPPING, HfConfigFactory, get_model_info_meta, get_model_name
+from swift.model import MODEL_MAPPING, HfConfigFactory, get_model_info_meta, get_model_name
from swift.utils import get_dist_setting, get_logger, json_parse_to_dict

logger = get_logger()
@@ -116,7 +116,7 @@ def _init_max_memory(self):

def _init_torch_dtype(self) -> None:
""""If torch_dtype is None, find a proper dtype by the train_type/GPU"""
-from swift.llm import TrainArguments
+from swift.pipelines import TrainArguments

self.torch_dtype: Optional[torch.dtype] = HfConfigFactory.to_torch_dtype(self.torch_dtype)
self.torch_dtype: torch.dtype = self._init_model_info()
File changed (path not shown).
@@ -5,7 +5,7 @@

import torch

-from swift.llm import HfConfigFactory
+from swift.model import HfConfigFactory
from swift.utils import get_modules_to_not_convert


@@ -72,7 +72,7 @@ def get_quantization_config(self):
if not hasattr(self, 'model_info'):
return
from transformers import FineGrainedFP8Config
-from swift.llm import get_model_tokenizer
+from swift.model import get_model_tokenizer
with torch.device('meta'):
hf_model, _ = get_model_tokenizer(self.model_dir, model_type=self.model_type, return_dummy_model=True)
modules_to_not_convert = get_modules_to_not_convert(hf_model)
File changed (path not shown).
@@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import Literal, Optional

-from swift.llm import TEMPLATE_MAPPING
+from swift.template import TEMPLATE_MAPPING
from swift.utils import get_logger

logger = get_logger()
File changed (path not shown).
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Literal, Optional

-from swift.llm import safe_snapshot_download
+from swift.model import safe_snapshot_download
from swift.utils import find_free_port, get_device_count, get_logger
from .base_args import BaseArguments
from .infer_args import InferArguments
File renamed without changes.
File changed (path not shown).
@@ -6,7 +6,7 @@
import torch
import torch.distributed as dist

-from swift.llm import HfConfigFactory
+from swift.model import HfConfigFactory
from swift.utils import get_logger, init_process_group, set_default_ddp_config
from .base_args import BaseArguments, to_abspath
from .merge_args import MergeArguments
File renamed without changes.
File renamed without changes.
File changed (path not shown).
@@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

-from swift.llm import MODEL_MAPPING
+from swift.model import MODEL_MAPPING
from swift.trainers import GRPOArgumentsMixin, RLHFArgumentsMixin
from swift.utils import get_current_device, get_logger, is_master, is_mp, json_parse_to_dict, set_default_ddp_config
from .train_args import TrainArguments
@@ -413,7 +413,7 @@ def _init_external_vllm(self):
if self.rlhf_type not in rlhf_support_vllm_types or (self.vllm_server_host is None
and self.vllm_server_base_url is None):
return
-from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
+from swift.trainers.rlhf_trainers.vllm_client import VLLMClient
if is_master():
logger.info('Start connecting to vLLM server')
self.vllm_client = VLLMClient(
File changed (path not shown).
@@ -7,8 +7,8 @@

import json

-from swift.llm import BaseArguments
from swift.utils import get_logger
+from .base_args import BaseArguments

logger = get_logger()

File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion swift/cli/app.py
@@ -1,4 +1,4 @@
-from swift.llm import app_main
+from swift.pipelines import app_main

if __name__ == '__main__':
app_main()
2 changes: 1 addition & 1 deletion swift/cli/deploy.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import deploy_main
+from swift.pipelines import deploy_main

if __name__ == '__main__':
deploy_main()
2 changes: 1 addition & 1 deletion swift/cli/eval.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import eval_main
+from swift.pipelines import eval_main

if __name__ == '__main__':
eval_main()
2 changes: 1 addition & 1 deletion swift/cli/export.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import export_main
+from swift.pipelines import export_main

if __name__ == '__main__':
export_main()
2 changes: 1 addition & 1 deletion swift/cli/infer.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import infer_main
+from swift.pipelines import infer_main

if __name__ == '__main__':
infer_main()
2 changes: 1 addition & 1 deletion swift/cli/merge_lora.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import ExportArguments, SwiftPipeline, merge_lora
+from swift.pipelines import ExportArguments, SwiftPipeline, merge_lora


class SwiftMergeLoRA(SwiftPipeline):
2 changes: 1 addition & 1 deletion swift/cli/pt.py
@@ -3,5 +3,5 @@
if __name__ == '__main__':
from swift.cli.utils import try_use_single_device_mode
try_use_single_device_mode()
-from swift.llm import pt_main
+from swift.pipelines import pt_main
pt_main()
2 changes: 1 addition & 1 deletion swift/cli/rlhf.py
@@ -3,5 +3,5 @@
if __name__ == '__main__':
from swift.cli.utils import try_use_single_device_mode
try_use_single_device_mode()
-from swift.llm import rlhf_main
+from swift.pipelines import rlhf_main
rlhf_main()
2 changes: 1 addition & 1 deletion swift/cli/rollout.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from swift.llm import rollout_main
+from swift.pipelines import rollout_main

if __name__ == '__main__':
rollout_main()
2 changes: 1 addition & 1 deletion swift/cli/sample.py
@@ -3,5 +3,5 @@
if __name__ == '__main__':
from swift.ray import try_init_ray
try_init_ray()
-from swift.llm.sampling import sampling_main
+from swift.pipelines import sampling_main
sampling_main()
2 changes: 1 addition & 1 deletion swift/cli/sft.py
@@ -16,5 +16,5 @@ def try_init_unsloth():
try_init_unsloth()
from swift.ray import try_init_ray
try_init_ray()
-from swift.llm import sft_main
+from swift.pipelines import sft_main
sft_main()
5 changes: 2 additions & 3 deletions swift/llm/dataset/__init__.py → swift/dataset/__init__.py
@@ -2,15 +2,14 @@
import datasets.fingerprint
from datasets import Dataset as HfDataset

-from ..utils import get_temporary_cache_files_directory
-from . import dataset
+from . import datasets
from .loader import DATASET_TYPE, load_dataset
from .media import MediaResource
from .preprocessor import (AlpacaPreprocessor, AutoPreprocessor, MessagesPreprocessor, ResponsePreprocessor,
RowPreprocessor)
from .register import DATASET_MAPPING, DatasetMeta, SubsetDataset, register_dataset, register_dataset_info
from .utils import (AddLengthPreprocessor, EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset,
-                    sample_dataset)
+                    get_temporary_cache_files_directory, sample_dataset)

datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory
datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File changed (path not shown).
@@ -13,7 +13,7 @@
from datasets import Sequence, Value
from modelscope.hub.utils.utils import get_cache_dir

-from swift.llm import history_to_messages
+from swift.template import history_to_messages
from swift.utils import get_logger, is_dist, is_master, safe_ddp_context

DATASET_TYPE = Union[HfDataset, HfIterableDataset]
File renamed without changes.
30 changes: 29 additions & 1 deletion swift/llm/dataset/utils.py → swift/dataset/utils.py
@@ -1,12 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
import math
import multiprocessing as mp
+import os
+import tempfile
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

import numpy as np
import torch.distributed as dist
from datasets import Dataset as HfDataset
+from modelscope.hub.utils.utils import get_cache_dir
from torch.utils.data import Dataset, IterableDataset
from tqdm import tqdm

@@ -17,7 +21,7 @@
logger = get_logger()

if TYPE_CHECKING:
-from swift.llm import Template
+from swift.template import Template


def sample_dataset(
@@ -323,3 +327,27 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        encoded = super().preprocess(row)
        row['length'] = encoded['length']
        return row
+
+
+TEMP_DIR_POOL = {}
+
+
+def get_temporary_cache_files_directory(prefix=None):
+    if prefix is None:
+        import datasets.config
+        prefix = datasets.config.TEMP_CACHE_DIR_PREFIX
+    global TEMP_DIR_POOL
+    if prefix in TEMP_DIR_POOL:
+        TEMP_DIR = TEMP_DIR_POOL[prefix]
+    else:
+        tmp_dir = os.path.join(get_cache_dir(), 'tmp')
+        os.makedirs(tmp_dir, exist_ok=True)
+        kwargs = {}
+        parameters = inspect.signature(tempfile.TemporaryDirectory.__init__).parameters
+        if 'ignore_cleanup_errors' in parameters:
+            kwargs['ignore_cleanup_errors'] = True
+        TEMP_DIR = tempfile.TemporaryDirectory(prefix=prefix, dir=tmp_dir, **kwargs)
+        logger.info(f'create tmp_dir: {TEMP_DIR.name}')
+        TEMP_DIR_POOL[prefix] = TEMP_DIR
+
+    return TEMP_DIR.name
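
The relocated helper above is self-contained, so it can be exercised on its own. A minimal usage sketch, assuming the post-refactor module path swift.dataset.utils shown in the rename above (not verified against a released version):

# Usage sketch only: the import path is an assumption based on the file rename in this PR.
from swift.dataset.utils import get_temporary_cache_files_directory

# With the same (default) prefix, the pooled TemporaryDirectory is reused,
# so repeated calls return one stable cache location for the datasets library.
path_a = get_temporary_cache_files_directory()
path_b = get_temporary_cache_files_directory()
assert path_a == path_b  # same entry in TEMP_DIR_POOL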
File changed (path not shown).
@@ -2,9 +2,8 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Iterator, List, Optional, Union

-from swift.llm import InferRequest
from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
+from .protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig


class BaseInferEngine(ABC):
File changed (path not shown).
@@ -6,10 +6,12 @@
from PIL import Image
from tqdm.asyncio import tqdm_asyncio

-from swift.llm import InferRequest, Template, VllmEngine
from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput
+from swift.template import Template
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, InferRequest, RequestConfig,
+                       RolloutOutput)
from .utils import AdapterRequest
+from .vllm_engine import VllmEngine

try:
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
File changed (path not shown).
@@ -10,9 +10,9 @@
from requests.exceptions import HTTPError

from swift.plugin import Metric
-from ..protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest,
-                        ModelList, RequestConfig)
from .infer_engine import InferEngine
+from .protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest,
+                       ModelList, RequestConfig)


class InferClient(InferEngine):
File changed (path not shown).
@@ -9,14 +9,13 @@

from tqdm import tqdm

-from swift.llm import InferRequest, ProcessorMixin, get_template
-from swift.llm.template import Template
-from swift.llm.utils import get_ckpt_dir
+from swift.model import get_ckpt_dir
from swift.plugin import Metric
-from swift.utils import get_logger
-from ..protocol import (ChatCompletionMessageToolCall, ChatCompletionResponse, ChatCompletionStreamResponse,
-                        RequestConfig, UsageInfo)
+from swift.template import Template, get_template
+from swift.utils import ProcessorMixin, get_logger
from .base import BaseInferEngine
+from .protocol import (ChatCompletionMessageToolCall, ChatCompletionResponse, ChatCompletionStreamResponse,
+                       InferRequest, RequestConfig, UsageInfo)

logger = get_logger()

@@ -37,7 +36,7 @@ def _post_init(self, template=None):
ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None))
logger.info('Create the default_template for the infer_engine')
if ckpt_dir:
-from swift.llm import BaseArguments
+from swift.pipelines import BaseArguments
args = BaseArguments.from_pretrained(ckpt_dir)
self.default_template = args.get_template(self.processor)
else:
File changed (path not shown).
@@ -17,13 +17,14 @@
from transformers import GenerationConfig
from transformers.utils.versions import require_version

-from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
+from swift.model import get_model_tokenizer
from swift.plugin import Metric
+from swift.template import Template, TemplateMeta
from swift.utils import get_logger, get_seed
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
+from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                       ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig)
from .utils import InferStreamer

try:
File renamed without changes.