diff --git a/R/pkg-arrow.R b/R/pkg-arrow.R index d38837fa..4b0fb8f7 100644 --- a/R/pkg-arrow.R +++ b/R/pkg-arrow.R @@ -291,6 +291,15 @@ arrow_funs[["grepl"]] <- function(pattern, x, ...) { substrait_call("string.contains", x, pattern) } +arrow_funs[["n"]] <- function() { + substrait_call_agg( + "aggregate_generic.count", + .output_type = substrait_i64(), + .phase = 3L, + .invocation = 1L + ) +} + check_na_rm <- function(na.rm) { if (!na.rm) { warning("Missing value removal from aggregate functions not yet supported, switching to na.rm = TRUE") diff --git a/tests/testthat/test-pkg-arrow.R b/tests/testthat/test-pkg-arrow.R index bb85296f..d40c3548 100644 --- a/tests/testthat/test-pkg-arrow.R +++ b/tests/testthat/test-pkg-arrow.R @@ -662,3 +662,16 @@ test_that("arrow translation for if_else() works", { ) ) }) + +test_that("arrow translation for n() works", { + skip_if_not(has_arrow_with_substrait()) + + expect_identical( + tibble::tibble(lgl = c(NA, TRUE, NA, TRUE, FALSE, FALSE, NA, TRUE, FALSE, TRUE)) %>% + arrow_substrait_compiler() %>% + dplyr::group_by(lgl) %>% + dplyr::summarise(n = n()) %>% + dplyr::collect(), + tibble::tibble(lgl = c(NA, TRUE, FALSE), n = c(3L, 4L, 3L)) + ) +})