-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Test only] BFloat16 test for SkipSimplifiedLayerNormalization #22941
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
skip_size); | ||
} | ||
else | ||
{ | ||
LaunchSkipLayerNormKernel<CudaT, Simplified>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skip_size); | |
} | |
else | |
{ | |
LaunchSkipLayerNormKernel<CudaT, Simplified>( | |
skip_size); | |
} else { | |
LaunchSkipLayerNormKernel<CudaT, Simplified>( |
import tempfile | ||
from typing import Dict | ||
from enum import Enum | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import tempfile | |
from typing import Dict | |
from enum import Enum | |
import tempfile | |
from enum import Enum | |
from typing import Dict | |
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper | ||
from onnx.shape_inference import infer_shapes, infer_shapes_path | ||
from onnx.helper import float32_to_bfloat16 | ||
from packaging import version |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper | |
from onnx.shape_inference import infer_shapes, infer_shapes_path | |
from onnx.helper import float32_to_bfloat16 | |
from packaging import version | |
from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper | |
from onnx.helper import float32_to_bfloat16 | |
from onnx.shape_inference import infer_shapes, infer_shapes_path | |
from packaging import version |
|
||
|
||
def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): | |
def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): |
|
||
def convert_tensor_float_to_bfloat16(tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def convert_tensor_float_to_bfloat16(tensor): | |
def convert_tensor_float_to_bfloat16(tensor): |
class NodeValueType(Enum): | ||
FP32 = 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class NodeValueType(Enum): | |
FP32 = 1 | |
class NodeValueType(Enum): |
class InitializerTracker: | ||
"""Class for keeping track of initializer.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class InitializerTracker: | |
"""Class for keeping track of initializer.""" | |
class InitializerTracker: |
def convert_float_to_float16( | ||
model, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def convert_float_to_float16( | |
model, | |
def convert_float_to_float16( |
|
||
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. | ||
for node in mixed_float_type_node_list: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. | |
for node in mixed_float_type_node_list: | |
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
import tempfile | ||
from typing import Dict | ||
from enum import Enum | ||
import ml_dtypes | ||
|
||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import tempfile | |
from typing import Dict | |
from enum import Enum | |
import ml_dtypes | |
import numpy as np | |
import tempfile | |
from enum import Enum | |
from typing import Dict | |
import ml_dtypes | |
import numpy as np |
03bf839
to
09e2cc1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
skip_size); | ||
} | ||
else | ||
{ | ||
LaunchSkipLayerNormKernel<CudaT, Simplified>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skip_size); | |
} | |
else | |
{ | |
LaunchSkipLayerNormKernel<CudaT, Simplified>( | |
skip_size); | |
} else { | |
LaunchSkipLayerNormKernel<CudaT, Simplified>( |
…lic_shape bugfix (#23558)
- Add a symbolic shape inference dispatcher for `ReduceMean`.
- `ReduceMean` is used in RMSNorm, so shape inference fails for torch-exported models such as Llama, Phi, etc.
- Reuse the dispatcher for `ReduceSum`, since `ReduceMean` 18+ and `ReduceSum` 13+ have the same spec apart from the type of reduction performed.
- Fix an issue with the `quant_pre_process` tool where the external data file is missing if `skip_symbolic_shape=True` and `skip_optimization=False`.
- Add `"session.optimized_model_external_initializers_file_name"` to the session options so that the external data is saved in the same temp directory as the optimized model.
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
import tempfile | ||
from typing import Dict | ||
from enum import Enum | ||
import ml_dtypes | ||
|
||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import tempfile | |
from typing import Dict | |
from enum import Enum | |
import ml_dtypes | |
import numpy as np | |
import tempfile | |
from enum import Enum | |
import ml_dtypes | |
import numpy as np |
|
||
import onnxscript | ||
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex | ||
import onnxruntime as rt | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import onnxscript | |
from onnxscript import optimizer, ir | |
import onnxconverter_common | |
from onnxconverter_common.onnx_ex import make_model_ex | |
import onnxruntime as rt | |
|
||
def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): | ||
def convert_tensor_float_to_float16(tensor, is_value_type_bfloat16=False, min_positive_val=5.96e-08, max_finite_val=65504.0): | ||
"""Convert tensor float to float16. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): | |
def convert_tensor_float_to_float16(tensor, is_value_type_bfloat16=False, min_positive_val=5.96e-08, max_finite_val=65504.0): | |
"""Convert tensor float to float16. | |
def convert_tensor_float_to_float16( | |
tensor, is_value_type_bfloat16=False, min_positive_val=5.96e-08, max_finite_val=65504.0 | |
): | |
"""Convert tensor float to float16. |
|
||
class NodeValueType(Enum): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class NodeValueType(Enum): | |
class NodeValueType(Enum): |
|
||
class InitializerTracker: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class InitializerTracker: | |
class InitializerTracker: |
|
||
|
||
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. | |
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. |
node_name = node.name + "_input_cast" + str(i) | ||
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] | ||
new_node = [helper.make_node("Cast", [input_name], [output_name], to=TensorProto.FLOAT, name=node_name)] | ||
model.graph.node.extend(new_node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
node_name = node.name + "_input_cast" + str(i) | |
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] | |
new_node = [helper.make_node("Cast", [input_name], [output_name], to=TensorProto.FLOAT, name=node_name)] | |
model.graph.node.extend(new_node) | |
node_name = node.name + "_input_cast" + str(i) | |
new_node = [ | |
helper.make_node("Cast", [input_name], [output_name], to=TensorProto.FLOAT, name=node_name) | |
] | |
model.graph.node.extend(new_node) |
# new_node = [helper.make_node("Cast", [input_name], [output], to=final_value_type, name=node_name)] | ||
new_node = [helper.make_node("Cast", [input_name], [output], to=TensorProto.FLOAT16, name=node_name)] | ||
model.graph.node.extend(new_node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# new_node = [helper.make_node("Cast", [input_name], [output], to=final_value_type, name=node_name)] | |
new_node = [helper.make_node("Cast", [input_name], [output], to=TensorProto.FLOAT16, name=node_name)] | |
model.graph.node.extend(new_node) | |
# new_node = [helper.make_node("Cast", [input_name], [output], to=final_value_type, name=node_name)] | |
new_node = [ | |
helper.make_node("Cast", [input_name], [output], to=TensorProto.FLOAT16, name=node_name) | |
] | |
model.graph.node.extend(new_node) |
# model = ir.serde.serialize_model(ir_model) | ||
''' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# model = ir.serde.serialize_model(ir_model) | |
''' | |
# model = ir.serde.serialize_model(ir_model) | |
""" | |
domain=model.domain) | ||
''' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
domain=model.domain) | |
''' | |
domain=model.domain) | |
""" | |
@@ -16,6 +16,9 @@ | |||
import logging | |||
import os | |||
import tempfile | |||
from typing import Dict |
Check warning
Code scanning / lintrunner
RUFF/UP035 Warning
See https://docs.astral.sh/ruff/rules/deprecated-import
@@ -16,6 +16,9 @@ | |||
import logging | |||
import os | |||
import tempfile | |||
from typing import Dict |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript |
Check warning
Code scanning / lintrunner
RUFF/E402 Warning
See https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript |
Check warning
Code scanning / lintrunner
RUFF/I001 Warning
See https://docs.astral.sh/ruff/rules/unsorted-imports
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript | |||
from onnxscript import optimizer, ir | |||
import onnxconverter_common |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
import onnxscript | ||
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex |
Check warning
Code scanning / lintrunner
RUFF/E402 Warning
See https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file
import onnxscript | ||
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex | ||
import onnxruntime as rt |
Check warning
Code scanning / lintrunner
RUFF/E402 Warning
See https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex | ||
import onnxruntime as rt |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
@@ -16,6 +16,9 @@ | |||
import logging | |||
import os | |||
import tempfile | |||
from typing import Dict |
Check notice
Code scanning / CodeQL
Unused import Note
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript |
Check notice
Code scanning / CodeQL
Unused import Note
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript | |||
from onnxscript import optimizer, ir |
Check notice
Code scanning / CodeQL
Unused import Note
Import of 'ir' is not used.
@@ -25,6 +28,12 @@ | |||
|
|||
logger = logging.getLogger(__name__) | |||
|
|||
import onnxscript | |||
from onnxscript import optimizer, ir | |||
import onnxconverter_common |
Check notice
Code scanning / CodeQL
Unused import Note
import onnxscript | ||
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex |
Check notice
Code scanning / CodeQL
Unused import Note
from onnxscript import optimizer, ir | ||
import onnxconverter_common | ||
from onnxconverter_common.onnx_ex import make_model_ex | ||
import onnxruntime as rt |
Check notice
Code scanning / CodeQL
Unused import Note
# if func_infer_shape is not None: | ||
# model = func_infer_shape(model) |
Check notice
Code scanning / CodeQL
Commented-out code Note
Description
Motivation and Context