-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[runtime][python] Add IRPA entry conversion to/from numpy #19492
[runtime][python] Add IRPA entry conversion to/from numpy #19492
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks reasonable to me but I'm afk and didn't do a detailed review. Would you move getting a second review (maybe from Scott)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drive-by since I saw my name. A few minor comments but otherwise LGTM. Thanks!
Please resolve the DCO check: https://github.com/iree-org/iree/pull/19492/checks?check_run_id=34489850995
https://iree.dev/developers/general/contributing/#developer-certificate-of-origin
def parameter_index_add_numpy_ndarray( | ||
index: ParameterIndex, name: str, array: np.ndarray | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file already has import array
, so this function shadows array
with a new local variable.
We have similar code that doesn't also import array
:
iree/runtime/bindings/python/iree/runtime/system_api.py
Lines 80 to 88 in a31da1f
def _bool_to_int8(array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]: | |
if not isinstance(array, np.ndarray): | |
return array | |
# IREE models booleans as i8s. | |
# TODO(#5359): This cast should be moved into the function abi. | |
if array.dtype == bool: | |
array = array.astype(np.int8) | |
return array |
I guess this is fine, just be careful :)
raise KeyError(f"Numpy dtype {dtype} not found.") | ||
|
||
|
||
_metadata_prefix = "PYTORCH:" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be changed to NUMPY:
? This file doesn't depend on torch.
_metadata_prefix = "PYTORCH:" | |
_metadata_prefix = "NUMPY:" |
You could also move this up in the file, ahead of its first uses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is essentially our format magic number. Unfortunately, this can not be change if we want files that were already created with IREE Turbine to be compatible.
At some point we could start writing files with a new prefix, but we would have to maintain support for the old one.
I think we would want to have some versioning scheme as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the format while maintaining support for loading IREE Turbine tensors.
flat_array = array.copy() | ||
else: | ||
flat_array = np.ascontiguousarray(array).view(np.uint8) | ||
index.add_buffer(name, flat_array, metadata=metadata) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I'd nearly forgotten that we have support for arbitrary metadata in parameter archives. So this lets the C runtime load the data as dense bytes without needing to care about the original types, but then Python code can read the metadata to unpack.
On that note, do you want to also write some version information into the metadata, to be able to handle encoding changes more gracefully? Maybe as part of the _metadata_prefix
, like NUMPY_v0:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I answered on the other question.
_metadata_iree_turbine_version = "PYTORCH" | ||
"""There are files created with IREE Turbine that use this prefix. | ||
This is here to maintain the ability to load such files.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Feel free to merge without this, but it would be nice)
We can probably trust this code... but if you want to go a bit further and guard against compatibility breaking changes, you could add a unit test with some binary data exported from iree-turbine to ensure that it is supported here:
def testParameterIndexEntryFromIreeTurbine(self):
index = rt.ParameterIndex()
index.load("test_iree_turbine_index.irpa") # load from a testdata file
index_entry_as_array = rt.parameter_index_entry_as_numpy_ndarray(...)
# then do some sanity check on the returned ndarray
(avoiding the temptation to put logic in the test and call the _make_tensor_metadata
helper function: https://testing.googleblog.com/2014/07/testing-on-toilet-dont-put-logic-in.html)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once DCO check is resolved (highly recommend setting up SSH commit signature verification on each machine you work from, that shows your commits as verified across all GitHub projects and appeases the DCO check: https://iree.dev/developers/general/contributing/#crypographically-signing-commits)
Add iterop between numpy ndarray and parameter index. This is an adaptation of the original in IREE Turbine https://github.com/iree-org/iree-turbine/blob/142c8a5044a4fedb43a11229f462363b05743b23/iree/turbine/aot/params.py The goal is to maintain compatibility with IRPA files that were already generated with IREE Turbine. At some point we can refactor the IREE Turbine side to use this implementation. Signed-off-by: Boian Petkantchin <[email protected]>
c10eca9
to
98f7be5
Compare
Add iterop between numpy ndarray and parameter index. This is an adaptation of the original in IREE Turbine https://github.com/iree-org/iree-turbine/blob/142c8a5044a4fedb43a11229f462363b05743b23/iree/turbine/aot/params.py
The goal is to maintain compatibility with IRPA files that were already generated with IREE Turbine.
At some point we can refactor the IREE Turbine side to use this implementation.