diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index ef10561a0f..768ef47544 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -92,6 +92,7 @@ def add_parser_smooth_quant():
                             type=str,
                             default='./work_dir',
                             help='The working directory for outputs. defaults to "./work_dir"')
+        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')
         ArgumentHelper.calib_dataset(parser)
         ArgumentHelper.calib_samples(parser)
         ArgumentHelper.calib_seqlen(parser)
diff --git a/lmdeploy/lite/apis/auto_awq.py b/lmdeploy/lite/apis/auto_awq.py
index dd1dc8f0f4..756e4e3a5f 100644
--- a/lmdeploy/lite/apis/auto_awq.py
+++ b/lmdeploy/lite/apis/auto_awq.py
@@ -10,7 +10,7 @@
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
-from lmdeploy.pytorch.check_env import try_import_deeplink
+from lmdeploy.utils import try_import_deeplink
 
 from .calibrate import LAYER_TYPE_MAP, calibrate
 
 
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index 37d924327e..9aa889c648 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -11,6 +11,7 @@
 from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers
 from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.models import QLinear, QRMSNorm
+from lmdeploy.utils import try_import_deeplink
 
 
 def smooth_quant(model: str,
@@ -26,6 +27,7 @@ def smooth_quant(model: str,
                  quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
                  revision: str = None,
                  download_dir: str = None):
+    try_import_deeplink(device)
     if quant_dtype == 'fp8':
         quant_dtype = 'float8_e4m3fn'
 
diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index bc95a32be6..ef101fec61 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -1,14 +1 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseChecker  # noqa: F401
-
-
-def check_env_deeplink(device_type: str):
-    """check Deeplink environment."""
-    from .deeplink import DeeplinkChecker
-    checker = DeeplinkChecker(device_type)
-    checker.handle()
-
-
-def try_import_deeplink(device_type: str):
-    """check Deeplink environment."""
-    check_env_deeplink(device_type)
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
index 00bcfdf77c..9bffb4c52b 100644
--- a/lmdeploy/pytorch/check_env/deeplink.py
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -1,12 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseChecker
+from lmdeploy.utils import try_import_deeplink
 
-deeplink_device_type_list = [
-    'ascend',
-    'npu',
-    'maca',
-    'camb',
-]
+from .base import BaseChecker
 
 
 class DeeplinkChecker(BaseChecker):
@@ -18,9 +13,4 @@ def __init__(self, device_type: str, logger=None) -> None:
 
     def check(self):
         """check."""
-        device_type = self.device_type
-        if device_type in deeplink_device_type_list:
-            try:
-                import dlinfer.framework.lmdeploy_ext  # noqa: F401
-            except Exception as e:
-                self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
+        try_import_deeplink(self.device_type)
diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py
index 2373a613f1..f741a8053d 100644
--- a/lmdeploy/pytorch/kernels/cuda/__init__.py
+++ b/lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
+from ..default.w8a8_kernels import per_channel_quant
 from .alibi_pagedattention import alibi_paged_attention_fwd
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .fill_kv_cache import fill_kv_cache
@@ -12,8 +12,7 @@
 from .pagedattention import paged_attention_fwd
 from .rms_norm import rms_norm
 from .w8a8_fused_moe import fused_moe_w8a8
-from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,
-                                  rms_norm_dynamic_quant)
+from .w8a8_triton_kernels import matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant
 
 __all__ = [
     'apply_rotary_pos_emb',
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
index 652e1a5be5..e252ed8188 100644
--- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
+++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
@@ -5,6 +5,7 @@
 import triton.language as tl
 from packaging import version
 
+from ..default.w8a8_kernels import per_channel_quant
 from .triton_utils import get_kernel_meta
 
 TRITON_VERSION = version.parse(triton.__version__)
@@ -14,34 +15,6 @@
 tl_round = tl.math.round
 
 
-def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
-    """Quantize the input tensor 'x' channel-wise using the given number of
-    bits.
-
-    Args:
-        x (torch.Tensor): The input tensor to be quantized. Must be a
-            2-dimensional tensor.
-        dtype (torch.dtype): The data type to which the quantized tensor should
-            be converted.
-
-    Returns:
-        tuple: A tuple containing two items -- the quantized tensor and
-            the scale used for quantization.
-    """
-    assert x.ndim == 2
-    x = x.to(torch.float32)
-    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
-    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
-    q_max = qtype_info.max
-    q_min = qtype_info.min
-    scale = x_absmax / q_max
-    x_q = x / scale
-    if not dtype.is_floating_point:
-        x_q = torch.round(x_q)
-    x_q = x_q.clamp(q_min, q_max).to(dtype)
-    return x_q, scale
-
-
 @triton.autotune(
     configs=[
         triton.Config({
diff --git a/lmdeploy/pytorch/kernels/default/__init__.py b/lmdeploy/pytorch/kernels/default/__init__.py
index d80f466ae9..ff62d5ba4b 100644
--- a/lmdeploy/pytorch/kernels/default/__init__.py
+++ b/lmdeploy/pytorch/kernels/default/__init__.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .multinomial_sampling import multinomial_sampling
+from .w8a8_kernels import per_channel_quant
 
 __all__ = [
     'multinomial_sampling',
+    'per_channel_quant',
 ]
diff --git a/lmdeploy/pytorch/kernels/default/w8a8_kernels.py b/lmdeploy/pytorch/kernels/default/w8a8_kernels.py
new file mode 100644
index 0000000000..5d215fe532
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/default/w8a8_kernels.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
+    """Quantize the input tensor 'x' channel-wise using the given number of
+    bits.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be a
+            2-dimensional tensor.
+        dtype (torch.dtype): The data type to which the quantized tensor should
+            be converted.
+
+    Returns:
+        tuple: A tuple containing two items -- the quantized tensor and
+            the scale used for quantization.
+    """
+    assert x.ndim == 2
+    x = x.to(torch.float32)
+    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
+    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+    q_max = qtype_info.max
+    q_min = qtype_info.min
+    scale = x_absmax / q_max
+    x_q = x / scale
+    if not dtype.is_floating_point:
+        x_q = torch.round(x_q)
+    x_q = x_q.clamp(q_min, q_max).to(dtype)
+    return x_q, scale
diff --git a/lmdeploy/pytorch/kernels/dispatcher.py b/lmdeploy/pytorch/kernels/dispatcher.py
index 1ae7cde9e1..7ca0648880 100644
--- a/lmdeploy/pytorch/kernels/dispatcher.py
+++ b/lmdeploy/pytorch/kernels/dispatcher.py
@@ -64,6 +64,7 @@ def __init__(self, func_name: str):
         self.func_name = func_name
         self.dispatched_func = self.load_and_call
         self.device_manager.register_context_callback(self.device_callback)
+        self.device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}
 
     def device_callback(self, context: DeviceContext):
         """device context callback."""
diff --git a/lmdeploy/pytorch/kernels/dlinfer/__init__.py b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
index fe82010761..7b226d7ff4 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from ..default import multinomial_sampling
+from ..default import multinomial_sampling, per_channel_quant
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .awq_kernels import awq_linear
 from .fill_kv_cache import fill_kv_cache
@@ -21,4 +21,5 @@
     'linear',
     'moe_gating_topk_softmax',
     'multinomial_sampling',
+    'per_channel_quant',
 ]
diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py
index af6496cfaf..f304d00108 100644
--- a/lmdeploy/pytorch/models/q_modules.py
+++ b/lmdeploy/pytorch/models/q_modules.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 
 import torch
 import torch.nn as nn
@@ -19,13 +19,15 @@ class QTensor:
     scale: torch.Tensor
     zero_point: torch.Tensor = None
 
+    def __post_init__(self):
+        self.fields = [field.name for field in fields(self)]
+
     def __getattr__(self, name: str):
         """Allows attribute access to be forwarded to the wrapped tensor when
         the attribute doesn't exist in QTensor."""
-        try:
+        if name in self.fields:
             return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.tensor, name)
+        return getattr(self.tensor, name)
 
 
 class QRMSNorm(nn.Module):
diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py
index 5bf227b661..e1413cdf8a 100644
--- a/lmdeploy/serve/vl_async_engine.py
+++ b/lmdeploy/serve/vl_async_engine.py
@@ -6,9 +6,8 @@
 import PIL
 
 from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig
-from lmdeploy.pytorch.check_env import try_import_deeplink
 from lmdeploy.serve.async_engine import AsyncEngine
-from lmdeploy.utils import get_logger
+from lmdeploy.utils import get_logger, try_import_deeplink
 from lmdeploy.vl.engine import ImageEncoder
 from lmdeploy.vl.utils import load_image
 
diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py
index eca7a2d058..376576e220 100644
--- a/lmdeploy/utils.py
+++ b/lmdeploy/utils.py
@@ -387,3 +387,19 @@ def is_bf16_supported(device_type: str = 'cuda'):
         return True
     else:
         return False
+
+
+def try_import_deeplink(device_type: str):
+    deeplink_device_type_list = [
+        'ascend',
+        'npu',
+        'maca',
+        'camb',
+    ]
+    if device_type in deeplink_device_type_list:
+        try:
+            import dlinfer.framework.lmdeploy_ext  # noqa: F401
+        except Exception as e:
+            logger = get_logger('lmdeploy')
+            logger.error(f'{type(e).__name__}: {e}')
+            exit(1)
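
Note (not part of the patch): a minimal usage sketch of the two helpers this change relocates, assuming an lmdeploy build with the patch applied. The import paths and names come from the diff above; the tensor shapes are illustrative.

# Illustrative only; not shipped with the patch.
import torch

from lmdeploy.pytorch.kernels.default.w8a8_kernels import per_channel_quant
from lmdeploy.utils import try_import_deeplink

# No-op for 'cuda'; for 'ascend', 'npu', 'maca' or 'camb' it imports dlinfer
# and logs an error and exits if the backend is unavailable.
try_import_deeplink('cuda')

# Per-channel int8 quantization of a 2-D weight matrix: one scale per row.
weight = torch.randn(128, 256)
w_q, scale = per_channel_quant(weight, torch.int8)
assert w_q.dtype == torch.int8 and scale.shape == (128, 1)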