diff --git a/docs/source/JAX.md b/docs/source/JAX.md index 2ed73916909..50b7b7e030f 100644 --- a/docs/source/JAX.md +++ b/docs/source/JAX.md @@ -10,7 +10,8 @@ JAX ## Introduction -`neural_compressor_jax` provides an API for applying quantization on Keras models such as ViT and Gemma3. +`neural_compressor.jax` provides an API for applying quantization to Keras models such as ViT and Gemma3. +Since only JAX is supported as the Keras backend, the environment variable `KERAS_BACKEND` should be set to `jax`. The following 8-bit floating-point formats are supported: `fp8_e4m3` and `fp8_e5m2`. Quantized models can be saved and loaded using standard Keras APIs diff --git a/neural_compressor/jax/quantization/layers_dynamic.py b/neural_compressor/jax/quantization/layers_dynamic.py index 7b8586802d7..2f78963d4fc 100644 --- a/neural_compressor/jax/quantization/layers_dynamic.py +++ b/neural_compressor/jax/quantization/layers_dynamic.py @@ -37,12 +37,6 @@ verify_api, ) -if keras.config.backend() != "jax": - raise ValueError( - f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n" - 'Consider setting KERAS_BACKEND env var to "jax".' - ) - dynamic_quant_mapping = {} diff --git a/neural_compressor/jax/quantization/layers_static.py b/neural_compressor/jax/quantization/layers_static.py index e9d56e705e5..2791936a0ff 100644 --- a/neural_compressor/jax/quantization/layers_static.py +++ b/neural_compressor/jax/quantization/layers_static.py @@ -37,12 +37,6 @@ verify_api, ) -if keras.config.backend() != "jax": - raise ValueError( - f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n" - 'Consider setting KERAS_BACKEND env var to "jax".' - ) - static_quant_mapping = {} diff --git a/neural_compressor/jax/quantization/quantize.py b/neural_compressor/jax/quantization/quantize.py index 765268362e3..15c28a7f5bb 100644 --- a/neural_compressor/jax/quantization/quantize.py +++ b/neural_compressor/jax/quantization/quantize.py @@ -21,7 +21,7 @@ from neural_compressor.common import logger from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry from neural_compressor.common.utils import Mode, log_process -from neural_compressor.jax.utils import algos_mapping +from neural_compressor.jax.utils import algos_mapping, check_backend def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_name): @@ -58,6 +58,7 @@ def quantize_model( keras.Model: The quantized model. """ # fmt: on + check_backend() if not inplace: raise NotImplementedError("Out of place quantization is not supported yet. " "Please set parameter inplace=True for quantize_model() to modify the model in-place") diff --git a/neural_compressor/jax/quantization/saving.py b/neural_compressor/jax/quantization/saving.py index ceca1014277..0626258d040 100644 --- a/neural_compressor/jax/quantization/saving.py +++ b/neural_compressor/jax/quantization/saving.py @@ -28,7 +28,7 @@ from neural_compressor.common import logger from neural_compressor.common.base_config import config_registry from neural_compressor.jax.quantization.config import FRAMEWORK_NAME, BaseConfig, DynamicQuantConfig, StaticQuantConfig -from neural_compressor.jax.utils.utility import dtype_mapping, iterate_over_layers +from neural_compressor.jax.utils.utility import check_backend, dtype_mapping, iterate_over_layers def quant_config_to_json_object(quant_config: BaseConfig) -> dict: @@ -446,6 +446,7 @@ def prepare_deserialized_quantized_model( Returns: Union[KerasQuantizedModelWrapperMixin, KerasQuantizedModelBackboneWrapper]: The transformed quantized model/backbone wrapper. """ + check_backend() model_info = quant_config.get_model_info(model) configs_mapping = quant_config.to_config_mapping(model_info=model_info) diff --git a/neural_compressor/jax/utils/__init__.py b/neural_compressor/jax/utils/__init__.py index b53e2c09606..ce4b608fe4f 100644 --- a/neural_compressor/jax/utils/__init__.py +++ b/neural_compressor/jax/utils/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neural_compressor.jax.utils.utility import algos_mapping, register_algo +from neural_compressor.jax.utils.utility import algos_mapping, register_algo, check_backend diff --git a/neural_compressor/jax/utils/utility.py b/neural_compressor/jax/utils/utility.py index e61fe112b32..34e982f153e 100644 --- a/neural_compressor/jax/utils/utility.py +++ b/neural_compressor/jax/utils/utility.py @@ -30,6 +30,23 @@ from neural_compressor.common import logger +def check_backend(raise_error=True): + """Check if the current Keras backend is JAX and log a warning or error if not.""" + + if keras.config.backend() != "jax": + message = ( + f"neural_compressor.jax only supports JAX backend, but the current Keras backend is {keras.config.backend()}. " + 'Consider setting KERAS_BACKEND env var to "jax".' + ) + if raise_error: + raise ValueError(message) + else: + logger.warning(message) + + +check_backend(raise_error=False) + + def add_fp8_support(function): """Extend a dtype size function to support FP8 dtypes.