diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py
index 40212c35c27..3d1e1eab0f2 100644
--- a/backends/vulkan/quantizer/vulkan_quantizer.py
+++ b/backends/vulkan/quantizer/vulkan_quantizer.py
@@ -124,11 +124,22 @@ class VulkanQuantizer(Quantizer):
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
+        # If specified, only quantize nodes that return true for the filter
+        # function.
+        self.filter_fn: Optional[Callable[[Node], bool]] = None
 
     def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
         self.global_config = quantization_config
         return self
 
+    def set_filter_function(self, filter_fn: Callable[[Node], bool]):
+        """
+        Set the filter function. We only quantize nodes that return True for
+        the filter function.
+        """
+        self.filter_fn = filter_fn
+        return self
+
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
@@ -149,8 +160,14 @@ def _annotate_all_patterns(
         if quantization_config is None:
             return model
 
+        # Create a combined filter function, which returns True only when
+        # both filter_fn and self.filter_fn return True.
+        def combined_filter_fn(n: Node) -> bool:
+            combined_filter = [self.filter_fn, filter_fn]
+            return all(f(n) for f in combined_filter if f is not None)
+
         for op in _SUPPORTED_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+            OP_TO_ANNOTATOR[op](model, quantization_config, combined_filter_fn)
         return model
 
     def _annotate_for_quantization_config(
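
Usage sketch (not part of the patch): the snippet below shows how the new set_filter_function API composes with set_global. The only_linear_nodes predicate and the build_quantizer helper are hypothetical names, the import path is inferred from this file's location in the repo, and the quantization_config value is assumed to come from whichever config helper your flow already uses.

    import torch
    from torch.fx import Node

    from executorch.backends.vulkan.quantizer.vulkan_quantizer import VulkanQuantizer


    def only_linear_nodes(node: Node) -> bool:
        # Quantize only aten.linear call_function nodes; all other nodes are
        # skipped by the quantizer.
        return (
            node.op == "call_function"
            and node.target == torch.ops.aten.linear.default
        )


    def build_quantizer(quantization_config) -> VulkanQuantizer:
        # quantization_config: a QuantizationConfig built elsewhere
        # (assumption; this patch does not define one).
        return (
            VulkanQuantizer()
            .set_global(quantization_config)
            .set_filter_function(only_linear_nodes)
        )

    # During annotation, each supported op is annotated only when both the
    # per-call filter and the user-supplied filter return True, per the
    # combined_filter_fn added in _annotate_all_patterns above.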