diff --git a/container_files/public_html/doc/builtin/sql/ValueExpression.md b/container_files/public_html/doc/builtin/sql/ValueExpression.md
index 1496f737a..8c6a62879 100644
--- a/container_files/public_html/doc/builtin/sql/ValueExpression.md
+++ b/container_files/public_html/doc/builtin/sql/ValueExpression.md
@@ -611,6 +611,7 @@ The standard SQL aggregation functions operate 'vertically' down columns. MLDB d
- `vertical_count()` alias of `count()`, operates on columns.
- `vertical_sum()` alias of `sum()`, operates on columns.
- `vertical_avg()` alias of `avg()`, operates on columns.
+ - `vertical_weighted_avg(<expr>, <weight>)` alias of `weighted_avg()`, operates on columns.
- `vertical_min()` alias of `min()`, operates on columns.
- `vertical_max()` alias of `max()`, operates on columns.
- `vertical_latest()` alias of `latest()`, operates on columns.
diff --git a/plugins/pooling_function.cc b/plugins/pooling_function.cc
index d55b0bd24..ed1839da4 100644
--- a/plugins/pooling_function.cc
+++ b/plugins/pooling_function.cc
@@ -157,7 +157,7 @@ applyT(const ApplierT & applier_, PoolingInput input) const
}
ExcAssertEqual(outputEmbedding.size(), num_embed_cols);
- }
+ }
return {ExpressionValue(std::move(outputEmbedding), outputTs)};
}
diff --git a/sql/builtin_aggregators.cc b/sql/builtin_aggregators.cc
index 467cec151..7f6569c29 100644
--- a/sql/builtin_aggregators.cc
+++ b/sql/builtin_aggregators.cc
@@ -14,6 +14,7 @@
#include "mldb/base/optimized_path.h"
#include
#include
+#include "mldb/jml/math/xdiv.h"
using namespace std;
@@ -135,14 +136,24 @@ struct AggregatorT {
void process(const ExpressionValue * args, size_t nargs)
{
- checkArgsSize(nargs, 1);
+ checkArgsSize(nargs, State::nargs);
const ExpressionValue & val = args[0];
+
// This must be a row...
auto onColumn = [&] (const PathElement & columnName,
const ExpressionValue & val)
{
- columns[columnName].process(&val, 1);
+ if(State::nargs == 1) {
+ columns[columnName].process(&val, 1);
+ }
+ else if(State::nargs == 2) {
+ const vector<ExpressionValue> weightedArgs = {val, args[1]};
+ columns[columnName].process(weightedArgs.data(), 2);
+ }
+ else {
+ throw ML::Exception("aggregator supports only 1 or 2 arguments");
+ }
return true;
};
@@ -363,7 +374,7 @@ struct AggregatorT {
// b) what is the best way to implement the query
// First output: information about the row
// Second output: is it dense (in other words, all rows are the same)?
- checkArgsSize(args.size(), 1, name);
+ checkArgsSize(args.size(), State::nargs, name);
ExcAssert(args[0].info);
// Create a value info object for the output. It has the same
@@ -446,7 +457,7 @@ struct AggregatorT {
if (!state->isDetermined) {
state->isDetermined = true;
- checkArgsSize(nargs, 1);
+ checkArgsSize(nargs, State::nargs);
state->isRow = args[0].isRow();
}
@@ -504,10 +515,9 @@ struct RegisterAggregatorT: public RegisterAggregator {
}
};
-struct AverageAccum {
- static constexpr int nargs = 1;
-
- AverageAccum()
+struct AverageAccumBase {
+
+ AverageAccumBase()
: total(0.0), n(0.0), ts(Date::negativeInfinity())
{
}
@@ -518,35 +528,170 @@ struct AverageAccum {
return std::make_shared();
}
+ ExpressionValue extract()
+ {
+ return ExpressionValue(ML::xdiv(total, n), ts);
+ }
+
+ void merge(AverageAccumBase* from)
+ {
+ total += from->total;
+ n += from->n;
+ ts.setMax(from->ts);
+ }
+
+ double total;
+ double n;
+ Date ts;
+};
+
+struct AverageAccum : public AverageAccumBase {
+ static constexpr int nargs = 1;
+
+ AverageAccum()
+ : AverageAccumBase()
+ {
+ }
+
void process(const ExpressionValue * args, size_t nargs)
{
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
if (val.empty())
return;
+
total += val.toDouble();
n += 1;
ts.setMax(val.getEffectiveTimestamp());
}
-
+};
+
+//struct WeightedAverageAccum : public AverageAccumBase {
+ //static constexpr int nargs = 2;
+
+ //WeightedAverageAccum()
+ //: AverageAccumBase()
+ //{
+ //}
+
+ //void process(const ExpressionValue * args, size_t nargs)
+ //{
+ //checkArgsSize(nargs, 2);
+ //const ExpressionValue & val = args[0];
+ //if (val.empty())
+ //return;
+
+ //double weight = 1;
+ //if(nargs == 2) {
+ //const ExpressionValue & ev_weight = args[1];
+ //if (!val.empty())
+ //weight = ev_weight.toDouble();
+ //}
+
+ //total += val.toDouble() * weight;
+ //n += weight;
+ //ts.setMax(val.getEffectiveTimestamp());
+ //}
+//};
+
+static RegisterAggregatorT<AverageAccum> registerAvg("avg", "vertical_avg");
+//static RegisterAggregatorT<WeightedAverageAccum> registerWAvg("weighted_avg", "vertical_weighted_avg");
+
+
+
+
+struct WeightedAverageAccum {
+ WeightedAverageAccum()
+ : n(0.0), ts(Date::negativeInfinity())
+ {
+ }
+
+ double n;
+ std::unordered_map<Path, double> counts;
+ Date ts;
+};
+
+BoundAggregator wavg(const std::vector<BoundSqlExpression> & args,
+ const string & name)
+{
+ auto init = [] () -> std::shared_ptr
+ {
+ return std::make_shared();
+ };
+
+ auto process = [name] (const ExpressionValue * args,
+ size_t nargs,
+ void * data)
+ {
+ checkArgsSize(nargs, 2, name);
+ const ExpressionValue & val = args[0];
+ double weight = args[1].toDouble();
+ WeightedAverageAccum & accum = *(WeightedAverageAccum *)data;
+ // This must be a row...
+ auto onAtom = [&] (const Path & columnName,
+ const Path & prefix,
+ const CellValue & val,
+ Date ts)
+ {
+ accum.counts[columnName] += val.toDouble() * weight;
+ accum.ts.setMax(ts);
+ return true;
+ };
+
+ val.forEachAtom(onAtom);
+
+ accum.n += weight;
+ };
+
+ /*
ExpressionValue extract()
{
- return ExpressionValue(total / n, ts);
+ return ExpressionValue(ML::xdiv(total, n), ts);
}
- void merge(AverageAccum* from)
+ void merge(AverageAccumBase* from)
{
total += from->total;
n += from->n;
ts.setMax(from->ts);
- }
-
- double total;
- double n;
- Date ts;
-};
-
-static RegisterAggregatorT<AverageAccum> registerAvg("avg", "vertical_avg");
+ }*/
+
+ auto extract = [] (void * data) -> ExpressionValue
+ {
+ WeightedAverageAccum & accum = *(WeightedAverageAccum *)data;
+
+ RowValue result;
+ for (auto & v: accum.counts) {
+ result.emplace_back(v.first, ML::xdiv(v.second, accum.n), accum.ts);
+ }
+
+ return ExpressionValue(std::move(result));
+ };
+
+ auto merge = [] (void * data, void* src)
+ {
+ WeightedAverageAccum & accum = *(WeightedAverageAccum *)data;
+ WeightedAverageAccum & srcAccum = *(WeightedAverageAccum *)src;
+
+ for (auto &iter : srcAccum.counts)
+ {
+ accum.counts[iter.first] += iter.second;
+ }
+
+
+ accum.n += srcAccum.n;
+ accum.ts.setMax(srcAccum.ts);
+ };
+
+
+ return { init, process, extract, merge };
+}
+
+static RegisterAggregator registerWavg(wavg, "weighted_avg", "vertical_weighted_avg");
+
+
+
+
template<typename Op>
struct ValueAccum {
@@ -571,7 +716,7 @@ struct ValueAccum {
value = Op()(value, val.toDouble());
ts.setMax(val.getEffectiveTimestamp());
}
-
+
ExpressionValue extract()
{
return ExpressionValue(value, ts);
@@ -621,12 +766,12 @@ struct StringAggAccum {
value += separator.coerceToString().toUtf8String();
}
first = false;
-
+
value += val.coerceToString().toUtf8String();
ts.setMax(val.getEffectiveTimestamp());
}
-
+
ExpressionValue extract()
{
return ExpressionValue(value, ts);
@@ -693,7 +838,7 @@ struct MinMaxAccum {
}
//cerr << "ts now " << ts << endl;
}
-
+
ExpressionValue extract()
{
return ExpressionValue(value, ts);
diff --git a/testing/MLDB-702-row-aggregators.py b/testing/MLDB-702-row-aggregators.py
index f5db67fca..b9494e83a 100644
--- a/testing/MLDB-702-row-aggregators.py
+++ b/testing/MLDB-702-row-aggregators.py
@@ -34,6 +34,19 @@ def recordExample(row, x, y, label, ts):
ds.commit()
+
+
+ ds = mldb.create_dataset({ "id": "test_weight", "type": "sparse.mutable" })
+
+ def recordExample(row, x, z, y, ts):
+ ds.record_row(row, [ [ "x", x, ts], ["z", z, ts], ["y", y, ts] ]);
+
+ recordExample("ex1", 25, 50, 0, 0);
+ recordExample("ex2", 1, 2, 5, 0);
+ recordExample("ex3", 10, 20, 2, 0);
+
+ ds.commit()
+
def test_min_max(self):
resp = mldb.get("/v1/query", q = "SELECT min({*}) AS min, max({*}) AS max FROM test GROUP BY label");
@@ -101,6 +114,32 @@ def test_vertical_avg_is_avg(self):
resp = mldb.get("/v1/query", q = "SELECT avg(x) AS avg FROM test GROUP BY x");
resp2 = mldb.get("/v1/query", q = "SELECT vertical_avg(x) AS avg FROM test GROUP BY x");
self.assertFullResultEquals(resp.json(), resp2.json())
+
+ def test_weighted_avg(self):
+
+ weighted_avg = (1*5 + 10*2) / 7.
+
+ wavg = mldb.query("SELECT weighted_avg(x, y) AS avg FROM test_weight")
+ self.assertTableResultEquals(
+ wavg,
+ [["_rowName","avg"],
+ ["[]",weighted_avg]])
+
+ self.assertTableResultEquals(wavg,
+ mldb.query("SELECT vertical_weighted_avg(x, y) AS avg FROM test_weight"))
+
+ #@unittest.expectedFailure
+ def test_weighted_avg_row(self):
+ weighted_avg = (1*5 + 10*2) / 7.
+
+ wavg = mldb.query("SELECT weighted_avg({* EXCLUDING (y)}, y) AS avg FROM test_weight")
+ mldb.log(wavg)
+
+ self.assertTableResultEquals(
+ wavg,
+ [["_rowName","avg.x", "avg.z"],
+ ["[]",weighted_avg, weighted_avg*2]])
+
def test_vertical_earliest_is_earliest(self):
resp = mldb.get("/v1/query", q = "SELECT earliest({*}) AS count FROM test GROUP BY x");