diff --git a/packages/db-ivm/src/operators/groupBy.ts b/packages/db-ivm/src/operators/groupBy.ts index 344c4b1c..071c098c 100644 --- a/packages/db-ivm/src/operators/groupBy.ts +++ b/packages/db-ivm/src/operators/groupBy.ts @@ -343,6 +343,27 @@ export function mode( } } +/** + * Creates a list aggregate function that collects all values into an array + * @param valueExtractor Function to extract a value from each data entry + */ +export function list( + valueExtractor: (value: T) => V = (v) => v as unknown as V +): AggregateFunction, V> { + return { + preMap: (data: T) => valueExtractor(data), + reduce: (values) => { + const total = [] + + for (const [value, _multiplicity] of values) { + total.push(value) + } + + return total as unknown as V + }, + } +} + export const groupByOperators = { sum, count, @@ -351,4 +372,5 @@ export const groupByOperators = { max, median, mode, + list, } diff --git a/packages/db-ivm/tests/operators/groupBy.test.ts b/packages/db-ivm/tests/operators/groupBy.test.ts index 5f20b8d0..43377a03 100644 --- a/packages/db-ivm/tests/operators/groupBy.test.ts +++ b/packages/db-ivm/tests/operators/groupBy.test.ts @@ -5,6 +5,7 @@ import { avg, count, groupBy, + list, max, median, min, @@ -131,6 +132,66 @@ describe(`Operators`, () => { expect(result).toEqual(expectedResult) }) + test(`with single list aggregate`, () => { + const graph = new D2() + const input = graph.newInput<{ + productId: number + category: string + }>() + let latestMessage: any = null + + input.pipe( + groupBy((data) => ({ productId: data.productId }), { + categories: list((data) => data.category), + }), + output((message) => { + latestMessage = message + }) + ) + + graph.finalize() + + // Initial data + input.sendData( + new MultiSet([ + [{ category: `A`, productId: 1 }, 1], + [{ category: `B`, productId: 1 }, 1], + [{ category: `A`, productId: 2 }, 1], + ]) + ) + graph.run() + + // Verify we have the latest message + expect(latestMessage).not.toBeNull() + + const result = latestMessage.getInner() + + const expectedResult = [ + [ + [ + `{"productId":1}`, + { + productId: 1, + categories: [`A`, `B`], + }, + ], + 1, + ], + [ + [ + `{"productId":2}`, + { + productId: 2, + categories: [`A`], + }, + ], + 1, + ], + ] + + expect(result).toEqual(expectedResult) + }) + test(`with sum and count aggregates`, () => { const graph = new D2() const input = graph.newInput<{ diff --git a/packages/db/src/query/builder/functions.ts b/packages/db/src/query/builder/functions.ts index b5902bd6..6cd6665a 100644 --- a/packages/db/src/query/builder/functions.ts +++ b/packages/db/src/query/builder/functions.ts @@ -1,5 +1,6 @@ import { Aggregate, Func } from "../ir" import { toExpression } from "./ref-proxy.js" +import type { RefProxyFor } from "./types" import type { BasicExpression } from "../ir" import type { RefProxy } from "./ref-proxy.js" @@ -266,6 +267,12 @@ export function max( return new Aggregate(`max`, [toExpression(arg)]) } +export function list( + arg: RefProxy | RefProxyFor +): Aggregate> { + return new Aggregate(`list`, [toExpression(arg)]) +} + /** * List of comparison function names that can be used with indexes */ diff --git a/packages/db/src/query/compiler/group-by.ts b/packages/db/src/query/compiler/group-by.ts index 57de09a9..a8871de8 100644 --- a/packages/db/src/query/compiler/group-by.ts +++ b/packages/db/src/query/compiler/group-by.ts @@ -16,7 +16,7 @@ import type { } from "../ir.js" import type { NamespacedAndKeyedStream, NamespacedRow } from "../../types.js" -const { sum, count, avg, min, max } = groupByOperators +const { sum, count, avg, min, max, list } = groupByOperators /** * Interface for caching the mapping between GROUP BY expressions and SELECT expressions @@ -342,25 +342,32 @@ function getAggregateFunction(aggExpr: Aggregate) { // Pre-compile the value extractor expression const compiledExpr = compileExpression(aggExpr.args[0]!) - // Create a value extractor function for the expression to aggregate - const valueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => { + // Create a number only value extractor function for the expression to aggregate + const numberExtractor = ([, namespacedRow]: [string, NamespacedRow]) => { const value = compiledExpr(namespacedRow) // Ensure we return a number for numeric aggregate functions return typeof value === `number` ? value : value != null ? Number(value) : 0 } + // Create a generic value extractor function for non-numeric aggregates + const valueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => { + return compiledExpr(namespacedRow) + } + // Return the appropriate aggregate function switch (aggExpr.name.toLowerCase()) { case `sum`: - return sum(valueExtractor) + return sum(numberExtractor) case `count`: return count() // count() doesn't need a value extractor case `avg`: - return avg(valueExtractor) + return avg(numberExtractor) case `min`: - return min(valueExtractor) + return min(numberExtractor) case `max`: - return max(valueExtractor) + return max(numberExtractor) + case `list`: + return list(valueExtractor) default: throw new UnsupportedAggregateFunctionError(aggExpr.name) } diff --git a/packages/db/src/query/index.ts b/packages/db/src/query/index.ts index 2fb81e6f..cb9ba121 100644 --- a/packages/db/src/query/index.ts +++ b/packages/db/src/query/index.ts @@ -38,6 +38,7 @@ export { sum, min, max, + list, } from "./builder/functions.js" // Ref proxy utilities diff --git a/packages/db/tests/query/group-by.test.ts b/packages/db/tests/query/group-by.test.ts index b9d10312..ac1299b5 100644 --- a/packages/db/tests/query/group-by.test.ts +++ b/packages/db/tests/query/group-by.test.ts @@ -9,6 +9,7 @@ import { eq, gt, gte, + list, lt, max, min, @@ -189,6 +190,7 @@ function createGroupByTests(autoIndex: `off` | `eager`): void { .groupBy(({ orders }) => orders.status) .select(({ orders }) => ({ status: orders.status, + customer_ids: list(orders.customer_id), total_amount: sum(orders.amount), order_count: count(orders.id), avg_amount: avg(orders.amount), @@ -200,6 +202,7 @@ function createGroupByTests(autoIndex: `off` | `eager`): void { // Completed orders: 1, 2, 4, 7 (amounts: 100, 200, 300, 400) const completed = statusSummary.get(`completed`) expect(completed?.status).toBe(`completed`) + expect(completed?.customer_ids).toEqual([1, 1, 2, 1]) expect(completed?.total_amount).toBe(1000) expect(completed?.order_count).toBe(4) expect(completed?.avg_amount).toBe(250) @@ -207,6 +210,7 @@ function createGroupByTests(autoIndex: `off` | `eager`): void { // Pending orders: 3, 5 (amounts: 150, 250) const pending = statusSummary.get(`pending`) expect(pending?.status).toBe(`pending`) + expect(pending?.customer_ids).toEqual([2, 3]) expect(pending?.total_amount).toBe(400) expect(pending?.order_count).toBe(2) expect(pending?.avg_amount).toBe(200) @@ -214,6 +218,7 @@ function createGroupByTests(autoIndex: `off` | `eager`): void { // Cancelled orders: 6 (amount: 75) const cancelled = statusSummary.get(`cancelled`) expect(cancelled?.status).toBe(`cancelled`) + expect(cancelled?.customer_ids).toEqual([3]) expect(cancelled?.total_amount).toBe(75) expect(cancelled?.order_count).toBe(1) expect(cancelled?.avg_amount).toBe(75)