diff --git a/R/slice.R b/R/slice.R index 4f8f4549df..f41609d3f7 100644 --- a/R/slice.R +++ b/R/slice.R @@ -139,16 +139,10 @@ slice_head <- function(.data, ..., n, prop) { #' @export slice_head.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) - idx <- function(n) { - to <- size(n) - if (to > n) { - to <- n - } - seq2(1, to) - } - dplyr_local_error_call() - slice(.data, idx(dplyr::n())) + group_idx <- group_rows(.data) + slice_idx <- lapply(group_idx, function(x) head(x, size(length(x)))) + dplyr_row_slice(.data, unlist(slice_idx)) } #' @export @@ -162,16 +156,10 @@ slice_tail <- function(.data, ..., n, prop) { #' @export slice_tail.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) - idx <- function(n) { - from <- n - size(n) + 1 - if (from < 1L) { - from <- 1L - } - seq2(from, n) - } - dplyr_local_error_call() - slice(.data, idx(dplyr::n())) + group_idx <- group_rows(.data) + slice_idx <- lapply(group_idx, function(x) tail(x, size(length(x)))) + dplyr_row_slice(.data, unlist(slice_idx)) } #' @export @@ -265,16 +253,20 @@ slice_sample <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) slice_sample.data.frame <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) { size <- get_slice_size(n = n, prop = prop, allow_negative = FALSE) - dplyr_local_error_call() - slice(.data, local({ - weight_by <- {{ weight_by }} + if (!missing(weight_by)) { + weight_by <- transmute(.data, ..weight_by = {{ weight_by }})[[1]] + } - n <- dplyr::n() - if (!is.null(weight_by)) { - weight_by <- vec_assert(weight_by, size = n, arg = "weight_by") - } - sample_int(n, size(n), replace = replace, wt = weight_by) - })) + group_idx <- group_rows(.data) + slice_idx <- vector("list", length(group_idx)) + for (i in seq_along(group_idx)) { + idx <- group_idx[[i]] + n <- size(length(idx)) + + slice_idx[[i]] <- sample_int(idx, n, replace = replace, wt = weight_by[idx]) + } + + dplyr_row_slice(.data, unlist(slice_idx)) } # helpers ----------------------------------------------------------------- @@ -466,15 +458,15 @@ get_slice_size <- function(n, prop, allow_negative = TRUE, error_call = caller_e } } -sample_int <- function(n, size, replace = FALSE, wt = NULL, call = caller_env()) { - if (!replace && n < size) { - size <- n +sample_int <- function(x, size, replace = FALSE, wt = NULL, call = caller_env()) { + if (!replace && length(x) < size) { + size <- length(x) } if (size == 0L) { - integer(0) + x[integer(0)] } else { - sample.int(n, size, prob = wt, replace = replace) + x[sample.int(length(x), size, prob = wt, replace = replace)] } } diff --git a/tests/testthat/_snaps/slice.md b/tests/testthat/_snaps/slice.md index 5e246eba73..3cbd141b2c 100644 --- a/tests/testthat/_snaps/slice.md +++ b/tests/testthat/_snaps/slice.md @@ -234,10 +234,9 @@ Code slice_sample(df, n = 2, weight_by = 1:6) Condition - Error in `slice_sample()`: - ! Problem while computing indices. - Caused by error: - ! `weight_by` must have size 10, not size 6. + Error in `transmute()`: + ! Problem while computing `..weight_by = 1:6`. + x `..weight_by` must be size 10 or 1, not 6. # `slice_sample()` validates `replace` diff --git a/tests/testthat/test-slice.r b/tests/testthat/test-slice.r index 6da205ff3c..7e6cb899d3 100644 --- a/tests/testthat/test-slice.r +++ b/tests/testthat/test-slice.r @@ -186,7 +186,7 @@ test_that("slice_*() checks that `n=` is explicitly named and ... is empty", { test_that("slice_helpers do call slice() and benefit from dispatch (#6084)", { local_methods( - slice.noisy = function(.data, ..., .preserve = FALSE) { + dplyr_row_slice.noisy = function(.data, ..., .preserve = FALSE) { warning("noisy") NextMethod() }