@@ -95,7 +95,7 @@ def quantize(self, mode, **kwargs):
         raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
     quantized_kernel, scale = quantizers.abs_max_quantize(
-        self._kernel, axis=0, to_numpy=True
+        self._kernel, axis=0, dtype="int8", to_numpy=True
     )
     scale = ops.squeeze(scale, axis=0)
@@ -137,7 +137,8 @@ INT8 variables. It should allocate:
 
 - `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as
   the original full-precision kernel).
-- `self.scale` as the scalar quantization scale in FP32.
+- `self.scale` as the scalar quantization scale in the layer's compute dtype,
+  which is FP32 in this case.
 
 
 ```python
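The ```python block whose opening fence closes this hunk holds the guide's `_int8_build` implementation, which the diff does not show. A sketch consistent with the updated description above (initializers are illustrative):

```python
def _int8_build(self, kernel_shape):
    # Quantized kernel: same shape as the FP32 kernel, stored as INT8.
    self._kernel = self.add_weight(
        name="kernel",
        shape=kernel_shape,
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    # Scalar scale, kept in the layer's compute dtype (FP32 here).
    self.scale = self.add_weight(
        name="scale",
        shape=(),
        initializer="ones",
        dtype=self.compute_dtype,
        trainable=False,
    )
```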
@@ -161,12 +162,16 @@ def _int8_build(self, kernel_shape):
 
 #### Note
 
-1. The INT8 variables should be `trainable=False` since PTQ does not involve
-   further training.
+1. INT8 variables should be created with `trainable=False`, as quantized parameters
+   are not meant to be updated during training. Subsequent fine-tuning should not
+   alter these quantized variables.
 2. If you support INT4 quantization, implement a similar `_int4_build(...)`
    method that allocates packed INT4 storage (often packed into INT8) plus
    per-feature scales. The original unpacked dimensions and packing axis should
-   be recorded as instance variables for use in the call path.
+   be recorded as instance variables for use in the call path. A reference
+   implementation is available in the Keras
+   [Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)
+   layer.
 
 ### The `_int8_call(...)` method
 
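On note 2 above: a hypothetical `_int4_build` for this single-vector layer could look like the sketch below. The attribute names (`_orig_input_dim`, `_packing_axis`) are illustrative, not the Dense layer's exact ones; see the linked reference for the real implementation.

```python
def _int4_build(self, kernel_shape):
    input_dim = kernel_shape[0]
    # Two int4 values fit in one int8 byte, so packed storage is half-length.
    packed_dim = (input_dim + 1) // 2
    # Record the unpacked length and packing axis for the call path.
    self._orig_input_dim = input_dim
    self._packing_axis = 0
    self._kernel = self.add_weight(
        name="kernel",
        shape=(packed_dim,),
        initializer="zeros",
        dtype="int8",  # int4 values packed two-per-byte into int8
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        shape=(),
        initializer="ones",
        dtype=self.compute_dtype,
        trainable=False,
    )
```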
@@ -187,7 +192,6 @@ The INT8 path mirrors the float computation `y = x * w` but performs:
 
 def _int8_call(self, inputs, training=None):
     x = ops.multiply(inputs, self._kernel)
-    x = ops.cast(x, self.compute_dtype)
     x = ops.divide(x, self.scale)
     return x
 
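Dropping the cast is sound because the arithmetic is plain dequantization folded into the call: `x * w_q / scale ≈ x * w`, since `w_q ≈ w * scale`, and the multiply result is already floating point when the inputs are, which is presumably why the cast was redundant. A tiny standalone check of the identity:

```python
import numpy as np

w = np.array([0.5, -0.25], dtype="float32")
scale = 127.0 / np.max(np.abs(w))         # abs-max scale
w_q = np.round(w * scale).astype("int8")  # quantized kernel

x = np.array([2.0, 4.0], dtype="float32")
print(x * w)            # float path: [ 1. -1.]
print(x * w_q / scale)  # int8 path: dequantized on the fly, ~[ 1. -1.01]
```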
@@ -224,7 +228,7 @@ class SimpleScale(Layer):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -260,7 +264,6 @@ class SimpleScale(Layer):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -293,8 +296,8 @@ print("SimpleScale INT8 sample:", y_int8[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale FP32 sample: [-0.00359688  0.00296069 -0.00846314  0.00070467]
-SimpleScale INT8 sample: [-0.00359092  0.00290875 -0.00846319  0.00070462]
+SimpleScale FP32 sample: [-0.01259411  0.00385596  0.0053392  -0.00877095]
+SimpleScale INT8 sample: [-0.01256325  0.0038252   0.00535317 -0.00877098]
 ```
 </div>
 
@@ -308,7 +311,7 @@ If you want to support INT4 quantization, add:
 - `quantize("int4")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage
 
 See the
-[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py)
+[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)
 reference for a complete packed int4 example, including how to track and use the
 original (unpacked) dimension in the call path.
 
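To make these bullets concrete without pinning down Keras internals, here is a standalone numpy sketch of the int4 quantize-then-pack round trip. Function names are illustrative; a real layer would call `quantizers.abs_max_quantize(..., value_range=(-8, 7), ...)` and the packing helpers used in the linked Dense reference.

```python
import numpy as np

def int4_quantize_and_pack(w):
    scale = 7.0 / (np.max(np.abs(w)) + 1e-7)  # value_range=(-8, 7)
    q = np.clip(np.round(w * scale), -8, 7).astype("int8")
    if q.size % 2:                            # pad to an even length
        q = np.concatenate([q, np.zeros(1, dtype="int8")])
    packed = ((q[1::2] << 4) | (q[0::2] & 0x0F)).astype("int8")
    return packed, scale, w.size              # keep the unpacked length

def int4_unpack(packed, orig_len):
    low = (packed << 4) >> 4  # arithmetic shifts restore the sign bit
    high = packed >> 4
    q = np.empty(packed.size * 2, dtype="int8")
    q[0::2], q[1::2] = low, high
    return q[:orig_len]

w = np.random.uniform(-1, 1, size=(5,)).astype("float32")
packed, scale, n = int4_quantize_and_pack(w)
print(np.max(np.abs(w - int4_unpack(packed, n) / scale)))  # small error
```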
@@ -443,7 +446,7 @@ class SimpleScale(Layer):
             raise NotImplementedError(f"Unsupported quantization mode: {mode}")
 
         quantized_kernel, scale = quantizers.abs_max_quantize(
-            self._kernel, axis=0, to_numpy=True
+            self._kernel, axis=0, dtype="int8", to_numpy=True
         )
         scale = ops.squeeze(scale, axis=0)
 
@@ -479,7 +482,6 @@ class SimpleScale(Layer):
 
     def _int8_call(self, inputs, training=None):
         x = ops.multiply(inputs, self._kernel)
-        x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, self.scale)
         return x
 
@@ -522,7 +524,7 @@ class SimpleScale(Layer):
 #### Note
 
 The `@keras.saving.register_keras_serializable()` decorator is needed to
-register the class for serialization. 
+register the class for serialization.
 
 ---
 ## Try it: quantize, save, and load
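For completeness, using the decorator is a one-liner above the class definition; a minimal sketch (the optional `package` argument merely namespaces the registered name):

```python
import keras

@keras.saving.register_keras_serializable(package="MyLayers")
class SimpleScale(keras.layers.Layer):
    ...  # build/call/quantize methods as shown earlier in the guide
```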
@@ -547,8 +549,8 @@ print("Loaded INT8 sample:", y_loaded[0].numpy())
 
 <div class="k-default-codeblock">
 ```
-SimpleScale INT8 sample: [0.00825868 0.01510935 0.02154977 0.00205997]
-Loaded INT8 sample: [0.00825868 0.01510935 0.02154977 0.00205997]
+SimpleScale INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]
+Loaded INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]
 ```
 </div>
 
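The updated samples come from a quantize/save/load round trip along these lines (a sketch: the file name and input shape are illustrative, and `SimpleScale` is assumed to take no required constructor arguments):

```python
import keras

model = keras.Sequential([SimpleScale()])
model.build((None, 4))
model.quantize("int8")  # dispatches to SimpleScale.quantize

x = keras.random.uniform((2, 4))
y_int8 = model(x)

model.save("simple_scale_int8.keras")
loaded = keras.saving.load_model("simple_scale_int8.keras")
y_loaded = loaded(x)  # matches y_int8 exactly, as the printed samples show
```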
@@ -589,11 +591,10 @@ Here are concrete patterns you can reuse when making your own layers PTQ-friendly
    avoid an infinite loop.
 
 - Serialization contract
-    - Provide a fixed-order logic for variable serialization so save/load is
-      deterministic.
-    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
-      `"int8"`: [kernel, bias, kernel_scale], `"int4"`:
-      [kernel, bias, kernel_scale]).
+  - Provide a fixed-order logic for variable serialization so save/load is
+    deterministic.
+  - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
+    `"int8"`: [kernel, bias, kernel_scale], `"int4"`: [kernel, bias, kernel_scale]).
 
 - Validation and error handling
   - Validate `mode` early and raise a `NotImplementedError` for unsupported
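One way to realize that fixed-order contract (a sketch; it assumes the layer exposes `self.quantization_mode`, which Keras layers derive from their dtype policy, and reuses the kernel/bias/kernel_scale naming from the bullet above):

```python
def _ordered_variables(self):
    # Fixed order per mode keeps save/load deterministic.
    variables = [self._kernel, self.bias]
    if self.quantization_mode in ("int8", "int4"):
        variables.append(self.kernel_scale)
    return variables

def save_own_variables(self, store):
    for i, variable in enumerate(self._ordered_variables()):
        store[str(i)] = variable

def load_own_variables(self, store):
    for i, variable in enumerate(self._ordered_variables()):
        variable.assign(store[str(i)])
```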