Skip to content

Commit dcf0f88

Browse files
committed
re-add nnls
1 parent 44dd3ba commit dcf0f88

4 files changed

Lines changed: 96 additions & 1 deletion

File tree

nmf/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include(FetchContent)
99
#add_library(sparse_nmf OBJECT SparseNMFDriver.cpp)
1010
#add_library(dense_nmf OBJECT NMFDriver.cpp)
1111
add_library(nmflib
12-
nmf_lib.cpp bppnmf.cpp
12+
nmf_lib.cpp bppnmf.cpp ../nnls/nnls_lib.cpp
1313
../common/utils.cc ../common/data.cpp)
1414
# add_executable(${DENSE_OR_SPARSE}_inmf inmf.cpp)
1515
# target_compile_features(${DENSE_OR_SPARSE}_inmf PRIVATE cxx_std_17)

nnls/nnls_lib.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//
2+
// Created by andrew on 3/21/2025.
3+
//
4+
5+
#include "nnls_lib.inl"
6+
7+
#define X(T) \
8+
template arma::mat planc::nnlslib<T, double>::runbppnnls(const arma::mat &C, const T &B, const int &ncores);
9+
#include "../nmf/nmf_types.inc"
10+
#undef X
11+
template arma::mat planc::nnlslib<arma::mat>::bppnnls_prod(const arma::mat &CtC, const arma::mat &CtB, const int &ncores);

nnls/nnls_lib.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//
2+
// Created by andrew on 3/21/2025.
3+
//
4+
5+
#ifndef NNLS_LIB_H
6+
#define NNLS_LIB_H
7+
8+
#include "../nmf/nmf_lib.hpp"
9+
10+
namespace planc {
11+
12+
template<typename T, typename eT = typename T::elem_type>
13+
class NMFLIB_EXPORT nnlslib {
14+
15+
public:
16+
nnlslib() {
17+
openblas_pthread_off(get_openblas_handle());
18+
}
19+
20+
~nnlslib() = default;
21+
22+
23+
static arma::mat runbppnnls(const arma::mat &C, const T &B, const int &ncores);
24+
25+
static arma::mat bppnnls_prod(const arma::mat &CtC, const arma::mat &CtB, const int& nCores = 2);
26+
27+
28+
};
29+
30+
} // planc
31+
32+
#endif //NNLS_LIB_H

nnls/nnls_lib.inl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#pragma once
2+
3+
#include "nnls_lib.hpp"
4+
#include "utils.hpp"
5+
#include "bppnnls.hpp"
6+
7+
template<typename T, typename eT>
8+
arma::mat planc::nnlslib<T, eT>::runbppnnls(const arma::mat &C, const T &B, const int& ncores) {
9+
arma::uword m_n = B.n_cols;
10+
arma::uword m_k = C.n_cols;
11+
arma::mat CtC = C.t() * C;
12+
arma::mat outmat = arma::zeros<arma::mat>(m_k, m_n);
13+
arma::mat *outmatptr;
14+
outmatptr = &outmat;
15+
arma::uword ONE_THREAD_MATRIX_SIZE = chunk_size_dense<double>(m_k);
16+
unsigned int numChunks = m_n / ONE_THREAD_MATRIX_SIZE;
17+
if (numChunks*ONE_THREAD_MATRIX_SIZE < m_n) numChunks++;
18+
#pragma omp parallel for schedule(dynamic) default(none) shared(numChunks, ONE_THREAD_MATRIX_SIZE, m_n, outmatptr, C, B, CtC) num_threads(ncores)
19+
for (unsigned int i = 0; i < numChunks; i++) {
20+
unsigned int spanStart = i * ONE_THREAD_MATRIX_SIZE;
21+
unsigned int spanEnd = (i + 1) * ONE_THREAD_MATRIX_SIZE - 1;
22+
if (spanEnd > m_n - 1) spanEnd = m_n - 1;
23+
// double start = omp_get_wtime();
24+
arma::mat CtBChunk = C.t() * B.cols(spanStart, spanEnd);
25+
BPPNNLS<arma::mat, arma::vec> solveProblem(CtC, CtBChunk, true);
26+
solveProblem.solveNNLS();
27+
(*outmatptr).cols(spanStart, spanEnd) = solveProblem.getSolutionMatrix();
28+
}
29+
return outmat;
30+
}
31+
32+
template<typename T, typename eT>
33+
arma::mat planc::nnlslib<T, eT>::bppnnls_prod(const arma::mat &CtC, const arma::mat &CtB, const int& nCores) {
34+
arma::uword n = CtB.n_cols;
35+
arma::uword k = CtC.n_cols;
36+
arma::uword ONE_THREAD_MATRIX_SIZE = chunk_size_dense<double>(k);
37+
arma::mat outmat = arma::zeros<arma::mat>(k, n);
38+
arma::mat* outmatptr = &outmat;
39+
unsigned int numChunks = n / ONE_THREAD_MATRIX_SIZE;
40+
if (numChunks*ONE_THREAD_MATRIX_SIZE < n) numChunks++;
41+
#pragma omp parallel for schedule(dynamic) default(none) shared(numChunks, CtB, ONE_THREAD_MATRIX_SIZE, outmatptr, CtC, n) num_threads(nCores)
42+
for (unsigned int i = 0; i < numChunks; i++) {
43+
unsigned int spanStart = i * ONE_THREAD_MATRIX_SIZE;
44+
unsigned int spanEnd = (i + 1) * ONE_THREAD_MATRIX_SIZE - 1;
45+
if (spanEnd > n - 1) spanEnd = n - 1;
46+
arma::mat CtBChunk = CtB.cols(spanStart, spanEnd);
47+
BPPNNLS<arma::mat, arma::vec> solveProblem(CtC, CtBChunk, true);
48+
solveProblem.solveNNLS();
49+
(*outmatptr).cols(spanStart, spanEnd) = solveProblem.getSolutionMatrix();
50+
}
51+
return outmat;
52+
}

0 commit comments

Comments
 (0)