Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/api-doc/apis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ API
tf_quantization_common.rst
tf_quantization_config.rst
tf_quantization_autotune.rst

**JAX Extension API:**

.. toctree::
:maxdepth: 1

jax_quantization_common.rst
jax_quantization_config.rst
6 changes: 6 additions & 0 deletions docs/source/api-doc/jax_quantization_common.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
JAX Quantization Base API
#################################

.. autoapisummary::

neural_compressor.jax.quantization.quantize
6 changes: 6 additions & 0 deletions docs/source/api-doc/jax_quantization_config.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
JAX Quantization Config
=======================

.. autoapisummary::

neural_compressor.jax.quantization.config
30 changes: 30 additions & 0 deletions neural_compressor/common/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ class ExampleAlgorithmConfig:
"""

def decorator(config_cls):
    """Register the configuration class for the given framework and algorithm.

    Args:
        config_cls: Configuration class to register.

    Returns:
        The same configuration class for decorator chaining.
    """
    # Create the per-framework registry on first use so registration
    # never fails for a framework that has no entries yet.
    cls.registered_configs.setdefault(framework_name, {})
    cls.registered_configs[framework_name][algo_name] = {"priority": priority, "cls": config_cls}
    return config_cls
Expand Down Expand Up @@ -207,6 +215,7 @@ def __init__(self, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_
self._white_list = white_list

def _post_init(self):
"""Populate global and local configs based on the whitelist settings."""
if self.white_list == DEFAULT_WHITE_LIST:
global_config = self.get_params_dict()
self._global_config = self.__class__(**global_config, white_list=None)
Expand Down Expand Up @@ -558,6 +567,11 @@ def expand(self) -> List[BaseConfig]:
return config_list

def _get_op_name_op_type_config(self):
"""Split local configs into op-type and op-name mappings.

Returns:
tuple[dict, dict]: Mapping of op types to configs and op names to configs.
"""
op_type_config_dict = dict()
op_name_config_dict = dict()
for name, config in self.local_config.items():
Expand Down Expand Up @@ -604,6 +618,14 @@ def to_config_mapping(

@staticmethod
def _op_type_to_str(op_type: Callable) -> str:
"""Convert an operator type to a string key.

Args:
op_type (Callable): Operator type or callable object.

Returns:
str: String identifier for the operator type.
"""
# * Ort and TF may override this method.
op_type_name = getattr(op_type, "__name__", "")
if op_type_name == "":
Expand All @@ -612,6 +634,14 @@ def _op_type_to_str(op_type: Callable) -> str:

@staticmethod
def _is_op_type(name: str) -> bool:
"""Check whether the identifier represents an operator type.

Args:
name (str): Operator identifier.

Returns:
bool: True if the identifier is an operator type, otherwise False.
"""
# * Ort and TF may override this method.
return not isinstance(name, str)

Expand Down
51 changes: 50 additions & 1 deletion neural_compressor/common/base_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Evaluator:
def eval_acc(model):
...

def eval_perf(molde):
def eval_perf(model):
...

# Usage
Expand Down Expand Up @@ -109,6 +109,16 @@ def evaluate(self, model) -> float:
return result

def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> float:
"""Update the aggregated objective score with a weighted evaluation result.

Args:
eval_pair (dict): Evaluation function metadata including weight.
eval_result (float): Result from the evaluation function.
overall_result (float): Current aggregated score.

Returns:
float: Updated aggregated score.
"""
return overall_result + eval_result * eval_pair[self.WEIGHT]

def get_number_of_eval_functions(self) -> int:
Expand All @@ -120,6 +130,11 @@ def get_number_of_eval_functions(self) -> int:
return len(self.eval_fn_registry)

def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None:
"""Normalize and store evaluation function metadata.

Args:
user_eval_fns (List[Dict]): User-provided evaluation function configs.
"""
self.eval_fn_registry = [
{
self.EVAL_FN: user_eval_fn_pair[self.EVAL_FN],
Expand Down Expand Up @@ -202,12 +217,28 @@ def __len__(self) -> int:

@classmethod
def _from_single_config(cls, config: BaseConfig) -> List[BaseConfig]:
    """Expand a single config into a list of configs.

    Args:
        config (BaseConfig): Configuration to expand.

    Returns:
        List[BaseConfig]: Expanded configuration list.
    """
    # `expand()` already returns a list, so the dead `config_list = []`
    # accumulator in the previous version was never used — return directly.
    return config.expand()

@classmethod
def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig]:
"""Expand a list of configs into a single list.

Args:
fwk_configs (List[BaseConfig]): Configurations to expand.

Returns:
List[BaseConfig]: Expanded configuration list.
"""
config_list = []
for config in fwk_configs:
config_list += cls._from_single_config(config)
Expand Down Expand Up @@ -378,12 +409,26 @@ def __init__(


class _TrialRecord:
"""Record information for a single tuning trial."""

@staticmethod
def _generate_unique_id():
"""Generate a unique identifier for a trial record.

Returns:
str: Unique identifier string.
"""
unique_id = str(uuid.uuid4())
return unique_id

def __init__(self, trial_index: int, trial_result: Union[int, float], quant_config: BaseConfig):
"""Initialize a trial record.

Args:
trial_index (int): Trial index in the tuning loop.
trial_result (Union[int, float]): Evaluation result for the trial.
quant_config (BaseConfig): Quantization configuration used for the trial.
"""
# The unique id to refer to one trial
self.trial_id = _TrialRecord._generate_unique_id()
self.trial_index = trial_index
Expand Down Expand Up @@ -525,6 +570,10 @@ def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger

Args:
tuning_config (TuningConfig): The configuration for the tuning process.

Returns:
Tuple[ConfigLoader, TuningLogger, TuningMonitor]: A tuple containing the config loader,
tuning logger, and tuning monitor.
"""
config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
tuning_logger = TuningLogger()
Expand Down
2 changes: 2 additions & 0 deletions neural_compressor/common/tuning_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def create_input_args_model(expect_args_type: Any):
"""

class DynamicInputArgsModel(BaseModel):
    """Pydantic model validating `input_args` against a dynamically chosen type."""

    # The annotation type is injected from the enclosing factory's
    # `expect_args_type` argument, so validation rules vary per call.
    input_args: expect_args_type

return DynamicInputArgsModel
Expand Down
8 changes: 8 additions & 0 deletions neural_compressor/common/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ def warning_once(msg, *args, **kwargs):


def _get_log_msg(mode):
"""Map a Mode enum value to a human-readable log message.

Args:
mode (Mode): Execution mode enum.

Returns:
str | None: Log message string or None when mode is unsupported.
"""
log_msg = None
if mode == Mode.QUANTIZE:
log_msg = "Quantization"
Expand Down
24 changes: 24 additions & 0 deletions neural_compressor/common/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def brand_raw(self, brand_name):

@staticmethod
def _detect_cores():
    """Detect physical CPU core count using psutil.

    Returns:
        int: Number of physical CPU cores (logical/hyper-threaded cores excluded).
    """
    return psutil.cpu_count(logical=False)

Expand Down Expand Up @@ -181,6 +186,11 @@ def sockets(self, num_of_sockets):
self._sockets = num_of_sockets

def _get_number_of_sockets(self) -> int:
"""Detect the number of CPU sockets available.

Returns:
int: Number of CPU sockets detected.
"""
if "arch" in self._info and "ARM" in self._info["arch"]: # pragma: no cover
return 1

Expand Down Expand Up @@ -224,7 +234,17 @@ def dump_elapsed_time(customized_msg=""):
"""

def f(func):
"""Decorator factory that times the wrapped function.

Args:
func (Callable): Function to wrap.

Returns:
Callable: Wrapped function that logs elapsed time.
"""

def fi(*args, **kwargs):
"""Execute the function and log elapsed time."""
start = time.time()
res = func(*args, **kwargs)
end = time.time()
Expand Down Expand Up @@ -288,7 +308,10 @@ def log_process(mode=Mode.QUANTIZE):
"""

def log_process_wrapper(func):
"""Wrap a function to log execution start and end."""

def inner_wrapper(*args, **kwargs):
"""Execute the wrapped function with start/end logging."""
start_log = default_tuning_logger.execution_start
end_log = default_tuning_logger.execution_end

Expand Down Expand Up @@ -321,6 +344,7 @@ def call_counter(func):
"""

def wrapper(*args, **kwargs):
    """Increment call count and invoke the wrapped function."""
    # NOTE(review): assumes FUNC_CALL_COUNTS auto-initializes missing keys
    # (e.g. Counter/defaultdict) — a plain dict would raise KeyError here.
    FUNC_CALL_COUNTS[func.__name__] += 1
    return func(*args, **kwargs)

Expand Down
12 changes: 9 additions & 3 deletions neural_compressor/jax/algorithms/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Dynamic quantization algorithm entry point for JAX models."""

# Copyright (c) 2025-2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -39,11 +41,15 @@ def dynamic_quantize(
"""Quantize model using Dynamic quantization algorithm.

Args:
model: a JAX model to be quantized.
configs_mapping: mapping of configurations for the algorithm.
model (keras.Model): JAX model to be quantized.
configs_mapping (Optional[OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]]): Mapping of configurations
for the algorithm.
quant_config (Optional[BaseConfig]): Quantization configuration for wrapper selection.
*args (Any): Additional positional arguments (unused).
**kwargs (Any): Additional keyword arguments (unused).

Returns:
q_model: the quantized model.
keras.Model: The quantized model wrapped for inference.
"""
for _, value in configs_mapping.items():
config = value
Expand Down
11 changes: 8 additions & 3 deletions neural_compressor/jax/algorithms/static.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Static quantization algorithm entry point for JAX models."""

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -43,11 +45,14 @@ def static_quantize(
"""Quantize model using Static quantization algorithm.

Args:
model: a JAX model to be quantized.
configs_mapping: mapping of configurations for the algorithm.
model (keras.Model): JAX model to be quantized.
configs_mapping (Optional[OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]]): Mapping of configurations
for the algorithm.
quant_config (Optional[BaseConfig]): Quantization configuration for wrapper selection.
calib_function (Optional[Callable]): Calibration function used to collect activation statistics.

Returns:
q_model: the quantized model
keras.Model: The quantized model wrapped for inference.
"""
for _, value in configs_mapping.items():
config = value
Expand Down
2 changes: 2 additions & 0 deletions neural_compressor/jax/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Public JAX quantization API exports."""

# Copyright (c) 2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Loading
Loading