Support Metal DLPack ndarray imports and byte offsets #1338

Open
XXXXRT666 wants to merge 4 commits into wjakob:master from XXXXRT666:metal-dlpack-cast

XXXXRT666 commented May 8, 2026

Summary

This adds test coverage for importing Metal-backed DLPack arrays into nb::ndarray, including PyTorch MPS tensors.

It also exposes raw DLPack storage metadata for device-backed arrays:

  • data_handle()
  • data_offset()

nb::ndarray exports can now set DLTensor::byte_offset explicitly via a new trailing data_offset constructor argument. This lets producers represent views into opaque device allocations, such as Metal id<MTLBuffer> handles, without copying when the logical array starts after the beginning of the allocation.
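To make the role of the offset concrete, here is a minimal sketch of the idea, not nanobind's actual implementation. `ViewDesc`, `make_view`, and `host_begin` are hypothetical names introduced for illustration; only the `data` and `byte_offset` field names mirror `DLTensor`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a few DLTensor fields; the real struct is
// defined in dlpack.h and has more members.
struct ViewDesc {
    void *data;                 // base pointer or opaque device handle
    std::uint64_t byte_offset;  // start of the logical array, in bytes
    std::int64_t size;          // number of elements (1-D for brevity)
};

// Producer side: describe `count` floats that start `skip` elements into an
// allocation of `capacity` floats. Nothing is copied; only metadata changes.
inline ViewDesc make_view(float *base, std::size_t capacity,
                          std::size_t skip, std::size_t count) {
    assert(skip + count <= capacity);
    return ViewDesc{ base,
                     static_cast<std::uint64_t>(skip * sizeof(float)),
                     static_cast<std::int64_t>(count) };
}

// Consumer side, host memory only: the first element is at data + byte_offset.
// For an opaque handle (such as a Metal buffer object) this pointer arithmetic
// is not meaningful, so handle and offset have to be passed on separately.
inline float *host_begin(const ViewDesc &v) {
    return reinterpret_cast<float *>(
        static_cast<unsigned char *>(v.data) + v.byte_offset);
}
```

With host memory the consumer could fold the offset into the pointer, but with an opaque handle it cannot, which is presumably why keeping the handle and the offset as separate values matters for Metal.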

Tests

  • Added a mocked Metal DLPack import test.
  • Added a real PyTorch MPS DLPack import test guarded by needs_torch_mps.
  • Added CPU and Metal DLPack export coverage for nonzero byte_offset.

Comment thread src/nb_ndarray.cpp Outdated
@XXXXRT666 XXXXRT666 changed the title Support Metal DLPack ndarray imports Support Metal DLPack ndarray imports and byte offsets May 10, 2026
Author

XXXXRT666 commented May 10, 2026

Added explicit DLPack byte_offset export support: nb::ndarray constructors now accept a trailing data_offset, with CPU/Metal tests covering nonzero offsets.

Comment thread src/nb_ndarray.cpp Outdated
Comment thread src/nb_ndarray.cpp Outdated
Comment thread src/nb_ndarray.cpp Outdated
Comment thread src/nb_ndarray.cpp Outdated
Comment thread src/nb_ndarray.cpp Outdated
Comment thread include/nanobind/ndarray.h Outdated
Author

XXXXRT666 commented May 11, 2026

Addressed the latest review comments.

Contributor

hpkfft commented May 11, 2026

Thanks.

I think it's helpful for module developers that this PR now uses the name byte_offset, since DLPack is inconsistent: strides are measured in elements of the dtype, but the initial offset is measured in bytes.
Twenty years ago, the buffer protocol made the wiser choice of measuring strides in bytes.
Note, for example, in C++ (assuming typical IEEE floating-point types):

  • sizeof(float) == 4
  • sizeof(std::complex<float>) == 8
  • alignof(std::complex<float>) == 4

so it's unnecessarily limiting to measure strides in terms of element size.
See test41_noninteger_stride() for an example of a slice that doesn't work with DLPack.
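One way such a case can arise is sketched below. This is an illustration under assumptions, not a reproduction of the referenced test: `Sample` and `to_element_stride` are hypothetical names, and the 12-byte record size assumes a typical platform where `float` is 4 bytes and `std::complex<float>` needs only 4-byte alignment.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <cstdint>

// Hypothetical record: one float followed by one complex<float>.
// On typical platforms this is 12 bytes, because complex<float> only
// requires 4-byte alignment and so no padding is inserted.
struct Sample {
    float weight;
    std::complex<float> value;
};

// Convert a byte stride into an element stride. Returns false when the byte
// stride is not a whole multiple of the item size, in which case the view
// cannot be described with element-based strides.
inline bool to_element_stride(std::size_t stride_bytes, std::size_t itemsize,
                              std::int64_t &out) {
    assert(itemsize > 0);
    if (stride_bytes % itemsize != 0)
        return false;
    out = static_cast<std::int64_t>(stride_bytes / itemsize);
    return true;
}
```

A view over the `value` fields of a `Sample` array advances by `sizeof(Sample)` bytes per element, which is 1.5 `std::complex<float>` items on such a platform, so `to_element_stride(sizeof(Sample), sizeof(std::complex<float>), s)` returns false. The same memory viewed as `float` has a stride of 3 elements and converts without trouble.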
