@@ -369,6 +369,7 @@ typedef struct DLManagedTensorVersioned {
369369 * \brief A generic C-style allocator that exposes allocation of a Tensor/Array.
370370 *
371- * This information can then be used to set allocators of a callee to run allocations.
372+ * This information can then be used to set the callee's allocator to perform allocations.
372373 * This function can be exposed by the framework through the DLPackExchangeAPI.
373374 *
374375 * This particular function does not assume a Python environment; as a result,
@@ -394,44 +395,66 @@ typedef int (*DLPackManagedTensorAllocator)( //
394395 * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
395396 *
396397 * This function is a C-style function pointer to quickly convert a PyObject* Tensor/NDArray
397- * to a DLManagedTensorVersioned without going through the Python Interpreter .
398+ * to a DLManagedTensorVersioned without going through the Python interpreter.
398399 *
399400 * It also provides an option to query the current context stream of the device provided
400401 * by the tensor.
401402 *
402403 * This function is exposed by the framework through the DLPackExchangeAPI.
403404 *
404- * This information can then be picked up by importers and libraries to run the speed conversion.
405+ * The function pointer can then be picked up by importers and libraries to perform a fast conversion.
405406 * This function should not throw any exceptions; if it fails, it should return -1 and
406407 * set the error message via PyErr_SetXXX.
407408 *
408409 * \param py_object The Python object to convert; this should be PyObject*.
409410 * We use void* to avoid dependency on Python.h.
410411 *
411- * \param max_version The maximum version of DLPack support that consumer supports.
412- * Consumer should fill in their own version here, this parameter is not null.
413- * Producer can use this information to produce the appropriate
414- * DLManagedTensorVersioned for maximum compatibility if needed.
415- * This field is primarily used for future compatibility in case
416- * of major version bump and ABI-breaking changes.
417- *
418412 * \param out The output DLManagedTensorVersioned.
419413 *
420- * \param optional_out_env_stream Outputs the current context stream of the device provided
421- * by the tensor; it can be NULL, in which case the stream will not be queried.
422- * optional_out_env_stream should points to cudaStream_t in the case of CUDA.
414+ * \param optional_out_last_active_stream Outputs the stream the tensor was last active on.
415+ * It can be NULL, in which case the stream will not be queried.
416+ * optional_out_last_active_stream should point to a cudaStream_t in the case of CUDA.
417+ * Note that for frameworks that use a stream context manager, optional_out_last_active_stream
418+ * can be the stream on which the context manager was most recently active.
419+ * The stream is owned by the producer, and the consumer cannot retain it.
420+ * Instead, the consumer can record an event or add wait dependencies to it.
421+ * It is the responsibility of the consumer to synchronize with the stream if necessary.
422+ * The producer may output `reinterpret_cast<void*>(-1)` to indicate that the last active stream
423+ * is not available; in such a case, a device sync is needed to ensure data is ready.
423424 *
424425 * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
425426 * \note We use void* to avoid dependency on Python.h, so this specific type is
426427 * not dependent on Python.h and can be copied to dlpack.h.
427428 *
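+ * Below is a minimal consumer-side sketch (non-normative); `api`, `py_obj`, and
+ * `my_consumer_stream` are illustrative names, and a CUDA consumer is assumed.
+ *
+ * \code
+ *   DLManagedTensorVersioned* tensor = nullptr;
+ *   void* last_stream = nullptr;
+ *   if (api->managed_tensor_from_py_object(py_obj, &tensor, &last_stream) != 0) {
+ *     return nullptr;  // Python error already set by the producer.
+ *   }
+ *   if (last_stream == reinterpret_cast<void*>(-1)) {
+ *     // Last active stream unavailable: fall back to a full device sync.
+ *     cudaDeviceSynchronize();
+ *   } else if (last_stream != nullptr) {
+ *     // Order my_consumer_stream after the producer's last active stream.
+ *     cudaEvent_t ev;
+ *     cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+ *     cudaEventRecord(ev, static_cast<cudaStream_t>(last_stream));
+ *     cudaStreamWaitEvent(my_consumer_stream, ev, 0);
+ *     cudaEventDestroy(ev);
+ *   }
+ * \endcode
+ *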
428- * \sa DLPackExchangeAPI
429+ * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
429430 */
430431typedef int (*DLPackManagedTensorFromPyObject)( //
431432 void * py_object, //
432- const DLPackVersion* max_version, //
433433 DLManagedTensorVersioned** out, //
434- void ** optional_out_env_stream //
434+ void ** optional_out_last_active_stream //
435+ );
436+
437+ /* !
438+ * \brief Obtain the current work stream of a device.
439+ *
440+ * This function is a C-style function pointer to obtain the current work stream of a device
441+ * for frameworks that rely on a context manager to manage the stream.
442+ * For example, it should map to torch.cuda.current_stream in PyTorch.
443+ *
444+ * This function can be set to NULL if the framework does not rely on a context manager to
445+ * manage the stream.
446+ *
447+ * \param device_type The device type.
448+ * \param device_id The device id.
449+ * \param optional_out_current_stream The output current work stream.
450+ * \return 0 on success, -1 on failure.
451+ *
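+ * A hedged usage sketch for the library-call path; `api`, `tensor`, and
+ * `LaunchMyKernel` are illustrative names, and a CUDA device is assumed:
+ *
+ * \code
+ *   void* cur = nullptr;
+ *   DLDevice dev = tensor->dl_tensor.device;
+ *   if (api->optional_current_work_stream != nullptr &&
+ *       api->optional_current_work_stream(dev.device_type, dev.device_id, &cur) == 0) {
+ *     // Launch on the producer's current stream to stay in its stream context.
+ *     LaunchMyKernel(static_cast<cudaStream_t>(cur), tensor);
+ *   }
+ * \endcode
+ *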
452+ * \sa DLPackExchangeAPI
453+ */
454+ typedef int (*DLPackCurrentWorkStream)( //
455+    DLDeviceType device_type, //
456+    int32_t device_id, //
457+ void ** optional_out_current_stream //
435458);
436459
437460/* !
@@ -457,13 +480,36 @@ typedef int (*DLPackManagedTensorToPyObject)( //
457480/* !
458481 * \brief Framework-specific function pointers table for DLPack exchange.
459482 *
460- * Array/Tensor librarie should statically create and initialize this structure
483+ * Guidelines for leveraging DLPackExchangeAPI:
484+ *
485+ * There are generally two kinds of consumer needs for DLPack exchange:
486+ * - N0: library support, where consumer.kernel(x, y, z) needs to run a kernel
487+ * on the data from x, y, z. The consumer is expected to run the kernel in the same
488+ * stream context as the producer. For example, when x, y, z are torch.Tensor objects,
489+ * the consumer should query exchange_api->optional_current_work_stream to get the
490+ * current stream and launch the kernel on that same stream.
491+ * This setup avoids synchronization at kernel launch and gives maximum compatibility
492+ * with CUDA graph capture in the producer.
493+ * This is the desirable behavior for library extension support for frameworks like PyTorch.
494+ * - N1: data ingestion and retention, where the consumer obtains
495+ * the data from the producer and runs further computation on its own stream.
496+ * In this case, the consumer can query the optional_out_last_active_stream output of
497+ * managed_tensor_from_py_object to get the last active stream and record a dependency on it.
498+ *
499+ * Consumers should consider their needs (N0 or N1) and act accordingly based on the
500+ * availability of the corresponding function pointer, as sketched below.
501+ *
502+ * Importantly, optional_current_work_stream may be NULL for frameworks that
503+ * do not rely on a context manager to manage the stream, in which case the consumer
504+ * should rely on the information in optional_out_last_active_stream.
505+ *
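+ * A hedged sketch of the consumer-side branch; `api`, `dev`, `my_own_stream`,
+ * and `last_stream` (the optional_out_last_active_stream output of
+ * managed_tensor_from_py_object) are illustrative:
+ *
+ * \code
+ *   void* stream = nullptr;
+ *   if (api->optional_current_work_stream != nullptr) {
+ *     // N0: run in the producer's current stream context.
+ *     api->optional_current_work_stream(dev.device_type, dev.device_id, &stream);
+ *   } else if (last_stream != nullptr && last_stream != reinterpret_cast<void*>(-1)) {
+ *     // N1: keep our own stream, ordered after the producer's last active stream
+ *     // (e.g., via cudaEventRecord + cudaStreamWaitEvent).
+ *     stream = my_own_stream;
+ *   }
+ * \endcode
+ *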
506+ * Array/Tensor libraries should statically create and initialize this structure
461507 * then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array.
462- * The DLPackExchangeAPI* should stay alive throughout the lifetime of process.
508+ * The DLPackExchangeAPI* should stay alive throughout the lifetime of the process.
463509 *
464510 * One simple way to do so is to create a static instance of DLPackExchangeAPI
465- * within the framework and return a pointer to it, the following code
466- * shows an example to do so in c ++. It should also be reasonably easy
511+ * within the framework and return a pointer to it. The following code
512+ * shows an example to do so in C++. It should also be reasonably easy
467513 * to do so in other languages.
468514 *
469515 * \code
@@ -474,6 +520,8 @@ typedef int (*DLPackManagedTensorToPyObject)( //
474520 * managed_tensor_allocator = MyDLPackManagedTensorAllocator;
475521 * managed_tensor_from_py_object = MyDLPackManagedTensorFromPyObject;
476522 * managed_tensor_to_py_object = MyDLPackManagedTensorToPyObject;
523+ * optional_current_work_stream = MyDLPackCurrentWorkStream;
524+ * prev_version_api = nullptr;
477525 * }
478526 *
479527 * static const DLPackExchangeAPI* Global() {
@@ -484,9 +532,9 @@ typedef int (*DLPackManagedTensorToPyObject)( //
484532 * \endcode
485533 *
486534 * Each framework should attach a dunder `__c_dlpack_exchange_api__` integer
487- * to point to the pointer of the DLPackExchangeAPI*
535+ * that holds the DLPackExchangeAPI* pointer value.
488536 *
489- * Importantly the attributed should be attached to the class of the Tensor, not the instance.
537+ * Importantly, the attribute should be attached to the class of the Tensor, not the instance.
490538 *
491539 * mypackage.Tensor.__c_dlpack_exchange_api__ = MyPackageDLPackExchangeAPI
492540 *
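+ * A hedged sketch of a consumer recovering the table from the class attribute
+ * (error handling elided; `tensor_obj` is an illustrative PyObject*):
+ *
+ * \code
+ *   PyObject* cls = reinterpret_cast<PyObject*>(Py_TYPE(tensor_obj));
+ *   PyObject* attr = PyObject_GetAttrString(cls, "__c_dlpack_exchange_api__");
+ *   const DLPackExchangeAPI* api = nullptr;
+ *   if (attr != nullptr) {
+ *     api = static_cast<const DLPackExchangeAPI*>(PyLong_AsVoidPtr(attr));
+ *     Py_DECREF(attr);
+ *   }
+ * \endcode
+ *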
@@ -499,6 +547,14 @@ struct DLPackExchangeAPI {
499547 * \brief The current DLPack version.
500548 */
501549 DLPackVersion version;
550+ /* !
551+ * \brief Optional pointer to an older DLPackExchangeAPI in the chain.
552+ *
553+ * It should be set to NULL if the framework does not support older versions.
554+ *
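+ * A hedged sketch of walking the chain to find a table matching the consumer's
+ * supported major version; `latest_api` is illustrative:
+ *
+ * \code
+ *   const DLPackExchangeAPI* api = latest_api;
+ *   while (api != nullptr && api->version.major != DLPACK_MAJOR_VERSION) {
+ *     api = api->prev_version_api;
+ *   }
+ *   // api == NULL here means no compatible table was found.
+ * \endcode
+ *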
555+ * \sa DLPackExchangeAPI
556+ */
557+  struct DLPackExchangeAPI* prev_version_api;
502558 /* !
503559 * \brief Framework-specific function pointer for DLPackManagedTensorAllocator
504560 * \sa DLPackManagedTensorAllocator
@@ -514,6 +570,14 @@ struct DLPackExchangeAPI {
514570 * \sa DLPackManagedTensorToPyObject
515571 */
516572 DLPackManagedTensorToPyObject managed_tensor_to_py_object;
573+ /* !
574+ * \brief Framework-specific function pointer for DLPackCurrentWorkStream
575+ *
576+ * This function can be set to NULL if the framework does not rely on a context manager to manage the stream.
577+ *
578+ * \sa DLPackCurrentWorkStream
579+ */
580+ DLPackCurrentWorkStream optional_current_work_stream;
517581};
518582
519583#ifdef __cplusplus