From fb05bf1d047461502244370ff68f976dcca2462f Mon Sep 17 00:00:00 2001
From: jeromeku
Date: Sun, 9 Feb 2025 09:44:15 -0800
Subject: [PATCH 1/6] implement aten.cat.default for nf4

---
 torchao/dtypes/nf4tensor.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 5ae06a1fe1..2c517bd65d 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -422,6 +422,31 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
     updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
+@implements(
+    [
+        aten.cat.default,
+    ]
+)
+def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
+
+    tensors_to_cat = args[0]
+    assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
+    remaining_args = args[1:]
+
+    ts = []
+    for t in tensors_to_cat:
+        assert isinstance(t, torch.Tensor)
+
+        if isinstance(t, NF4Tensor):
+            ts.append(t.get_original_weight())
+        else:
+            ts.append(t)
+
+    if kwargs is None:
+        kwargs = {}
+
+    tensors = aten_op(ts, *remaining_args, **kwargs)
+    return tensors
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
@@ -1058,3 +1083,4 @@ def nf4_constructor(
 
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals([NF4Tensor])
+    torch.serialization.add_safe_globals([NF4Tensor])
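
For context, a minimal sketch of the behavior [PATCH 1/6] enables (illustrative only, assuming default to_nf4 block sizes): torch.cat with any NF4Tensor argument should now dispatch to nf4_cat above, which dequantizes NF4 inputs via get_original_weight() before concatenating, so the result is a plain dense tensor rather than an NF4Tensor.

    import torch
    from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

    w = torch.randn(512, 512, dtype=torch.bfloat16)
    nf4_w = to_nf4(w)  # NF4-quantize a bf16 weight

    # cat an NF4 tensor with a plain tensor; the NF4 input is
    # dequantized first, so `out` is an ordinary bf16 tensor
    out = torch.cat([nf4_w, w], dim=0)
    assert not isinstance(out, NF4Tensor)
    assert out.shape == (1024, 512)
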
consistent") + print(f"Passed!") \ No newline at end of file diff --git a/test/dtypes/ddp/ddp_nf4.py b/test/dtypes/ddp/ddp_nf4.py new file mode 100644 index 0000000000..1f83114580 --- /dev/null +++ b/test/dtypes/ddp/ddp_nf4.py @@ -0,0 +1,155 @@ +import argparse +import math +import os +import time +from contextlib import contextmanager + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._dynamo import config as dynamo_config +from torch.nn.parallel import DistributedDataParallel as DDP + +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + +class LoRALinear(nn.Module): + def __init__( + self, + hidden_dim: int, + lora_rank: int = None, + lora_alpha: float = 16, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.hidden_dim = hidden_dim + if lora_rank is None: + lora_rank = hidden_dim // 2 + + weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype) + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.register_parameter( + "weight", nn.Parameter(to_nf4(weight), requires_grad=False) + ) + self.lora_a = nn.Linear( + in_features=hidden_dim, out_features=self.lora_rank, bias=False + ) + self.lora_b = nn.Linear( + in_features=self.lora_rank, out_features=hidden_dim, bias=False + ) + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = linear_nf4(input=x, weight=self.weight) + lora_out = self.lora_a(x) + lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out) + return out + lora_out + + +def _init_model(dim, num_linears, device, dtype) -> nn.Module: + with torch.device(device): + + modules = [] + for i in range(num_linears): + modules += [LoRALinear(hidden_dim=dim, dtype=dtype)] + seq = nn.Sequential(*modules) + + return seq + + +def dist_print(*args, delay=0.5): + rank = dist.get_rank() + time.sleep(delay * rank) + print(f"[rank{rank}]: ", *args, flush=True) + + +def make_batch(global_bs, dim, dtype, device): + batch = torch.randn((global_bs, dim), dtype=dtype, device=device) + if dist.get_world_size() > 1: + batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()] + return batch + + +def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile): + os.makedirs(save_dir, exist_ok=True) + model = _init_model(dim, num_linears, device, dtype) + model = DDP(model, device_ids=[device]) + + if compile: + model = torch.compile(model) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + + losses = [] + + for i in range(num_steps): + inp = make_batch(global_bs, dim, dtype, device) + loss = model(inp).sum() + losses.append(loss) + loss.backward() + optim.step() + optim.zero_grad() + + dist.barrier() + + save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt" + torch.save(model.state_dict(), save_path) + dist_print("Saved model to", save_path) + +def init_dist(): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) + dist_print("Dist initialized with world size", dist.get_world_size()) + + +def cleanup_dist(): + dist.barrier() + if dist.get_rank() == 0: + print("Cleaning up dist") + dist.destroy_process_group() + + +@contextmanager +def distributed_context(): + init_dist() + yield + cleanup_dist() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--global_bs", type=int, default=8) + parser.add_argument("--dim", 
type=int, default=128) + parser.add_argument("--num_linears", type=int, default=1) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--num_steps", type=int, default=3) + parser.add_argument("--save_dir", type=str, default="checkpoints") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer") + args = parser.parse_args() + + args.dtype = getattr(torch, args.dtype) + dynamo_config.optimize_ddp = args.optimize_ddp + + if args.optimize_ddp == "python_reducer": + dynamo_config.compiled_autograd = True + + with distributed_context(): + torch.manual_seed(args.seed) + run_ddp( + global_bs=args.global_bs, + dim=args.dim, + num_linears=args.num_linears, + device=args.device, + dtype=args.dtype, + num_steps=args.num_steps, + save_dir=args.save_dir, + compile=args.compile, + ) diff --git a/test/dtypes/ddp/run_ddp_nf4_test.sh b/test/dtypes/ddp/run_ddp_nf4_test.sh new file mode 100755 index 0000000000..b9a3c2929f --- /dev/null +++ b/test/dtypes/ddp/run_ddp_nf4_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -euo pipefail +WORLD_SIZE=${1:-2} + + +# Test params +GLOBAL_BS=8 +DIM=128 +NUM_LINEARS=1 +NUM_STEPS=3 + +PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS" +SAVE_DIR="checkpoints" +REF_DIR="${SAVE_DIR}/ref" +TEST_DIR="${SAVE_DIR}/test" +DDP_PROGRAM="ddp_nf4.py" +CHECK_PROGRAM="check_ddp_nf4.py" +REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR" +TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR" +CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR" +CLEANUP_CMD="rm -rf $SAVE_DIR" + +echo "Step 1: Generating reference checkpoint..." +echo $REF_CMD +$REF_CMD +echo -e "\n --- \n" +sleep 2 + +echo "Step 2: Generating test checkpoints..." +echo $TEST_CMD +$TEST_CMD +echo -e "\n --- \n" +sleep 2 + +# Check params +echo "Step 3: Checking params..." +echo $CHECK_CMD +$CHECK_CMD +echo -e "\n --- \n" +sleep 2 + +# Cleanup +echo "Step 4: Cleaning up..." +echo $CLEANUP_CMD +$CLEANUP_CMD +echo -e "\n --- \n" +echo "Done!" 
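
The intended flow of [PATCH 2/6], as the scripts above define it: run_ddp_nf4_test.sh first trains the LoRALinear stack on a single rank to produce a reference checkpoint, re-trains with torchrun --nproc_per_node $WORLD_SIZE, then check_ddp_nf4.py dequantizes every NF4 param via get_original_weight() and compares each rank's checkpoint against the reference at atol/rtol 1e-4. On a multi-GPU machine this should be exercisable end to end with something like (assuming it is run from test/dtypes/ddp/):

    cd test/dtypes/ddp
    ./run_ddp_nf4_test.sh 2   # world size 2; any size that evenly chunks GLOBAL_BS=8 should work
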
From 49a1ea21cd257f0750ca31af7bec811119b85e49 Mon Sep 17 00:00:00 2001
From: jeromeku
Date: Sun, 9 Feb 2025 09:46:10 -0800
Subject: [PATCH 3/6] run ruff

---
 test/dtypes/ddp/check_ddp_nf4.py | 2 +-
 test/dtypes/ddp/ddp_nf4.py       | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/dtypes/ddp/check_ddp_nf4.py b/test/dtypes/ddp/check_ddp_nf4.py
index 5b12aef1bd..608bcb9c02 100644
--- a/test/dtypes/ddp/check_ddp_nf4.py
+++ b/test/dtypes/ddp/check_ddp_nf4.py
@@ -37,4 +37,4 @@
                 print(f" \u2718 Param {name} differs by {diff}")
             else:
                 print(f" \u2713 Param {name} is consistent")
-    print(f"Passed!")
\ No newline at end of file
+    print("Passed!")
diff --git a/test/dtypes/ddp/ddp_nf4.py b/test/dtypes/ddp/ddp_nf4.py
index 1f83114580..2dc7b4081e 100644
--- a/test/dtypes/ddp/ddp_nf4.py
+++ b/test/dtypes/ddp/ddp_nf4.py
@@ -100,6 +100,7 @@ def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, com
     torch.save(model.state_dict(), save_path)
     dist_print("Saved model to", save_path)
 
+
 def init_dist():
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(dist.get_rank())

From 8814571d233b9325e02e2797240bdbb7b3439b84 Mon Sep 17 00:00:00 2001
From: jeromeku
Date: Sun, 9 Feb 2025 09:55:12 -0800
Subject: [PATCH 4/6] add dtype check

---
 torchao/dtypes/nf4tensor.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 2c517bd65d..879782a734 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -442,6 +442,9 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
         else:
             ts.append(t)
 
+    dtype = ts[0].dtype
+    assert all(t.dtype == dtype for t in ts)
+
     if kwargs is None:
         kwargs = {}
 

From 22bc211f0f2c800c4096b67a17b4c932bcc3d117 Mon Sep 17 00:00:00 2001
From: jeromeku
Date: Sun, 9 Feb 2025 10:25:00 -0800
Subject: [PATCH 5/6] formatting

---
 test/dtypes/ddp/ddp_nf4.py  |  1 -
 torchao/dtypes/nf4tensor.py | 65 +++++++++++++++++++------------------
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/test/dtypes/ddp/ddp_nf4.py b/test/dtypes/ddp/ddp_nf4.py
index 2dc7b4081e..e38d0015b1 100644
--- a/test/dtypes/ddp/ddp_nf4.py
+++ b/test/dtypes/ddp/ddp_nf4.py
@@ -53,7 +53,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def _init_model(dim, num_linears, device, dtype) -> nn.Module:
     with torch.device(device):
-
         modules = []
         for i in range(num_linears):
             modules += [LoRALinear(hidden_dim=dim, dtype=dtype)]
         seq = nn.Sequential(*modules)
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 879782a734..dfe64501cf 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -193,9 +193,9 @@ def nf4_split(aten_op, args, kwargs=None):
     attr_to_chunks = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert (
-            inner_tensor.numel() % num_chunks == 0
-        ), f"{attr}.numel() not divisible by {num_chunks}"
+        assert inner_tensor.numel() % num_chunks == 0, (
+            f"{attr}.numel() not divisible by {num_chunks}"
+        )
         chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs)
         attr_to_chunks[attr] = chunks
 
@@ -236,9 +236,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
     updated_attrs = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert (
-            inner_tensor.size(0) % ratio == 0
-        ), f"{attr}.numel() must be divisible by {ratio}"
+        assert inner_tensor.size(0) % ratio == 0, (
+            f"{attr}.numel() must be divisible by {ratio}"
+        )
         inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs)
         updated_attrs[attr] = inner_tensor
     updated_attrs["size"] = new_size
@@ -422,17 +422,17 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
     updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
+
 @implements(
     [
         aten.cat.default,
     ]
 )
 def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
-
     tensors_to_cat = args[0]
     assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
     remaining_args = args[1:]
-
+
     ts = []
     for t in tensors_to_cat:
         assert isinstance(t, torch.Tensor)
@@ -441,16 +441,17 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
         else:
             ts.append(t)
-
+
     dtype = ts[0].dtype
     assert all(t.dtype == dtype for t in ts)
-
+
     if kwargs is None:
         kwargs = {}
-
+
     tensors = aten_op(ts, *remaining_args, **kwargs)
     return tensors
 
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
@@ -472,9 +473,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso
         torch.Tensor: Tensor of scalers for each block
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
-    assert (
-        (input_tensor.numel() % block_size) == 0
-    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+    assert (input_tensor.numel() % block_size) == 0, (
+        f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+    )
 
     n_blocks = input_tensor.numel() // block_size
     blocks = input_tensor.view(n_blocks, block_size)
@@ -557,12 +558,12 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
-        assert (
-            input_tensor.dim() <= 2
-        ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
-        assert (
-            input_tensor.numel() % block_size == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        assert input_tensor.dim() <= 2, (
+            f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
+        )
+        assert input_tensor.numel() % block_size == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        )
         assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
         # assert not input_tensor.requires_grad, "Input tensor must not require grad"
@@ -643,9 +644,9 @@ def double_quantize_scalers(
             size: (n_scaler_blocks)
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (input_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        assert (input_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        )
 
         # First round of quantization
         # Produces: A tensor of size (n_blocks) of input_tensor.dtype
@@ -653,9 +654,9 @@ def double_quantize_scalers(
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
-        assert (
-            scalers_1.numel() % scaler_block_size == 0
-        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
+        assert scalers_1.numel() % scaler_block_size == 0, (
+            f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
+        )
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)
 
@@ -697,9 +698,9 @@ def dequantize_scalers(
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (input_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        assert (input_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        )
         n_scaler_blocks = input_tensor.numel() // scaler_block_size
         input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
@@ -715,9 +716,9 @@ def convert_to_norm_float_weight(
         flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
        numel = input_tensor.numel()
-        assert (
-            numel % 2 == 0
-        ), "Number of elements must be even just to not have to think about the end"
+        assert numel % 2 == 0, (
+            "Number of elements must be even just to not have to think about the end"
+        )
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)

From 71caddbe9b4a90e0e15a46f4b98ac6391d9a2e1e Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 17 Feb 2025 16:35:52 -0800
Subject: [PATCH 6/6] run ruff format on nf4tensor

---
 torchao/dtypes/nf4tensor.py | 54 ++++++++++++++++++-------------------
 1 file changed, 27 insertions(+), 27 deletions(-)

diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index dfe64501cf..457cf352fa 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -193,9 +193,9 @@ def nf4_split(aten_op, args, kwargs=None):
     attr_to_chunks = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert inner_tensor.numel() % num_chunks == 0, (
-            f"{attr}.numel() not divisible by {num_chunks}"
-        )
+        assert (
+            inner_tensor.numel() % num_chunks == 0
+        ), f"{attr}.numel() not divisible by {num_chunks}"
         chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs)
         attr_to_chunks[attr] = chunks
 
@@ -236,9 +236,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
     updated_attrs = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert inner_tensor.size(0) % ratio == 0, (
-            f"{attr}.numel() must be divisible by {ratio}"
-        )
+        assert (
+            inner_tensor.size(0) % ratio == 0
+        ), f"{attr}.numel() must be divisible by {ratio}"
         inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs)
         updated_attrs[attr] = inner_tensor
     updated_attrs["size"] = new_size
@@ -473,9 +473,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso
         torch.Tensor: Tensor of scalers for each block
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
-    assert (input_tensor.numel() % block_size) == 0, (
-        f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
-    )
+    assert (
+        (input_tensor.numel() % block_size) == 0
+    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
 
     n_blocks = input_tensor.numel() // block_size
     blocks = input_tensor.view(n_blocks, block_size)
@@ -558,12 +558,12 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
-        assert input_tensor.dim() <= 2, (
-            f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
-        )
-        assert input_tensor.numel() % block_size == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
-        )
+        assert (
+            input_tensor.dim() <= 2
+        ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
+        assert (
+            input_tensor.numel() % block_size == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
         assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
         # assert not input_tensor.requires_grad, "Input tensor must not require grad"
@@ -644,9 +644,9 @@ def double_quantize_scalers(
             size: (n_scaler_blocks)
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (input_tensor.numel() % scaler_block_size) == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
-        )
+        assert (
+            (input_tensor.numel() % scaler_block_size) == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
 
         # First round of quantization
         # Produces: A tensor of size (n_blocks) of input_tensor.dtype
@@ -654,9 +654,9 @@ def double_quantize_scalers(
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
-        assert scalers_1.numel() % scaler_block_size == 0, (
-            f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
-        )
+        assert (
+            scalers_1.numel() % scaler_block_size == 0
+        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)
 
@@ -698,9 +698,9 @@ def dequantize_scalers(
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (input_tensor.numel() % scaler_block_size) == 0, (
-            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
-        )
+        assert (
+            (input_tensor.numel() % scaler_block_size) == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
         n_scaler_blocks = input_tensor.numel() // scaler_block_size
         input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
@@ -716,9 +716,9 @@ def convert_to_norm_float_weight(
         flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
         numel = input_tensor.numel()
-        assert numel % 2 == 0, (
-            "Number of elements must be even just to not have to think about the end"
-        )
+        assert (
+            numel % 2 == 0
+        ), "Number of elements must be even just to not have to think about the end"
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)
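
A closing note on the asserts that [PATCH 5/6] and [PATCH 6/6] keep reflowing: they all guard the same invariant, that a flattened tensor divides evenly into (n_blocks, block_size) before taking a per-block absmax. A standalone sketch of that scaling step, mirroring the get_block_absmax docstring in the diffs above (illustrative only, not the torchao kernel itself):

    import torch

    def block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
        # numel must divide evenly into blocks, hence the asserts in the diffs above
        assert input_tensor.dim() == 1, "Input tensor must be flattened"
        assert input_tensor.numel() % block_size == 0
        n_blocks = input_tensor.numel() // block_size
        blocks = input_tensor.view(n_blocks, block_size)
        return blocks.abs().max(dim=-1).values  # one scaler per block

    scalers = block_absmax(torch.randn(4096), block_size=64)  # shape (64,)
    # double_quantize_scalers then centers these scalers by their mean and
    # quantizes them again in groups of scaler_block_size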