diff --git a/docs/source/api-doc/apis.rst b/docs/source/api-doc/apis.rst index cfc9be59537..98e89b31cc8 100644 --- a/docs/source/api-doc/apis.rst +++ b/docs/source/api-doc/apis.rst @@ -18,3 +18,11 @@ API tf_quantization_common.rst tf_quantization_config.rst tf_quantization_autotune.rst + +**JAX Extension API:** + +.. toctree:: + :maxdepth: 1 + + jax_quantization_common.rst + jax_quantization_config.rst diff --git a/docs/source/api-doc/jax_quantization_common.rst b/docs/source/api-doc/jax_quantization_common.rst new file mode 100644 index 00000000000..3c3cd7abb60 --- /dev/null +++ b/docs/source/api-doc/jax_quantization_common.rst @@ -0,0 +1,6 @@ +JAX Quantization Base API +################################# + +.. autoapisummary:: + + neural_compressor.jax.quantization.quantize diff --git a/docs/source/api-doc/jax_quantization_config.rst b/docs/source/api-doc/jax_quantization_config.rst new file mode 100644 index 00000000000..54e55b22e66 --- /dev/null +++ b/docs/source/api-doc/jax_quantization_config.rst @@ -0,0 +1,6 @@ +JAX Quantization Config +======================= + +.. autoapisummary:: + + neural_compressor.jax.quantization.config diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 4531f0115ff..0f8aee804a4 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -97,6 +97,14 @@ class ExampleAlgorithmConfig: """ def decorator(config_cls): + """Register the configuration class for the given framework and algorithm. + + Args: + config_cls: Configuration class to register. + + Returns: + The same configuration class for decorator chaining. 
+ """ cls.registered_configs.setdefault(framework_name, {}) cls.registered_configs[framework_name][algo_name] = {"priority": priority, "cls": config_cls} return config_cls @@ -207,6 +215,7 @@ def __init__(self, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_ self._white_list = white_list def _post_init(self): + """Populate global and local configs based on the whitelist settings.""" if self.white_list == DEFAULT_WHITE_LIST: global_config = self.get_params_dict() self._global_config = self.__class__(**global_config, white_list=None) @@ -558,6 +567,11 @@ def expand(self) -> List[BaseConfig]: return config_list def _get_op_name_op_type_config(self): + """Split local configs into op-type and op-name mappings. + + Returns: + tuple[dict, dict]: Mapping of op types to configs and op names to configs. + """ op_type_config_dict = dict() op_name_config_dict = dict() for name, config in self.local_config.items(): @@ -604,6 +618,14 @@ def to_config_mapping( @staticmethod def _op_type_to_str(op_type: Callable) -> str: + """Convert an operator type to a string key. + + Args: + op_type (Callable): Operator type or callable object. + + Returns: + str: String identifier for the operator type. + """ # * Ort and TF may override this method. op_type_name = getattr(op_type, "__name__", "") if op_type_name == "": @@ -612,6 +634,14 @@ def _op_type_to_str(op_type: Callable) -> str: @staticmethod def _is_op_type(name: str) -> bool: + """Check whether the identifier represents an operator type. + + Args: + name (str): Operator identifier. + + Returns: + bool: True if the identifier is an operator type, otherwise False. + """ # * Ort and TF may override this method. return not isinstance(name, str) diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py index 50cdbc2af68..742f4e63c36 100644 --- a/neural_compressor/common/base_tuning.py +++ b/neural_compressor/common/base_tuning.py @@ -70,7 +70,7 @@ class Evaluator: def eval_acc(model): ... 
- def eval_perf(molde): + def eval_perf(model): ... # Usage @@ -109,6 +109,16 @@ def evaluate(self, model) -> float: return result def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> float: + """Update the aggregated objective score with a weighted evaluation result. + + Args: + eval_pair (dict): Evaluation function metadata including weight. + eval_result (float): Result from the evaluation function. + overall_result (float): Current aggregated score. + + Returns: + float: Updated aggregated score. + """ return overall_result + eval_result * eval_pair[self.WEIGHT] def get_number_of_eval_functions(self) -> int: @@ -120,6 +130,11 @@ def get_number_of_eval_functions(self) -> int: return len(self.eval_fn_registry) def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None: + """Normalize and store evaluation function metadata. + + Args: + user_eval_fns (List[Dict]): User-provided evaluation function configs. + """ self.eval_fn_registry = [ { self.EVAL_FN: user_eval_fn_pair[self.EVAL_FN], @@ -202,12 +217,28 @@ def __len__(self) -> int: @classmethod def _from_single_config(cls, config: BaseConfig) -> List[BaseConfig]: + """Expand a single config into a list of configs. + + Args: + config (BaseConfig): Configuration to expand. + + Returns: + List[BaseConfig]: Expanded configuration list. + """ config_list = [] config_list = config.expand() return config_list @classmethod def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig]: + """Expand a list of configs into a single list. + + Args: + fwk_configs (List[BaseConfig]): Configurations to expand. + + Returns: + List[BaseConfig]: Expanded configuration list. + """ config_list = [] for config in fwk_configs: config_list += cls._from_single_config(config) @@ -378,12 +409,26 @@ def __init__( class _TrialRecord: + """Record information for a single tuning trial.""" + @staticmethod def _generate_unique_id(): + """Generate a unique identifier for a trial record. 
+ + Returns: + str: Unique identifier string. + """ unique_id = str(uuid.uuid4()) return unique_id def __init__(self, trial_index: int, trial_result: Union[int, float], quant_config: BaseConfig): + """Initialize a trial record. + + Args: + trial_index (int): Trial index in the tuning loop. + trial_result (Union[int, float]): Evaluation result for the trial. + quant_config (BaseConfig): Quantization configuration used for the trial. + """ # The unique id to refer to one trial self.trial_id = _TrialRecord._generate_unique_id() self.trial_index = trial_index @@ -525,6 +570,10 @@ def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger Args: tuning_config (TuningConfig): The configuration for the tuning process. + + Returns: + Tuple[ConfigLoader, TuningLogger, TuningMonitor]: A tuple containing the config loader, + tuning logger, and tuning monitor. """ config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) tuning_logger = TuningLogger() diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py index 924a6eed8a8..e5f7772e689 100644 --- a/neural_compressor/common/tuning_param.py +++ b/neural_compressor/common/tuning_param.py @@ -100,6 +100,8 @@ def create_input_args_model(expect_args_type: Any): """ class DynamicInputArgsModel(BaseModel): + """Pydantic model for validating dynamic input arguments.""" + input_args: expect_args_type return DynamicInputArgsModel diff --git a/neural_compressor/common/utils/logger.py b/neural_compressor/common/utils/logger.py index aa487fb38b1..938bdd11479 100644 --- a/neural_compressor/common/utils/logger.py +++ b/neural_compressor/common/utils/logger.py @@ -152,6 +152,14 @@ def warning_once(msg, *args, **kwargs): def _get_log_msg(mode): + """Map a Mode enum value to a human-readable log message. + + Args: + mode (Mode): Execution mode enum. + + Returns: + str | None: Log message string or None when mode is unsupported. 
+ """ log_msg = None if mode == Mode.QUANTIZE: log_msg = "Quantization" diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py index e585c435dc3..b0490da66b6 100644 --- a/neural_compressor/common/utils/utility.py +++ b/neural_compressor/common/utils/utility.py @@ -136,6 +136,11 @@ def brand_raw(self, brand_name): @staticmethod def _detect_cores(): + """Detect physical CPU core count using psutil. + + Returns: + int: Number of physical CPU cores. + """ physical_cores = psutil.cpu_count(logical=False) return physical_cores @@ -181,6 +186,11 @@ def sockets(self, num_of_sockets): self._sockets = num_of_sockets def _get_number_of_sockets(self) -> int: + """Detect the number of CPU sockets available. + + Returns: + int: Number of CPU sockets detected. + """ if "arch" in self._info and "ARM" in self._info["arch"]: # pragma: no cover return 1 @@ -224,7 +234,17 @@ def dump_elapsed_time(customized_msg=""): """ def f(func): + """Decorator factory that times the wrapped function. + + Args: + func (Callable): Function to wrap. + + Returns: + Callable: Wrapped function that logs elapsed time. 
+ """ + def fi(*args, **kwargs): + """Execute the function and log elapsed time.""" start = time.time() res = func(*args, **kwargs) end = time.time() @@ -288,7 +308,10 @@ def log_process(mode=Mode.QUANTIZE): """ def log_process_wrapper(func): + """Wrap a function to log execution start and end.""" + def inner_wrapper(*args, **kwargs): + """Execute the wrapped function with start/end logging.""" start_log = default_tuning_logger.execution_start end_log = default_tuning_logger.execution_end @@ -321,6 +344,7 @@ def call_counter(func): """ def wrapper(*args, **kwargs): + """Increment call count and invoke the wrapped function.""" FUNC_CALL_COUNTS[func.__name__] += 1 return func(*args, **kwargs) diff --git a/neural_compressor/jax/algorithms/dynamic.py b/neural_compressor/jax/algorithms/dynamic.py index bc6ac6d9c9d..e5f21baf867 100644 --- a/neural_compressor/jax/algorithms/dynamic.py +++ b/neural_compressor/jax/algorithms/dynamic.py @@ -1,3 +1,5 @@ +"""Dynamic quantization algorithm entry point for JAX models.""" + # Copyright (c) 2025-2026 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -39,11 +41,15 @@ def dynamic_quantize( """Quantize model using Dynamic quantization algorithm. Args: - model: a JAX model to be quantized. - configs_mapping: mapping of configurations for the algorithm. + model (keras.Model): JAX model to be quantized. + configs_mapping (Optional[OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]]): Mapping of configurations + for the algorithm. + quant_config (Optional[BaseConfig]): Quantization configuration for wrapper selection. + *args (Any): Additional positional arguments (unused). + **kwargs (Any): Additional keyword arguments (unused). Returns: - q_model: the quantized model. + keras.Model: The quantized model wrapped for inference. 
""" for _, value in configs_mapping.items(): config = value diff --git a/neural_compressor/jax/algorithms/static.py b/neural_compressor/jax/algorithms/static.py index 6a740452fab..4eddba25645 100644 --- a/neural_compressor/jax/algorithms/static.py +++ b/neural_compressor/jax/algorithms/static.py @@ -1,3 +1,5 @@ +"""Static quantization algorithm entry point for JAX models.""" + # Copyright (c) 2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,11 +45,14 @@ def static_quantize( """Quantize model using Static quantization algorithm. Args: - model: a JAX model to be quantized. - configs_mapping: mapping of configurations for the algorithm. + model (keras.Model): JAX model to be quantized. + configs_mapping (Optional[OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]]): Mapping of configurations + for the algorithm. + quant_config (Optional[BaseConfig]): Quantization configuration for wrapper selection. + calib_function (Optional[Callable]): Calibration function used to collect activation statistics. Returns: - q_model: the quantized model + keras.Model: The quantized model wrapped for inference. 
""" for _, value in configs_mapping.items(): config = value diff --git a/neural_compressor/jax/quantization/__init__.py b/neural_compressor/jax/quantization/__init__.py index 44e9227a421..65a2e6561c4 100644 --- a/neural_compressor/jax/quantization/__init__.py +++ b/neural_compressor/jax/quantization/__init__.py @@ -1,3 +1,5 @@ +"""Public JAX quantization API exports.""" + # Copyright (c) 2026 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/neural_compressor/jax/quantization/config.py b/neural_compressor/jax/quantization/config.py index c49192a711b..78cf21c9214 100644 --- a/neural_compressor/jax/quantization/config.py +++ b/neural_compressor/jax/quantization/config.py @@ -39,7 +39,7 @@ class OperatorConfig(NamedTuple): - """The config for operator.""" + """Configuration pairing a quantization config with supported operators.""" config: BaseConfig operators: List[str] @@ -81,6 +81,9 @@ def __init__( weight_dtype (str): Data type for weights, default is "fp8_e4m3". activation_dtype (str): Data type for activations, default is "fp8_e4m3". white_list (list): A list of supported operators of this algorithm. + + Returns: + None: Initializes the configuration instance. """ super().__init__(white_list=white_list) if not isinstance(weight_dtype, list): @@ -95,8 +98,12 @@ def __init__( self._post_init() @classmethod - def register_supported_configs(cls) -> List[OperatorConfig]: - """Register supported configs.""" + def register_supported_configs(cls) -> None: + """Register supported configs for dynamic quantization. + + Returns: + None: Updates the class-level supported configuration list. 
+ """ supported_configs = [] dynamic_config = DynamicQuantConfig( weight_dtype=["fp8", "fp8_e4m3", "fp8_e5m2", "int8"], @@ -108,8 +115,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]: cls.supported_configs = supported_configs @staticmethod - def get_model_info(model) -> List[Tuple[str, Callable]]: - """Get concrete node names for supported operators.""" + def get_model_info(model) -> List[Tuple[str, str]]: + """Get concrete node names for supported operators. + + Args: + model (keras.Model): Keras model to inspect. + + Returns: + List[Tuple[str, str]]: List of (layer name, layer class name) pairs. + """ white_list = ["Dense", "EinsumDense"] filter_result = [] @@ -123,7 +137,11 @@ def get_model_info(model) -> List[Tuple[str, Callable]]: @classmethod def get_config_set_for_tuning(cls) -> Union[None, "DynamicQuantConfig", List["DynamicQuantConfig"]]: - """Get a default config set for tuning.""" + """Get a default config set for tuning. + + Returns: + DynamicQuantConfig: Configuration to use for tuning. + """ return DynamicQuantConfig( weight_dtype=["fp8", "fp8_e4m3", "fp8_e5m2", "int8"], activation_dtype=["fp8", "fp8_e4m3", "fp8_e5m2", "int8"], @@ -131,11 +149,27 @@ def get_config_set_for_tuning(cls) -> Union[None, "DynamicQuantConfig", List["Dy @classmethod def from_json_string(cls, json_string: str) -> "DynamicQuantConfig": + """Create a DynamicQuantConfig from a JSON string. + + Args: + json_string (str): JSON string describing the config. + + Returns: + DynamicQuantConfig: Parsed configuration instance. + """ cfg = json.loads(json_string) return cls.from_dict(cfg) @classmethod def from_dict(cls, config_dict: Dict) -> "DynamicQuantConfig": + """Create a DynamicQuantConfig from a dictionary. + + Args: + config_dict (Dict): Configuration fields. + + Returns: + DynamicQuantConfig: Parsed configuration instance. 
+ """ weight_dtype = config_dict.get("weight_dtype", "fp8_e4m3") activation_dtype = config_dict.get("activation_dtype", "fp8_e4m3") white_list = config_dict.get("white_list", DEFAULT_WHITE_LIST) @@ -183,6 +217,9 @@ def __init__( weight_dtype (str): Data type for weights, default is "fp8_e4m3". activation_dtype (str): Data type for activations, default is "fp8_e4m3". white_list (list): A list of supported operators of this algorithm. + + Returns: + None: Initializes the configuration instance. """ super().__init__(white_list=white_list) if not isinstance(weight_dtype, list): @@ -198,8 +235,12 @@ def __init__( self._post_init() @classmethod - def register_supported_configs(cls) -> List[OperatorConfig]: - """Register supported configs.""" + def register_supported_configs(cls) -> None: + """Register supported configs for static quantization. + + Returns: + None: Updates the class-level supported configuration list. + """ supported_configs = [] static_config = StaticQuantConfig( weight_dtype=["fp8", "fp8_e4m3", "fp8_e5m2", "int8"], @@ -211,8 +252,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]: cls.supported_configs = supported_configs @staticmethod - def get_model_info(model) -> List[Tuple[str, Callable]]: - """Get concrete node names for supported operators.""" + def get_model_info(model) -> List[Tuple[str, str]]: + """Get concrete node names for supported operators. + + Args: + model (keras.Model): Keras model to inspect. + + Returns: + List[Tuple[str, str]]: List of (layer name, layer class name) pairs. + """ white_list = ["Dense", "EinsumDense", "MultiHeadAttention"] filter_result = [] @@ -226,7 +274,11 @@ def get_model_info(model) -> List[Tuple[str, Callable]]: @classmethod def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]: - """Get a default config set for tuning.""" + """Get a default config set for tuning. + + Returns: + StaticQuantConfig: Configuration to use for tuning. 
+ """ return StaticQuantConfig( weight_dtype=["fp8_e4m3", "fp8_e5m2", "int8"], activation_dtype=["fp8_e4m3", "fp8_e5m2", "int8"], @@ -234,11 +286,27 @@ def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["Sta @classmethod def from_json_string(cls, json_string: str) -> "StaticQuantConfig": + """Create a StaticQuantConfig from a JSON string. + + Args: + json_string (str): JSON string describing the config. + + Returns: + StaticQuantConfig: Parsed configuration instance. + """ cfg = json.loads(json_string) return cls.from_dict(cfg) @classmethod def from_dict(cls, config_dict: Dict) -> "StaticQuantConfig": + """Create a StaticQuantConfig from a dictionary. + + Args: + config_dict (Dict): Configuration fields. + + Returns: + StaticQuantConfig: Parsed configuration instance. + """ weight_dtype = config_dict.get("weight_dtype", "fp8_e5m2") activation_dtype = config_dict.get("activation_dtype", "fp8_e5m2") white_list = config_dict.get("white_list", DEFAULT_WHITE_LIST) @@ -253,7 +321,11 @@ def from_dict(cls, config_dict: Dict) -> "StaticQuantConfig": def get_all_registered_configs() -> Dict[str, BaseConfig]: - """Get all registered configs for JAX framework.""" + """Get all registered configs for JAX framework. + + Returns: + Dict[str, BaseConfig]: Mapping of config names to config classes. + """ registered_configs = config_registry.get_cls_configs() return registered_configs.get(FRAMEWORK_NAME, {}) @@ -262,7 +334,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig: """Generate the default Dynamic quantization config. Returns: - the default JAX Dynamic quantization config. + DynamicQuantConfig: The default JAX Dynamic quantization config. """ return DynamicQuantConfig() @@ -271,6 +343,6 @@ def get_default_static_config() -> StaticQuantConfig: """Generate the default Static quantization config. Returns: - the default JAX Static quantization config. + StaticQuantConfig: The default JAX Static quantization config. 
""" return StaticQuantConfig() diff --git a/neural_compressor/jax/quantization/layers_dynamic.py b/neural_compressor/jax/quantization/layers_dynamic.py index 2b3467b7beb..7b8586802d7 100644 --- a/neural_compressor/jax/quantization/layers_dynamic.py +++ b/neural_compressor/jax/quantization/layers_dynamic.py @@ -1,3 +1,5 @@ +"""Dynamic quantized layer implementations for JAX-backed Keras models.""" + # Copyright (c) 2025-2026 Intel Corporation # # Portions of this code are derived from: @@ -45,9 +47,24 @@ def register_dynamic_quantized_layer(clso): - """Register quantized layer class for original layer class.""" + """Register quantized layer class for an original layer class. + + Args: + clso (type): Original layer class to map to a quantized implementation. + + Returns: + Callable: Decorator that registers the quantized class. + """ def decorator(cls): + """Attach the quantized class to the dynamic mapping. + + Args: + cls (type): Quantized layer class to register. + + Returns: + type: The same class, for decorator chaining. + """ dynamic_quant_mapping[clso] = cls return cls @@ -55,31 +72,77 @@ def decorator(cls): class DynamicQDQLayer(keras.layers.Layer, SaveableLayerMixin): + """Layer that applies dynamic quantize-dequantize to activations.""" + def __init__(self, name, activation_dtype, asymmetric=False): + """Initialize the dynamic QDQ helper layer. + + Args: + name (str): Layer name. + activation_dtype (jnp.dtype): Activation dtype used for quantization. + asymmetric (bool): Whether to use asymmetric quantization. + + Returns: + None: Initializes the layer instance. + """ super().__init__(name=name) self.activation_dtype = activation_dtype self._is_asymmetric = asymmetric self.supports_masking = True def add_variables(self): + """Create quantization helper functions for activations. + + Returns: + None: Initializes quantization functions. 
+ """ self._tracker.unlock() self.aquantfun = get_quantize_fun(dtype=self.activation_dtype, asymmetric=self._is_asymmetric) self.adequantfun = get_dequantize_fun(dtype=self.compute_dtype, asymmetric=self._is_asymmetric) self._tracker.lock() def call_symmetric(self, inputs, batch_min_max, mask=None): + """Apply symmetric quantization to inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + batch_min_max (jnp.ndarray): Min/max tensor for the batch. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Quantized-dequantized tensor. + """ ascale, _ = get_q_params(batch_min_max, self.activation_dtype, asymmetric=False) x = self.aquantfun(inputs, ascale) x = self.adequantfun(x, ascale) return x def call_asymmetric(self, inputs, batch_min_max, mask=None): + """Apply asymmetric quantization to inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + batch_min_max (jnp.ndarray): Min/max tensor for the batch. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Quantized-dequantized tensor. + """ ascale, azero_point = get_q_params(batch_min_max, self.activation_dtype, asymmetric=True) x = self.aquantfun(inputs, ascale, azero_point) x = self.adequantfun(x, ascale, azero_point) return x def call(self, inputs, mask=None): + """Apply dynamic activation quantize-dequantize. + + Args: + inputs (jnp.ndarray): Input tensor. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Tensor with quantize-dequantize applied. + """ if any([dim == 0 for dim in inputs.shape]): # Skip quantization for zero-size inputs return inputs @@ -105,8 +168,20 @@ def call(self, inputs, mask=None): class QDynamicDenseMixin(SaveableLayerMixin): + """Mixin that adds dynamic quantization to dense-like layers.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a dense-like layer instance for dynamic quantization. + + Args: + orig (keras.layers.Layer): Original layer instance. 
+ weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + keras.layers.Layer: The updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.weight_dtype = weight_dtype @@ -116,6 +191,11 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_variables(self): + """Create quantization variables and cached weight tensor. + + Returns: + None: Initializes quantization variables. + """ self._tracker.unlock() self.input_qdq.add_variables() wscale, _ = get_q_params(self._kernel.value, self.weight_dtype, asymmetric=False) @@ -141,6 +221,11 @@ def add_variables(self): self._tracker.lock() def post_quantization_cleanup(self): + """Remove original weights after quantization is complete. + + Returns: + None: Cleans up original weights. + """ self._tracker.unlock() self._trainable_variables.remove(self._kernel) del self._kernel @@ -148,10 +233,24 @@ def post_quantization_cleanup(self): @property def kernel(self): + """Return the dequantized kernel tensor. + + Returns: + jnp.ndarray: Dequantized kernel tensor. + """ w = self.wdequantfun(self._kernel_quant.value, self.wscale.value) return w def call(self, inputs, training=None): + """Apply quantized input processing before the dense computation. + + Args: + inputs (jnp.ndarray): Input tensor. + training (Optional[bool]): Training mode flag. + + Returns: + jnp.ndarray: Layer output tensor. 
+ """ x = self.input_qdq(inputs) x = super().call(x, training=training) return x @@ -159,6 +258,8 @@ def call(self, inputs, training=None): @register_dynamic_quantized_layer(Dense) class QDynamicDense(QDynamicDenseMixin, Dense): + """Dynamically quantized Dense layer.""" + pass @@ -167,6 +268,8 @@ class QDynamicDense(QDynamicDenseMixin, Dense): @register_dynamic_quantized_layer(EinsumDense) class QDynamicEinsumDense(QDynamicDenseMixin, EinsumDense): + """Dynamically quantized EinsumDense layer.""" + pass @@ -175,8 +278,20 @@ class QDynamicEinsumDense(QDynamicDenseMixin, EinsumDense): @register_dynamic_quantized_layer(MultiHeadAttention) class QDynamicMultiHeadAttention(MultiHeadAttention, SaveableLayerMixin): + """Dynamically quantized MultiHeadAttention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a MultiHeadAttention instance for dynamic quantization. + + Args: + orig (keras.layers.MultiHeadAttention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + keras.layers.MultiHeadAttention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer) @@ -188,12 +303,22 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_variables(self): + """Create quantization helper layers for activations. + + Returns: + None: Initializes quantization helper layers. + """ self.q_qdq.add_variables() self.k_qdq.add_variables() self.a_qdq.add_variables() self.v_qdq.add_variables() def post_quantization_cleanup(self): + """Finalize dynamic quantization with no extra cleanup. + + Returns: + None: Keeps the layer ready for inference. + """ pass # fmt: off @@ -224,8 +349,7 @@ def _compute_attention( nothing). Returns: - attention_output: Multi-headed outputs of attention computation. 
- attention_scores: Multi-headed attention weights. + Tuple[jnp.ndarray, Optional[jnp.ndarray]]: Attention outputs and attention scores. """ # Check for flash attention constraints if self._flash_attention and return_attention_scores: @@ -310,8 +434,20 @@ def _compute_attention( @register_dynamic_quantized_layer(CachedGemma3Attention) class QDynamicCachedGemma3Attention(CachedGemma3Attention, SaveableLayerMixin): + """Dynamically quantized CachedGemma3Attention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a CachedGemma3Attention instance for dynamic quantization. + + Args: + orig (CachedGemma3Attention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + CachedGemma3Attention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.qdq = DynamicQDQLayer("qdq", activation_dtype, False) @@ -319,9 +455,19 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_variables(self): + """Create activation QDQ helper layer. + + Returns: + None: Initializes activation helper layer. + """ self.qdq.add_variables() def post_quantization_cleanup(self): + """Finalize dynamic quantization with no extra cleanup. + + Returns: + None: Keeps the layer ready for inference. + """ pass def _compute_attention( @@ -333,6 +479,19 @@ def _compute_attention( training=False, cache_update_index=0, ): + """Compute attention with dynamic activation quantization. + + Args: + q (jnp.ndarray): Query tensor. + k (jnp.ndarray): Key tensor. + v (jnp.ndarray): Value tensor. + attention_mask (Optional[jnp.ndarray]): Optional attention mask. + training (bool): Training mode flag. + cache_update_index (int): Cache update index for generation. + + Returns: + jnp.ndarray: Attention output tensor. 
+ """ if self.query_head_dim_normalize: query_normalization = 1 / np.sqrt(self.head_dim) else: @@ -390,8 +549,20 @@ def _compute_attention( @register_dynamic_quantized_layer(Gemma3VisionAttention) class QDynamicGemma3VisionAttention(Gemma3VisionAttention, SaveableLayerMixin): + """Dynamically quantized Gemma3VisionAttention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a Gemma3VisionAttention instance for dynamic quantization. + + Args: + orig (Gemma3VisionAttention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + Gemma3VisionAttention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.qdq = DynamicQDQLayer("qdq", activation_dtype, False) @@ -399,9 +570,19 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_variables(self): + """Create activation QDQ helper layer. + + Returns: + None: Initializes activation helper layer. + """ self.qdq.add_variables() def post_quantization_cleanup(self): + """Finalize dynamic quantization with no extra cleanup. + + Returns: + None: Keeps the layer ready for inference. + """ pass def call( @@ -411,6 +592,17 @@ def call( return_attention_scores=None, training=False, ): + """Compute vision attention with quantized activations. + + Args: + x (jnp.ndarray): Input tensor. + attention_mask (Optional[jnp.ndarray]): Optional attention mask. + return_attention_scores (Optional[bool]): Whether to return attention scores. + training (bool): Training mode flag. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: Attention output and attention probabilities. 
+ """ batch_size = ops.shape(x)[0] mixed_query_layer = self.query_proj(inputs=x) mixed_key_layer = self.key_proj(inputs=x) @@ -457,8 +649,20 @@ def call( @register_dynamic_quantized_layer(ReversibleEmbedding) class QDynamicReversibleEmbedding(ReversibleEmbedding, SaveableLayerMixin): + """Dynamically quantized ReversibleEmbedding layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a ReversibleEmbedding instance for dynamic quantization. + + Args: + orig (ReversibleEmbedding): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + ReversibleEmbedding: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer) @@ -468,15 +672,34 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_variables(self): + """Create activation QDQ helper layers. + + Returns: + None: Initializes activation helper layers. + """ self.inputs_qdq.add_variables() self.kernel_qdq.add_variables() def post_quantization_cleanup(self): + """Finalize dynamic quantization with no extra cleanup. + + Returns: + None: Keeps the layer ready for inference. + """ pass # TODO maybe make kernel (offline) quantization for reversible embedding (self.embeddings in our path) ? def call(self, inputs, reverse=False): + """Compute forward or reverse embedding with activation quantization. + + Args: + inputs (jnp.ndarray): Input tensor. + reverse (bool): Whether to compute the reverse embedding. + + Returns: + jnp.ndarray: Embedded outputs or logits. 
+ """ if reverse: if self.tie_weights: kernel = ops.transpose(ops.convert_to_tensor(self.embeddings)) diff --git a/neural_compressor/jax/quantization/layers_static.py b/neural_compressor/jax/quantization/layers_static.py index f00fcffbf10..e9d56e705e5 100644 --- a/neural_compressor/jax/quantization/layers_static.py +++ b/neural_compressor/jax/quantization/layers_static.py @@ -1,3 +1,5 @@ +"""Static quantized layer implementations for JAX-backed Keras models.""" + # Copyright (c) 2025-2026 Intel Corporation # # Portions of this code are derived from: @@ -45,9 +47,24 @@ def register_static_quantized_layer(clso): - """Register quantized layer class for original layer class.""" + """Register quantized layer class for an original layer class. + + Args: + clso (type): Original layer class to map to a quantized implementation. + + Returns: + Callable: Decorator that registers the quantized class. + """ def decorator(cls): + """Attach the quantized class to the static mapping. + + Args: + cls (type): Quantized layer class to register. + + Returns: + type: The same class, for decorator chaining. + """ static_quant_mapping[clso] = cls return cls @@ -55,7 +72,18 @@ def decorator(cls): class MinMaxObserver(keras.layers.Layer): + """Observer that tracks running min/max values for calibration.""" + def __init__(self, *args, **kwargs): + """Initialize the min/max observer layer. + + Args: + *args: Positional arguments for the base layer. + **kwargs: Keyword arguments for the base layer. + + Returns: + None: Initializes the observer layer. + """ super().__init__(*args, **kwargs, name="min_max") # Track running min/max as non-trainable weights self.min_val = self.add_weight( @@ -67,6 +95,15 @@ def __init__(self, *args, **kwargs): self.supports_masking = True def call(self, inputs, mask=None): + """Update min/max statistics during calibration. + + Args: + inputs (jnp.ndarray): Input tensor to observe. + mask (Optional[jnp.ndarray]): Optional mask to ignore padded elements. 
+ + Returns: + jnp.ndarray: The original inputs for passthrough. + """ if 0 not in inputs.shape: if mask is not None: # Expand mask to match input dimensions if needed @@ -87,14 +124,39 @@ def call(self, inputs, mask=None): return inputs def build(self, input_shape): + """Override build with no additional variables. + + Args: + input_shape (Tuple[int, ...]): Input shape for the layer. + + Returns: + None: No additional variables are created. + """ pass def get_calibrated_range(self): + """Return the calibrated min/max range as a tensor. + + Returns: + jnp.ndarray: Tensor containing min and max values. + """ return ops.array((self.min_val, self.max_val)) class StaticQDQLayer(keras.layers.Layer, SaveableLayerMixin): + """Layer that applies static quantize-dequantize to activations.""" + def __init__(self, name, activation_dtype, asymmetric=False): + """Initialize the static QDQ helper layer. + + Args: + name (str): Layer name. + activation_dtype (jnp.dtype): Activation dtype used for quantization. + asymmetric (bool): Whether to use asymmetric quantization. + + Returns: + None: Initializes the layer instance. + """ super().__init__(name=name) self.activation_dtype = activation_dtype self._is_asymmetric = asymmetric @@ -102,11 +164,21 @@ def __init__(self, name, activation_dtype, asymmetric=False): self._is_quantized = False def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self._tracker.unlock() self.input_observer = MinMaxObserver() self._tracker.lock() def add_variables(self): + """Create quantization variables for activations. + + Returns: + None: Initializes quantization variables. + """ self._tracker.unlock() if self._is_asymmetric: self.azero_point = self.add_weight( @@ -130,6 +202,11 @@ def add_variables(self): self._tracker.lock() def convert(self): + """Compute activation scale and finalize static quantization. + + Returns: + None: Updates activation scale variables. 
+ """ self._tracker.unlock() arange = self.input_observer.get_calibrated_range() ascale, azero_point = get_q_params(arange, self.activation_dtype, asymmetric=self._is_asymmetric) @@ -147,6 +224,11 @@ def convert(self): self._tracker.lock() def post_quantization_cleanup(self): + """Remove observers and finalize quantized call path. + + Returns: + None: Cleans up observers and sets quantized call. + """ self._tracker.unlock() if hasattr(self, "_layers") and hasattr(self, "input_observer"): if self.input_observer in self._layers: @@ -157,16 +239,43 @@ def post_quantization_cleanup(self): self._tracker.lock() def call(self, inputs, mask=None): + """Run calibration observer on inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Observed inputs. + """ x = self.input_observer(inputs, mask=mask) return x def call_symmetric(self, inputs, mask=None): + """Apply symmetric quantize-dequantize to inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Quantized-dequantized tensor. + """ ascale = self.ascale.value x = self.aquantfun(inputs, ascale) x = self.adequantfun(x, ascale) return x def call_asymmetric(self, inputs, mask=None): + """Apply asymmetric quantize-dequantize to inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + mask (Optional[jnp.ndarray]): Optional mask tensor. + + Returns: + jnp.ndarray: Quantized-dequantized tensor. + """ ascale = self.ascale.value zero_point = self.azero_point.value x = self.aquantfun(inputs, ascale, zero_point) @@ -175,9 +284,20 @@ def call_asymmetric(self, inputs, mask=None): class QStaticDenseMixin(SaveableLayerMixin): + """Mixin that adds static quantization to dense-like layers.""" @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a dense-like layer instance for static quantization. 
+ + Args: + orig (keras.layers.Layer): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + keras.layers.Layer: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.weight_dtype = weight_dtype @@ -189,11 +309,21 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self._tracker.unlock() self.input_observer = MinMaxObserver() self._tracker.lock() def add_variables(self): + """Create quantization variables for activations and weights. + + Returns: + None: Initializes quantization variables. + """ self._tracker.unlock() if self._is_int8: self.azero_point = self.add_weight( @@ -237,6 +367,11 @@ def add_variables(self): self._tracker.lock() def convert(self): + """Compute activation/weight scales and quantize weights. + + Returns: + None: Updates quantization variables with calibrated values. + """ self._tracker.unlock() arange = self.input_observer.get_calibrated_range() @@ -262,6 +397,11 @@ def convert(self): self._tracker.lock() def post_quantization_cleanup(self): + """Finalize static quantization and drop unused weights. + + Returns: + None: Cleans up observers and original weights. + """ self._tracker.unlock() if hasattr(self, "_kernel") and self._kernel in self._trainable_variables: self._trainable_variables.remove(self._kernel) @@ -278,6 +418,11 @@ def post_quantization_cleanup(self): @property def kernel(self): + """Return the dequantized kernel tensor. + + Returns: + jnp.ndarray: Dequantized kernel tensor. + """ if self._is_quantized: w = self.wdequantfun(self.w.value, self.wscale.value) return w @@ -285,11 +430,29 @@ def kernel(self): return ret.value def call(self, inputs, training=None): + """Run calibration observer before the dense computation. 
+ + Args: + inputs (jnp.ndarray): Input tensor. + training (Optional[bool]): Training mode flag. + + Returns: + jnp.ndarray: Layer output tensor. + """ x = self.input_observer(inputs) x = super().call(x, training=training) return x def call_fp8(self, inputs, training=None): + """Apply FP8 quantize-dequantize before dense computation. + + Args: + inputs (jnp.ndarray): Input tensor. + training (Optional[bool]): Training mode flag. + + Returns: + jnp.ndarray: Layer output tensor. + """ ascale = self.ascale.value x = self.aquantfun(inputs, ascale) x = self.adequantfun(x, ascale) @@ -297,6 +460,15 @@ def call_fp8(self, inputs, training=None): return x def call_int8(self, inputs, training=None): + """Apply int8 quantize-dequantize before dense computation. + + Args: + inputs (jnp.ndarray): Input tensor. + training (Optional[bool]): Training mode flag. + + Returns: + jnp.ndarray: Layer output tensor. + """ ascale = self.ascale.value zero_point = self.azero_point.value x = self.aquantfun(inputs, ascale, zero_point) @@ -307,6 +479,8 @@ def call_int8(self, inputs, training=None): @register_static_quantized_layer(Dense) class QStaticDense(QStaticDenseMixin, Dense): + """Statically quantized Dense layer.""" + pass @@ -315,6 +489,8 @@ class QStaticDense(QStaticDenseMixin, Dense): @register_static_quantized_layer(EinsumDense) class QStaticEinsumDense(QStaticDenseMixin, EinsumDense): + """Statically quantized EinsumDense layer.""" + pass @@ -323,8 +499,20 @@ class QStaticEinsumDense(QStaticDenseMixin, EinsumDense): @register_static_quantized_layer(MultiHeadAttention) class QStaticMultiHeadAttention(MultiHeadAttention, SaveableLayerMixin): + """Statically quantized MultiHeadAttention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a MultiHeadAttention instance for static quantization. + + Args: + orig (keras.layers.MultiHeadAttention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. 
+ activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + keras.layers.MultiHeadAttention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer) @@ -341,24 +529,44 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self.q_qdq.add_observers() self.k_qdq.add_observers() self.a_qdq.add_observers() self.v_qdq.add_observers() def add_variables(self): + """Create quantization variables for activation QDQ. + + Returns: + None: Initializes QDQ helper variables. + """ self.q_qdq.add_variables() self.k_qdq.add_variables() self.a_qdq.add_variables() self.v_qdq.add_variables() def convert(self): + """Compute activation calibration values for QDQ helpers. + + Returns: + None: Updates QDQ helpers with calibrated values. + """ self.q_qdq.convert() self.k_qdq.convert() self.a_qdq.convert() self.v_qdq.convert() def post_quantization_cleanup(self): + """Finalize static quantization and mark the layer as quantized. + + Returns: + None: Cleans up observers and marks quantized state. + """ self._tracker.unlock() self.q_qdq.post_quantization_cleanup() self.k_qdq.post_quantization_cleanup() @@ -395,8 +603,7 @@ def _compute_attention( nothing). Returns: - attention_output: Multi-headed outputs of attention computation. - attention_scores: Multi-headed attention weights. + Tuple[jnp.ndarray, Optional[jnp.ndarray]]: Attention outputs and attention scores. 
""" # Check for flash attention constraints if self._flash_attention and return_attention_scores: @@ -481,8 +688,20 @@ def _compute_attention( @register_static_quantized_layer(CachedGemma3Attention) class QStaticCachedGemma3Attention(CachedGemma3Attention, SaveableLayerMixin): + """Statically quantized CachedGemma3Attention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a CachedGemma3Attention instance for static quantization. + + Args: + orig (CachedGemma3Attention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + CachedGemma3Attention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.q_qdq = StaticQDQLayer("q_qdq", activation_dtype, False) @@ -494,24 +713,44 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self.q_qdq.add_observers() self.k_qdq.add_observers() self.attention_softmax_qdq.add_observers() self.v_qdq.add_observers() def add_variables(self): + """Create quantization variables for activation QDQ. + + Returns: + None: Initializes QDQ helper variables. + """ self.q_qdq.add_variables() self.k_qdq.add_variables() self.attention_softmax_qdq.add_variables() self.v_qdq.add_variables() def convert(self): + """Compute activation calibration values for QDQ helpers. + + Returns: + None: Updates QDQ helpers with calibrated values. + """ self.q_qdq.convert() self.k_qdq.convert() self.attention_softmax_qdq.convert() self.v_qdq.convert() def post_quantization_cleanup(self): + """Finalize static quantization and mark the layer as quantized. + + Returns: + None: Cleans up observers and marks quantized state. 
+ """ self._tracker.unlock() self.q_qdq.post_quantization_cleanup() self.k_qdq.post_quantization_cleanup() @@ -529,6 +768,19 @@ def _compute_attention( training=False, cache_update_index=0, ): + """Compute attention with static activation quantization. + + Args: + q (jnp.ndarray): Query tensor. + k (jnp.ndarray): Key tensor. + v (jnp.ndarray): Value tensor. + attention_mask (Optional[jnp.ndarray]): Optional attention mask. + training (bool): Training mode flag. + cache_update_index (int): Cache update index for generation. + + Returns: + jnp.ndarray: Attention output tensor. + """ if self.query_head_dim_normalize: query_normalization = 1 / np.sqrt(self.head_dim) else: @@ -586,8 +838,20 @@ def _compute_attention( @register_static_quantized_layer(Gemma3VisionAttention) class QStaticGemma3VisionAttention(Gemma3VisionAttention, SaveableLayerMixin): + """Statically quantized Gemma3VisionAttention layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a Gemma3VisionAttention instance for static quantization. + + Args: + orig (Gemma3VisionAttention): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + Gemma3VisionAttention: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig.query_qdq = StaticQDQLayer("query_qdq", activation_dtype, False) @@ -599,24 +863,44 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self.query_qdq.add_observers() self.key_qdq.add_observers() self.dropout_attention_probs_qdq.add_observers() self.value_qdq.add_observers() def add_variables(self): + """Create quantization variables for activation QDQ. + + Returns: + None: Initializes QDQ helper variables. 
+ """ self.query_qdq.add_variables() self.key_qdq.add_variables() self.dropout_attention_probs_qdq.add_variables() self.value_qdq.add_variables() def convert(self): + """Compute activation calibration values for QDQ helpers. + + Returns: + None: Updates QDQ helpers with calibrated values. + """ self.query_qdq.convert() self.key_qdq.convert() self.dropout_attention_probs_qdq.convert() self.value_qdq.convert() def post_quantization_cleanup(self): + """Finalize static quantization and mark the layer as quantized. + + Returns: + None: Cleans up observers and marks quantized state. + """ self._tracker.unlock() self.query_qdq.post_quantization_cleanup() self.key_qdq.post_quantization_cleanup() @@ -632,6 +916,17 @@ def call( return_attention_scores=None, training=False, ): + """Compute vision attention with static activation quantization. + + Args: + x (jnp.ndarray): Input tensor. + attention_mask (Optional[jnp.ndarray]): Optional attention mask. + return_attention_scores (Optional[bool]): Whether to return attention scores. + training (bool): Training mode flag. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: Attention output and attention probabilities. + """ batch_size = ops.shape(x)[0] mixed_query_layer = self.query_proj(inputs=x) mixed_key_layer = self.key_proj(inputs=x) @@ -678,8 +973,20 @@ def call( # @register_static_quantized_layer(RotaryEmbedding) class QStaticRotaryEmbedding(RotaryEmbedding, SaveableLayerMixin): + """Statically quantized RotaryEmbedding layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a RotaryEmbedding instance for static quantization. + + Args: + orig (RotaryEmbedding): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + RotaryEmbedding: Updated layer instance. 
+ """ orig._tracker.unlock() orig.__class__ = cls orig.positions_qdq = StaticQDQLayer("positions_qdq", activation_dtype, False) @@ -689,18 +996,38 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self.positions_qdq.add_observers() self.inverse_freq_qdq.add_observers() def add_variables(self): + """Create quantization variables for activation QDQ. + + Returns: + None: Initializes QDQ helper variables. + """ self.positions_qdq.add_variables() self.inverse_freq_qdq.add_variables() def convert(self): + """Compute activation calibration values for QDQ helpers. + + Returns: + None: Updates QDQ helpers with calibrated values. + """ self.positions_qdq.convert() self.inverse_freq_qdq.convert() def post_quantization_cleanup(self): + """Finalize static quantization and mark the layer as quantized. + + Returns: + None: Cleans up observers and marks quantized state. + """ self._tracker.unlock() self.positions_qdq.post_quantization_cleanup() self.inverse_freq_qdq.post_quantization_cleanup() @@ -708,6 +1035,16 @@ def post_quantization_cleanup(self): self._tracker.lock() def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None): + """Compute cosine/sine embeddings with quantized inputs. + + Args: + inputs (jnp.ndarray): Input tensor. + start_index (int): Starting index for positions. + positions (Optional[jnp.ndarray]): Optional explicit positions tensor. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: Cosine and sine embeddings. 
+ """ feature_axis = len(inputs.shape) - 1 sequence_axis = 1 @@ -743,8 +1080,20 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None): @register_static_quantized_layer(ReversibleEmbedding) class QStaticReversibleEmbedding(ReversibleEmbedding, SaveableLayerMixin): + """Statically quantized ReversibleEmbedding layer.""" + @classmethod def prepare(cls, orig, weight_dtype, activation_dtype): + """Convert a ReversibleEmbedding instance for static quantization. + + Args: + orig (ReversibleEmbedding): Original layer instance. + weight_dtype (jnp.dtype): Dtype for quantized weights. + activation_dtype (jnp.dtype): Dtype for quantized activations. + + Returns: + ReversibleEmbedding: Updated layer instance. + """ orig._tracker.unlock() orig.__class__ = cls orig._is_int8 = jnp.issubdtype(activation_dtype, jnp.integer) @@ -755,19 +1104,39 @@ def prepare(cls, orig, weight_dtype, activation_dtype): return orig def add_observers(self): + """Attach observer layers for calibration. + + Returns: + None: Adds observer layers. + """ self.inputs_qdq.add_observers() self.kernel_qdq.add_observers() def add_variables(self): + """Create quantization variables for activation QDQ. + + Returns: + None: Initializes QDQ helper variables. + """ self.inputs_qdq.add_variables() self.kernel_qdq.add_variables() def convert(self): + """Compute activation calibration values for QDQ helpers. + + Returns: + None: Updates QDQ helpers with calibrated values. + """ # TODO maybe make kernel (offline) quantization for reversible embedding (self.embeddings in our path) ? self.inputs_qdq.convert() self.kernel_qdq.convert() def post_quantization_cleanup(self): + """Finalize static quantization and mark the layer as quantized. + + Returns: + None: Cleans up observers and marks quantized state. 
+ """ self._tracker.unlock() self.inputs_qdq.post_quantization_cleanup() self.kernel_qdq.post_quantization_cleanup() @@ -775,6 +1144,15 @@ def post_quantization_cleanup(self): self._tracker.lock() def call(self, inputs, reverse=False): + """Compute forward or reverse embedding with static quantization. + + Args: + inputs (jnp.ndarray): Input tensor. + reverse (bool): Whether to compute the reverse embedding. + + Returns: + jnp.ndarray: Embedded outputs or logits. + """ if reverse: if self.tie_weights: kernel = ops.transpose(ops.convert_to_tensor(self.embeddings)) diff --git a/neural_compressor/jax/quantization/quantize.py b/neural_compressor/jax/quantization/quantize.py index 2a55d12f9e3..765268362e3 100644 --- a/neural_compressor/jax/quantization/quantize.py +++ b/neural_compressor/jax/quantization/quantize.py @@ -25,7 +25,15 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_name): - """Whether to apply the algorithm.""" + """Determine whether a quantization algorithm should be applied. + + Args: + configs_mapping (Dict[Tuple[str, callable], BaseConfig]): Mapping of layer identifiers to configs. + algo_name (str): Algorithm name to check. + + Returns: + bool: True if any config matches the algorithm name. + """ return any(config.name == algo_name for config in configs_mapping.values()) @@ -40,14 +48,14 @@ def quantize_model( """Return a quantized Keras model according to the given configuration. Args: - model: FP32 Keras model to be quantized. - quant_config: Quantization configuration. - calib_function: Function used for model calibration, required for static quantization. - inplace: When True, the original model is modified in-place and should not be used - afterward. A value of False is not yet supported. + model (keras.Model): FP32 Keras model to be quantized. + quant_config (BaseConfig): Quantization configuration. + calib_function (Callable, optional): Function used for model calibration, required for static quantization. 
+ inplace (bool): When True, the original model is modified in-place and should not be used afterward. A value of + False is not yet supported. Returns: - The quantized model. + keras.Model: The quantized model. """ # fmt: on if not inplace: diff --git a/neural_compressor/jax/quantization/saving.py b/neural_compressor/jax/quantization/saving.py index aa78e97676b..14334eeb916 100644 --- a/neural_compressor/jax/quantization/saving.py +++ b/neural_compressor/jax/quantization/saving.py @@ -1,3 +1,5 @@ +"""Serialization helpers for JAX quantized Keras models.""" + # Copyright (c) 2026 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,10 +35,10 @@ def quant_config_to_json_object(quant_config: BaseConfig) -> dict: """Serialize a quant config to a JSON-compatible dict with class name. Args: - quant_config: The quantization config object to serialize. + quant_config (BaseConfig): The quantization config object to serialize. Returns: - A dict with 'quantization_type' and 'config' keys. + dict: A dict with 'quantization_type' and 'config' keys. """ return { "quantization_type": quant_config.name, @@ -48,10 +50,10 @@ def quant_config_from_json_object(json_obj: dict) -> BaseConfig: """Deserialize a quant config from a JSON-compatible dict with class name. Args: - json_obj: A dict with 'quantization_type' and 'config' keys. + json_obj (dict): A dict with 'quantization_type' and 'config' keys. Returns: - The instantiated quantization config object. + BaseConfig: The instantiated quantization config object. Raises: ValueError: If the class name is unknown. @@ -68,16 +70,34 @@ def quant_config_from_json_object(json_obj: dict) -> BaseConfig: class VersionManager: + """Handle version metadata for serialized quantized models.""" + _MODULES = ["neural_compressor_jax", "keras", "keras_hub"] @classmethod def add_versions(cls, config): + """Insert package versions into the serialized config. 
+ + Args: + config (dict): Configuration dictionary to update in-place. + + Returns: + None: Updates the config dictionary in-place. + """ config["_versions"] = {} for package in cls._MODULES: config["_versions"][package] = importlib_metadata.version(package) @classmethod def check_versions_mismatch(cls, config): + """Check for version mismatches between saved and current packages. + + Args: + config (dict): Configuration dictionary that may include version metadata. + + Returns: + None: Logs warnings if mismatches are found. + """ versions = config.get("_versions") if versions is None: logger.error( @@ -94,7 +114,17 @@ def check_versions_mismatch(cls, config): class SaveableLayerMixin: + """Mixin for saving and loading quantized layer variables.""" + def save_own_variables(self, store): + """Save layer variables into the provided store. + + Args: + store (dict): Mutable mapping to receive serialized variables. + + Returns: + None: Updates the store mapping with serialized variables. + """ weight_dtype = getattr(self, "weight_dtype", None) for var in self._trainable_variables + self._non_trainable_variables: is_one_byte_format = dtype_utils.dtype_size(var.dtype) == 8 @@ -106,6 +136,14 @@ def save_own_variables(self, store): store[var.name] = value_to_save def load_own_variables(self, store): + """Load layer variables from the provided store. + + Args: + store (dict): Mapping containing serialized variables. + + Returns: + None: Loads variables into the layer. 
+ """ weight_dtype = getattr(self, "weight_dtype", None) for var in self._trainable_variables + self._non_trainable_variables: value_to_load = store[var.name] @@ -117,7 +155,18 @@ def load_own_variables(self, store): @keras.saving.register_keras_serializable(package="INC", name=None) class KerasQuantizedModelBackboneWrapper(Backbone): + """Wrapper that preserves quantization config when saving Keras backbones.""" + def __init__(self, model, quant_config: Optional[BaseConfig] = None): + """Initialize the wrapper around a backbone model. + + Args: + model (keras.Model): Backbone model to wrap. + quant_config (Optional[BaseConfig]): Quantization configuration. + + Returns: + None: Initializes the wrapper. + """ object.__setattr__(self, "_wrapped_model", model) object.__setattr__( self, @@ -138,26 +187,65 @@ def __init__(self, model, quant_config: Optional[BaseConfig] = None): object.__setattr__(self, "_quant_config", quant_config) def __getattribute__(self, name): + """Delegate attribute access to the wrapped model. + + Args: + name (str): Attribute name to access. + + Returns: + Any: Attribute value from the wrapper or wrapped model. + """ if name in object.__getattribute__(self, "fields"): return object.__getattribute__(self, name) return object.__getattribute__(self, "_wrapped_model").__getattribute__(name) def __setattr__(self, name, value): + """Delegate attribute updates to the wrapped model. + + Args: + name (str): Attribute name to update. + value (Any): Value to assign. + + Returns: + None: Updates the attribute on the wrapper or wrapped model. + """ if name in object.__getattribute__(self, "fields"): return object.__setattr__(self, name, value) return object.__getattribute__(self, "_wrapped_model").__setattr__(name, value) def get_config(self): + """Serialize the wrapper configuration for Keras saving. + + Returns: + dict: Serialized configuration for the wrapper. 
+ """ config = super().get_config() config["_quant_config"] = quant_config_to_json_object(self._quant_config) config["_wrapped_model"] = keras.saving.serialize_keras_object(self._wrapped_model) return config def __new__(cls, *args, **kwargs): + """Bypass BaseModel __new__ to allow manual initialization. + + Args: + *args: Positional arguments for object creation. + **kwargs: Keyword arguments for object creation. + + Returns: + KerasQuantizedModelBackboneWrapper: New wrapper instance. + """ return object.__new__(cls) @classmethod def from_config(cls, config): + """Recreate a wrapper from a serialized config dictionary. + + Args: + config (dict): Serialized configuration dictionary. + + Returns: + KerasQuantizedModelWrapper: Reconstructed quantized model wrapper. + """ model = keras.saving.deserialize_keras_object(config["_wrapped_model"]) quant_config_json = config.get("_quant_config") quant_config = quant_config_from_json_object(quant_config_json) @@ -172,6 +260,9 @@ def save_to_preset(self, preset_dir, max_shard_size=10): max_shard_size: `int` or `float`. Maximum size in GB for each sharded file. If `None`, no sharding will be done. Defaults to `10`. + + Returns: + None: Writes the preset files to disk. """ saver = get_preset_saver(preset_dir) saver.save_backbone(self, max_shard_size=max_shard_size) @@ -179,10 +270,20 @@ def save_to_preset(self, preset_dir, max_shard_size=10): @keras.saving.register_keras_serializable(package="INC", name=None) class KerasQuantizedModelWrapper(Task): + """Wrapper that preserves quantization config for Keras tasks.""" backbone_cls = KerasQuantizedModelBackboneWrapper def __init__(self, model, quant_config: Optional[BaseConfig] = None): + """Initialize the wrapper around a task model. + + Args: + model (keras.Model): Task model to wrap. + quant_config (Optional[BaseConfig]): Quantization configuration. + + Returns: + None: Initializes the wrapper. 
+ """ object.__setattr__(self, "_wrapped_model", model) object.__setattr__( self, @@ -202,16 +303,38 @@ def __init__(self, model, quant_config: Optional[BaseConfig] = None): object.__setattr__(self, "_quant_config", quant_config) def __getattribute__(self, name): + """Delegate attribute access to the wrapped model. + + Args: + name (str): Attribute name to access. + + Returns: + Any: Attribute value from the wrapper or wrapped model. + """ if name in object.__getattribute__(self, "fields"): return object.__getattribute__(self, name) return object.__getattribute__(self, "_wrapped_model").__getattribute__(name) def __setattr__(self, name, value): + """Delegate attribute updates to the wrapped model. + + Args: + name (str): Attribute name to update. + value (Any): Value to assign. + + Returns: + None: Updates the attribute on the wrapper or wrapped model. + """ if name in object.__getattribute__(self, "fields"): return object.__setattr__(self, name, value) return object.__getattribute__(self, "_wrapped_model").__setattr__(name, value) def get_config(self): + """Serialize the wrapper configuration for Keras saving. + + Returns: + dict: Serialized configuration for the wrapper. + """ config = super().get_config() VersionManager.add_versions(config) config["_quant_config"] = quant_config_to_json_object(self._quant_config) @@ -227,10 +350,27 @@ def get_config(self): return config def __new__(cls, *args, **kwargs): + """Bypass BaseModel __new__ to allow manual initialization. + + Args: + *args: Positional arguments for object creation. + **kwargs: Keyword arguments for object creation. + + Returns: + KerasQuantizedModelWrapper: New wrapper instance. + """ return object.__new__(cls) @classmethod def from_config(cls, config): + """Recreate a wrapper from a serialized config dictionary. + + Args: + config (dict): Serialized configuration dictionary. + + Returns: + KerasQuantizedModelWrapper: Reconstructed quantized model wrapper. 
+ """ VersionManager.check_versions_mismatch(config) model = keras.saving.deserialize_keras_object(config["_wrapped_model"]) quant_config_json = config.get("_quant_config") @@ -247,6 +387,9 @@ def save_to_preset(self, preset_dir, max_shard_size=10): max_shard_size: `int` or `float`. Maximum size in GB for each sharded file. If `None`, no sharding will be done. Defaults to `10`. + + Returns: + None: Writes the preset files to disk. """ saver = get_preset_saver(preset_dir) saver.save_task(self, max_shard_size=max_shard_size) @@ -254,16 +397,22 @@ def save_to_preset(self, preset_dir, max_shard_size=10): @keras.saving.register_keras_serializable(package="INC", name=None) class KerasQuantizedGemmaWrapper(KerasQuantizedModelWrapper, Gemma3CausalLM): + """Quantized wrapper for Gemma3CausalLM models.""" + backbone_cls = KerasQuantizedModelBackboneWrapper @keras.saving.register_keras_serializable(package="INC", name=None) class KerasQuantizedViTWrapper(KerasQuantizedModelWrapper, ViTImageClassifier): + """Quantized wrapper for ViTImageClassifier models.""" + backbone_cls = KerasQuantizedModelBackboneWrapper @keras.saving.register_keras_serializable(package="INC", name=None) class KerasQuantizedTokenizerWrapper(KerasQuantizedModelWrapper, Gemma3Tokenizer): + """Quantized wrapper for Gemma3Tokenizer models.""" + backbone_cls = KerasQuantizedModelBackboneWrapper @@ -282,10 +431,10 @@ def prepare_deserialized_quantized_model( It prepares the model for inference by preparing the quantized layers. Args: - model: loaded base keras model - quant_config: quantization configuration + model (keras.Model): Loaded base keras model. + quant_config (BaseConfig): Quantization configuration. Returns: - KerasQuantizedModelWrapper: the transformed quantized model + KerasQuantizedModelWrapper: The transformed quantized model wrapper. 
""" model_info = quant_config.get_model_info(model) configs_mapping = quant_config.to_config_mapping(model_info=model_info) diff --git a/neural_compressor/jax/utils/__init__.py b/neural_compressor/jax/utils/__init__.py index 07d5bef5cf2..b53e2c09606 100644 --- a/neural_compressor/jax/utils/__init__.py +++ b/neural_compressor/jax/utils/__init__.py @@ -1,3 +1,5 @@ +"""JAX utility exports for algorithm registration.""" + # Copyright (c) 2026 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/neural_compressor/jax/utils/utility.py b/neural_compressor/jax/utils/utility.py index 26f0f06ab69..e61fe112b32 100644 --- a/neural_compressor/jax/utils/utility.py +++ b/neural_compressor/jax/utils/utility.py @@ -31,9 +31,24 @@ def add_fp8_support(function): - """Extend the given size function to support FP8 dtypes.""" + """Extend a dtype size function to support FP8 dtypes. + + Args: + function (Callable): Function that returns the size of a dtype in bits. + + Returns: + Callable: Wrapped function that handles FP8 dtypes. + """ def wrapper(dtype): + """Return dtype size in bits with added FP8 support. + + Args: + dtype (str): Dtype name to query. + + Returns: + int: Size of the dtype in bits. + """ q_dtypes = ["float8_e4m3fn", "float8_e4m3", "float8_e5m2"] if dtype in q_dtypes: return 8 @@ -62,6 +77,14 @@ def example_algo(model: tf.keras.Model, quant_config: StaticQuantConfig) -> tf.k """ def decorator(algo_func): + """Register an algorithm implementation in the global mapping. + + Args: + algo_func (Callable): Algorithm implementation to register. + + Returns: + Callable: The original algorithm function. + """ algos_mapping[name] = algo_func return algo_func @@ -69,20 +92,58 @@ def decorator(algo_func): def get_quantize_fun(dtype=ml_dtypes.float8_e4m3, asymmetric=False): + """Create a quantization function for the specified dtype. + + Args: + dtype (jnp.dtype): Target quantization dtype. 
+ asymmetric (bool): Whether to use asymmetric quantization for integer dtypes. + + Returns: + Callable: Quantization function that maps tensors to the target dtype. + """ + @partial(jax.lax.composite, name="inc.quantize_fp8") def quantize_tensor_float(x, scale): + """Quantize floating-point tensors using clamping. + + Args: + x (jnp.ndarray): Input tensor. + scale (jnp.ndarray): Scale factor for quantization. + + Returns: + jnp.ndarray: Quantized tensor. + """ return jax.lax.clamp( jnp.finfo(dtype).min.astype(x.dtype), x / scale, jnp.finfo(dtype).max.astype(x.dtype) ).astype(dtype) @partial(jax.lax.composite, name="inc.quantize_int8") def quantize_tensor_int(x, scale): + """Quantize integer tensors using symmetric scaling. + + Args: + x (jnp.ndarray): Input tensor. + scale (jnp.ndarray): Scale factor for quantization. + + Returns: + jnp.ndarray: Quantized tensor. + """ val = jnp.round(x / scale) val = jnp.clip(val, jnp.iinfo(dtype).min, jnp.iinfo(dtype).max) return val.astype(dtype) @partial(jax.lax.composite, name="inc.quantize_int8_asymmetric") def quantize_tensor_int_asymmetric(x, scale, zero_point): + """Quantize integer tensors using asymmetric scaling. + + Args: + x (jnp.ndarray): Input tensor. + scale (jnp.ndarray): Scale factor for quantization. + zero_point (jnp.ndarray): Zero point offset. + + Returns: + jnp.ndarray: Quantized tensor. + """ val = jnp.round(x / scale) + zero_point val = jnp.clip(val, jnp.iinfo(dtype).min, jnp.iinfo(dtype).max) return val.astype(dtype) @@ -96,21 +157,69 @@ def quantize_tensor_int_asymmetric(x, scale, zero_point): def get_dequantize_fun(dtype=jnp.float32, asymmetric=False): + """Create a dequantization function for the specified dtype. + + Args: + dtype (jnp.dtype): Output dtype after dequantization. + asymmetric (bool): Whether to use asymmetric dequantization. + + Returns: + Callable: Function that dequantizes tensors. 
+ """ + @partial(jax.lax.composite, name="inc.dequantize") def dequantize(x, scale): + """Dequantize a tensor by applying the scale. + + Args: + x (jnp.ndarray): Quantized tensor. + scale (jnp.ndarray): Scale factor used for quantization. + + Returns: + jnp.ndarray: Dequantized tensor. + """ return x.astype(dtype) * scale @partial(jax.lax.composite, name="inc.dequantize_asymmetric") def dequantize_asymmetric(x, scale, zero_point=jnp.array(0, dtype=dtype)): + """Dequantize a tensor with asymmetric scaling. + + Args: + x (jnp.ndarray): Quantized tensor. + scale (jnp.ndarray): Scale factor used for quantization. + zero_point (jnp.ndarray): Zero point offset. + + Returns: + jnp.ndarray: Dequantized tensor. + """ return (x.astype(dtype) - zero_point) * scale return dequantize_asymmetric if asymmetric else dequantize def get_scale(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32): + """Compute the quantization scale for a weight tensor. + + Args: + orig_weight (jnp.ndarray): Weight tensor to analyze. + dtype (jnp.dtype): Target quantized dtype. + compute_dtype (jnp.dtype): dtype for scale computation. + + Returns: + jnp.ndarray: Computed scale tensor. + """ + # fp8 quantization @partial(jax.lax.composite, name="inc.get_scale_fp8") def float_get_scale(orig_weight): + """Compute scale for floating-point quantization. + + Args: + orig_weight (jnp.ndarray): Weight tensor to analyze. + + Returns: + jnp.ndarray: Computed scale tensor. + """ if 0 in orig_weight.shape: # For empty tensor, return scale as 1.0 return jnp.array(1.0, dtype=compute_dtype) @@ -122,6 +231,14 @@ def float_get_scale(orig_weight): @partial(jax.lax.composite, name="inc.get_scale_int") def integer_get_scale(orig_weight): + """Compute scale for integer quantization. + + Args: + orig_weight (jnp.ndarray): Weight tensor to analyze. + + Returns: + jnp.ndarray: Computed scale tensor. 
+ """ if 0 in orig_weight.shape: # For empty tensor, return scale as 1.0 return jnp.array(1.0, dtype=compute_dtype) @@ -140,9 +257,29 @@ def integer_get_scale(orig_weight): def get_q_params(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False): + """Compute quantization scale and zero-point for a weight tensor. + + Args: + orig_weight (jnp.ndarray): Weight tensor to analyze. + dtype (jnp.dtype): Target quantized dtype. + compute_dtype (jnp.dtype): dtype for scale computation. + asymmetric (bool): Whether to compute asymmetric quantization parameters. + + Returns: + Tuple[jnp.ndarray, Optional[jnp.ndarray]]: Scale and zero-point. Zero-point is `None` for floating-point + dtypes or symmetric quantization. + """ @partial(jax.lax.composite, name="inc.get_q_params_int") def integer_get_q_params(orig_weight): + """Compute scale and zero-point for integer quantization. + + Args: + orig_weight (jnp.ndarray): Weight tensor to analyze. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: Scale and zero-point tensors. + """ if 0 in orig_weight.shape: # For empty tensor, return scale as 1.0 return jnp.array(1.0, dtype=compute_dtype), jnp.array(0.0, dtype=compute_dtype) @@ -169,13 +306,27 @@ def print_model(container, max_lines=999999, internal=True, str_length=(0, 0), p """Print the model structure. Args: - container: The model or layer to be printed. - max_lines: The maximum number of elements to print. - internal: Whether to print layers from internal _layers (True) or public layers API (False). - str_length: Tuple with max lengths for class name and path. + container (keras.Model): The model or layer to be printed. + max_lines (int): The maximum number of elements to print. + internal (bool): Whether to print layers from internal _layers (True) or public layers API (False). + str_length (Tuple[int, int]): Tuple with max lengths for class name and path. + path (str): Prefix path for the current layer. 
+ + Returns: + None: Logs model structure via the logger. """ def get_str_length(container, max_long=(0, 0), path=""): + """Compute maximum string lengths for aligned model printing. + + Args: + container (keras.Layer): Layer or model to inspect. + max_long (Tuple[int, int]): Current maximum lengths. + path (str): Path prefix for this layer. + + Returns: + Tuple[int, int]: Updated maximum lengths for class name and path. + """ current = (len(container.__class__.__name__), len(path)) max_long = (max(current[0], max_long[0]), max(current[1], max_long[1])) if hasattr(container, "_layers"): @@ -228,14 +379,28 @@ def get_str_length(container, max_long=(0, 0), path=""): def causal_lm_make_replace_generate_function(self, revert=False): - """Replace generate function for the model to version suitable for calibration, - where non-trainable are also stored. + """Replace generate function for calibration and restore on demand. + + Args: + self (keras.Model): Causal language model instance to modify. + revert (bool): When True, restore the original generate function. - For revert=True, restore the original generate function. + Returns: + Callable: Updated generate function. """ @partial(jax.jit, static_argnames=["stop_token_ids"]) def compiled_generate_function(inputs, stop_token_ids, state): + """JIT-compiled generate function for calibration-friendly state handling. + + Args: + inputs (jnp.ndarray): Input tokens for generation. + stop_token_ids (Tuple[int, ...]): Token IDs used to stop generation. + state (Tuple[Any, Any, Any]): Tuple of sampler, trainable, and non-trainable variables. + + Returns: + Tuple[Any, List[Any], List[Any]]: Outputs, updated non-trainable variables, and sampler variables. + """ ( sampler_variables, trainable_variables, @@ -265,6 +430,15 @@ def wrapped_generate_function( inputs, stop_token_ids=None, ): + """Wrapper around generate_step to preserve variable state. + + Args: + inputs (jnp.ndarray): Input tokens for generation. 
+ stop_token_ids (Optional[Tuple[int, ...]]): Token IDs used to stop generation. + + Returns: + Any: Model outputs from generate_step. + """ if isinstance(stop_token_ids, list): stop_token_ids = tuple(stop_token_ids) @@ -302,7 +476,16 @@ def wrapped_generate_function( def iterate_over_layers(model, operations, /, *, filter_function: Optional[Callable] = lambda _: True): + """Apply operations to model layers matching the filter function. + Args: + model (keras.Model): Keras model with a _flatten_layers iterator. + operations (Iterable[Callable]): Operations to apply to each layer. + filter_function (Callable, optional): Predicate to select layers. Defaults to always True. + + Returns: + keras.Model: The original model after operations have been applied. + """ for layer in model._flatten_layers(): if filter_function(layer.__class__): @@ -313,7 +496,16 @@ def iterate_over_layers(model, operations, /, *, filter_function: Optional[Calla def verify_api(orig_cls, quant_cls, method_name): - """Check if quantized layer method API matches original layer method API.""" + """Check if quantized layer method API matches original layer method API. + + Args: + orig_cls (type): Original layer class. + quant_cls (type): Quantized layer class. + method_name (str): Method name to compare. + + Returns: + None: Logs an error if the method signatures differ. 
+ """ orig_method = getattr(orig_cls, method_name) quant_method = getattr(quant_cls, method_name) if inspect.signature(orig_method) != inspect.signature(quant_method): diff --git a/neural_compressor/tensorflow/algorithms/smoother/calibration.py b/neural_compressor/tensorflow/algorithms/smoother/calibration.py index 48bb19b5d7a..004bcdb27c0 100644 --- a/neural_compressor/tensorflow/algorithms/smoother/calibration.py +++ b/neural_compressor/tensorflow/algorithms/smoother/calibration.py @@ -125,6 +125,7 @@ def _inference_for_calibration(self, model): # sometimes the input_tensor is not the same order with inputs # we should check and pair them def check_shape(tensor, data): + """Validate that a tensor shape matches the sample data.""" # scalar or 1 dim default True if ( tensor.shape is None @@ -451,6 +452,7 @@ def _inference_for_calibration(self, model): del sampling_graph_def def _get_weight_tensors(self): + """Load and cache weight tensors needed for smooth quantization.""" model = load.load(self.model, [tag_constants.SERVING]) for weight_tensor in model.variables: parsed_name = self.weight_name_mapping(weight_tensor.name) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index a93b225c84e..f844ca038e4 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -133,6 +133,7 @@ def _check_quantize_format(self, model): break def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers): # pragma: no cover + """Fuse batch normalization into convolution layers for Keras 3 graphs.""" fuse_layers = [] fused_bn_name = "" for idx, layer in enumerate(fp32_layers): @@ -180,6 +181,7 @@ def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers): # pragma: no cover return fuse_layers def _fuse_bn_keras2(self, fuse_conv_bn, fp32_layers): # pragma: no cover + """Fuse batch normalization into convolution layers for Keras 2 
graphs.""" fuse_layers = [] for idx, layer in enumerate(fp32_layers): if hasattr(layer, "_inbound_nodes"): @@ -250,6 +252,17 @@ def _fuse_bn(self, model): # pragma: no cover fp32_layers = fuse_bn_model.layers def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): + """Fuse convolution weights with batch normalization parameters. + + Args: + conv_weight (list[np.ndarray]): Convolution weights and optional bias. + bn_weight (list[np.ndarray]): BatchNorm parameters. + conv_type (str): Convolution layer type. + eps (float): Epsilon used in BatchNorm. + + Returns: + list[np.ndarray]: Fused convolution weights (and bias). + """ assert conv_type in [ "Conv2D", "DepthwiseConv2D", diff --git a/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py b/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py index 3bf9cff80af..1fb6b05fcc5 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py @@ -110,6 +110,7 @@ def __init__(self, framework_specific_info): self._last_dequantize_ops = None def _check_itex(self): # pragma: no cover + """Verify that Intel Extension for TensorFlow is installed.""" try: import intel_extension_for_tensorflow except: @@ -587,6 +588,7 @@ def _query_fw_capability(self, model): ) def check_match(patterns, input_pattern): + """Check whether an input pattern matches a configured pattern list.""" for i in patterns: if input_pattern == [i for i in i.replace("+", " ").strip().split(" ") if i]: # pragma: no cover return True @@ -946,6 +948,7 @@ def _get_specified_version_cfg(self, data): config = None def _compare(version1, version2): + """Compare two TensorFlow version strings for sorting.""" if parse_version(version1) == parse_version(version2): # pragma: no cover return 0 elif parse_version(version1) < parse_version(version2): @@ -1414,6 +1417,7 @@ def generate_internal_patterns(self): """Translate the patterns defined 
in the yaml to internal pattern expression.""" def _generate_pattern(data): + """Generate a normalized internal pattern from op sequences.""" length = [len(i) for i in data] res = [] for index in range(max(length)): diff --git a/neural_compressor/tensorflow/quantization/config.py b/neural_compressor/tensorflow/quantization/config.py index 10db0249f35..3d181ef0783 100644 --- a/neural_compressor/tensorflow/quantization/config.py +++ b/neural_compressor/tensorflow/quantization/config.py @@ -47,6 +47,8 @@ class OperatorConfig(NamedTuple): + """Configuration tuple describing operators and validation functions.""" + config: BaseConfig operators: List[Union[str, Callable]] valid_func_list: List[Callable] = [] diff --git a/neural_compressor/tensorflow/quantization/utils/graph_converter.py b/neural_compressor/tensorflow/quantization/utils/graph_converter.py index e3c1c640c86..01871f08e50 100644 --- a/neural_compressor/tensorflow/quantization/utils/graph_converter.py +++ b/neural_compressor/tensorflow/quantization/utils/graph_converter.py @@ -291,6 +291,7 @@ def _inference(self, model): # sometimes the input_tensor is not the same order with inputs # we should check and pair them def check_shape(tensor, data): + """Validate that a tensor shape matches the sample data.""" # scalar or 1 dim default True if ( tensor.shape is None @@ -330,6 +331,11 @@ def check_shape(tensor, data): os.environ["ITEX_REMAPPER"] = "1" def _inference_llm(self, model): + """Run inference for large language models during calibration. + + Args: + model: TensorFlow model wrapper with input_tensor_names and model signatures. 
+ """ input_tensor_names = model.input_tensor_names auto_trackable = model.model infer = auto_trackable.signatures["serving_default"] diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py index 924536db696..19b97eed60c 100644 --- a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py @@ -61,6 +61,7 @@ def _fold_value(self, end_node_name): end_node = self.graph_info[end_node_name].node def can_broadcast(s1, s2): + """Check whether two tensor shapes can be broadcast together.""" if s1.shape and s2.shape: s1a = np.asarray(s1.shape) s2a = np.asarray(s2.shape) diff --git a/neural_compressor/tensorflow/quantization/utils/graph_util.py b/neural_compressor/tensorflow/quantization/utils/graph_util.py index 2e6a745366e..167ae6eb25c 100644 --- a/neural_compressor/tensorflow/quantization/utils/graph_util.py +++ b/neural_compressor/tensorflow/quantization/utils/graph_util.py @@ -244,6 +244,7 @@ def _search_patterns(self, input_pattern): """ def _validate_input(data, criteria): + """Check whether a node op matches the criteria.""" if isinstance(criteria, str) and data == criteria: return True @@ -275,6 +276,7 @@ def _compare_list(list_a, list_b): return is_subset def _dfs(op_names, op_types, graph_info, node, pattern): + """Depth-first search to match graph patterns.""" if pattern == []: return start_index = 0 @@ -1025,6 +1027,7 @@ def gen_valid_sampling_log(log_path): """ def gen_per_iter(data): + """Normalize per-iteration log entries for min/max values.""" res = [] requant_tmp = [] for i in data: diff --git a/neural_compressor/tensorflow/quantization/utils/utility.py b/neural_compressor/tensorflow/quantization/utils/utility.py index 5e3fa83ea90..e9a45e39a7b 100644 --- 
a/neural_compressor/tensorflow/quantization/utils/utility.py +++ b/neural_compressor/tensorflow/quantization/utils/utility.py @@ -315,6 +315,7 @@ def strip_equivalent_nodes(graph_def, output_node_names): stripped_graph_info = stripped_graph.parse_graph() def is_equivalent_input(input_tensor_list_1, input_tensor_list_2): + """Check whether two input tensor lists are equivalent.""" if len(input_tensor_list_1) != len(input_tensor_list_2): return False const_num = 0 @@ -435,6 +436,7 @@ def generate_feed_dict(input_tensor, inputs): # sometimes the input_tensor is not the same order with inputs # we should check and pair them def check_shape(tensor, data): + """Validate that a tensor shape matches the sample data.""" # scalar or 1 dim default True if tensor.shape is None or len(tensor.shape.dims) == 1 or not hasattr(data, "shape"): return True diff --git a/neural_compressor/tensorflow/utils/data.py b/neural_compressor/tensorflow/utils/data.py index 5854e45ad75..008b7c29bd9 100644 --- a/neural_compressor/tensorflow/utils/data.py +++ b/neural_compressor/tensorflow/utils/data.py @@ -337,6 +337,20 @@ def _generate_dataloader( shuffle, distributed, ): + """Yield batches from the dataset using the configured sampler. + + Args: + dataset: Dataset to iterate. + batch_size (int): Batch size. + last_batch: Last batch handling mode. + collate_fn (Callable): Function to collate batch samples. + sampler: Sampler instance (unused; generated internally). + batch_sampler: Batch sampler instance (unused; generated internally). + num_workers (int): Worker count (unused for TF). + pin_memory (bool): Pin memory flag. + shuffle (bool): Whether to shuffle (handled by sampler). + distributed (bool): Whether to use distributed sampling. 
+ """ sampler = self._generate_sampler(dataset, distributed) self.batch_sampler = BatchSampler(sampler, batch_size, self.drop_last) @@ -353,6 +367,15 @@ def _generate_dataloader( return def _generate_sampler(self, dataset, distributed): + """Create a sampler based on dataset capabilities. + + Args: + dataset: Dataset object to inspect. + distributed (bool): Whether to use distributed sampling. + + Returns: + Sampler: IterableSampler or SequentialSampler depending on dataset type. + """ if hasattr(dataset, "__getitem__"): self.dataset_type = "index" return SequentialSampler(dataset, distributed) diff --git a/neural_compressor/tensorflow/utils/model_wrappers.py b/neural_compressor/tensorflow/utils/model_wrappers.py index 5740bd882fc..4ac85fac3c5 100644 --- a/neural_compressor/tensorflow/utils/model_wrappers.py +++ b/neural_compressor/tensorflow/utils/model_wrappers.py @@ -253,6 +253,14 @@ def frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs): def _contains_function_with_implements_attr(saved_model_proto): + """Check whether SavedModel functions declare implementation attributes. + + Args: + saved_model_proto: Loaded SavedModel protocol buffer. + + Returns: + bool: True if a function contains _implements or api_implements attributes. + """ meta_graph = saved_model_proto.meta_graphs[0] for function in meta_graph.graph_def.library.function: if function.attr.get("_implements", None) or function.attr.get("api_implements", None): # pragma: no cover @@ -954,6 +962,15 @@ def graph_def(self, graph_def): self.model_type = "graph_def" def _load_sess(self, model, **kwargs): + """Load a TensorFlow session for the wrapped model. + + Args: + model: Model path or graph object. + **kwargs: Additional session creation arguments. + + Returns: + tf.compat.v1.Session: Initialized session object. + """ if self.name: kwargs.update({"name": self.name}) # assert self.model_type, 'model type not set....' 
diff --git a/neural_compressor/tensorflow/utils/utility.py b/neural_compressor/tensorflow/utils/utility.py index 279f091a63a..3d2be1ecf16 100644 --- a/neural_compressor/tensorflow/utils/utility.py +++ b/neural_compressor/tensorflow/utils/utility.py @@ -74,6 +74,14 @@ def example_algo(model: tf.keras.Model, quant_config: StaticQuantConfig) -> tf.k """ def decorator(algo_func): + """Register the algorithm function in the global mapping. + + Args: + algo_func (Callable): Algorithm implementation to register. + + Returns: + Callable: The original algorithm function. + """ algos_mapping[name] = algo_func return algo_func @@ -115,7 +123,10 @@ def dump_elapsed_time(customized_msg=""): """ def f(func): + """Wrap the function to log elapsed execution time.""" + def fi(*args, **kwargs): + """Execute the function and log elapsed time.""" start = time.time() res = func(*args, **kwargs) end = time.time() @@ -185,7 +196,10 @@ def disable_random(seed=1): import tensorflow as tf def decorator(func): + """Decorate a function to disable TensorFlow randomness.""" + def wrapper(*args, **kw): + """Reset graph state and run the wrapped function.""" tf.compat.v1.disable_eager_execution() tf.compat.v1.reset_default_graph() tf.compat.v1.set_random_seed(seed) diff --git a/neural_compressor/torch/algorithms/autoround/autoround.py b/neural_compressor/torch/algorithms/autoround/autoround.py index cabb1e99468..528999b9b4c 100644 --- a/neural_compressor/torch/algorithms/autoround/autoround.py +++ b/neural_compressor/torch/algorithms/autoround/autoround.py @@ -23,6 +23,11 @@ @lru_cache(None) def _is_auto_round_available(): + """Check whether the AutoRound package is importable. + + Returns: + bool: True when the auto_round package can be imported. + """ try: import auto_round # pylint: disable=E0401 except ImportError: @@ -137,6 +142,11 @@ def __init__( self.device = self.accelerator.name() def _is_w4afp8(self) -> bool: + """Return whether the configuration requests W4AFP8 quantization. 
+ + Returns: + bool: True when using fp8_to_int_sym data type. + """ return self.data_type == "fp8_to_int_sym" def prepare(self, model: torch.nn.Module, *args, **kwargs): @@ -236,13 +246,12 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42 Args: tokenizer (Tokenizer): The tokenizer to use for tokenization. - seqlen (int): The exact sequence length. samples < seqlen will be dropped, - samples longer than seqlen will be truncated + seqlen (int): The exact sequence length. Samples shorter than `seqlen` will be dropped, + and samples longer than `seqlen` will be truncated. dataset_name (str, optional): The name of the dataset or datasets separated by commas. - Defaults to "NeelNanda/pile-10k". - split (str, optional): The data split to use. Defaults to None. + Defaults to "NeelNanda/pile-10k". seed (int, optional): The random seed for reproducibility. Defaults to 42. - bs (int, optional): The batch size. Defaults to 4. + bs (int, optional): The batch size. Defaults to 8. nsamples (int, optional): The total number of samples to include. Defaults to 128. Returns: @@ -275,16 +284,23 @@ def get_mllm_dataloader( """Generate a DataLoader for calibration using specified parameters. Args: - template (Template): The template to specify process for different mllms. - model (Model): The model to quantized. + model (Model): The model to quantize. tokenizer (Tokenizer): The tokenizer to use for tokenization. - Dataset_name (str): The name or path of the dataset. - extra_data_dir (str): The path for extra data such as images, audio or videos. - seqlen (int): The exact sequence length. samples < seqlen will be dropped, - samples longer than seqlen will be truncated - bs (int, optional): The batch size. Defaults to 4. + template (Template, optional): The template to specify process for different MLLMs. + processor (transformers.AutoProcessor, optional): The processor for multi-modal inputs. 
+ image_processor (object, optional): The image processor for multi-modal inputs. + dataset (str, optional): The name or path of the dataset. + extra_data_dir (str, optional): The path for extra data such as images, audio, or videos. + seqlen (int, optional): The exact sequence length. Samples shorter than `seqlen` will be dropped, + and samples longer than `seqlen` will be truncated. + batch_size (int, optional): The batch size. Defaults to 8. split (str, optional): The data split to use. Defaults to None. - apply_template: Whether to apply chat template in tokenization. + apply_template (bool, optional): Whether to apply chat template in tokenization. + truncation (bool, optional): Whether to truncate sequences during tokenization. + seed (int, optional): The random seed for reproducibility. Defaults to 42. + nsamples (int, optional): The total number of samples to include. Defaults to 128. + gradient_accumulate_steps (int, optional): The number of gradient accumulation steps. Defaults to 1. + quant_nontext_module (bool, optional): Whether to quantize non-text modules. Defaults to False. Returns: DataLoader: The DataLoader for the calibrated datasets. 
diff --git a/neural_compressor/torch/algorithms/fp8_quant/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/__init__.py index 5fc72d88908..255f1f106e3 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/__init__.py +++ b/neural_compressor/torch/algorithms/fp8_quant/__init__.py @@ -1,3 +1,5 @@ +"""Public entry points for the FP8 quantization algorithm.""" + # Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py index 28f108cb636..f93f597390d 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py @@ -1,3 +1,5 @@ +"""Internal helpers for FP8 quantization.""" + # Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py index 456bc3aa7a8..6ddeabdf9b8 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py @@ -1,3 +1,5 @@ +"""Shared utilities for FP8 quantization configuration and scale files.""" + # Copyright (c) 2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +31,15 @@ UNMEASURED_MODELS = "UnmeasuredModels" def dequant_original_fp8_weight_if_needed(mod: torch.nn.Module, param: torch.Tensor) -> torch.Tensor: + """Dequantize FP8 weights using the module hook when required. + + Args: + mod (torch.nn.Module): Module that may provide a dequantization callback. + param (torch.Tensor): Parameter tensor to inspect for FP8 data types. + + Returns: + torch.Tensor: The original parameter or a dequantized replacement. 
+ """ if param.dtype in [torch.float8_e4m3fn]: if hasattr(mod, "get_dequant_weights_func"): dequant_weights_func = mod.get_dequant_weights_func() @@ -42,16 +53,39 @@ def dequant_original_fp8_weight_if_needed(mod: torch.nn.Module, param: torch.Ten return param class QuantTensorType(Enum): + """Enum describing the type of quantized tensor representation. + + Attributes: + MEASUREMENTS: Tensor values captured from calibration measurements. + CONST: Constant quantized tensors derived from stored parameters. + DYNAMIC: Dynamically computed quantized tensors at runtime. + """ + MEASUREMENTS = auto() CONST = auto() DYNAMIC = auto() class ShapeList: + """Container for transporting shape information through format conversions. + + Attributes: + data (list[int] | None): Shape dimensions carried with the wrapper. + """ + data = None def rec_fn(x, fn): + """Recursively apply a function to nested container values. + + Args: + x (Any): Input object that can be a dict, list, tuple, or leaf value. + fn (Callable[[Any], Any]): Function to apply to each leaf value. + + Returns: + Any: Structure with the same container layout and transformed leaves. + """ if isinstance(x, dict): return {k: rec_fn(x[k], fn) for k in x} elif isinstance(x, list): @@ -63,26 +97,67 @@ def rec_fn(x, fn): def save_json(d, fname): + """Save a Python dictionary to a JSON file. + + Args: + d (dict): Data to serialize. + fname (str): Destination file path. + """ with open(fname, "w") as f: json.dump(d, f, indent=4) def load_json(fname): + """Load a JSON file into a Python dictionary. + + Args: + fname (str): Source file path. + + Returns: + dict: Deserialized JSON content. + """ with open(fname, "r") as f: d = json.load(f) return d def save_npz(d, fname): + """Save a Python object to a NumPy NPZ archive. + + Args: + d (Any): Object to store in the archive. + fname (str): Destination file path. + """ np.savez(fname, d) def load_npz(fname): + """Load a Python object from a NumPy NPZ archive. 
+ + Warning: + NumPy may deserialize pickled objects during load. Only use trusted files. + + Args: + fname (str): Source file path to the NPZ archive. + + Returns: + Any: The stored object from the archive. + """ d = np.load(fname, allow_pickle=True) return d["arr_0"].item() def save_file(model, d, source_format, fname, mode, num_samples=0): + """Persist a scale-related payload to disk in the chosen format. + + Args: + model (torch.nn.Module): Model associated with the data. + d (dict): Source data keyed by module name. + source_format (type): In-memory data format type. + fname (str): Target file path. + mode (str): Label describing the saved data (e.g., "Scale"). + num_samples (int, optional): Optional number of calibration samples. Defaults to 0. + """ from .._quant_common.quant_config import get_hqt_config config = get_hqt_config(model) logger.debug("Saving %s file: %s", mode, fname) @@ -104,6 +179,16 @@ def save_file(model, d, source_format, fname, mode, num_samples=0): def load_file(fname, target_format, fail_on_file_not_exist): + """Load a scale file and convert it to module configuration objects. + + Args: + fname (str): Source file path. + target_format (type): Desired data format type for loaded content. + fail_on_file_not_exist (bool): Whether to raise if the file is missing. + + Returns: + dict[str, ModuleConfig]: Mapping of module names to configuration objects. + """ logger.debug("Loading file: %s", fname) ext = os.path.splitext(fname)[1] source_format = file_functions[ext]['format'] @@ -122,6 +207,15 @@ def load_file(fname, target_format, fail_on_file_not_exist): # convert module config data to other format def module_convert(m, fcn): + """Convert a module configuration to the specified data format. + + Args: + m (ModuleConfig): Module configuration to convert. + fcn (Callable[[Any], Any]): Conversion function applied to all tensor-like values. + + Returns: + ModuleConfig: Converted configuration instance. 
+ """ mt = ModuleConfig( tuple([fcn(x) for x in m.inputs]), ( @@ -137,6 +231,14 @@ def module_convert(m, fcn): def fix_fields(d): + """Normalize legacy field names in serialized configuration data. + + Args: + d (dict): Input dictionary to update in-place. + + Returns: + dict: The updated dictionary with normalized keys. + """ if "input" in d: d["inputs"] = d.pop("input") if "output" in d: @@ -170,6 +272,17 @@ def load_scales(fname, target_format): def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, device=cur_device): + """Convert stored scale objects to tensors on the requested device. + + Args: + scales_obj (dict[str, ModuleConfig]): Module scales loaded from disk. + scales_file_format (type): Format used in the stored file. + hp_dtype (torch.dtype): Target tensor dtype. + device (str, optional): Device identifier to move tensors to. Defaults to current device. + + Returns: + dict[str, ModuleConfig]: Scales represented as tensors in the target dtype/device. + """ scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj} scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp) scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device=device)) @@ -200,6 +313,14 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev def get_device_type_for_scales(mod): + """Return the device name used for storing scales for a module. + + Args: + mod (torch.nn.Module): Module whose quantization config is consulted. + + Returns: + str: Device type string for scale storage. + """ from .._quant_common.quant_config import get_hqt_config config = get_hqt_config(mod).cfg return config["device_for_scales"] @@ -207,14 +328,29 @@ def get_device_type_for_scales(mod): @lru_cache def is_runtime_scale_patching(): + """Check whether runtime scale patching is enabled via environment variable. + + Returns: + bool: True when runtime patching is enabled. 
+ """ return os.getenv("RUNTIME_SCALE_PATCHING", "False").lower() in ["true", "1"] #TODO [SW-224612]: Use cguid to calc scales and remove the check @lru_cache def is_calc_scale_with_cguid(): + """Check whether scale calculation uses cguid logic. + + Returns: + bool: True when cguid-based scale calculation is enabled. + """ return os.getenv("CALC_SCALE_WITH_CGUID", "True").lower() in ["true", "1"] #TODO [SW-224612]: Use cguid to calc scales and remove the check @lru_cache def is_calc_scale_rounding_with_cguid(): + """Check whether scale rounding uses cguid-based configuration. + + Returns: + bool: True when rounding with cguid is enabled. + """ return is_calc_scale_with_cguid() and os.getenv("CALC_ROUNDING_WITH_CGUID", "False").lower() in ["true", "1"] diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/external_func_impl.py b/neural_compressor/torch/algorithms/fp8_quant/_core/external_func_impl.py index a32337b8cb7..8f6e072d13a 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/external_func_impl.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/external_func_impl.py @@ -1,3 +1,5 @@ +"""External distributed collective function import helpers.""" + # Copyright (c) 2025 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,9 +34,23 @@ tensor_model_parallel_all_reduce = None def get_external_column_parallel_collective_func(): - assert tensor_model_parallel_all_gather is not None, "Couldn't import function tensor_model_parallel_all_gather from external source" + """Return the column-parallel all-gather collective from external runtime. + + Returns: + Callable: The tensor-model-parallel all-gather function. 
+ """ + assert ( + tensor_model_parallel_all_gather is not None + ), "Could not import function tensor_model_parallel_all_gather from external source" return tensor_model_parallel_all_gather def get_external_row_parallel_collective_func(): - assert tensor_model_parallel_all_reduce is not None, "Couldn't import function tensor_model_parallel_all_reduce from external source" + """Return the row-parallel all-reduce collective from external runtime. + + Returns: + Callable: The tensor-model-parallel all-reduce function. + """ + assert ( + tensor_model_parallel_all_reduce is not None + ), "Could not import function tensor_model_parallel_all_reduce from external source" return tensor_model_parallel_all_reduce diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py index 52bf10a4caf..48e116629ad 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py @@ -1,3 +1,5 @@ +"""Scale loading helpers for FP8 quantization preparation.""" + # Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,8 +25,41 @@ -def load_layer_scales(mod, mod_name, config, mod_type_str, measurement, scales, scale_file, scales_file_format, - scales_obj, scaling_method_config, scale_config, save_file, scale_method_config_by_mod_map): +def load_layer_scales( + mod, + mod_name, + config, + mod_type_str, + measurement, + scales, + scale_file, + scales_file_format, + scales_obj, + scaling_method_config, + scale_config, + save_file, + scale_method_config_by_mod_map, +): + """Load or calculate scales for a module and build quant-dequant configs. + + Args: + mod (torch.nn.Module): Target module. + mod_name (str): Qualified module name. + config: Quantization configuration object. + mod_type_str (str): Module type identifier string. + measurement (dict): Measurements collected during calibration. 
+ scales (dict): Cached module scale configurations. + scale_file (str | None): Optional path to persist scale data. + scales_file_format (type): File format for serialized scales. + scales_obj (dict): Cache of serialized scale objects. + scaling_method_config: Scaling method configuration. + scale_config: Scale configuration object. + save_file (bool): Flag indicating whether to write updated scales to disk. + scale_method_config_by_mod_map (dict): Output map for per-module method configs. + + Returns: + tuple[ModuleExtraConfig | None, bool]: The module extra config and updated save flag. + """ module_type = mod_default_dict[mod_type_str].type logger.debug( "Preparing quantization functions for module %s module_type=%s", diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py index 26209b60beb..dc963e53920 100644 --- a/neural_compressor/torch/quantization/autotune.py +++ b/neural_compressor/torch/quantization/autotune.py @@ -55,6 +55,14 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]: def _deepcopy_warp(model): + """Create a deep copy of the model while preserving specific attributes. + + Args: + model (torch.nn.Module): The model to deep copy. + + Returns: + torch.nn.Module: A deep copy of the model with preserved attributes. + """ additional_attr_lst = ["_exported", "dynamic_shapes"] original_attr = {key: getattr(model, key, None) for key in additional_attr_lst} new_model = deepcopy(model) @@ -64,7 +72,15 @@ def _deepcopy_warp(model): def _preprocess_model_quant_config(model, quant_config): - """Preprocess model and quant config before quantization.""" + """Preprocess model and quant config before quantization. + + Args: + model (torch.nn.Module): The model to be quantized. + quant_config (TuningConfig): The quantization configuration to preprocess. + + Returns: + Tuple[torch.nn.Module, TuningConfig]: The preprocessed model and quantization configuration. 
+ """ for config in quant_config.config_set: # handle tokenizer attribute in AutoRoundConfig if isinstance(config, AutoRoundConfig): @@ -88,8 +104,8 @@ def autotune( """The main entry of auto-tune. Args: - model (torch.nn.Module): _description_ - tune_config (TuningConfig): _description_ + model (torch.nn.Module): The model to be quantized. + tune_config (TuningConfig): The configuration for the auto-tuning process. eval_fn (Callable): for evaluation of quantized models. eval_args (tuple, optional): arguments used by eval_fn. Defaults to None. run_fn (Callable, optional): for calibration to quantize model. Defaults to None. diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index b1aeae4bd4e..a85dc17e35b 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1767,7 +1767,7 @@ def __init__( allowlist (dict, optional): Whether to execute fp8 quantization for specific op names or types. Defaults to {"names": [], "types": FP8_WHITE_LIST}. mode (str, optional): Choose the quantization mode. Defaults to "AUTO". scale_method (str or dict, optional): Select method used to generate scale from calibration info. Can be a string or a dict. Defaults to "maxabs_hw". - scale_params (dict, optional): _description_. Defaults to {}. + scale_params (dict, optional): Scaling parameters that override the default ones for specific modules. Defaults to {}. observer (str, optional): Params of scales. Defaults to "maxabs". mod_dict (dict, optional): The dict of modules to quantize. Defaults to {}. measure_exclude (str, optional): Select INPUT/OUTPUT to be exculded by measurement. Defaults to "OUTPUT". 
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index d9bad24283b..19ed0b1b3d7 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -50,15 +50,14 @@ def preprocess_quant_config(model, quant_config, mode="prepare", example_inputs= """Preprocess the quantization configuration. Args: - model: a float model to be quantized. - quant_config: a quantization configuration. - mode (str, optional): Which mode is in use currently. Defaults to "prepare". - run_fn: a calibration function for calibrating the model. Defaults to None. - example_inputs: used to trace torch model. + model (torch.nn.Module): Float model to be quantized. + quant_config (BaseConfig | dict): Quantization configuration or configuration dictionary. + mode (str, optional): Quantization mode to prepare for. Defaults to "prepare". + example_inputs (Any, optional): Example inputs for tracing the model. Defaults to None. + run_fn (Callable | None): Calibration function for collecting statistics. Defaults to None. Returns: - model: model to be quantized. - OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: The configuration mapping. + tuple[torch.nn.Module, dict]: Updated model and configuration mapping. """ registered_configs = config_registry.get_cls_configs() if isinstance(quant_config, dict): diff --git a/neural_compressor/torch/utils/llm_utility.py b/neural_compressor/torch/utils/llm_utility.py index e93366b1945..01e671cc076 100644 --- a/neural_compressor/torch/utils/llm_utility.py +++ b/neural_compressor/torch/utils/llm_utility.py @@ -64,7 +64,7 @@ def get_default_llm_dataloader(tokenizer, dataset_name="NeelNanda/pile-10k", bs= Args: tokenizer (obj): tokenizer object. - seq_len (int, optional): _description_. Defaults to 128. + seq_len (int, optional): the sequence length of the input tokens. Defaults to 128. dataset_name (str, optional): dataset name. 
Defaults to "NeelNanda/pile-10k". seed (int, optional): random seed. Defaults to 42. bs (int, optional): batch size. Defaults to 8.