diff --git a/container_files/public_html/doc/builtin/sql/ValueExpression.md b/container_files/public_html/doc/builtin/sql/ValueExpression.md index 1496f737a..8c6a62879 100644 --- a/container_files/public_html/doc/builtin/sql/ValueExpression.md +++ b/container_files/public_html/doc/builtin/sql/ValueExpression.md @@ -611,6 +611,7 @@ The standard SQL aggregation functions operate 'vertically' down columns. MLDB d - `vertical_count()` alias of `count()`, operates on columns. - `vertical_sum()` alias of `sum()`, operates on columns. - `vertical_avg()` alias of `avg()`, operates on columns. + - `vertical_weighted_avg(, )` alias of `weighted_avg()`, operates on columns. - `vertical_min()` alias of `min()`, operates on columns. - `vertical_max()` alias of `max()`, operates on columns. - `vertical_latest()` alias of `latest()`, operates on columns. diff --git a/plugins/pooling_function.cc b/plugins/pooling_function.cc index d55b0bd24..ed1839da4 100644 --- a/plugins/pooling_function.cc +++ b/plugins/pooling_function.cc @@ -157,7 +157,7 @@ applyT(const ApplierT & applier_, PoolingInput input) const } ExcAssertEqual(outputEmbedding.size(), num_embed_cols); - } + } return {ExpressionValue(std::move(outputEmbedding), outputTs)}; } diff --git a/sql/builtin_aggregators.cc b/sql/builtin_aggregators.cc index 467cec151..7f6569c29 100644 --- a/sql/builtin_aggregators.cc +++ b/sql/builtin_aggregators.cc @@ -14,6 +14,7 @@ #include "mldb/base/optimized_path.h" #include #include +#include "mldb/jml/math/xdiv.h" using namespace std; @@ -135,14 +136,24 @@ struct AggregatorT { void process(const ExpressionValue * args, size_t nargs) { - checkArgsSize(nargs, 1); + checkArgsSize(nargs, State::nargs); const ExpressionValue & val = args[0]; + // This must be a row... auto onColumn = [&] (const PathElement & columnName, const ExpressionValue & val) { - columns[columnName].process(&val, 1); + if(State::nargs == 1) { + columns[columnName].process(&val, 1); + } + else if(State::nargs == 2) { + const vector pwet = {val, args[1]}; + columns[columnName].process(pwet.data(), 2); + } + else { + throw ML::Exception("pwet!"); + } return true; }; @@ -363,7 +374,7 @@ struct AggregatorT { // b) what is the best way to implement the query // First output: information about the row // Second output: is it dense (in other words, all rows are the same)? - checkArgsSize(args.size(), 1, name); + checkArgsSize(args.size(), State::nargs, name); ExcAssert(args[0].info); // Create a value info object for the output. It has the same @@ -446,7 +457,7 @@ struct AggregatorT { if (!state->isDetermined) { state->isDetermined = true; - checkArgsSize(nargs, 1); + checkArgsSize(nargs, State::nargs); state->isRow = args[0].isRow(); } @@ -504,10 +515,9 @@ struct RegisterAggregatorT: public RegisterAggregator { } }; -struct AverageAccum { - static constexpr int nargs = 1; - - AverageAccum() +struct AverageAccumBase { + + AverageAccumBase() : total(0.0), n(0.0), ts(Date::negativeInfinity()) { } @@ -518,35 +528,170 @@ struct AverageAccum { return std::make_shared(); } + ExpressionValue extract() + { + return ExpressionValue(ML::xdiv(total, n), ts); + } + + void merge(AverageAccumBase* from) + { + total += from->total; + n += from->n; + ts.setMax(from->ts); + } + + double total; + double n; + Date ts; +}; + +struct AverageAccum : public AverageAccumBase { + static constexpr int nargs = 1; + + AverageAccum() + : AverageAccumBase() + { + } + void process(const ExpressionValue * args, size_t nargs) { checkArgsSize(nargs, 1); const ExpressionValue & val = args[0]; if (val.empty()) return; + total += val.toDouble(); n += 1; ts.setMax(val.getEffectiveTimestamp()); } - +}; + +//struct WeightedAverageAccum : public AverageAccumBase { + //static constexpr int nargs = 2; + + //WeightedAverageAccum() + //: AverageAccumBase() + //{ + //} + + //void process(const ExpressionValue * args, size_t nargs) + //{ + //checkArgsSize(nargs, 2); + //const ExpressionValue & val = args[0]; + //if (val.empty()) + //return; + + //double weight = 1; + //if(nargs == 2) { + //const ExpressionValue & ev_weight = args[1]; + //if (!val.empty()) + //weight = ev_weight.toDouble(); + //} + + //total += val.toDouble() * weight; + //n += weight; + //ts.setMax(val.getEffectiveTimestamp()); + //} +//}; + +static RegisterAggregatorT registerAvg("avg", "vertical_avg"); +//static RegisterAggregatorT registerWAvg("weighted_avg", "vertical_weighted_avg"); + + + + +struct WeightedAverageAccum { + WeightedAverageAccum() + : ts(Date::negativeInfinity()) + { + } + + double n; + std::unordered_map counts; + Date ts; +}; + +BoundAggregator wavg(const std::vector & args, + const string & name) +{ + auto init = [] () -> std::shared_ptr + { + return std::make_shared(); + }; + + auto process = [name] (const ExpressionValue * args, + size_t nargs, + void * data) + { + checkArgsSize(nargs, 2, name); + const ExpressionValue & val = args[0]; + double weight = args[1].toDouble(); + WeightedAverageAccum & accum = *(WeightedAverageAccum *)data; + // This must be a row... + auto onAtom = [&] (const Path & columnName, + const Path & prefix, + const CellValue & val, + Date ts) + { + accum.counts[columnName] += val.toDouble() * weight; + accum.ts.setMax(ts); + return true; + }; + + val.forEachAtom(onAtom); + + accum.n += weight; + }; + + /* ExpressionValue extract() { - return ExpressionValue(total / n, ts); + return ExpressionValue(ML::xdiv(total, n), ts); } - void merge(AverageAccum* from) + void merge(AverageAccumBase* from) { total += from->total; n += from->n; ts.setMax(from->ts); - } - - double total; - double n; - Date ts; -}; - -static RegisterAggregatorT registerAvg("avg", "vertical_avg"); + }*/ + + auto extract = [] (void * data) -> ExpressionValue + { + WeightedAverageAccum & accum = *(WeightedAverageAccum *)data; + + RowValue result; + for (auto & v: accum.counts) { + result.emplace_back(v.first, ML::xdiv(v.second, accum.n), accum.ts); + } + + return ExpressionValue(std::move(result)); + }; + + auto merge = [] (void * data, void* src) + { + WeightedAverageAccum & accum = *(WeightedAverageAccum *)data; + WeightedAverageAccum & srcAccum = *(WeightedAverageAccum *)src; + + for (auto &iter : srcAccum.counts) + { + accum.counts[iter.first] += iter.second; + } + + + accum.n += srcAccum.n; + accum.ts.setMax(srcAccum.ts); + }; + + + return { init, process, extract, merge }; +} + +static RegisterAggregator registerWavg(wavg, "weighted_avg", "vertical_weighted_avg"); + + + + template struct ValueAccum { @@ -571,7 +716,7 @@ struct ValueAccum { value = Op()(value, val.toDouble()); ts.setMax(val.getEffectiveTimestamp()); } - + ExpressionValue extract() { return ExpressionValue(value, ts); @@ -621,12 +766,12 @@ struct StringAggAccum { value += separator.coerceToString().toUtf8String(); } first = false; - + value += val.coerceToString().toUtf8String(); ts.setMax(val.getEffectiveTimestamp()); } - + ExpressionValue extract() { return ExpressionValue(value, ts); @@ -693,7 +838,7 @@ struct MinMaxAccum { } //cerr << "ts now " << ts << endl; } - + ExpressionValue extract() { return ExpressionValue(value, ts); diff --git a/testing/MLDB-702-row-aggregators.py b/testing/MLDB-702-row-aggregators.py index f5db67fca..b9494e83a 100644 --- a/testing/MLDB-702-row-aggregators.py +++ b/testing/MLDB-702-row-aggregators.py @@ -34,6 +34,19 @@ def recordExample(row, x, y, label, ts): ds.commit() + + + ds = mldb.create_dataset({ "id": "test_weight", "type": "sparse.mutable" }) + + def recordExample(row, x, z, y, ts): + ds.record_row(row, [ [ "x", x, ts], ["z", z, ts], ["y", y, ts] ]); + + recordExample("ex1", 25, 50, 0, 0); + recordExample("ex2", 1, 2, 5, 0); + recordExample("ex3", 10, 20, 2, 0); + + ds.commit() + def test_min_max(self): resp = mldb.get("/v1/query", q = "SELECT min({*}) AS min, max({*}) AS max FROM test GROUP BY label"); @@ -101,6 +114,32 @@ def test_vertical_avg_is_avg(self): resp = mldb.get("/v1/query", q = "SELECT avg(x) AS avg FROM test GROUP BY x"); resp2 = mldb.get("/v1/query", q = "SELECT vertical_avg(x) AS avg FROM test GROUP BY x"); self.assertFullResultEquals(resp.json(), resp2.json()) + + def test_weighted_avg(self): + + weighted_avg = (1*5 + 10*2) / 7. + + wavg = mldb.query("SELECT weighted_avg(x, y) AS avg FROM test_weight") + self.assertTableResultEquals( + wavg, + [["_rowName","avg"], + ["[]",weighted_avg]]) + + self.assertTableResultEquals(wavg, + mldb.query("SELECT vertical_weighted_avg(x, y) AS avg FROM test_weight")) + + #@unittest.expectedFailure + def test_weighted_avg_row(self): + weighted_avg = (1*5 + 10*2) / 7. + + wavg = mldb.query("SELECT weighted_avg({* EXCLUDING (y)}, y) AS avg FROM test_weight") + mldb.log(wavg) + + self.assertTableResultEquals( + wavg, + [["_rowName","avg.x", "avg.z"], + ["[]",weighted_avg, weighted_avg*2]]) + def test_vertical_earliest_is_earliest(self): resp = mldb.get("/v1/query", q = "SELECT earliest({*}) AS count FROM test GROUP BY x");