
[MIGraphX EP] Add migx ep fp8 support and int4 weights #23534

Open
wants to merge 2 commits into base: main
Conversation

@TedThemistokleous (Contributor) commented Jan 29, 2025:

  • Add fp8 and int4 types to the supported list for the ONNX Runtime EP

  • Add support for int4 inputs

Map these to int8 for now, since we don't explicitly set an int4 input type or pack/unpack int4 operands

  • Add a flag to allow fp8 quantization through the ONNX Runtime API

  • Add fp8 quantization to the compile stage of the MIGraphX EP

Mirror the calibration code we use for int8 and just change which quantize call we make through the MIGraphX API (see the sketch under Description below)

  • Clean up logging

  • Clean up and encapsulate the quantization/compile functions

  • Add additional flags for fp8 that are shared with int8

  • Add a lockout warning message when int8 and fp8 are enabled at the same time

  • Run lintrunner pass

  • Fix session options inputs and add better logging.

Previous runs using session options failed because we weren't pulling in inputs from the Python interface. This, plus additional logging, let me track which options were invoked via the environment and which were added at the start of an inference session.

  • Fix naming of the save/load path variables to be consistent with the enable flags.

  • Print only the environment variables that are set, as warnings

We need this so the user knows which environment variables are active in the background, ensuring consistency between runs (as sketched below).
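A minimal sketch of that env-var warning, assuming ORT's LOGS_DEFAULT logging macro is available inside the EP; the variable name shown is hypothetical, not necessarily one the EP actually reads:

    // Warn only about environment variables that are actually set, so users
    // can see which background overrides affect a run.
    #include <cstdlib>

    static void warn_if_env_set(const char* name) {
      if (const char* value = std::getenv(name)) {
        LOGS_DEFAULT(WARNING) << "[MIGraphX EP] " << name << " is set to: " << value;
      }
    }
    // e.g. warn_if_env_set("ORT_MIGRAPHX_FP8_ENABLE");  // hypothetical name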


Description

Changes that clean up the MIGraphX EP quantization code, and add fp8 quantization support along with int4 support.

The cleanup changes handle a few issues seen with the Python interface when taking in provider options.
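The fp8 compile-stage path mirrors the int8 calibration flow. A minimal sketch follows, with the quantize_fp8 entry points taken from the build log later in this thread and the int8 counterparts from the MIGraphX C++ API; the function shape is illustrative, not the PR's actual code:

    #include <iostream>
    #include <migraphx/migraphx.hpp>

    void calibrate_and_quantize(migraphx::program& prog,
                                const migraphx::target& t,
                                migraphx::program_parameters quant_params,
                                bool int8_enable, bool fp8_enable) {
      if (int8_enable && fp8_enable) {
        // Lockout: the two modes are mutually exclusive; warn and skip.
        std::cerr << "[MIGraphX EP] int8 and fp8 quantization cannot be "
                     "enabled together; skipping quantization\n";
        return;
      }
      if (int8_enable) {
        migraphx::quantize_int8_options quant_opts;
        quant_opts.add_calibration_data(quant_params);
        migraphx::quantize_int8(prog, t, quant_opts);
      } else if (fp8_enable) {
        migraphx::quantize_fp8_options quant_opts;  // needs a newer MIGraphX
        quant_opts.add_calibration_data(quant_params);
        migraphx::quantize_fp8(prog, t, quant_opts);
      }
    }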

Motivation and Context

Required as we fix flags that were ignored when using provider_options with the MIGraphX EP.
Adds fp8 quantization through the MIGraphX API.
Adds int4 weight support for packed int4 weights in MIGraphX inference.
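For context, a usage sketch through the ONNX Runtime C++ API. OrtMIGraphXProviderOptions and its existing fields come from the public C API; migraphx_fp8_enable is assumed to be the field this PR adds and may not match the final name:

    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "migx_fp8_demo");
      Ort::SessionOptions so;

      OrtMIGraphXProviderOptions opts{};
      opts.device_id = 0;
      opts.migraphx_fp8_enable = 1;   // assumed new field from this PR
      opts.migraphx_int8_enable = 0;  // int8 and fp8 are mutually exclusive

      so.AppendExecutionProvider_MIGraphX(opts);
      Ort::Session session(env, "model.onnx", so);  // illustrative model path
      return 0;
    }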

@TedThemistokleous changed the title from "Add migx ep fp8 support and int4 weights" to "[MIGraphX EP] Add migx ep fp8 support and int4 weights" on Jan 29, 2025
@TedThemistokleous (Contributor, Author):

Ping @tianleiwu: required for additional quantization support through the MIGraphX API.

@tianleiwu (Contributor):

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline

@tianleiwu (Contributor):

/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

@tianleiwu (Contributor):

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,CoreML CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline


Azure Pipelines successfully started running 7 pipeline(s).


Azure Pipelines successfully started running 8 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu (Contributor):

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline

@tianleiwu (Contributor):

/azp run Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

@tianleiwu (Contributor):

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline,CoreML CI Pipeline,Linux DNNL CI Pipeline,Linux MIGraphX CI Pipeline,Linux ROCm CI Pipeline


Azure Pipelines successfully started running 7 pipeline(s).


Azure Pipelines successfully started running 10 pipeline(s).


Azure Pipelines successfully started running 8 pipeline(s).

@tianleiwu (Contributor):

@TedThemistokleous, please take a look at the build error:

/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:345:18: error: ‘migraphx_shape_fp8e4m3fn_type’ was not declared in this scope; did you mean ‘migraphx_shape_fp8e4m3fnuz_type’?
345 | mgx_type = migraphx_shape_fp8e4m3fn_type;
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| migraphx_shape_fp8e4m3fnuz_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:348:18: error: ‘migraphx_shape_fp8e5m2_type’ was not declared in this scope; did you mean ‘migraphx_shape_uint32_type’?
348 | mgx_type = migraphx_shape_fp8e5m2_type;
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~
| migraphx_shape_uint32_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:351:18: error: ‘migraphx_shape_fp8e5m2fnuz_type’ was not declared in this scope; did you mean ‘migraphx_shape_fp8e4m3fnuz_type’?
351 | mgx_type = migraphx_shape_fp8e5m2fnuz_type;
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| migraphx_shape_fp8e4m3fnuz_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc: In function ‘void onnxruntime::calibrate_and_quantize(migraphx::api::program&, const migraphx::api::target&, migraphx::api::program_parameters, bool, bool, bool, bool, std::unordered_map<std::__cxx11::basic_string, float>&)’:
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1249:17: error: ‘quantize_fp8_options’ is not a member of ‘migraphx’
1249 | migraphx::quantize_fp8_options quant_opts;
| ^~~~~~~~~~~~~~~~~~~~
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1250:7: error: ‘quant_opts’ was not declared in this scope
1250 | quant_opts.add_calibration_data(quant_params);
| ^~~~~~~~~~
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1251:17: error: ‘quantize_fp8’ is not a member of ‘migraphx’
1251 | migraphx::quantize_fp8(prog, t, quant_opts);

@TedThemistokleous (Contributor, Author):

@TedThemistokleous, please take a look at the build error:

/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:345:18: error: ‘migraphx_shape_fp8e4m3fn_type’ was not declared in this scope; did you mean ‘migraphx_shape_fp8e4m3fnuz_type’?
345 | mgx_type = migraphx_shape_fp8e4m3fn_type;
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| migraphx_shape_fp8e4m3fnuz_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:348:18: error: ‘migraphx_shape_fp8e5m2_type’ was not declared in this scope; did you mean ‘migraphx_shape_uint32_type’?
348 | mgx_type = migraphx_shape_fp8e5m2_type;
| ^~~~~~~~~~~~~~~~~~~
| migraphx_shape_uint32_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:351:18: error: ‘migraphx_shape_fp8e5m2fnuz_type’ was not declared in this scope; did you mean ‘migraphx_shape_fp8e4m3fnuz_type’?
351 | mgx_type = migraphx_shape_fp8e5m2fnuz_type;
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| migraphx_shape_fp8e4m3fnuz_type
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc: In function ‘void onnxruntime::calibrate_and_quantize(migraphx::api::program&, const migraphx::api::target&, migraphx::api::program_parameters, bool, bool, bool, bool, std::unordered_map<std::__cxx11::basic_string, float>&)’:
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1249:17: error: ‘quantize_fp8_options’ is not a member of ‘migraphx’
1249 | migraphx::quantize_fp8_options quant_opts;
| ^~~~~~~~~~~~~~~~~~~~
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1250:7: error: ‘quant_opts’ was not declared in this scope
1250 | quant_opts.add_calibration_data(quant_params);
| ^~~~~~~~~~
/onnxruntime_src/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:1251:17: error: ‘quantize_fp8’ is not a member of ‘migraphx’
1251 | migraphx::quantize_fp8(prog, t, quant_opts);

Shoot, you'll need to use a later version of MIGraphX in your CI. We've recently added the quantize_fp8 function to our API.

Where does your CI pull in/use MIGraphX, or do you pin things to a ROCm version?

@tianleiwu (Contributor):

Shoot, you'll need to use a later version of MIGraphX in your CI. We've recently added the quantize_fp8 function to our API.

Where does your CI pull in/use MIGraphX, or do you pin things to a ROCm version?

This is the Dockerfile used in CI:
https://github.com/microsoft/onnxruntime/blob/2d8b86e09a5fbc9aeec30069c9a1ad75325a5264/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile

Related commit:
28ee049

@TedThemistokleous (Contributor, Author):

Shoot, you'll need to use a later version of MIGraphX in your CI. We've recently added the quantize_fp8 function to our API.
Where does your CI pull in/use MIGraphX, or do you pin things to a ROCm version?

This is the Dockerfile used in CI: https://github.com/microsoft/onnxruntime/blob/2d8b86e09a5fbc9aeec30069c9a1ad75325a5264/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile

Related commit: 28ee049

Let's put this on ice then, and I can gate this feature on ROCm 6.4; that's the release this change targets for MIGraphX. In the meantime, do you want me to help update your CI to ROCm 6.3? We just released ROCm 6.3.2, and it looks like you're using a ROCm 6.2.3 build in your CI instead.
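A sketch of what that gating could look like in the EP's compile path; MIGRAPHX_EP_HAS_FP8 is a placeholder macro the build system would have to define for new-enough ROCm, not a real MIGraphX define:

    // Compile the fp8 path only when the MIGraphX headers provide it.
    #if defined(MIGRAPHX_EP_HAS_FP8)  // placeholder feature macro (assumed)
      migraphx::quantize_fp8_options quant_opts;
      quant_opts.add_calibration_data(quant_params);
      migraphx::quantize_fp8(prog, t, quant_opts);
    #else
      LOGS_DEFAULT(WARNING) << "[MIGraphX EP] fp8 quantization requires a "
                               "newer MIGraphX (ROCm >= 6.3.2); flag ignored.";
    #endif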

@TedThemistokleous (Contributor, Author):

Related: #23535

@TedThemistokleous (Contributor, Author):

Shoot, you'll need to use a later version of MIGraphX in your CI. We've recently added the quantize_fp8 function to our API.
Where does your CI pull in/use MIGraphX, or do you pin things to a ROCm version?

This is the Dockerfile used in CI: https://github.com/microsoft/onnxruntime/blob/2d8b86e09a5fbc9aeec30069c9a1ad75325a5264/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile

Related commit: 28ee049

Actually, the ROCm 6.3.2 release is sufficient for this API call, as we've already added it there; a quick check confirms this. Changing CI to ROCm 6.3.2 should be enough.

~/AMDMIGraphX ((rocm-6.3.2)) $ git log | grep fp8
    fp8 functional not performance (#2845)
    use `optimize_module` for the fp8 JIT compilation  (#2656)
    * use eliminate_convert for the fp8
    Autocast_fp8 pass (#2527)
    * Autocast_fp8 pass
    * Address review: used contains to check fp8_types set
    Add `--fp8` option to quantize models in FP8 inside `migraphx-driver` (#2535)
    Enable simplify qdq to work with FP8 types, add warning for fp8 when loading mxr
    ad fp8 warnings (#2531)
    Update onnx proto and onnx parser to handle fp8 types (#2493)
    Only adding fp8e4m3fnuz in MIGraphX IR for now.
    fp16 and fp8 quantization to include subgraph and parameters

It would pull the existing package from our public repo.radeon.com repository, found here: https://repo.radeon.com/rocm/apt/6.3.2/pool/main/m/migraphx/
