forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_quant_param.py
60 lines (45 loc) · 2.23 KB
/
test_quant_param.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.linear.quantization import QuantizedParameter
from deepspeed.linear.config import QuantizationConfig
from deepspeed.ops.op_builder import FPQuantizerBuilder
from unit.common import DistributedTest
if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)
class TestQuantParam(DistributedTest):
world_size = 1
@pytest.mark.parametrize('dtype', [torch.half, torch.float])
def test_unsupported_dtypes(self, dtype):
device = get_accelerator().current_device_name()
data = torch.rand(5, 5, device='cpu', dtype=dtype)
qp = QuantizedParameter(data)
with pytest.raises(AssertionError):
qp.to(device)
def test_requires_grad(self):
data = torch.rand(5, 5, dtype=torch.bfloat16)
with pytest.raises(ValueError):
QuantizedParameter(data, requires_grad=True)
def test_move_to_accelerator(self):
device = get_accelerator().current_device()
data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
quantization_config = QuantizationConfig()
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
qp = QuantizedParameter(data, quantization_config=quantization_config)
assert qp.device == torch.device('cpu')
qp = qp.to(get_accelerator().current_device_name())
assert qp.device == torch.device(device)
assert qp.dtype == quantization_config.q_dtype
def test_hf_clone(self):
device = get_accelerator().current_device_name()
data = torch.rand(5, 5, device=device, dtype=torch.bfloat16)
quantization_config = QuantizationConfig(q_bits=6)
qp = QuantizedParameter(data, quantization_config=quantization_config)
# should be able to clone parameter via dict, HF expects this to work
qp_copy = QuantizedParameter(qp.data, **qp.__dict__)
assert all(qp.data == qp_copy.data)
assert qp.quantization_config == qp_copy.quantization_config