Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ The standard SQL aggregation functions operate 'vertically' down columns. MLDB d
- `vertical_count(<row>)` alias of `count()`, operates on columns.
- `vertical_sum(<row>)` alias of `sum()`, operates on columns.
- `vertical_avg(<row>)` alias of `avg()`, operates on columns.
- `vertical_weighted_avg(<values row>, <weight row>)` alias of `weighted_avg()`, operates on columns.
- `vertical_min(<row>)` alias of `min()`, operates on columns.
- `vertical_max(<row>)` alias of `max()`, operates on columns.
- `vertical_latest(<row>)` alias of `latest()`, operates on columns.
Expand Down
2 changes: 1 addition & 1 deletion plugins/pooling_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ applyT(const ApplierT & applier_, PoolingInput input) const
}

ExcAssertEqual(outputEmbedding.size(), num_embed_cols);
}
}

return {ExpressionValue(std::move(outputEmbedding), outputTs)};
}
Expand Down
81 changes: 60 additions & 21 deletions sql/builtin_aggregators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mldb/base/optimized_path.h"
#include <array>
#include <unordered_set>
#include "mldb/jml/math/xdiv.h"

using namespace std;

Expand Down Expand Up @@ -446,7 +447,7 @@ struct AggregatorT {

if (!state->isDetermined) {
state->isDetermined = true;
checkArgsSize(nargs, 1);
checkArgsSize(nargs, State::nargs);
state->isRow = args[0].isRow();
}

Expand Down Expand Up @@ -504,10 +505,9 @@ struct RegisterAggregatorT: public RegisterAggregator {
}
};

struct AverageAccum {
static constexpr int nargs = 1;

AverageAccum()
struct AverageAccumBase {

AverageAccumBase()
: total(0.0), n(0.0), ts(Date::negativeInfinity())
{
}
Expand All @@ -518,35 +518,74 @@ struct AverageAccum {
return std::make_shared<Float64ValueInfo>();
}

ExpressionValue extract()
{
return ExpressionValue(ML::xdiv(total, n), ts);
}

void merge(AverageAccumBase* from)
{
total += from->total;
n += from->n;
ts.setMax(from->ts);
}

double total;
double n;
Date ts;
};

/** Accumulator behind avg(): the unweighted mean of the non-empty values.
 *
 *  The running state (`total`, `n`, `ts`) lives in AverageAccumBase, which
 *  also supplies extract() and merge(); this struct only contributes the
 *  per-value update.  No extract() is declared here so that the base-class
 *  implementation is the one used.
 */
struct AverageAccum : public AverageAccumBase {
    /// Number of arguments the aggregator expects.
    static constexpr int nargs = 1;

    AverageAccum()
        : AverageAccumBase()
    {
    }

    /** Fold one value into the running sum and count.
     *
     *  @param args   args[0] is the value to average.
     *  @param nargs  number of entries in args; checked against 1.
     *
     *  An empty value is skipped and leaves the state untouched.
     */
    void process(const ExpressionValue * args, size_t nargs)
    {
        checkArgsSize(nargs, 1);
        const ExpressionValue & val = args[0];
        if (val.empty())
            return;

        total += val.toDouble();
        n += 1;
        ts.setMax(val.getEffectiveTimestamp());
    }
};

struct WeightedAverageAccum : public AverageAccumBase {
static constexpr int nargs = 2;

WeightedAverageAccum()
: AverageAccumBase()
{
return ExpressionValue(total / n, ts);
}

void merge(AverageAccum* from)
void process(const ExpressionValue * args, size_t nargs)
{
total += from->total;
n += from->n;
ts.setMax(from->ts);
checkArgsSize(nargs, 2);
const ExpressionValue & val = args[0];
if (val.empty())
return;

double weight = 1;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.0

if(nargs == 2) {
const ExpressionValue & ev_weight = args[1];
if (!val.empty())
weight = ev_weight.toDouble();
}

total += val.toDouble() * weight;
n += weight;
ts.setMax(val.getEffectiveTimestamp());
}

double total;
double n;
Date ts;
};

// Static registration of the averaging aggregators.  Each registrar is given
// the plain SQL name followed by its `vertical_` spelling.
// NOTE(review): presumably both names resolve to the same accumulator --
// confirm against RegisterAggregatorT's constructor.
static RegisterAggregatorT<AverageAccum> registerAvg("avg", "vertical_avg");
static RegisterAggregatorT<WeightedAverageAccum> registerWAvg("weighted_avg", "vertical_weighted_avg");

template<typename Op, int Init>
struct ValueAccum {
Expand All @@ -571,7 +610,7 @@ struct ValueAccum {
value = Op()(value, val.toDouble());
ts.setMax(val.getEffectiveTimestamp());
}

ExpressionValue extract()
{
return ExpressionValue(value, ts);
Expand Down Expand Up @@ -621,12 +660,12 @@ struct StringAggAccum {
value += separator.coerceToString().toUtf8String();
}
first = false;

value += val.coerceToString().toUtf8String();

ts.setMax(val.getEffectiveTimestamp());
}

ExpressionValue extract()
{
return ExpressionValue(value, ts);
Expand Down Expand Up @@ -693,7 +732,7 @@ struct MinMaxAccum {
}
//cerr << "ts now " << ts << endl;
}

ExpressionValue extract()
{
return ExpressionValue(value, ts);
Expand Down
25 changes: 25 additions & 0 deletions testing/MLDB-702-row-aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ def recordExample(row, x, y, label, ts):

ds.commit()



ds = mldb.create_dataset({ "id": "test_weight", "type": "sparse.mutable" })

# Record one row named `row` into `ds` with two cells, "x" and "y", both
# stamped with the timestamp `ts`.
def recordExample(row, x, y, ts):
ds.record_row(row, [ [ "x", x, ts], ["y", y, ts] ]);

recordExample("ex1", 25, 0, 0);
recordExample("ex2", 1, 5, 0);
recordExample("ex3", 10, 2, 0);

ds.commit()


def test_min_max(self):
resp = mldb.get("/v1/query", q = "SELECT min({*}) AS min, max({*}) AS max FROM test GROUP BY label");
Expand Down Expand Up @@ -101,6 +114,18 @@ def test_vertical_avg_is_avg(self):
resp = mldb.get("/v1/query", q = "SELECT avg(x) AS avg FROM test GROUP BY x");
resp2 = mldb.get("/v1/query", q = "SELECT vertical_avg(x) AS avg FROM test GROUP BY x");
self.assertFullResultEquals(resp.json(), resp2.json())

# Check weighted_avg(x, y) over the test_weight dataset, where x is the value
# and y is the weight, and that vertical_weighted_avg gives the same table.
def test_weighted_avg(self):

wavg = mldb.query("SELECT weighted_avg(x, y) AS avg FROM test_weight")

# Rows are (x=25, y=0), (x=1, y=5), (x=10, y=2).  The first row has weight 0
# and so adds nothing to the numerator; the weights sum to 0 + 5 + 2 = 7.
self.assertTableResultEquals(
wavg,
[["_rowName","avg"],
["[]",(1*5 + 10*2) / 7.]])

# The vertical_ spelling must return exactly the same result table.
self.assertTableResultEquals(wavg,
mldb.query("SELECT vertical_weighted_avg(x, y) AS avg FROM test_weight"))

def test_vertical_earliest_is_earliest(self):
resp = mldb.get("/v1/query", q = "SELECT earliest({*}) AS count FROM test GROUP BY x");
Expand Down