Commit 629ede3

address reviews
1 parent 4538375 commit 629ede3

3 files changed: +58 −55 lines changed

guides/ipynb/writing_quantization_compatible_layers.ipynb

Lines changed: 18 additions & 17 deletions
@@ -136,7 +136,7 @@
 "        raise NotImplementedError(f\"Unsupported quantization mode: {mode}\")\n",
 "\n",
 "    quantized_kernel, scale = quantizers.abs_max_quantize(\n",
-"        self._kernel, axis=0, to_numpy=True\n",
+"        self._kernel, axis=0, dtype=\"int8\", to_numpy=True\n",
 "    )\n",
 "    scale = ops.squeeze(scale, axis=0)\n",
 "\n",
@@ -191,7 +191,8 @@
 "\n",
 "- `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as\n",
 "  the original full-precision kernel).\n",
-"- `self.scale` as the scalar quantization scale in FP32."
+"- `self.scale` as the scalar quantization scale in the layer's compute dtype,\n",
+"  which is FP32 in this case."
 ]
 },
 {
@@ -228,12 +229,16 @@
 "source": [
 "#### Note\n",
 "\n",
-"1. The INT8 variables should be `trainable=False` since PTQ does not involve\n",
-"   further training.\n",
+"1. INT8 variables should be created with `trainable=False`, as quantized parameters\n",
+"   are not meant to be updated during training. Subsequent fine-tuning should not\n",
+"   alter these quantized variables.\n",
 "2. If you support INT4 quantization, implement a similar `_int4_build(...)`\n",
 "   method that allocates packed INT4 storage (often packed into INT8) plus\n",
 "   per-feature scales. The original unpacked dimensions and packing axis should\n",
-"   be recorded as instance variables for use in the call path."
+"   be recorded as instance variables for use in the call path. A reference\n",
+"   implementation is available in the Keras\n",
+"   [Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)\n",
+"   layer."
 ]
 },
 {
@@ -268,7 +273,6 @@
 "\n",
 "def _int8_call(self, inputs, training=None):\n",
 "    x = ops.multiply(inputs, self._kernel)\n",
-"    x = ops.cast(x, self.compute_dtype)\n",
 "    x = ops.divide(x, self.scale)\n",
 "    return x\n",
 ""
@@ -317,7 +321,7 @@
 "            raise NotImplementedError(f\"Unsupported quantization mode: {mode}\")\n",
 "\n",
 "        quantized_kernel, scale = quantizers.abs_max_quantize(\n",
-"            self._kernel, axis=0, to_numpy=True\n",
+"            self._kernel, axis=0, dtype=\"int8\", to_numpy=True\n",
 "        )\n",
 "        scale = ops.squeeze(scale, axis=0)\n",
 "\n",
@@ -353,7 +357,6 @@
 "\n",
 "    def _int8_call(self, inputs, training=None):\n",
 "        x = ops.multiply(inputs, self._kernel)\n",
-"        x = ops.cast(x, self.compute_dtype)\n",
 "        x = ops.divide(x, self.scale)\n",
 "        return x\n",
 ""
@@ -411,7 +414,7 @@
 "- `quantize(\"int4\")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage\n",
 "\n",
 "See the\n",
-"[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py)\n",
+"[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)\n",
 "reference for a complete packed int4 example, including how to track and use the\n",
 "original (unpacked) dimension in the call path."
 ]
@@ -584,7 +587,7 @@
 "            raise NotImplementedError(f\"Unsupported quantization mode: {mode}\")\n",
 "\n",
 "        quantized_kernel, scale = quantizers.abs_max_quantize(\n",
-"            self._kernel, axis=0, to_numpy=True\n",
+"            self._kernel, axis=0, dtype=\"int8\", to_numpy=True\n",
 "        )\n",
 "        scale = ops.squeeze(scale, axis=0)\n",
 "\n",
@@ -620,7 +623,6 @@
 "\n",
 "    def _int8_call(self, inputs, training=None):\n",
 "        x = ops.multiply(inputs, self._kernel)\n",
-"        x = ops.cast(x, self.compute_dtype)\n",
 "        x = ops.divide(x, self.scale)\n",
 "        return x\n",
 "\n",
@@ -669,7 +671,7 @@
 "#### Note\n",
 "\n",
 "The `@keras.saving.register_keras_serializable()` decorator is needed to\n",
-" register the class for serialization."
+"register the class for serialization."
 ]
 },
 {
@@ -747,11 +749,10 @@
 "  avoid an infinite loop.\n",
 "\n",
 "- Serialization contract\n",
-"    - Provide a fixed-order logic for variable serialization so save/load is\n",
-"      deterministic.\n",
-"    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],\n",
-"      `\"int8\"`: [kernel, bias, kernel_scale], `\"int4\"`:\n",
-"      [kernel, bias, kernel_scale]).\n",
+"    - Provide a fixed-order logic for variable serialization so save/load is\n",
+"      deterministic.\n",
+"    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],\n",
+"      `\"int8\"`: [kernel, bias, kernel_scale], `\"int4\"`: [kernel, bias, kernel_scale]).\n",
 "\n",
 "- Validation and error handling\n",
 "    - Validate `mode` early and raise a `NotImplementedError` for unsupported\n",

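The recurring change in this file is passing an explicit `dtype="int8"` to `quantizers.abs_max_quantize` and dropping an `ops.cast` in `_int8_call`. A minimal round-trip sketch of that call, assuming only the signature visible in the diff (the kernel values are made up):

```python
import numpy as np
from keras import ops, quantizers

# Toy FP32 kernel standing in for self._kernel.
kernel = np.array([0.25, -0.5, 0.75, -1.0], dtype="float32")

# Abs-max quantization; dtype="int8" makes the integer storage explicit.
quantized_kernel, scale = quantizers.abs_max_quantize(
    kernel, axis=0, dtype="int8", to_numpy=True
)
scale = ops.squeeze(scale, axis=0)

# Dequantizing mirrors _int8_call: multiply, then divide by the scale.
# Dividing by the float scale already yields a float result, which is
# why the explicit ops.cast removed in this commit appears redundant.
recovered = ops.divide(ops.cast(quantized_kernel, "float32"), scale)
print(ops.convert_to_numpy(recovered))  # close to the original kernel
```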
guides/md/writing_quantization_compatible_layers.md

Lines changed: 22 additions & 21 deletions
@@ -95,7 +95,7 @@ def quantize(self, mode, **kwargs):
         raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
     quantized_kernel, scale = quantizers.abs_max_quantize(
-        self._kernel, axis=0, to_numpy=True
+        self._kernel, axis=0, dtype="int8", to_numpy=True
     )
     scale = ops.squeeze(scale, axis=0)
 
@@ -137,7 +137,8 @@ INT8 variables. It should allocate:
 
 - `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as
   the original full-precision kernel).
-- `self.scale` as the scalar quantization scale in FP32.
+- `self.scale` as the scalar quantization scale in the layer's compute dtype,
+  which is FP32 in this case.
 
 
 ```python
@@ -161,12 +162,16 @@ def _int8_build(self, kernel_shape):
 
 #### Note
 
-1. The INT8 variables should be `trainable=False` since PTQ does not involve
-   further training.
+1. INT8 variables should be created with `trainable=False`, as quantized parameters
+   are not meant to be updated during training. Subsequent fine-tuning should not
+   alter these quantized variables.
 2. If you support INT4 quantization, implement a similar `_int4_build(...)`
    method that allocates packed INT4 storage (often packed into INT8) plus
    per-feature scales. The original unpacked dimensions and packing axis should
-   be recorded as instance variables for use in the call path.
+   be recorded as instance variables for use in the call path. A reference
+   implementation is available in the Keras
+   [Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)
+   layer.
 
 ### The `_int8_call(...)` method
 
@@ -187,7 +192,6 @@ The INT8 path mirrors the float computation `y = x * w` but performs:
 
 def _int8_call(self, inputs, training=None):
     x = ops.multiply(inputs, self._kernel)
-    x = ops.cast(x, self.compute_dtype)
     x = ops.divide(x, self.scale)
     return x
 
@@ -224,7 +228,7 @@ class SimpleScale(Layer):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -260,7 +264,6 @@ class SimpleScale(Layer):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -293,8 +296,8 @@ print("SimpleScale INT8 sample:", y_int8[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale FP32 sample: [-0.00359688  0.00296069 -0.00846314  0.00070467]
-SimpleScale INT8 sample: [-0.00359092  0.00290875 -0.00846319  0.00070462]
+SimpleScale FP32 sample: [-0.01259411  0.00385596  0.0053392  -0.00877095]
+SimpleScale INT8 sample: [-0.01256325  0.0038252   0.00535317 -0.00877098]
 ```
 </div>
 
@@ -308,7 +311,7 @@ If you want to support INT4 quantization, add:
 - `quantize("int4")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage
 
 See the
-[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py)
+[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)
 reference for a complete packed int4 example, including how to track and use the
 original (unpacked) dimension in the call path.

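For the packed int4 storage referenced above, here is an illustrative numpy sketch of the nibble-packing idea. These are not the Keras helpers the Dense layer uses; the function names and the even-length assumption are ours:

```python
import numpy as np

def pack_int4(values):
    """Pack int8 values in [-8, 7] (even length) two per byte."""
    nibbles = values.astype(np.uint8) & 0x0F  # two's-complement nibbles
    return (nibbles[0::2] | (nibbles[1::2] << 4)).astype(np.int8)

def unpack_int4(packed, orig_len):
    """Recover the original int8 values; orig_len is the unpacked size."""
    b = packed.astype(np.uint8)
    out = np.empty(orig_len, dtype=np.int8)
    # (n ^ 8) - 8 sign-extends a 4-bit two's-complement nibble.
    out[0::2] = (((b & 0x0F).astype(np.int16) ^ 8) - 8).astype(np.int8)
    out[1::2] = (((b >> 4).astype(np.int16) ^ 8) - 8).astype(np.int8)
    return out

w = np.array([-8, 7, 3, -1], dtype=np.int8)
assert (unpack_int4(pack_int4(w), w.size) == w).all()
```

The `orig_len` argument here plays the role of the "original (unpacked) dimension" that the text above says must be tracked for use in the call path.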
@@ -443,7 +446,7 @@ class SimpleScale(Layer):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -479,7 +482,6 @@ class SimpleScale(Layer):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -522,7 +524,7 @@ class SimpleScale(Layer):
 #### Note
 
 The `@keras.saving.register_keras_serializable()` decorator is needed to
- register the class for serialization.
+register the class for serialization.
 
 ---
 ## Try it: quantize, save, and load
@@ -547,8 +549,8 @@ print("Loaded INT8 sample:", y_loaded[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale INT8 sample: [0.00825868 0.01510935 0.02154977 0.00205997]
-Loaded INT8 sample: [0.00825868 0.01510935 0.02154977 0.00205997]
+SimpleScale INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]
+Loaded INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]
 ```
 </div>
 
@@ -589,11 +591,10 @@ Here are concrete patterns you can reuse when making your own layers PTQ-friendl
   avoid an infinite loop.
 
 - Serialization contract
-    - Provide a fixed-order logic for variable serialization so save/load is
-      deterministic.
-    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
-      `"int8"`: [kernel, bias, kernel_scale], `"int4"`:
-      [kernel, bias, kernel_scale]).
+    - Provide a fixed-order logic for variable serialization so save/load is
+      deterministic.
+    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
+      `"int8"`: [kernel, bias, kernel_scale], `"int4"`: [kernel, bias, kernel_scale]).
 
 - Validation and error handling
     - Validate `mode` early and raise a `NotImplementedError` for unsupported

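As a companion to the `_int8_build` hunks above, here is a sketch of the allocation pattern the edited note describes. The `add_weight` arguments are assumptions for illustration, not necessarily the guide's exact code:

```python
def _int8_build(self, kernel_shape):
    # Quantized parameters must not be updated by subsequent training,
    # hence trainable=False on both variables.
    self._kernel = self.add_weight(
        name="kernel",
        shape=kernel_shape,  # (input_dim,), same shape as the FP32 kernel
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    # Scalar scale kept in the layer's compute dtype (FP32 here).
    self.scale = self.add_weight(
        name="scale",
        shape=(),
        initializer="ones",
        dtype=self.compute_dtype,
        trainable=False,
    )
```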
guides/writing_quantization_compatible_layers.py

Lines changed: 18 additions & 17 deletions
@@ -92,7 +92,7 @@ def quantize(self, mode, **kwargs):
         raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
     quantized_kernel, scale = quantizers.abs_max_quantize(
-        self._kernel, axis=0, to_numpy=True
+        self._kernel, axis=0, dtype="int8", to_numpy=True
    )
     scale = ops.squeeze(scale, axis=0)
 
@@ -135,7 +135,8 @@ def quantize(self, mode, **kwargs):
 
 - `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as
   the original full-precision kernel).
-- `self.scale` as the scalar quantization scale in FP32.
+- `self.scale` as the scalar quantization scale in the layer's compute dtype,
+  which is FP32 in this case.
 """
 
 
@@ -158,12 +159,16 @@ def _int8_build(self, kernel_shape):
 """
 #### Note
 
-1. The INT8 variables should be `trainable=False` since PTQ does not involve
-   further training.
+1. INT8 variables should be created with `trainable=False`, as quantized parameters
+   are not meant to be updated during training. Subsequent fine-tuning should not
+   alter these quantized variables.
 2. If you support INT4 quantization, implement a similar `_int4_build(...)`
    method that allocates packed INT4 storage (often packed into INT8) plus
    per-feature scales. The original unpacked dimensions and packing axis should
-   be recorded as instance variables for use in the call path.
+   be recorded as instance variables for use in the call path. A reference
+   implementation is available in the Keras
+   [Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)
+   layer.
 """
 
 """
@@ -185,7 +190,6 @@ def _int8_build(self, kernel_shape):
 
 def _int8_call(self, inputs, training=None):
     x = ops.multiply(inputs, self._kernel)
-    x = ops.cast(x, self.compute_dtype)
     x = ops.divide(x, self.scale)
     return x
 
@@ -220,7 +224,7 @@ def quantize(self, mode, **kwargs):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -256,7 +260,6 @@ def _int8_build(self, kernel_shape):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -294,7 +297,7 @@ def _int8_call(self, inputs, training=None):
 - `quantize("int4")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage
 
 See the
-[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py)
+[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)
 reference for a complete packed int4 example, including how to track and use the
 original (unpacked) dimension in the call path.
 """
@@ -426,7 +429,7 @@ def quantize(self, mode, **kwargs):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -462,7 +465,6 @@ def _int8_build(self, kernel_shape):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -505,7 +507,7 @@ def load_own_variables(self, store):
 #### Note
 
 The `@keras.saving.register_keras_serializable()` decorator is needed to
- register the class for serialization.
+register the class for serialization.
 """
 """
 ## Try it: quantize, save, and load
@@ -562,11 +564,10 @@ def load_own_variables(self, store):
   avoid an infinite loop.
 
 - Serialization contract
-    - Provide a fixed-order logic for variable serialization so save/load is
-      deterministic.
-    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
-      `"int8"`: [kernel, bias, kernel_scale], `"int4"`:
-      [kernel, bias, kernel_scale]).
+    - Provide a fixed-order logic for variable serialization so save/load is
+      deterministic.
+    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
+      `"int8"`: [kernel, bias, kernel_scale], `"int4"`: [kernel, bias, kernel_scale]).
 
 - Validation and error handling
     - Validate `mode` early and raise a `NotImplementedError` for unsupported

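Finally, the "fixed order per mode" contract from the takeaways can be sketched for a bias-free layer like the guide's `SimpleScale`. This is an assumed shape of the idea, not the guide's actual `save_own_variables`/`load_own_variables`, and the `_variables_in_saved_order` helper is hypothetical:

```python
def _variables_in_saved_order(self):
    # None: [kernel]; "int8"/"int4": [kernel, kernel_scale] — a fixed
    # order per mode keeps save/load deterministic.
    if self.quantization_mode in ("int8", "int4"):
        return [self._kernel, self.scale]
    return [self._kernel]

def save_own_variables(self, store):
    for i, variable in enumerate(self._variables_in_saved_order()):
        store[str(i)] = variable

def load_own_variables(self, store):
    for i, variable in enumerate(self._variables_in_saved_order()):
        variable.assign(store[str(i)])
```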