Skip to content

Compatible Implementation on CU118#17

Open
Y-L-LIU wants to merge 3 commits intoJeffreyXiang:mainfrom
Y-L-LIU:main
Open

Compatible Implementation on CU118#17
Y-L-LIU wants to merge 3 commits intoJeffreyXiang:mainfrom
Y-L-LIU:main

Conversation

@Y-L-LIU
Copy link

@Y-L-LIU Y-L-LIU commented Dec 29, 2025

As some issues stated https://github.com/microsoft/TRELLIS.2/issues/34, current implementation uses some new features on CUB on cu124. I changes the implementation so that the code could be correctly built on cu118.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like remove_duplicate_faces() has gone through a significant refactor and now introduces Thrust as a dependency.

Could you share more context on the motivation behind this change? Specifically, which CUB feature is broken or unavailable in CUDA 11.8 that necessitated this switch?

Also, were any alternative approaches considered to avoid adding a new dependency?

@JeffreyXiang
Copy link
Owner

Thanks for your contribution! I have some concerns before merging the change. It would be great if you could add more context.

@Y-L-LIU
Copy link
Author

Y-L-LIU commented Dec 29, 2025

Thanks for your timely response. This PR addresses compatibility issues with CUDA 11.8 by updating two CUB-based operations that rely on newer API features not available in older CUDA toolkits.:

  1. Fix for cub::DeviceScan::ExclusiveSum

Problem: Compilation failed with the error:

no instance of function template "cub::DeviceScan::ExclusiveSum" matches the argument list
argument types are: (std::nullptr_t, size_t, int *, unsigned long)

This occurred because CUDA 11.8’s bundled CUB version requires explicit input and output iterators for ExclusiveSum. The original code used an in-place pattern that relies on overloads only available in CUDA 12.4+.

Solution:
All calls to ExclusiveSum have been updated to explicitly pass the same pointer as both input (d_in) and output (d_out), ensuring valid in-place execution on CUDA 11.8 while preserving semantics.

  1. Replacement of cub::DeviceRadixSort::SortPairs with Thrust

Problem: Compilation failed due to:

no instance of overloaded function "cub::DeviceRadixSort::SortPairs" matches the argument list argument types are: (char *, size_t, int3 *, int3 *, int *, int *, size_t, cumesh::int3_decomposer)

The compilation error occurs because cub::DeviceRadixSort operates on bitwise representations. While newer versions of CUB allow flexible custom Decomposers to handle structs like int3, the version included in CUDA 11.8 is strictly limited to arithmetic types (or types with explicit UnsignedBits traits). The int3_decomposer pattern used in the original code does not match the template signature required by the older CUB API, causing the no instance matches argument list error.

I attempted to stay within CUB by using cub::DeviceMergeSort with a custom comparator (which usually handles structs better than RadixSort). However, this also failed to compile on CUDA 11.8 due to issues resolving the comparator for the int3 type within the device lambda context of that specific CUB version.

Solution:
We replaced the CUB-based sorting with Thrust:

  • No external dependency: Thrust is part of the standard CUDA Toolkit, just like CUB.
  • Cross-version robustness: thrust::sort works reliably with custom structs (e.g., int3) via standard iterator and comparator patterns, avoiding bitwise decomposition constraints.
  • Improved clarity: The new implementation removes complex scattering/masking logic previously needed to fit int3 into CUB’s radix framework, resulting in cleaner, more maintainable code.

I’m more than happy to run additional experiments to confirm the stability and correctness of these changes.

@JeffreyXiang
Copy link
Owner

I see. Then let’s use Thrust for CUDA < 12.4.
Could you use CUDA version macros (see example below) to isolate this code path, so that it does not affect the behavior for CUDA ≥ 12.4 before it is fully tested? Then I can merge it into main.

Example:

#if defined(CUDART_VERSION) && (CUDART_VERSION < 12040)
#include <thrust/...>
#endif

...

#if defined(CUDART_VERSION) && (CUDART_VERSION < 12040)
// CUDA < 12.4: use Thrust implementation
// Thrust-based code here
#else
// CUDA >= 12.4: keep existing behavior
// current implementation
#endif

This way we can keep the new path well isolated and minimize risk for newer CUDA versions.

@Y-L-LIU
Copy link
Author

Y-L-LIU commented Feb 10, 2026

Implemented as requested: the Thrust-based remove_duplicate_faces path is now isolated behind #if defined(CUDART_VERSION) && (CUDART_VERSION < 12040), and CUDA >= 12.4 keeps the original CUB DeviceRadixSort::SortPairs behavior.

Validation:

  • Existing example run completed.
  • Additional duplicate-face regression test passed:
    • before: 4 3
    • after: 4 2
  • Larger randomized stress test (including exact duplicates + permuted duplicates + repeated duplicates) passed with assertion checks.

Please take another look when convenient.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants