support ascend w8a8 graph_mode #3267

Merged · 5 commits · Apr 1, 2025
1 change: 1 addition & 0 deletions lmdeploy/cli/lite.py
@@ -92,6 +92,7 @@ def add_parser_smooth_quant():
type=str,
default='./work_dir',
help='The working directory for outputs. defaults to "./work_dir"')
parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')
ArgumentHelper.calib_dataset(parser)
ArgumentHelper.calib_samples(parser)
ArgumentHelper.calib_seqlen(parser)
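With this addition, the smooth_quant subcommand can select the quantization device directly from the command line, for example (the model path is a placeholder; all other options follow the existing parser):

    lmdeploy lite smooth_quant <model_path> --device npu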
2 changes: 1 addition & 1 deletion lmdeploy/lite/apis/auto_awq.py
@@ -10,7 +10,7 @@

from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.utils import try_import_deeplink

from .calibrate import LAYER_TYPE_MAP, calibrate

2 changes: 2 additions & 0 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -11,6 +11,7 @@
from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.models import QLinear, QRMSNorm
from lmdeploy.utils import try_import_deeplink


def smooth_quant(model: str,
@@ -26,6 +27,7 @@ def smooth_quant(model: str,
quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',
revision: str = None,
download_dir: str = None):
try_import_deeplink(device)
if quant_dtype == 'fp8':
quant_dtype = 'float8_e4m3fn'

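Calling try_import_deeplink(device) at the top of smooth_quant verifies the dlinfer extension up front when a non-CUDA device is requested. A hedged sketch of driving this from Python; the device keyword is inferred from the try_import_deeplink(device) call above, and arguments not visible in this hunk are left at their defaults:

# Hypothetical API usage; only model, quant_dtype and the inferred device
# keyword are taken from this diff, everything else is assumed.
from lmdeploy.lite.apis.smooth_quant import smooth_quant

smooth_quant('internlm/internlm2-chat-7b',  # placeholder model id
             quant_dtype='int8',
             device='npu')                  # triggers the dlinfer import check on Ascend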
13 changes: 0 additions & 13 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -1,14 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker # noqa: F401


def check_env_deeplink(device_type: str):
"""check Deeplink environment."""
from .deeplink import DeeplinkChecker
checker = DeeplinkChecker(device_type)
checker.handle()


def try_import_deeplink(device_type: str):
"""check Deeplink environment."""
check_env_deeplink(device_type)
16 changes: 3 additions & 13 deletions lmdeploy/pytorch/check_env/deeplink.py
@@ -1,12 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker
from lmdeploy.utils import try_import_deeplink

deeplink_device_type_list = [
'ascend',
'npu',
'maca',
'camb',
]
from .base import BaseChecker


class DeeplinkChecker(BaseChecker):
@@ -18,9 +13,4 @@ def __init__(self, device_type: str, logger=None) -> None:

def check(self):
"""check."""
device_type = self.device_type
if device_type in deeplink_device_type_list:
try:
import dlinfer.framework.lmdeploy_ext # noqa: F401
except Exception as e:
self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
try_import_deeplink(self.device_type)
5 changes: 2 additions & 3 deletions lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from ..default.w8a8_kernels import per_channel_quant
from .alibi_pagedattention import alibi_paged_attention_fwd
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
@@ -12,8 +12,7 @@
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
from .w8a8_fused_moe import fused_moe_w8a8
from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,
rms_norm_dynamic_quant)
from .w8a8_triton_kernels import matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant

__all__ = [
'apply_rotary_pos_emb',
29 changes: 1 addition & 28 deletions lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
@@ -5,6 +5,7 @@
import triton.language as tl
from packaging import version

from ..default.w8a8_kernels import per_channel_quant
from .triton_utils import get_kernel_meta

TRITON_VERSION = version.parse(triton.__version__)
@@ -14,34 +15,6 @@
tl_round = tl.math.round


def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
"""Quantize the input tensor 'x' channel-wise using the given number of
bits.

Args:
x (torch.Tensor): The input tensor to be quantized. Must be a
2-dimensional tensor.
dtype (torch.dtype): The data type to which the quantized tensor should
be converted.

Returns:
tuple: A tuple containing two items -- the quantized tensor and
the scale used for quantization.
"""
assert x.ndim == 2
x = x.to(torch.float32)
x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
q_max = qtype_info.max
q_min = qtype_info.min
scale = x_absmax / q_max
x_q = x / scale
if not dtype.is_floating_point:
x_q = torch.round(x_q)
x_q = x_q.clamp(q_min, q_max).to(dtype)
return x_q, scale


@triton.autotune(
configs=[
triton.Config({
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/default/__init__.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multinomial_sampling import multinomial_sampling
from .w8a8_kernels import per_channel_quant

__all__ = [
'multinomial_sampling',
'per_channel_quant',
]
30 changes: 30 additions & 0 deletions lmdeploy/pytorch/kernels/default/w8a8_kernels.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
"""Quantize the input tensor 'x' channel-wise using the given number of
bits.

Args:
x (torch.Tensor): The input tensor to be quantized. Must be a
2-dimensional tensor.
dtype (torch.dtype): The data type to which the quantized tensor should
be converted.

Returns:
tuple: A tuple containing two items -- the quantized tensor and
the scale used for quantization.
"""
assert x.ndim == 2
x = x.to(torch.float32)
x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
q_max = qtype_info.max
q_min = qtype_info.min
scale = x_absmax / q_max
x_q = x / scale
if not dtype.is_floating_point:
x_q = torch.round(x_q)
x_q = x_q.clamp(q_min, q_max).to(dtype)
return x_q, scale
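per_channel_quant is plain symmetric per-row quantization: each output channel gets scale = absmax / q_max, values are divided by the scale, rounded for integer dtypes, clamped to [q_min, q_max] and cast. A small self-contained check (a sketch, not part of this PR):

import torch

from lmdeploy.pytorch.kernels.default.w8a8_kernels import per_channel_quant

w = torch.randn(4, 16)                        # [out_channels, in_features]
w_q, scale = per_channel_quant(w, torch.int8)

assert w_q.dtype == torch.int8
assert scale.shape == (4, 1)
# Dequantization recovers the weights up to per-element rounding error:
w_dq = w_q.to(torch.float32) * scale
print((w - w_dq).abs().max())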
1 change: 1 addition & 0 deletions lmdeploy/pytorch/kernels/dispatcher.py
@@ -64,6 +64,7 @@ def __init__(self, func_name: str):
self.func_name = func_name
self.dispatched_func = self.load_and_call
self.device_manager.register_context_callback(self.device_callback)
self.device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

def device_callback(self, context: DeviceContext):
"""device context callback."""
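The new device_map routes every deeplink device type ('ascend', 'npu', 'maca', 'camb') to the shared dlinfer kernel package, so the w8a8 kernels only need one non-CUDA implementation. A hedged illustration of the lookup this map implies; the real loading logic around this hunk is not shown in the diff and the helper below is hypothetical:

# Sketch only: mirrors the mapping added above; kernel_package_for is not real code.
device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}

def kernel_package_for(device_type: str) -> str:
    backend = device_map.get(device_type, 'default')  # fallback is an assumption
    return f'lmdeploy.pytorch.kernels.{backend}'

assert kernel_package_for('ascend') == 'lmdeploy.pytorch.kernels.dlinfer'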
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..default import multinomial_sampling
from ..default import multinomial_sampling, per_channel_quant
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
@@ -21,4 +21,5 @@
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
'per_channel_quant',
]
10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/q_modules.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.

from dataclasses import dataclass
from dataclasses import dataclass, fields

import torch
import torch.nn as nn
@@ -19,13 +19,15 @@ class QTensor:
scale: torch.Tensor
zero_point: torch.Tensor = None

def __post_init__(self):
self.fields = [field.name for field in fields(self)]

def __getattr__(self, name: str):
"""Allows attribute access to be forwarded to the wrapped tensor when
the attribute doesn't exist in QTensor."""
try:
if name in self.fields:
return super().__getattr__(name)
except AttributeError:
return getattr(self.tensor, name)
return getattr(self.tensor, name)


class QRMSNorm(nn.Module):
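The reworked __getattr__ caches the dataclass field names in __post_init__ and forwards every non-field lookup to the wrapped tensor, instead of relying on catching AttributeError from super().__getattr__. A minimal illustration of the forwarding behaviour (a sketch, not part of this PR; it assumes the wrapped field is named tensor, as the self.tensor reference above suggests):

import torch

from lmdeploy.pytorch.models.q_modules import QTensor

q = QTensor(tensor=torch.zeros(2, 8, dtype=torch.int8), scale=torch.ones(2, 1))

print(q.scale.shape)  # dataclass field, resolved on QTensor itself
print(q.shape)        # not a field: forwarded to the wrapped tensor -> torch.Size([2, 8])
print(q.device)       # likewise forwarded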
3 changes: 1 addition & 2 deletions lmdeploy/serve/vl_async_engine.py
@@ -6,9 +6,8 @@
import PIL

from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig
from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.utils import get_logger
from lmdeploy.utils import get_logger, try_import_deeplink
from lmdeploy.vl.engine import ImageEncoder
from lmdeploy.vl.utils import load_image

16 changes: 16 additions & 0 deletions lmdeploy/utils.py
@@ -387,3 +387,19 @@ def is_bf16_supported(device_type: str = 'cuda'):
return True
else:
return False


def try_import_deeplink(device_type: str):
deeplink_device_type_list = [
'ascend',
'npu',
'maca',
'camb',
]
if device_type in deeplink_device_type_list:
try:
import dlinfer.framework.lmdeploy_ext # noqa: F401
except Exception as e:
logger = get_logger('lmdeploy')
logger.error(f'{type(e).__name__}: {e}')
exit(1)
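Moving try_import_deeplink into lmdeploy.utils gives the lite APIs, the env checker and the serving code one shared entry point: for deeplink device types it imports the dlinfer extension and exits with a logged error if that fails; any other device type is a no-op. A quick usage sketch:

from lmdeploy.utils import try_import_deeplink

try_import_deeplink('cuda')    # not a deeplink device type: returns without doing anything
try_import_deeplink('ascend')  # imports dlinfer.framework.lmdeploy_ext, or logs the error and exits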