Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions packages/db-ivm/src/operators/groupBy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,27 @@ export function mode<T>(
}
}

/**
 * Creates a list aggregate function that collects all values of a group into an array.
 *
 * Every incoming entry is a `[value, multiplicity]` pair. A value is appended once per
 * unit of multiplicity, so identical rows that were merged into a single entry with a
 * multiplicity greater than one still contribute one array element per occurrence.
 * Entries whose multiplicity is zero or negative contribute nothing.
 *
 * @param valueExtractor Function to extract a value from each data entry;
 *   defaults to passing the entry through unchanged
 * @returns An aggregate function whose `reduce` step yields the collected values
 */
export function list<T, V>(
  valueExtractor: (value: T) => V = (v) => v as unknown as V
): AggregateFunction<T, Array<V>, V> {
  return {
    preMap: (data: T) => valueExtractor(data),
    reduce: (values) => {
      const collected: Array<V> = []

      for (const [value, multiplicity] of values) {
        // Honor the multiset semantics: repeat the value `multiplicity` times.
        // The loop body never runs for multiplicities <= 0.
        for (let i = 0; i < multiplicity; i++) {
          collected.push(value)
        }
      }

      // NOTE(review): the cast is kept from the original signature, which declares
      // the reduced value as `V` even though an array is produced — confirm whether
      // `AggregateFunction` can express this without a cast.
      return collected as unknown as V
    },
  }
}

export const groupByOperators = {
sum,
count,
Expand All @@ -351,4 +372,5 @@ export const groupByOperators = {
max,
median,
mode,
list,
}
61 changes: 61 additions & 0 deletions packages/db-ivm/tests/operators/groupBy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
avg,
count,
groupBy,
list,
max,
median,
min,
Expand Down Expand Up @@ -131,6 +132,66 @@ describe(`Operators`, () => {
expect(result).toEqual(expectedResult)
})

test(`with single list aggregate`, () => {
  // Pipeline under test: group rows by productId and gather their categories
  const graph = new D2()
  const input = graph.newInput<{ productId: number; category: string }>()
  let lastMessage: any = null

  input.pipe(
    groupBy((row) => ({ productId: row.productId }), {
      categories: list((row) => row.category),
    }),
    output((message) => {
      lastMessage = message
    })
  )

  graph.finalize()

  // Feed the initial rows: two for product 1, one for product 2
  const initialRows = new MultiSet([
    [{ category: `A`, productId: 1 }, 1],
    [{ category: `B`, productId: 1 }, 1],
    [{ category: `A`, productId: 2 }, 1],
  ])
  input.sendData(initialRows)
  graph.run()

  // The output operator must have received a message
  expect(lastMessage).not.toBeNull()

  // One keyed entry per product, each emitted with multiplicity 1
  const productOne = [
    `{"productId":1}`,
    { productId: 1, categories: [`A`, `B`] },
  ]
  const productTwo = [`{"productId":2}`, { productId: 2, categories: [`A`] }]

  expect(lastMessage.getInner()).toEqual([
    [productOne, 1],
    [productTwo, 1],
  ])
})

test(`with sum and count aggregates`, () => {
const graph = new D2()
const input = graph.newInput<{
Expand Down
7 changes: 7 additions & 0 deletions packages/db/src/query/builder/functions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Aggregate, Func } from "../ir"
import { toExpression } from "./ref-proxy.js"
import type { RefProxyFor } from "./types"
import type { BasicExpression } from "../ir"
import type { RefProxy } from "./ref-proxy.js"

Expand Down Expand Up @@ -266,6 +267,12 @@ export function max(
return new Aggregate(`max`, [toExpression(arg)])
}

/**
 * Builds a `list` aggregate expression, as used in the select clause of a grouped query.
 *
 * @param arg Reference to the value that becomes the aggregate's single argument
 * @returns An `Aggregate` node named `list`, typed as producing an array of `T`
 */
export function list<T>(
  arg: RefProxy<T> | RefProxyFor<T>
): Aggregate<Array<T>> {
  const expression = toExpression(arg)
  return new Aggregate(`list`, [expression])
}

/**
* List of comparison function names that can be used with indexes
*/
Expand Down
21 changes: 14 additions & 7 deletions packages/db/src/query/compiler/group-by.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import type {
} from "../ir.js"
import type { NamespacedAndKeyedStream, NamespacedRow } from "../../types.js"

const { sum, count, avg, min, max } = groupByOperators
const { sum, count, avg, min, max, list } = groupByOperators

/**
* Interface for caching the mapping between GROUP BY expressions and SELECT expressions
Expand Down Expand Up @@ -342,25 +342,32 @@ function getAggregateFunction(aggExpr: Aggregate) {
// Pre-compile the value extractor expression
const compiledExpr = compileExpression(aggExpr.args[0]!)

// Create a value extractor function for the expression to aggregate
const valueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
// Create a number only value extractor function for the expression to aggregate
const numberExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
const value = compiledExpr(namespacedRow)
// Ensure we return a number for numeric aggregate functions
return typeof value === `number` ? value : value != null ? Number(value) : 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is a good change, since the other functions rely on the value being a number.

Might be a good idea to add another value extractor specific to your need. At least that's what I've done in my PR for min and max to support dates.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's a workaround for now. I didn't have the time to do it properly, and I needed it to work quickly for my project.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this should be a separate extractor

}

// Create a generic value extractor function for non-numeric aggregates
const valueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
return compiledExpr(namespacedRow)
}

// Return the appropriate aggregate function
switch (aggExpr.name.toLowerCase()) {
case `sum`:
return sum(valueExtractor)
return sum(numberExtractor)
case `count`:
return count() // count() doesn't need a value extractor
case `avg`:
return avg(valueExtractor)
return avg(numberExtractor)
case `min`:
return min(valueExtractor)
return min(numberExtractor)
case `max`:
return max(valueExtractor)
return max(numberExtractor)
case `list`:
return list(valueExtractor)
default:
throw new UnsupportedAggregateFunctionError(aggExpr.name)
}
Expand Down
1 change: 1 addition & 0 deletions packages/db/src/query/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export {
sum,
min,
max,
list,
} from "./builder/functions.js"

// Ref proxy utilities
Expand Down
5 changes: 5 additions & 0 deletions packages/db/tests/query/group-by.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
eq,
gt,
gte,
list,
lt,
max,
min,
Expand Down Expand Up @@ -189,6 +190,7 @@ function createGroupByTests(autoIndex: `off` | `eager`): void {
.groupBy(({ orders }) => orders.status)
.select(({ orders }) => ({
status: orders.status,
customer_ids: list(orders.customer_id),
total_amount: sum(orders.amount),
order_count: count(orders.id),
avg_amount: avg(orders.amount),
Expand All @@ -200,20 +202,23 @@ function createGroupByTests(autoIndex: `off` | `eager`): void {
// Completed orders: 1, 2, 4, 7 (amounts: 100, 200, 300, 400)
const completed = statusSummary.get(`completed`)
expect(completed?.status).toBe(`completed`)
expect(completed?.customer_ids).toEqual([1, 1, 2, 1])
expect(completed?.total_amount).toBe(1000)
expect(completed?.order_count).toBe(4)
expect(completed?.avg_amount).toBe(250)

// Pending orders: 3, 5 (amounts: 150, 250)
const pending = statusSummary.get(`pending`)
expect(pending?.status).toBe(`pending`)
expect(pending?.customer_ids).toEqual([2, 3])
expect(pending?.total_amount).toBe(400)
expect(pending?.order_count).toBe(2)
expect(pending?.avg_amount).toBe(200)

// Cancelled orders: 6 (amount: 75)
const cancelled = statusSummary.get(`cancelled`)
expect(cancelled?.status).toBe(`cancelled`)
expect(cancelled?.customer_ids).toEqual([3])
expect(cancelled?.total_amount).toBe(75)
expect(cancelled?.order_count).toBe(1)
expect(cancelled?.avg_amount).toBe(75)
Expand Down