Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/pyutils/parallel_tensor.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include <Python.h>

#include <iostream>
#include <map>
#include <vector>

#include <ATen/DLConvertor.h>
#include <ATen/dlpack.h>
#include <ATen/ops/from_blob.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/utils/pybind.h>
Expand Down Expand Up @@ -39,6 +43,13 @@ struct TKParallelTensor {

detail::ipc::flavor ipc_flavor_;

// Build at::Tensor from DLPack capsule (zero-copy) for use in (tensor, ...) ctor.
// Accepts a PyCapsule named "dltensor" (e.g. the result of tensor.__dlpack__()).
// Per the DLPack protocol, consuming the capsule transfers ownership of the
// DLManagedTensor to the consumer, so the capsule must be renamed to
// "used_dltensor" afterwards — otherwise the capsule destructor would invoke
// the deleter a second time (double free).
__host__ static at::Tensor from_dlpack_capsule(pybind11::object capsule) {
    void* ptr = PyCapsule_GetPointer(capsule.ptr(), "dltensor");
    TORCH_CHECK(ptr != nullptr, "Object must be a DLPack capsule (name 'dltensor')");
    // fromDLPack takes ownership: the returned tensor's destructor calls the
    // DLManagedTensor's deleter.
    at::Tensor tensor = at::fromDLPack(static_cast<DLManagedTensor*>(ptr));
    // Mark the capsule as consumed so its destructor does not double-free.
    PyCapsule_SetName(capsule.ptr(), "used_dltensor");
    return tensor;
}

__host__ inline TKParallelTensor(
const at::Tensor &tensor,
int local_rank,
Expand Down Expand Up @@ -112,6 +123,14 @@ struct TKParallelTensor {
initialize_multicast();
}

// DLPack capsule -> zero-copy at::Tensor then same path as (at::Tensor, ...)
// Unwraps the capsule via from_dlpack_capsule() and delegates to the
// (at::Tensor, ...) constructor; the remaining arguments are forwarded
// unchanged. The capsule must be a valid "dltensor" PyCapsule
// (e.g. tensor.__dlpack__()), otherwise from_dlpack_capsule raises.
__host__ inline TKParallelTensor(
    pybind11::object dlpack_capsule,
    int local_rank,
    int local_world_size,
    bool multicast
) : TKParallelTensor(from_dlpack_capsule(dlpack_capsule), local_rank, local_world_size, multicast) {}

// Copying and assignment are disabled — presumably because this object owns
// per-process resources (see ipc_flavor_ and the initialize_multicast() call
// above) that must not be duplicated or rebound; confirm against the
// destructor, which is outside this view.
__host__ TKParallelTensor(const TKParallelTensor&) = delete;
__host__ TKParallelTensor& operator=(const TKParallelTensor&) = delete;
__host__ TKParallelTensor& operator=(TKParallelTensor&& other) = delete;
Expand Down Expand Up @@ -330,6 +349,12 @@ struct TKParallelTensor {
pybind11::arg("local_rank"), \
pybind11::arg("local_world_size"), \
pybind11::arg("multicast") = false) \
.def(pybind11::init<pybind11::object, int, int, bool>(), \
pybind11::arg("dlpack_capsule"), \
pybind11::arg("local_rank"), \
pybind11::arg("local_world_size"), \
pybind11::arg("multicast") = false, \
"Construct from DLPack capsule (e.g. tensor.__dlpack__()); zero-copy, then same as tensor ctor.") \
.def("data", &kittens::py::TKParallelTensor::data) \
.def_readonly("data_", &kittens::py::TKParallelTensor::data_) \
.def_readonly("local_rank_", &kittens::py::TKParallelTensor::local_rank_) \
Expand Down
22 changes: 22 additions & 0 deletions include/pyutils/torchutils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/Tensor.h>

#include <ATen/dlpack.h>

#include "kittens.cuh"
#include "parallel_tensor.cuh"

Expand Down Expand Up @@ -138,6 +140,26 @@ __host__ static inline GL make_fake_gl(const int batch, const int depth, const i
return ::kittens::make_gl<GL>(reinterpret_cast<uint64_t>(nullptr), batch, depth, rows, cols);
}

// Zero-copy wrap DLPack DLManagedTensor into GL (B,D,R,C); pads leading dims with 1.
// Preconditions enforced here: non-null data, 0 <= ndim <= 4, a non-null shape
// whenever ndim > 0, every dim fits in a (non-negative) int, and a dense
// row-major layout (strides == nullptr means compact row-major per the DLPack
// spec). The element dtype is NOT checked against GL's — caller must ensure it
// matches.
template <kittens::ducks::gl::all GL>
__host__ static inline GL tensor_to_gl_from_dlpack(const DLManagedTensor* dlm) {
    TORCH_CHECK(dlm != nullptr && dlm->dl_tensor.data != nullptr, "DLPack tensor is null");
    const int ndim = dlm->dl_tensor.ndim;
    TORCH_CHECK(ndim >= 0 && ndim <= 4, "DLPack ndim must be in [0, 4]");
    // A null shape with ndim > 0 is malformed; previously it was silently
    // treated as an all-ones shape.
    TORCH_CHECK(ndim == 0 || dlm->dl_tensor.shape != nullptr,
                "DLPack shape pointer is null but ndim > 0");
    std::array<int, 4> shape = {1, 1, 1, 1};
    for (int i = 0; i < ndim; ++i) {
        const int64_t dim = dlm->dl_tensor.shape[i];
        // Guard the narrowing cast: int64 dims must round-trip through int.
        TORCH_CHECK(dim >= 0 && dim == static_cast<int64_t>(static_cast<int>(dim)),
                    "DLPack dimension does not fit in int");
        shape[4 - ndim + i] = static_cast<int>(dim);
    }
    // Zero-copy wrapping is only valid for contiguous row-major memory.
    // If strides are present, verify they describe exactly that layout
    // (dims of size <= 1 have arbitrary strides and are exempt).
    if (dlm->dl_tensor.strides != nullptr) {
        int64_t expected = 1;
        for (int i = ndim - 1; i >= 0; --i) {
            TORCH_CHECK(dlm->dl_tensor.shape[i] <= 1 || dlm->dl_tensor.strides[i] == expected,
                        "DLPack tensor must be contiguous (row-major) for zero-copy GL wrap");
            expected *= dlm->dl_tensor.shape[i];
        }
    }
    return ::kittens::make_gl<GL>(
        reinterpret_cast<uint64_t>(dlm->dl_tensor.data),
        shape[0], shape[1], shape[2], shape[3]);
}

// Direct pointer+shape path for C++ or simplified call sites (no DLPack struct)
// NOTE(review): despite the name, this overload never touches DLPack — it just
// forwards a raw pointer and explicit (B, D, R, C) dims to make_gl. The caller
// is responsible for the pointer's lifetime and for the memory being laid out
// as make_gl expects — presumably dense row-major device memory; confirm
// against make_gl's contract.
template <kittens::ducks::gl::all GL>
__host__ static inline GL tensor_to_gl_from_dlpack(void* data, int B, int D, int R, int C) {
    return ::kittens::make_gl<GL>(reinterpret_cast<uint64_t>(data), B, D, R, C);
}

// Assert that both tensors reside on the same device (throws via TORCH_CHECK
// if they do not).
__host__ static inline void _device_check(const at::Tensor& first, const at::Tensor& second) {
    const bool same_device = first.device() == second.device();
    TORCH_CHECK(same_device, "All tensors must be on the same device");
}
Expand Down