Skip to content

Commit 873db84

Browse files
authored
re-arrange omp parallel region to make more efficient memory allocattions (#75)
* re-arrange omp parallel region to make more efficient memory allocations. Related to #72 * optimize R code, avoid double work in transform * ignore bench files * update github actions * fix accidentally introduced segfault * run CI only for master * - update readme - update NEWS * simplify r cmd check options
1 parent aa9cf58 commit 873db84

File tree

9 files changed

+132
-111
lines changed

9 files changed

+132
-111
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ docs/
1414
extradata/
1515
revdep/
1616
^CRAN-SUBMISSION$
17+
bench/

.github/FUNDING.yml

Lines changed: 0 additions & 2 deletions
This file was deleted.

.github/workflows/R-CMD-check.yaml

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,45 @@
1-
# For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag.
2-
# https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
2+
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
on:
44
push:
5-
branches:
6-
- master
5+
branches: [master]
76
pull_request:
8-
branches:
9-
- master
7+
branches: [master]
108

119
name: R-CMD-check
1210

1311
jobs:
1412
R-CMD-check:
15-
runs-on: macOS-latest
13+
runs-on: ubuntu-latest
14+
15+
name: (${{ matrix.config.r }})
16+
17+
strategy:
18+
fail-fast: false
19+
matrix:
20+
config:
21+
- {r: 'devel'}
22+
# minimal required R version
23+
- {r: '3.6.0'}
1624
env:
1725
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
26+
R_KEEP_PKG_SOURCE: yes
27+
1828
steps:
19-
- uses: actions/checkout@v2
20-
- uses: r-lib/actions/setup-r@v1
21-
- name: Install dependencies
22-
run: |
23-
install.packages(c("remotes", "rcmdcheck", "Matrix"))
24-
remotes::install_deps(dependencies = TRUE)
25-
shell: Rscript {0}
26-
- name: Check
27-
run: rcmdcheck::rcmdcheck(args = "--no-manual", error_on = "error")
28-
shell: Rscript {0}
29+
- uses: actions/checkout@v3
30+
31+
- uses: r-lib/actions/setup-pandoc@v2
32+
33+
- uses: r-lib/actions/setup-r@v2
34+
with:
35+
r-version: ${{ matrix.config.r }}
36+
use-public-rspm: true
37+
38+
- uses: r-lib/actions/setup-r-dependencies@v2
39+
with:
40+
extra-packages: any::rcmdcheck, any::Matrix
41+
needs: check
42+
43+
- uses: r-lib/actions/check-r-package@v2
44+
with:
45+
upload-snapshots: true
Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,31 @@
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
2+
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
13
on:
24
push:
3-
branches:
4-
- master
5+
branches: [master]
56
pull_request:
6-
branches:
7-
- master
7+
branches: [master]
88

99
name: test-coverage
1010

1111
jobs:
1212
test-coverage:
13-
runs-on: macOS-latest
13+
runs-on: ubuntu-latest
1414
env:
1515
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
16-
steps:
17-
18-
- uses: actions/checkout@v2
19-
20-
- uses: r-lib/actions/setup-r@master
21-
22-
- uses: r-lib/actions/setup-pandoc@master
2316

24-
- name: Query dependencies
25-
run: |
26-
install.packages('remotes')
27-
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
28-
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
29-
shell: Rscript {0}
17+
steps:
18+
- uses: actions/checkout@v3
3019

31-
- name: Cache R packages
32-
uses: actions/cache@v1
20+
- uses: r-lib/actions/setup-r@v2
3321
with:
34-
path: ${{ env.R_LIBS_USER }}
35-
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
36-
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
22+
use-public-rspm: true
3723

38-
- name: Install dependencies
39-
run: |
40-
install.packages(c("remotes", "Matrix"))
41-
remotes::install_deps(dependencies = TRUE)
42-
remotes::install_cran("covr")
43-
shell: Rscript {0}
24+
- uses: r-lib/actions/setup-r-dependencies@v2
25+
with:
26+
extra-packages: any::covr, any::Matrix
27+
needs: coverage
4428

4529
- name: Test coverage
46-
run: covr::codecov()
30+
run: covr::codecov(quiet = FALSE)
4731
shell: Rscript {0}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ autom4te.cache
1313
src/Makevars
1414
revdep
1515
.Rprofile
16+
bench/

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# rsparse dev
2+
- faster WRMF solver see #72, #75
3+
- updated github actions
4+
15
# rsparse 0.5.1 (2022-09-11)
26
- update `configure` script, thanks to @david-cortes, see #73
37
- minor fixes in WRMF

R/model_WRMF.R

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,15 @@ WRMF = R6::R6Class(
180180
RhpcBLASctl::blas_set_num_threads(blas_threads_keep)
181181
})
182182
}
183-
183+
logger$debug("converting input user-item matrix")
184184
c_ui = MatrixExtra::as.csc.matrix(x)
185+
# c_ui = as(x, "CsparseMatrix")
186+
logger$debug("pre-processing input")
185187
c_ui = private$preprocess(c_ui)
186-
c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(x))
188+
logger$debug("creating item-user matrix")
189+
c_iu = MatrixExtra::t_shallow(MatrixExtra::as.csr.matrix(c_ui))
190+
# c_iu = t(c_ui)
191+
logger$debug("created item-user matrix")
187192
# store item_ids in order to use them in predict method
188193
private$item_ids = colnames(c_ui)
189194

@@ -195,7 +200,7 @@ WRMF = R6::R6Class(
195200
n_user = nrow(c_ui)
196201
n_item = ncol(c_ui)
197202

198-
logger$trace("initializing U")
203+
logger$debug("initializing U")
199204
if (private$precision == "double") {
200205
private$U = large_rand_matrix(private$rank, n_user)
201206
# for item biases
@@ -210,7 +215,7 @@ WRMF = R6::R6Class(
210215
}
211216

212217
if (is.null(self$components)) {
213-
218+
logger$debug("initializing components")
214219
if (private$solver_code == 1L) { ### <- cholesky
215220
if (private$precision == "double") {
216221
self$components = matrix(0, private$rank, n_item)
@@ -331,6 +336,7 @@ WRMF = R6::R6Class(
331336

332337
loss_prev_iter = loss
333338
}
339+
logger$debug("solver finished")
334340

335341
if (private$precision == "double")
336342
data.table::setattr(self$components, "dimnames", list(NULL, colnames(x)))
@@ -341,12 +347,16 @@ WRMF = R6::R6Class(
341347
rank_ = ifelse(private$with_user_item_bias, private$rank - 1L, private$rank)
342348
ridge = fl(diag(x = private$lambda, nrow = rank_, ncol = rank_))
343349
XX = if (private$with_user_item_bias) self$components[-1L, , drop = FALSE] else self$components
350+
351+
RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
344352
private$XtX = tcrossprod(XX) + ridge
353+
RhpcBLASctl::blas_set_num_threads(1)
345354

346355
# call extra transform to ensure results from transform() and fit_transform()
347356
# are the same (due to avoid_cg, etc)
348357
# this adds some extra computation, but not a big deal though
349-
self$transform(x)
358+
# self$transform(x)
359+
private$transform_(c_iu, ...)
350360
},
351361
# project new users into latent user space - just make ALS step given fixed items matrix
352362
#' @description create user embeddings for new input
@@ -366,6 +376,41 @@ WRMF = R6::R6Class(
366376
x = MatrixExtra::t_shallow(x)
367377
}
368378

379+
x = private$preprocess(x)
380+
381+
if (self$global_bias != 0. && private$feedback == "explicit")
382+
x@x = x@x - self$global_bias
383+
384+
private$transform_(x, ...)
385+
}
386+
),
387+
#### private -----
388+
private = list(
389+
solver_code = NULL,
390+
cg_steps = NULL,
391+
scorers = NULL,
392+
lambda = NULL,
393+
dynamic_lambda = FALSE,
394+
rank = NULL,
395+
non_negative = NULL,
396+
cnt_u = NULL,
397+
# user factor matrix = rank * n_users
398+
U = NULL,
399+
# item factor matrix = rank * n_items
400+
I = NULL,
401+
# preprocess - transformation of input matrix before passing it to ALS
402+
# for example we can scale each row or apply log() to values
403+
# this is essentially "confidence" transformation from WRMF article
404+
preprocess = NULL,
405+
feedback = NULL,
406+
precision = NULL,
407+
XtX = NULL,
408+
solver = NULL,
409+
with_user_item_bias = NULL,
410+
with_global_bias = NULL,
411+
init_user_item_bias = NULL,
412+
transform_ = function(x, ...) {
413+
logger$debug('starting transform')
369414
if (private$feedback == "implicit" ) {
370415
logger$trace("WRMF$transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)")
371416
blas_threads_keep = RhpcBLASctl::blas_get_num_procs()
@@ -375,11 +420,6 @@ WRMF = R6::R6Class(
375420
RhpcBLASctl::blas_set_num_threads(blas_threads_keep)
376421
})
377422
}
378-
379-
x = private$preprocess(x)
380-
if (self$global_bias != 0. && private$feedback == "explicit")
381-
x@x = x@x - self$global_bias
382-
383423
if (private$precision == "double") {
384424
res = matrix(0, nrow = private$rank, ncol = ncol(x))
385425
} else {
@@ -389,7 +429,7 @@ WRMF = R6::R6Class(
389429
if (private$with_user_item_bias) {
390430
res[1, ] = if(private$precision == "double") 1.0 else float::fl(1.0)
391431
}
392-
432+
logger$debug('starting transform solver')
393433
loss = private$solver(
394434
x,
395435
self$components,
@@ -399,42 +439,17 @@ WRMF = R6::R6Class(
399439
cnt_X = private$cnt_u,
400440
avoid_cg = TRUE
401441
)
442+
logger$debug('finished transform solver')
402443

403444
res = t(res)
404445

405446
if (private$precision == "double")
406447
setattr(res, "dimnames", list(colnames(x), NULL))
407448
else
408449
setattr(res@Data, "dimnames", list(colnames(x), NULL))
409-
450+
logger$debug('finished transform')
410451
res
411452
}
412-
),
413-
#### private -----
414-
private = list(
415-
solver_code = NULL,
416-
cg_steps = NULL,
417-
scorers = NULL,
418-
lambda = NULL,
419-
dynamic_lambda = FALSE,
420-
rank = NULL,
421-
non_negative = NULL,
422-
cnt_u = NULL,
423-
# user factor matrix = rank * n_users
424-
U = NULL,
425-
# item factor matrix = rank * n_items
426-
I = NULL,
427-
# preprocess - transformation of input matrix before passing it to ALS
428-
# for example we can scale each row or apply log() to values
429-
# this is essentially "confidence" transformation from WRMF article
430-
preprocess = NULL,
431-
feedback = NULL,
432-
precision = NULL,
433-
XtX = NULL,
434-
solver = NULL,
435-
with_user_item_bias = NULL,
436-
with_global_bias = NULL,
437-
init_user_item_bias = NULL
438453
)
439454
)
440455

@@ -465,7 +480,9 @@ als_implicit = function(
465480
} else {
466481
XX = X
467482
}
483+
RhpcBLASctl::blas_set_num_threads(RhpcBLASctl::get_num_cores())
468484
XtX = tcrossprod(XX) + ridge
485+
RhpcBLASctl::blas_set_num_threads(1)
469486
}
470487
if (is.null(global_bias_base)) {
471488
global_bias_base = numeric()

README.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@
1111

1212
We've paid some attention to the implementation details - we try to avoid data copies, utilize multiple threads via OpenMP and use SIMD where appropriate. Package **allows to work on datasets with millions of rows and millions of columns**.
1313

14-
15-
### Support
16-
17-
Please reach us if you need **commercial support** - [hello@rexy.ai](mailto:hello@rexy.ai).
18-
19-
20-
2114
# Features
2215

2316
### Classification/Regression

0 commit comments

Comments
 (0)