Move to nanobind #455
Open: Intron7 wants to merge 54 commits into main from move-to-nanobind
54 commits (the diff below shows changes from 17 commits):
- 7d2c62b add precommit (Intron7)
- 7bc5367 add first implementation (Intron7)
- dc76b24 clang format (Intron7)
- dc3648b format (Intron7)
- ef8a756 format c++ (Intron7)
- 76ae7aa Merge branch 'main' into move-to-nanobind (Intron7)
- 7ed481d change clang-format (Intron7)
- 48ba592 fix version (Intron7)
- 1d2f12a test docs (Intron7)
- 60c4863 fix yml (Intron7)
- 9a2b113 fix sparse to dense kernel launch (Intron7)
- 2d5ea85 test read the docs (Intron7)
- 9878e5d try env (Intron7)
- 6b46e8a test cmakeargs (Intron7)
- 55027f7 add system back (Intron7)
- b780405 add failsafe (Intron7)
- 24104ff remove print and slim down toml (Intron7)
- dddd9e8 Add almost unchanged cibw (flying-sheep)
- 5981d50 No macOS (flying-sheep)
- b3c3853 test build wheels (Intron7)
- b24bf7b next (Intron7)
- 56aca24 remove wheels workflow (Intron7)
- 7068b19 remove windows (Intron7)
- 9a58ff0 remove optional parts (flying-sheep)
- 7f65657 test publish (Intron7)
- 56f837a 3.12 (flying-sheep)
- 3a9a9f1 fix path (Intron7)
- 474de68 remove bad/useless (flying-sheep)
- 646ba23 fix container (Intron7)
- ae57cb1 try CUDA_PATH (flying-sheep)
- 33ac5af skip musl again (flying-sheep)
- a529a58 add next kernels (Intron7)
- 0685436 add pca and make safe docs (Intron7)
- 24551bd Merge branch 'main' into move-to-nanobind (Intron7)
- 5d327bd make aggr safe (Intron7)
- 30414ab add harmony (Intron7)
- d46ab83 make qc smaller (Intron7)
- d45d6bf add ligrec (Intron7)
- 20cf11e move decoupler (Intron7)
- 134d2e0 remove rawkernels (Intron7)
- a872962 add release note (Intron7)
- 2825de7 fix shape qc (Intron7)
- 66e930f fix entropy (Intron7)
- d386000 fix version (Intron7)
- cfdec19 add streams (Intron7)
- 4876400 Merge branch 'main' into move-to-nanobind (Intron7)
- 3a20dc2 fix pointer (Intron7)
- 948b86a fix test (Intron7)
- 3fdde98 terse args (flying-sheep)
- 8abaab0 kw-only for aggr.cu (flying-sheep)
- 84a34c4 remaining cleanup (flying-sheep)
- e53c87a add keywords (Intron7)
- ad7ed53 fix keywords ligrec (Intron7)
- a62a596 add 120 (Intron7)
New file: clang-format configuration (@@ -0,0 +1,22 @@)

```yaml
BasedOnStyle: Google
Language: Cpp

# Make braces stay on the same line (like your diffs)
BreakBeforeBraces: Attach
AllowShortFunctionsOnASingleLine: None

# Compact/“binpack” parameter lists (what produced your earlier diffs)
BinPackParameters: true
BinPackArguments: true

# Typical CUDA/C++ ergonomics
IndentWidth: 2
ColumnLimit: 100
PointerAlignment: Left
DerivePointerAlignment: false

# Don’t reorder #includes if you don’t want surprise churn
SortIncludes: false

# Optional: make templates break more aggressively
AlwaysBreakTemplateDeclarations: Yes
```
New file: CMake build configuration (@@ -0,0 +1,48 @@)

```cmake
cmake_minimum_required(VERSION 3.24)

project(rapids_singlecell_cuda LANGUAGES CXX)

# Option to disable building compiled extensions (for docs/RTD)
option(RSC_BUILD_EXTENSIONS "Build CUDA/C++ extensions" ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

if (RSC_BUILD_EXTENSIONS)
  enable_language(CUDA)
  find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
  find_package(nanobind CONFIG REQUIRED)
  find_package(CUDAToolkit REQUIRED)
else()
  message(STATUS "RSC_BUILD_EXTENSIONS=OFF -> skipping compiled extensions for docs")
endif()

# Helper to declare a nanobind CUDA module uniformly
function(add_nb_cuda_module target src)
  if (RSC_BUILD_EXTENSIONS)
    nanobind_add_module(${target} STABLE_ABI LTO
      ${src}
    )
    target_link_libraries(${target} PRIVATE CUDA::cudart)
    set_target_properties(${target} PROPERTIES
      CUDA_SEPARABLE_COMPILATION ON
    )
    install(TARGETS ${target} LIBRARY DESTINATION rapids_singlecell/_cuda)
    # Also copy built module into source tree for editable installs
    add_custom_command(TARGET ${target} POST_BUILD
      COMMAND ${CMAKE_COMMAND} -E copy
        $<TARGET_FILE:${target}>
        ${PROJECT_SOURCE_DIR}/src/rapids_singlecell/_cuda/$<TARGET_FILE_NAME:${target}>
    )
  endif()
endfunction()

if (RSC_BUILD_EXTENSIONS)
  # CUDA modules
  add_nb_cuda_module(_mean_var_cuda src/rapids_singlecell/_cuda/mean_var/mean_var.cu)
  add_nb_cuda_module(_sparse2dense_cuda src/rapids_singlecell/_cuda/sparse2dense/sparse2dense.cu)
  add_nb_cuda_module(_scale_cuda src/rapids_singlecell/_cuda/scale/scale.cu)
  add_nb_cuda_module(_qc_cuda src/rapids_singlecell/_cuda/qc/qc.cu)
  add_nb_cuda_module(_qc_dask_cuda src/rapids_singlecell/_cuda/qc_dask/qc_kernels_dask.cu)
endif()
```
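When RSC_BUILD_EXTENSIONS is OFF (as in a Read the Docs build), the compiled modules are simply never built, so the Python side presumably has to tolerate their absence. A minimal sketch of such an import guard, with a hypothetical helper name and a module path assumed from the install rule above (this is not code from the PR):

```python
# Hypothetical import guard, not taken from this PR: tolerate a build made
# with -DRSC_BUILD_EXTENSIONS=OFF (e.g. a docs build without CUDA).
from __future__ import annotations


def _load_mean_var_extension():
    """Return the compiled _mean_var_cuda module, or None if it was not built."""
    try:
        # Module path assumed from the install destination rapids_singlecell/_cuda above.
        from rapids_singlecell._cuda import _mean_var_cuda
    except ImportError:
        return None
    return _mean_var_cuda
```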
New file: __init__.py for the CUDA-extension subpackage (@@ -0,0 +1,3 @@)

```python
from __future__ import annotations

# Subpackage for CUDA extensions (built via scikit-build-core/nanobind)
```
New file: kernels.cuh for the mean/variance extension (@@ -0,0 +1,53 @@)

```cpp
#pragma once

#include <cuda_runtime.h>

template <typename T>
__global__ void mean_var_major_kernel(const int* __restrict__ indptr,
                                      const int* __restrict__ indices, const T* __restrict__ data,
                                      double* __restrict__ means, double* __restrict__ vars,
                                      int major, int /*minor*/) {
  int major_idx = blockIdx.x;
  if (major_idx >= major) return;

  int start_idx = indptr[major_idx];
  int stop_idx = indptr[major_idx + 1];

  __shared__ double mean_place[64];
  __shared__ double var_place[64];

  mean_place[threadIdx.x] = 0.0;
  var_place[threadIdx.x] = 0.0;
  __syncthreads();

  for (int minor_idx = start_idx + threadIdx.x; minor_idx < stop_idx; minor_idx += blockDim.x) {
    double value = static_cast<double>(data[minor_idx]);
    mean_place[threadIdx.x] += value;
    var_place[threadIdx.x] += value * value;
  }
  __syncthreads();

  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      mean_place[threadIdx.x] += mean_place[threadIdx.x + s];
      var_place[threadIdx.x] += var_place[threadIdx.x + s];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    means[major_idx] = mean_place[0];
    vars[major_idx] = var_place[0];
  }
}

template <typename T>
__global__ void mean_var_minor_kernel(const int* __restrict__ indices, const T* __restrict__ data,
                                      double* __restrict__ means, double* __restrict__ vars,
                                      int nnz) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= nnz) return;
  double value = static_cast<double>(data[idx]);
  int minor_pos = indices[idx];
  atomicAdd(&means[minor_pos], value);
  atomicAdd(&vars[minor_pos], value * value);
}
```
New file: mean_var.cu, nanobind bindings for the kernels above (@@ -0,0 +1,76 @@)

```cpp
#include <cuda_runtime.h>
#include <nanobind/nanobind.h>
#include <cstdint>

#include "kernels.cuh"

namespace nb = nanobind;
using nb::handle;

template <typename T>
static inline void launch_mean_var_major(std::uintptr_t indptr_ptr, std::uintptr_t indices_ptr,
                                         std::uintptr_t data_ptr, std::uintptr_t means_ptr,
                                         std::uintptr_t vars_ptr, int major, int minor) {
  dim3 block(64);
  dim3 grid(major);
  const int* indptr = reinterpret_cast<const int*>(indptr_ptr);
  const int* indices = reinterpret_cast<const int*>(indices_ptr);
  const T* data = reinterpret_cast<const T*>(data_ptr);
  double* means = reinterpret_cast<double*>(means_ptr);
  double* vars = reinterpret_cast<double*>(vars_ptr);
  mean_var_major_kernel<T><<<grid, block>>>(indptr, indices, data, means, vars, major, minor);
}

template <typename T>
static inline void launch_mean_var_minor(std::uintptr_t indices_ptr, std::uintptr_t data_ptr,
                                         std::uintptr_t means_ptr, std::uintptr_t vars_ptr,
                                         int nnz) {
  int block = 256;
  int grid = (nnz + block - 1) / block;
  const int* indices = reinterpret_cast<const int*>(indices_ptr);
  const T* data = reinterpret_cast<const T*>(data_ptr);
  double* means = reinterpret_cast<double*>(means_ptr);
  double* vars = reinterpret_cast<double*>(vars_ptr);
  mean_var_minor_kernel<T><<<grid, block>>>(indices, data, means, vars, nnz);
}

template <typename T>
void mean_var_major_api(std::uintptr_t indptr, std::uintptr_t indices, std::uintptr_t data,
                        std::uintptr_t means, std::uintptr_t vars, int major, int minor) {
  launch_mean_var_major<T>(indptr, indices, data, means, vars, major, minor);
}

template <typename T>
void mean_var_minor_api(std::uintptr_t indices, std::uintptr_t data, std::uintptr_t means,
                        std::uintptr_t vars, int nnz) {
  launch_mean_var_minor<T>(indices, data, means, vars, nnz);
}

NB_MODULE(_mean_var_cuda, m) {
  m.def("mean_var_major_f32", &mean_var_major_api<float>);
  m.def("mean_var_major_f64", &mean_var_major_api<double>);
  m.def("mean_var_minor_f32", &mean_var_minor_api<float>);
  m.def("mean_var_minor_f64", &mean_var_minor_api<double>);

  m.def("mean_var_major",
        [](std::uintptr_t indptr, std::uintptr_t indices, std::uintptr_t data, std::uintptr_t means,
           std::uintptr_t vars, int major, int minor, int itemsize) {
          if (itemsize == 4) {
            mean_var_major_api<float>(indptr, indices, data, means, vars, major, minor);
          } else if (itemsize == 8) {
            mean_var_major_api<double>(indptr, indices, data, means, vars, major, minor);
          } else {
            throw nb::value_error("Unsupported itemsize for mean_var_major (expected 4 or 8)");
          }
        });
  m.def("mean_var_minor", [](std::uintptr_t indices, std::uintptr_t data, std::uintptr_t means,
                             std::uintptr_t vars, int nnz, int itemsize) {
    if (itemsize == 4) {
      mean_var_minor_api<float>(indices, data, means, vars, nnz);
    } else if (itemsize == 8) {
      mean_var_minor_api<double>(indices, data, means, vars, nnz);
    } else {
      throw nb::value_error("Unsupported itemsize for mean_var_minor (expected 4 or 8)");
    }
  });
}
```
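Because the bindings take raw device pointers as `std::uintptr_t`, the natural caller passes CuPy pointer values (`.data.ptr`) rather than array objects. A hedged usage sketch, assuming the module ends up importable as `rapids_singlecell._cuda._mean_var_cuda` (per the CMake install destination) and that the caller finalizes mean/variance itself, since the kernels only write sums and sums of squares:

```python
import cupy as cp
from cupyx.scipy import sparse as cpsparse

# Assumed import path; the install rule above places the module in rapids_singlecell/_cuda.
from rapids_singlecell._cuda import _mean_var_cuda

X = cpsparse.random(1000, 200, density=0.1, format="csr", dtype=cp.float32)
means = cp.zeros(X.shape[0], dtype=cp.float64)
vars_ = cp.zeros(X.shape[0], dtype=cp.float64)

# Dispatch on itemsize (4 -> float32 kernel, 8 -> float64 kernel), passing raw device pointers.
_mean_var_cuda.mean_var_major(
    X.indptr.data.ptr, X.indices.data.ptr, X.data.data.ptr,
    means.data.ptr, vars_.data.ptr,
    X.shape[0], X.shape[1], X.dtype.itemsize,
)
cp.cuda.runtime.deviceSynchronize()

# Finalization on the caller side (illustrative only, not from this PR):
n_cols = X.shape[1]
means /= n_cols
vars_ = vars_ / n_cols - means**2
```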