Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/JAX.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ JAX

## Introduction

`neural_compressor_jax` provides an API for applying quantization on Keras models such as ViT and Gemma3.
`neural_compressor.jax` provides an API for applying quantization to Keras models such as ViT and Gemma3.
Since only JAX is supported as the Keras backend, the environment variable `KERAS_BACKEND` should be set to `jax`.
The following 8-bit floating-point formats are supported: `fp8_e4m3` and `fp8_e5m2`.

Quantized models can be saved and loaded using standard Keras APIs
Expand Down
6 changes: 0 additions & 6 deletions neural_compressor/jax/quantization/layers_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
verify_api,
)

if keras.config.backend() != "jax":
raise ValueError(
f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n"
'Consider setting KERAS_BACKEND env var to "jax".'
)

dynamic_quant_mapping = {}


Expand Down
6 changes: 0 additions & 6 deletions neural_compressor/jax/quantization/layers_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
verify_api,
)

if keras.config.backend() != "jax":
raise ValueError(
f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n"
'Consider setting KERAS_BACKEND env var to "jax".'
)

static_quant_mapping = {}


Expand Down
3 changes: 2 additions & 1 deletion neural_compressor/jax/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from neural_compressor.common import logger
from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
from neural_compressor.common.utils import Mode, log_process
from neural_compressor.jax.utils import algos_mapping
from neural_compressor.jax.utils import algos_mapping, check_backend


def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_name):
Expand Down Expand Up @@ -58,6 +58,7 @@ def quantize_model(
keras.Model: The quantized model.
"""
# fmt: on
check_backend()
if not inplace:
raise NotImplementedError("Out of place quantization is not supported yet. "
"Please set parameter inplace=True for quantize_model() to modify the model in-place")
Expand Down
3 changes: 2 additions & 1 deletion neural_compressor/jax/quantization/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from neural_compressor.common import logger
from neural_compressor.common.base_config import config_registry
from neural_compressor.jax.quantization.config import FRAMEWORK_NAME, BaseConfig, DynamicQuantConfig, StaticQuantConfig
from neural_compressor.jax.utils.utility import dtype_mapping, iterate_over_layers
from neural_compressor.jax.utils.utility import check_backend, dtype_mapping, iterate_over_layers


def quant_config_to_json_object(quant_config: BaseConfig) -> dict:
Expand Down Expand Up @@ -446,6 +446,7 @@ def prepare_deserialized_quantized_model(
Returns:
Union[KerasQuantizedModelWrapperMixin, KerasQuantizedModelBackboneWrapper]: The transformed quantized model/backbone wrapper.
"""
check_backend()
model_info = quant_config.get_model_info(model)
configs_mapping = quant_config.to_config_mapping(model_info=model_info)

Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/jax/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from neural_compressor.jax.utils.utility import algos_mapping, register_algo
from neural_compressor.jax.utils.utility import algos_mapping, register_algo, check_backend
17 changes: 17 additions & 0 deletions neural_compressor/jax/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@
from neural_compressor.common import logger


def check_backend(raise_error=True):
"""Check if the current Keras backend is JAX and log a warning or error if not."""

if keras.config.backend() != "jax":
message = (
f"neural_compressor.jax only supports JAX backend, but the current Keras backend is {keras.config.backend()}. "
'Consider setting KERAS_BACKEND env var to "jax".'
)
if raise_error:
raise ValueError(message)
else:
logger.warning(message)


check_backend(raise_error=False)


def add_fp8_support(function):
"""Extend a dtype size function to support FP8 dtypes.

Expand Down
Loading