Dynamic Shapes Design
- dynamic shape: Refers to tensors whose shape information is only known at model runtime. These can occur as user input or from the output of certain operations. Dynamic shape can refer to tensors with dynamic dimension length and dynamic rank. In MIGraphX we are currently only considering dynamic dimension lengths.
- dynamic batch: A shape that has a dynamic dimension for the batch dimension.
- shape function: The function that calculates the output shape of an operation.
- data-dependent shape function: A shape function whose output shape depends on the input data, not just the input tensor's shape. Examples: NonZero, TopK
Static shape information is held within a `shape` object as dimension lengths and strides.
Dynamic shape information is held within the same `shape` object as a vector of `dynamic_dimension` objects.
`dynamic_dimension` objects are used in MIGraphX to specify the range of values a dimension can take at model evaluation time, from a minimum value to a maximum value, along with optimal values.
For example, a `dynamic_dimension` with `{min:1, max:10, optimals:{1, 4, 10}}` means that the dimension can be any value from 1 through 10, with 1, 4, and 10 as the optimal values.
Supplied optimal values may allow MIGraphX to optimize the program for those specific shapes.
In this way a dynamic dimension carries both a range of values and possible optimal values, allowing for future optimizations.
A fixed `dynamic_dimension` can be specified by setting `min` and `max` to the same value (e.g. `{min:3, max:3}`).
A dynamic shape specified solely by fixed `dynamic_dimension` objects will be converted to a static shape during parsing.
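The `dynamic_dimension` semantics above can be sketched in Python. This is an illustrative model only; the class and function names here are hypothetical, not the MIGraphX C++ API.

```python
from dataclasses import dataclass

@dataclass
class DynamicDimension:
    """Illustrative stand-in for MIGraphX's dynamic_dimension (hypothetical names)."""
    min: int
    max: int
    optimals: tuple = ()

    def is_fixed(self):
        # A fixed dimension has min == max, e.g. {min:3, max:3}.
        return self.min == self.max

def to_static_lens(dyn_dims):
    """Return concrete dimension lengths if every dimension is fixed, else None.

    Mirrors the parse-time rule: a shape made solely of fixed
    dynamic_dimension objects collapses to a static shape."""
    if all(d.is_fixed() for d in dyn_dims):
        return [d.min for d in dyn_dims]
    return None

batch = DynamicDimension(1, 10, (1, 4, 10))  # dynamic batch, optimals 1/4/10
chan = DynamicDimension(3, 3)                # fixed dimension
print(to_static_lens([batch, chan]))         # None: shape stays dynamic
print(to_static_lens([DynamicDimension(3, 3), DynamicDimension(224, 224)]))  # [3, 224]
```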
MIGraphX does not support symbolic shapes. In the future we may add named shapes to track dynamic shapes in a model for improved optimizations.
MIGraphX supports parsing dynamic shapes from ONNX models.
There are two `onnx_options` where a user can supply `dynamic_dimension` data: `default_dyn_dim_value` and `map_dyn_input_dims`.
`default_dyn_dim_value` sets the default `dynamic_dimension` value to use for symbolic or missing dimensions in an ONNX node; this is mainly used to set the `dynamic_dimension` for an ONNX model already set up for dynamic batch sizes.
`map_dyn_input_dims` can be used to parse any ONNX input shape as a dynamic shape.
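How the two options interact during parsing can be sketched as follows. This is a behavioral model under stated assumptions, not the real MIGraphX parser: the function name, the `(min, max)` tuple encoding, and the use of `None` for symbolic dims are all hypothetical.

```python
# Illustrative sketch of how default_dyn_dim_value and map_dyn_input_dims
# behave during ONNX parsing; names and data structures are hypothetical.

def resolve_input_shape(onnx_dims, input_name, default_dyn_dim, dyn_input_map):
    """onnx_dims: ints for fixed dims, None for symbolic/missing dims."""
    # map_dyn_input_dims overrides the entire shape for a named input.
    if input_name in dyn_input_map:
        return dyn_input_map[input_name]
    resolved = []
    for d in onnx_dims:
        if d is None:
            # default_dyn_dim_value fills symbolic or missing dimensions.
            resolved.append(default_dyn_dim)
        else:
            resolved.append((d, d))  # fixed dim encoded as {min:d, max:d}
    return resolved

default_dyn = (1, 10)  # {min:1, max:10}
dyn_map = {"x": [(1, 4), (3, 3), (224, 224), (224, 224)]}
# Input "y" is not in the map, so only its symbolic batch dim is defaulted:
print(resolve_input_shape([None, 3, 224, 224], "y", default_dyn, dyn_map))
# [(1, 10), (3, 3), (224, 224), (224, 224)]
```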
See the examples folder for usage of the APIs and parsing:
TODO: Make more examples
Dynamic shapes propagate forward through the shape calculations of a model.
If an operator receives an input with a dynamic shape, or has a data-dependent shape function, its shape function will return a dynamic output shape.
Each subsequent operation will then also have to support dynamic shape input.
This unfortunately means that most operators will have to be updated to support dynamic shapes.
Additionally, there are ONNX operators that can be parsed into other MIGraphX operators for a static input shape but must use a different calculation/kernel for a dynamic shape input (e.g. `Resize`).
MIGraphX handles the dynamic batch case on the GPU by creating submodules for each batch size of interest.
MIGraphX first parses a dynamic model into the `ref` versions of the instructions, and then the `split_single_dyn_dim` compiler pass splits the batch sizes into submodules.
A `select_module` operator is inserted into the main module that selects the correct submodule to run at module evaluation time.
Within each submodule the shapes are static, allowing us to retain the performance of static kernels and the existing graph compiler optimizations.
Caveat: This method is not effective for operators with data-dependent shape functions.
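The `split_single_dyn_dim` / `select_module` idea can be sketched as compiling one static submodule per batch size of interest and dispatching on the runtime batch size. Everything below is illustrative; the function names and the doubling "model" are placeholders, not the MIGraphX implementation.

```python
# Sketch: one statically-shaped submodule per batch size, selected at runtime.

def compile_submodule(batch):
    # Stand-in for a submodule compiled with a static batch dimension.
    def run(data):
        assert len(data) == batch  # shapes inside the submodule are static
        return [x * 2.0 for x in data]  # placeholder "model" computation
    return run

optimal_batches = [1, 4, 10]  # e.g. the optimals of the batch dynamic_dimension
submodules = {b: compile_submodule(b) for b in optimal_batches}

def select_module(data):
    # Mirrors the select_module operator: pick the submodule whose static
    # batch size matches the runtime input, then run it.
    return submodules[len(data)](data)

print(select_module([1.0]))                 # runs the batch-1 submodule
print(select_module([1.0, 2.0, 3.0, 4.0]))  # runs the batch-4 submodule
```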
The major work items for extending support for full dynamic shape support on the GPU include:
- Changing GPU kernels to dynamic shape versions, either by using libraries that already support them or by rewriting the kernels.
- Combing through the current compiler passes to extend them for dynamic shapes where possible, or disabling them entirely.
- Creating new compiler passes designed for dynamic shapes to improve performance.