1+ #pragma once
2+
3+ #include " nnls_lib.hpp"
4+ #include " utils.hpp"
5+ #include " bppnnls.hpp"
6+
7+ template <typename T, typename eT>
8+ arma::mat planc::nnlslib<T, eT>::runbppnnls(const arma::mat &C, const T &B, const int & ncores) {
9+ arma::uword m_n = B.n_cols ;
10+ arma::uword m_k = C.n_cols ;
11+ arma::mat CtC = C.t () * C;
12+ arma::mat outmat = arma::zeros<arma::mat>(m_k, m_n);
13+ arma::mat *outmatptr;
14+ outmatptr = &outmat;
15+ arma::uword ONE_THREAD_MATRIX_SIZE = chunk_size_dense<double >(m_k);
16+ unsigned int numChunks = m_n / ONE_THREAD_MATRIX_SIZE;
17+ if (numChunks*ONE_THREAD_MATRIX_SIZE < m_n) numChunks++;
18+ #pragma omp parallel for schedule(dynamic) default(none) shared(numChunks, ONE_THREAD_MATRIX_SIZE, m_n, outmatptr, C, B, CtC) num_threads(ncores)
19+ for (unsigned int i = 0 ; i < numChunks; i++) {
20+ unsigned int spanStart = i * ONE_THREAD_MATRIX_SIZE;
21+ unsigned int spanEnd = (i + 1 ) * ONE_THREAD_MATRIX_SIZE - 1 ;
22+ if (spanEnd > m_n - 1 ) spanEnd = m_n - 1 ;
23+ // double start = omp_get_wtime();
24+ arma::mat CtBChunk = C.t () * B.cols (spanStart, spanEnd);
25+ BPPNNLS<arma::mat, arma::vec> solveProblem (CtC, CtBChunk, true );
26+ solveProblem.solveNNLS ();
27+ (*outmatptr).cols (spanStart, spanEnd) = solveProblem.getSolutionMatrix ();
28+ }
29+ return outmat;
30+ }
31+
32+ template <typename T, typename eT>
33+ arma::mat planc::nnlslib<T, eT>::bppnnls_prod(const arma::mat &CtC, const arma::mat &CtB, const int & nCores) {
34+ arma::uword n = CtB.n_cols ;
35+ arma::uword k = CtC.n_cols ;
36+ arma::uword ONE_THREAD_MATRIX_SIZE = chunk_size_dense<double >(k);
37+ arma::mat outmat = arma::zeros<arma::mat>(k, n);
38+ arma::mat* outmatptr = &outmat;
39+ unsigned int numChunks = n / ONE_THREAD_MATRIX_SIZE;
40+ if (numChunks*ONE_THREAD_MATRIX_SIZE < n) numChunks++;
41+ #pragma omp parallel for schedule(dynamic) default(none) shared(numChunks, CtB, ONE_THREAD_MATRIX_SIZE, outmatptr, CtC, n) num_threads(nCores)
42+ for (unsigned int i = 0 ; i < numChunks; i++) {
43+ unsigned int spanStart = i * ONE_THREAD_MATRIX_SIZE;
44+ unsigned int spanEnd = (i + 1 ) * ONE_THREAD_MATRIX_SIZE - 1 ;
45+ if (spanEnd > n - 1 ) spanEnd = n - 1 ;
46+ arma::mat CtBChunk = CtB.cols (spanStart, spanEnd);
47+ BPPNNLS<arma::mat, arma::vec> solveProblem (CtC, CtBChunk, true );
48+ solveProblem.solveNNLS ();
49+ (*outmatptr).cols (spanStart, spanEnd) = solveProblem.getSolutionMatrix ();
50+ }
51+ return outmat;
52+ }
0 commit comments