Skip to content

Commit 4b1de24

Browse files
committed
Update to include WorkStream proposal, move version to API chain
1 parent a947bef commit 4b1de24

File tree

1 file changed

+85
-21
lines changed

1 file changed

+85
-21
lines changed

include/dlpack/dlpack.h

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ typedef struct DLManagedTensorVersioned {
369369
* \brief A generic C-style allocator that exposes allocation of a Tensor/Array.
370370
*
371371
* This information can then be used to set allocators of a callee to run allocations.
372+
* This information can then be used to set the callee's allocator to perform allocations.
372373
* This function can be exposed by the framework through the DLPackExchangeAPI.
373374
*
374375
* This particular function does not assume a Python environment; as a result,
@@ -394,44 +395,66 @@ typedef int (*DLPackManagedTensorAllocator)( //
394395
* \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
395396
*
396397
* This function is a C-style function pointer to quickly convert a PyObject* Tensor/NDArray
397-
* to a DLManagedTensorVersioned without going through the Python Interpreter.
398+
* to a DLManagedTensorVersioned without going through the Python interpreter.
398399
*
399400
* It also provides an option to query the current context stream of the device provided
400401
* by the tensor.
401402
*
402403
* This function is exposed by the framework through the DLPackExchangeAPI.
403404
*
404-
* This information can then be picked up by importers and libraries to run the speed conversion.
405+
* This information can then be picked up by importers and libraries to perform a fast conversion.
405406
* This function should not throw any exceptions; if it fails, it should return -1 and
406407
* set the error message via PyErr_SetXXX.
407408
*
408409
* \param py_object The Python object to convert; this should be PyObject*.
409410
* We use void* to avoid dependency on Python.h.
410411
*
411-
* \param max_version The maximum version of DLPack support that consumer supports.
412-
* Consumer should fill in their own version here, this parameter is not null.
413-
* Producer can use this information to produce the appropriate
414-
* DLManagedTensorVersioned for maximum compatibility if needed.
415-
* This field is primarily used for future compatibility in case
416-
* of major version bump and ABI-breaking changes.
417-
*
418412
* \param out The output DLManagedTensorVersioned.
419413
*
420-
* \param optional_out_env_stream Outputs the current context stream of the device provided
421-
* by the tensor; it can be NULL, in which case the stream will not be queried.
422-
* optional_out_env_stream should points to cudaStream_t in the case of CUDA.
414+
* \param optional_out_last_active_stream Outputs the current stream the tensor is synced to.
415+
* It can be NULL, in which case the stream will not be queried.
416+
* optional_out_last_active_stream should point to cudaStream_t in the case of CUDA.
417+
* Note that for frameworks that use a stream context manager, optional_out_last_active_stream
418+
* can be the stream that the context manager was most recently active on.
419+
* The stream is owned by the producer, and the consumer cannot retain it.
420+
* Instead, the consumer can record an event or add wait dependencies to it.
421+
* It is the responsibility of the consumer to synchronize with the stream if necessary.
422+
* The producer may output `reinterpret_cast<void*>(-1)` to indicate that the last active stream
423+
* is not available; in such a case, a device sync is needed to ensure data is ready.
423424
*
424425
* \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
425426
* \note We use void* to avoid dependency on Python.h, so this specific type is
426427
* not dependent on Python.h and can be copied to dlpack.h.
427428
*
428-
* \sa DLPackExchangeAPI
429+
* \sa DLPackExchangeAPI, DLPackCurrentWorkStream
429430
*/
430431
typedef int (*DLPackManagedTensorFromPyObject)( //
431432
void* py_object, //
432-
const DLPackVersion* max_version, //
433433
DLManagedTensorVersioned** out, //
434-
void** optional_out_env_stream //
434+
void** optional_out_last_active_stream //
435+
);
436+
437+
/*!
438+
* \brief Obtain the current work stream of a device.
439+
*
440+
* This function is a C-style function pointer to obtain the current work stream of a device
441+
* for frameworks that rely on a context manager to manage the stream.
442+
* For example, it should map to torch.cuda.current_stream in PyTorch.
443+
*
444+
* This function can be set to NULL if the framework does not rely on a context manager to
445+
* manage the stream.
446+
*
447+
* \param device_type The device type.
448+
* \param device_id The device id.
449+
* \param optional_out_current_stream The output current work stream.
450+
* \return 0 on success, -1 on failure.
451+
*
452+
* \sa DLPackExchangeAPI
453+
*/
454+
typedef int (*DLPackCurrentWorkStream)( //
455+
DLDeviceType device_type, //
456+
int32_t device_id, //
457+
void** optional_out_current_stream //
435458
);
436459

437460
/*!
@@ -457,13 +480,36 @@ typedef int (*DLPackManagedTensorToPyObject)( //
457480
/*!
458481
* \brief Framework-specific function pointers table for DLPack exchange.
459482
*
460-
* Array/Tensor librarie should statically create and initialize this structure
483+
* Guidelines for leveraging DLPackExchangeAPI:
484+
*
485+
* There are generally two kinds of consumer needs for DLPack exchange:
486+
* - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel
487+
* with the data from x, y, z. The consumer is also expected to run the kernel with the same
488+
* stream context as the producer. For example, when x, y, z is torch.Tensor,
489+
* consumer should query exchange_api->optional_current_work_stream to get the
490+
* current stream and launch the kernel with the same stream.
491+
 * This setup is necessary to avoid synchronization at kernel launch and for maximum compatibility
492+
* with CUDA graph capture in the producer.
493+
* This is the desirable behavior for library extension support for frameworks like PyTorch.
494+
* - N1: data ingestion and retention, in such a case, the consumer is interested in obtaining
495+
* the data from the producer and runs further computation on its own stream.
496+
* In such a case, the consumer can directly query optional_last_active_stream to
497+
* get the last active stream and record a dependency.
498+
*
499+
 * Consumers should consider their needs (N0 or N1) and act accordingly based on the
500+
* availability of the function pointer.
501+
*
502+
* Importantly, optional_current_work_stream may be NULL for frameworks that
503+
* do not rely on a context manager to manage the stream, in which case the consumer
504+
* should rely on the information in optional_last_active_stream.
505+
*
506+
* Array/Tensor libraries should statically create and initialize this structure
461507
* then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array.
462-
* The DLPackExchangeAPI* should stay alive throughout the lifetime of process.
508+
* The DLPackExchangeAPI* should stay alive throughout the lifetime of the process.
463509
*
464510
* One simple way to do so is to create a static instance of DLPackExchangeAPI
465-
* within the framework and return a pointer to it, the following code
466-
* shows an example to do so in c++. It should also be reasonably easy
511+
* within the framework and return a pointer to it. The following code
512+
* shows an example to do so in C++. It should also be reasonably easy
467513
* to do so in other languages.
468514
*
469515
* \code
@@ -474,6 +520,8 @@ typedef int (*DLPackManagedTensorToPyObject)( //
474520
* managed_tensor_allocator = MyDLPackManagedTensorAllocator;
475521
* managed_tensor_from_py_object = MyDLPackManagedTensorFromPyObject;
476522
 * managed_tensor_to_py_object = MyDLPackManagedTensorToPyObject;
523+
* optional_current_work_stream = MyDLPackCurrentWorkStream;
524+
* prev_version_api = nullptr;
477525
* }
478526
*
479527
* static const DLPackExchangeAPI* Global() {
@@ -484,9 +532,9 @@ typedef int (*DLPackManagedTensorToPyObject)( //
484532
* \endcode
485533
*
486534
* Each framework should attach a dunder `__c_dlpack_exchange_api__` integer
487-
* to point to the pointer of the DLPackExchangeAPI*
535+
* to point to the DLPackExchangeAPI* pointer.
488536
*
489-
* Importantly the attributed should be attached to the class of the Tensor, not the instance.
537+
* Importantly, the attribute should be attached to the class of the Tensor, not the instance.
490538
*
491539
* mypackage.Tensor.__c_dlpack_exchange_api__ = MyPackageDLPackExchangeAPI
492540
*
@@ -499,6 +547,14 @@ struct DLPackExchangeAPI {
499547
* \brief The current DLPack version.
500548
*/
501549
DLPackVersion version;
550+
/*!
551+
* \brief Optional pointer to an older DLPackExchangeAPI in the chain.
552+
*
553+
* It should be set to NULL if the framework does not support older versions.
554+
*
555+
* \sa DLPackExchangeAPI
556+
*/
557+
DLPackExchangeAPI* prev_version_api;
502558
/*!
503559
* \brief Framework-specific function pointer for DLPackManagedTensorAllocator
504560
* \sa DLPackManagedTensorAllocator
@@ -514,6 +570,14 @@ struct DLPackExchangeAPI {
514570
* \sa DLPackManagedTensorToPyObject
515571
*/
516572
DLPackManagedTensorToPyObject managed_tensor_to_py_object;
573+
/*!
574+
* \brief Framework-specific function pointer for DLPackCurrentWorkStream
575+
*
576+
 * This function can be set to NULL if the framework does not rely on a context manager to manage the stream.
577+
*
578+
* \sa DLPackCurrentWorkStream
579+
*/
580+
DLPackCurrentWorkStream optional_current_work_stream;
517581
};
518582

519583
#ifdef __cplusplus

0 commit comments

Comments
 (0)