Initialize the codebook using k-means clustering on blocks of the input tensor.

Args:
input_tensor (torch.Tensor): The input tensor to be quantized.
code_dtype (torch.dtype): The dtype for the codes, one of torch.uint1, ..., torch.uint8.
block_size (List[int]): the number of elements in each dimension that share
the same lookup table (len(block_size) == input_tensor.dim())
Each dimension of input_tensor must be divisible by the corresponding element of block_size
Lookup tables are indexed by {(di // bi) for i in range(input_tensor.dim())}, where di is an element's index along dimension i and bi = block_size[i]
For example, if the input tensor has shape (N, K) and block_size is (N, group_size),
there is one lookup table for every group_size columns, i.e., (K // group_size) lookup tables in total
force_kmeans1d (bool): Use kmeans1d regardless of number of weights
cluster_dim (int): the size of each vector in vector lookup table quantization,
e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
the tensor in units of 4-element vectors: each vector of the original tensor is mapped to
a vector in the codebook (lookup table) based on the indices.
vector_axis (Optional[int]): used in vector quantization; for details, see https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/optimize/_utils.py#L371

Returns:
Tuple[torch.Tensor, torch.Tensor]: The codebook (lookup table) tensor and the quantized tensor (codes, torch.uint8).
The LUT has shape (g0, ..., g(N-1), 2**nbits, vec_dim), where:
* The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension)
* Dimension N + 1 indexes over the 2 ** nbits codebook entries
* Dimension N + 2 indexes over the lookup values (size 1 for scalar quantization)
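
The shape and indexing rules in this docstring can be illustrated with a minimal pure-Python sketch. It is not the code from this commit: the helper names `lut_shape`, `table_index`, and `dequantize_2d` are hypothetical, it covers only the scalar case (last LUT dimension of size 1), and it assumes that an element at index `di` along dimension `i` uses table `di // block_size[i]`.

```python
from typing import List, Sequence, Tuple


def lut_shape(input_shape: Sequence[int], block_size: Sequence[int],
              nbits: int, vec_dim: int = 1) -> Tuple[int, ...]:
    """Return (g0, ..., g(N-1), 2**nbits, vec_dim) for the given blocking."""
    if len(block_size) != len(input_shape):
        raise ValueError("len(block_size) must equal the number of input dimensions")
    groups = []
    for d, b in zip(input_shape, block_size):
        if d % b != 0:
            raise ValueError(f"dimension {d} is not divisible by block size {b}")
        groups.append(d // b)  # gi = input_shape[i] // block_size[i]
    return (*groups, 2 ** nbits, vec_dim)


def table_index(element_index: Sequence[int],
                block_size: Sequence[int]) -> Tuple[int, ...]:
    """Map an element index to the index of the lookup table it uses (assumed di // bi)."""
    return tuple(d // b for d, b in zip(element_index, block_size))


def dequantize_2d(lut: List, codes: List[List[int]],
                  block_size: Sequence[int]) -> List[List[float]]:
    """Rebuild a 2-D tensor from scalar codes; lut is nested as [g0][g1][code][0]."""
    out = []
    for i, row in enumerate(codes):
        out_row = []
        for j, code in enumerate(row):
            t0, t1 = table_index((i, j), block_size)
            out_row.append(lut[t0][t1][code][0])
        out.append(out_row)
    return out


# Input of shape (N, K) = (2, 4) with block_size = (N, group_size) = (2, 2)
# and 1-bit codes: K // group_size = 2 lookup tables with 2 entries each.
shape = lut_shape((2, 4), (2, 2), nbits=1)
lut = [[[[-1.0], [1.0]], [[-0.5], [0.5]]]]  # made-up values for illustration
codes = [[0, 1, 0, 1], [1, 1, 0, 0]]
weights = dequantize_2d(lut, codes, (2, 2))
```

Here columns 0-1 read from the first table and columns 2-3 from the second, which matches the (N, group_size) example in the Args section.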