diff --git a/datasource/context.go b/datasource/context.go index 7b2bf902..ea9e4f1d 100644 --- a/datasource/context.go +++ b/datasource/context.go @@ -116,6 +116,7 @@ func (m *SqlDriverMessageMap) Values() []driver.Value { return m.Vals } func (m *SqlDriverMessageMap) SetRow(row []driver.Value) { m.Vals = row } func (m *SqlDriverMessageMap) Ts() time.Time { return time.Time{} } func (m *SqlDriverMessageMap) Get(key string) (value.Value, bool) { + key = strings.ToLower(key) if idx, ok := m.ColIndex[key]; ok { return value.NewValue(m.Vals[idx]), true } @@ -229,6 +230,9 @@ func NewNestedContextReadWriter(readers []expr.ContextReader, writer expr.Contex func (n *NestedContextReader) Get(key string) (value.Value, bool) { for _, r := range n.readers { + if r == nil { + continue + } val, ok := r.Get(key) if ok && val != nil { return val, ok diff --git a/datasource/datatypes.go b/datasource/datatypes.go index b43d77d9..ed7eb499 100644 --- a/datasource/datatypes.go +++ b/datasource/datatypes.go @@ -63,6 +63,7 @@ func (m *TimeValue) Time() time.Time { func (m *TimeValue) Scan(src interface{}) error { + u.Debugf("time %T %v", src, src) var t time.Time var dstr string switch val := src.(type) { diff --git a/datasource/files/filesource_test.go b/datasource/files/filesource_test.go index 6c5a6edd..ae5d3f71 100644 --- a/datasource/files/filesource_test.go +++ b/datasource/files/filesource_test.go @@ -70,6 +70,7 @@ func TestFileList(t *testing.T) { {"testjson"}, }, ) + return testutil.TestSqlSelect(t, "testcsvs", `show tables;`, [][]driver.Value{ {"appearances"}, diff --git a/datasource/mockcsv/mockcsv.go b/datasource/mockcsv/mockcsv.go index e31ede88..bb6f40ce 100644 --- a/datasource/mockcsv/mockcsv.go +++ b/datasource/mockcsv/mockcsv.go @@ -150,6 +150,7 @@ func (m *Source) loadTable(tableName string) error { return fmt.Errorf("No csv-source created for %q", tableName) } ds := membtree.NewStaticData(tableName) + //u.Infof("loaded columns table=%q cols=%v", tableName, csvSource.Columns()) ds.SetColumns(csvSource.Columns()) m.tables[tableName] = ds diff --git a/datasource/schemadb.go b/datasource/schemadb.go index 4aa34968..71eaa40a 100644 --- a/datasource/schemadb.go +++ b/datasource/schemadb.go @@ -132,6 +132,7 @@ func (m *SchemaDb) Open(schemaObjectName string) (schema.Conn, error) { case "engines", "procedures", "functions", "indexes": return &SchemaSource{db: m, tbl: tbl, rows: nil}, nil default: + u.Warnf("here") return &SchemaSource{db: m, tbl: tbl, rows: tbl.AsRows()}, nil } diff --git a/datasource/sqlite/conn.go b/datasource/sqlite/conn.go index 8b79cffd..6665a15a 100644 --- a/datasource/sqlite/conn.go +++ b/datasource/sqlite/conn.go @@ -234,11 +234,11 @@ func (m *qryconn) WalkSourceSelect(planner plan.Planner, p *plan.Source) (plan.T sqlSelect := p.Stmt.Source u.Infof("original %s", sqlSelect.String()) - p.Stmt.Source = nil - p.Stmt.Rewrite(sqlSelect) - sqlSelect = p.Stmt.Source - u.Infof("original after From(source) rewrite %s", sqlSelect.String()) - sqlSelect.RewriteAsRawSelect() + //p.Stmt.Source = nil + //p.Stmt.Rewrite(sqlSelect) + //sqlSelect = p.Stmt.Source + //u.Infof("original after From(source) rewrite %s", sqlSelect.String()) + //sqlSelect.RewriteAsRawSelect() m.cols = sqlSelect.Columns.UnAliasedFieldNames() m.colidx = sqlSelect.ColIndexes() diff --git a/exec/join.go b/exec/join.go index 7b7bdba4..14358b36 100644 --- a/exec/join.go +++ b/exec/join.go @@ -242,7 +242,7 @@ func (m *JoinMerge) Run() error { //u.Debugf("msgsct: %v msgs:%#v", len(msgs), msgs) for _, msg := range msgs { 
//outCh <- datasource.NewUrlValuesMsg(i, msg) - //u.Debugf("i:%d msg:%#v", i, msg) + u.Warnf("i:%d msg:%#v", i, msg) msg.IdVal = i i++ outCh <- msg @@ -289,8 +289,8 @@ func (m *JoinMerge) valIndexing(valOut, valSource []driver.Value, cols []*rel.Co if col.Index < 0 || col.Index >= len(valSource) { u.Errorf("source index out of range? idx:%v of %d source: %#v \n\tcol=%#v", col.Index, len(valSource), valSource, col) } - //u.Infof("found: si=%v pi:%v idx:%d as=%v vals:%v len(out):%v", col.SourceIndex, col.ParentIndex, col.Index, col.As, valSource, len(valOut)) - valOut[col.ParentIndex] = valSource[col.Index] + //u.Infof("found: si=%v pi:%v idx:%d as=%v val:%v len(out):%v", col.SourceIndex, col.ParentIndex, col.Index, col.As, valSource[col.Index], len(valOut)) + valOut[col.ParentIndex] = valSource[col.SourceIndex] } return valOut } diff --git a/exec/projection.go b/exec/projection.go index 90048399..e30ab6b7 100644 --- a/exec/projection.go +++ b/exec/projection.go @@ -122,7 +122,21 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { colCt := len(columns) // If we have a projection, use that as col count if m.p.Proj != nil { - colCt = len(m.p.Proj.Columns) + if len(m.p.Proj.Columns) > colCt { + colCt = len(m.p.Proj.Columns) + } else if len(m.p.Proj.Columns) != colCt { + u.Warnf("wtf less? %v vs %v", colCt, len(m.p.Proj.Columns)) + } + + if len(m.p.Proj.Columns) == 0 { + u.Errorf("crap %+v", m.p.Proj) + } + for i, col := range m.p.Proj.Columns { + u.Debugf("%d %#v", i, col) + } + for i, col := range columns { + u.Debugf("%d %#v", i, col) + } } rowCt := 0 @@ -139,17 +153,27 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { var outMsg schema.Message switch mt := msg.(type) { case *datasource.SqlDriverMessageMap: + var rdr expr.ContextReader // use our custom write context for example purposes row := make([]driver.Value, colCt) - rdr := datasource.NewNestedContextReader([]expr.ContextReader{ - mt, - ctx.Session, - }, mt.Ts()) - //u.Debugf("about to project: %#v", mt) + if ctx.Session == nil { + rdr = mt + } else { + rdr = datasource.NewNestedContextReader([]expr.ContextReader{ + mt, + ctx.Session, + }, mt.Ts()) + } + + u.Debugf("about to project: colCt:%d message:%#v", colCt, mt) colIdx := -1 for _, col := range columns { colIdx += 1 - //u.Debugf("%d colidx:%v sidx: %v pidx:%v key:%q Expr:%v", colIdx, col.Index, col.SourceIndex, col.ParentIndex, col.Key(), col.Expr) + u.Debugf("%d colidx:%v sidx: %v pidx:%v star=%v key:%q Expr:%v", colIdx, col.Index, col.SourceIndex, col.ParentIndex, col.Star, col.Key(), col.Expr) + if len(row) <= colIdx { + row = append(row, nil) + u.Warnf("wtf wrong count %v %v", colIdx, len(row)) + } if isFinal && col.ParentIndex < 0 { continue @@ -175,6 +199,9 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { } if col.Star { starRow := mt.Values() + if colCt != len(starRow) { + u.Warnf("wtf wrong count %v %v", colCt, len(starRow)) + } //u.Infof("star row: %#v", starRow) if len(columns) > 1 { // select *, myvar, 1 @@ -217,14 +244,15 @@ func (m *Projection) projectionEvaluator(isFinal bool) MessageHandler { //u.Infof("mt: %T mt %#v", mt, mt) row[colIdx] = nil //v.Value() } else { - //u.Debugf("%d:%d row:%d evaled: %v val=%v", colIdx, colCt, len(row), col, v.Value()) + u.Debugf("%d:%d row:%d evaled: %v val=%v", colIdx, colCt, len(row), col, v.Value()) //writeContext.Put(col, mt, v) row[colIdx] = v.Value() + } } } - //u.Infof("row: %#v", row) - //u.Infof("row cols: %v", colIndex) + u.Infof("row: %#v", row) + 
u.Infof("row cols: %v", colIndex) outMsg = datasource.NewSqlDriverMessageMap(0, row, colIndex) case expr.ContextReader: diff --git a/exec/source.go b/exec/source.go index 01d3e314..451f50fc 100644 --- a/exec/source.go +++ b/exec/source.go @@ -133,6 +133,7 @@ func (m *Source) Run() error { for item := m.Scanner.Next(); item != nil; item = m.Scanner.Next() { + u.Debugf("source msg %#v", item) select { case <-sigChan: return nil diff --git a/exec/where.go b/exec/where.go index 3da1577d..2965d4f0 100644 --- a/exec/where.go +++ b/exec/where.go @@ -87,7 +87,7 @@ func NewHaving(ctx *plan.Context, p *plan.Having) *Where { func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) MessageHandler { out := task.MessageOut() - //u.Debugf("prepare filter %s", filter) + u.Debugf("WHERE prepare filter %s", filter) return func(ctx *plan.Context, msg schema.Message) bool { var filterValue value.Value @@ -102,7 +102,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message case *datasource.SqlDriverMessageMap: filterValue, ok = vm.Eval(mt, filter) if !ok { - u.Warnf("wtf %s %#v", filter, mt) + //u.Warnf("wtf %s %#v", filter, mt) } //u.Debugf("WHERE: result:%v T:%T \n\trow:%#v \n\tvals:%#v", filterValue, msg, mt, mt.Values()) //u.Debugf("cols: %#v", cols) @@ -125,7 +125,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message switch valTyped := filterValue.(type) { case value.BoolValue: if valTyped.Val() == false { - //u.Debugf("Filtering out: T:%T v:%#v", valTyped, valTyped) + u.Debugf("Filtering out: T:%T v:%#v \n\t%#v", valTyped, valTyped, msg) return true } case nil: @@ -136,7 +136,7 @@ func whereFilter(filter expr.Node, task TaskRunner, cols map[string]int) Message } } - //u.Debugf("about to send from where to forward: %#v", msg) + u.Debugf("about to send from where to forward: %#v", msg) select { case out <- msg: return true diff --git a/plan/plan.go b/plan/plan.go index 54789c20..87659383 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -749,7 +749,7 @@ func (m *Source) serializeToPb() error { return nil } func (m *Source) load() error { - // u.Debugf("source load schema=%s from=%s %#v", m.ctx.Schema.Name, m.Stmt.SourceName(), m.Stmt) + u.Debugf("source load schema=%s from=%s %#v", m.ctx.Schema.Name, m.Stmt.SourceName(), m.Stmt) if m.Stmt == nil { return nil } @@ -821,6 +821,9 @@ func (m *Projection) ToPb() (*PlanPb, error) { if err != nil { return nil, err } + if m.Proj == nil { + u.WarnT(10) + } ppbptr := m.Proj.ToPB() ppcpy := *ppbptr ppcpy.Final = m.Final @@ -864,15 +867,32 @@ func NewJoinMerge(l, r Task, lf, rf *rel.SqlSource) *JoinMerge { // Build an index of source to destination column indexing for _, col := range lf.Source.Columns { //u.Debugf("left col: idx=%d key=%q as=%q col=%v parentidx=%v", len(m.colIndex), col.Key(), col.As, col.String(), col.ParentIndex) - m.ColIndex[lf.Alias+"."+col.Key()] = col.ParentIndex - //u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", m.leftStmt.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) + if col.ParentIndex >= 0 { + m.ColIndex[lf.Alias+"."+col.Key()] = col.ParentIndex + } + u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", lf.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) } for _, col := range rf.Source.Columns { //u.Debugf("right col: idx=%d key=%q as=%q col=%v", len(m.colIndex), col.Key(), col.As, col.String()) - m.ColIndex[rf.Alias+"."+col.Key()] = col.ParentIndex - //u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", 
m.rightStmt.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) + if col.ParentIndex >= 0 { + m.ColIndex[rf.Alias+"."+col.Key()] = col.ParentIndex + } + u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", rf.Alias+"."+col.Key(), col.Index, col.SourceIndex, col.ParentIndex) + } + for _, col := range lf.Source.Columns { + //u.Debugf("left col: idx=%d key=%q as=%q col=%v parentidx=%v", len(m.colIndex), col.Key(), col.As, col.String(), col.ParentIndex) + if col.ParentIndex < 0 { + m.ColIndex[lf.Alias+"."+col.Key()] = len(m.ColIndex) + } + u.Debugf("left colIndex: %15q : idx:%d sidx:%d pidx:%d", lf.Alias+"."+col.Key(), col.Index, col.SourceIndex, len(m.ColIndex)-1) + } + for _, col := range rf.Source.Columns { + //u.Debugf("right col: idx=%d key=%q as=%q col=%v", len(m.colIndex), col.Key(), col.As, col.String()) + if col.ParentIndex < 0 { + m.ColIndex[rf.Alias+"."+col.Key()] = len(m.ColIndex) + } + u.Debugf("right colIndex: %15q : idx:%d sidx:%d pidx:%d", rf.Alias+"."+col.Key(), col.Index, col.SourceIndex, len(m.ColIndex)-1) } - return m } diff --git a/plan/planner_select.go b/plan/planner_select.go index 8b0a0035..960c3a6d 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -34,7 +34,8 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { return m.WalkLiteralQuery(p) - } else if len(p.Stmt.From) == 1 { + } + /*else if len(p.Stmt.From) == 1 { p.Stmt.From[0].Source = p.Stmt // TODO: move to a Finalize() in query parser/planner @@ -55,40 +56,58 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { } } else { - - var prevSource *Source - var prevTask Task - - for i, from := range p.Stmt.From { - - // Need to rewrite the From statement to ensure all fields necessary to support - // joins, wheres, etc exist but is standalone query + */ + var prevSource *Source + var rootTask Task + isFinal := len(p.Stmt.From) == 1 + for i, from := range p.Stmt.From { + + // Need to rewrite the From statement to ensure all fields necessary to support + // joins, wheres, etc exist but is standalone query + if len(p.Stmt.From) == 1 { + from.Source = p.Stmt + } else { + u.Debugf("from.Source: %s", p.Stmt) + u.Debugf("from: %s", from.String()) from.Rewrite(p.Stmt) - srcPlan, err := NewSource(m.Ctx, from, false) - if err != nil { - return nil - } - err = m.Planner.WalkSourceSelect(srcPlan) - if err != nil { - u.Errorf("Could not visitsubselect %v %s", err, from) - return err - } + u.Debugf("from-rewrite: %s", from.String()) + u.Debugf("from.Source: %s", from.Source.String()) + } - // now fold into previous task - if i != 0 { - from.Seekable = true - // fold this source into previous - curMergeTask := NewJoinMerge(prevTask, srcPlan, prevSource.Stmt, srcPlan.Stmt) - prevTask = curMergeTask - } else { - prevTask = srcPlan - } - prevSource = srcPlan - //u.Debugf("got task: %T", lastSource) + sourceTask, err := NewSource(m.Ctx, from, isFinal) + if err != nil { + return nil + } + // if len(p.Stmt.From) == 1 { + // p.From = []*Source{sourceTask} + // } + p.From = append(p.From, sourceTask) + err = m.Planner.WalkSourceSelect(sourceTask) + if err != nil { + u.Errorf("Could not visitsubselect %v %s", err, from) + return err + } + + var curSource Task = sourceTask + if from.Source.Where != nil { + u.Errorf("got a WHERE") + curSource.Add(NewWhere(from.Source)) } - p.Add(prevTask) + // now fold into previous task + if i != 0 { + from.Seekable = true + // fold this source into previous + rootTask = NewJoinMerge(rootTask, sourceTask, prevSource.Stmt, sourceTask.Stmt) + //rootTask = 
curMergeTask + } else { + rootTask = sourceTask + } + prevSource = sourceTask + //u.Debugf("got task: %T", lastSource) } + p.Add(rootTask) + u.Infof("did we accidentally mutate the original statement? \n\t%s", p.Stmt) if p.Stmt.Where != nil { switch { @@ -118,6 +137,7 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { p.Add(NewOrder(p.Stmt)) } + u.Debugf("needsFinalProject?%v", needsFinalProject) if needsFinalProject { err := m.WalkProjectionFinal(p) if err != nil { @@ -125,7 +145,6 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { } } -finalProjection: if m.Ctx.Projection == nil { proj, err := NewProjectionFinal(m.Ctx, p) //u.Infof("Projection: %T:%p %T:%p", proj, proj, proj.Proj, proj.Proj) @@ -146,10 +165,12 @@ func (m *PlannerDefault) WalkProjectionFinal(p *Select) error { proj, err := NewProjectionFinal(m.Ctx, p) //u.Infof("Projection: %T:%p %T:%p", proj, proj, proj.Proj, proj.Proj) if err != nil { + u.Warnf("could not build projection err=%v for %s", err, p.Stmt) return err } p.Add(proj) if m.Ctx.Projection == nil { + u.Infof("set projection") m.Ctx.Projection = proj } else { // Not entirely sure we should be over-writing the projection? @@ -209,6 +230,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { // Can do our own planning t, err := sourcePlanner.WalkSourceSelect(m.Planner, p) if err != nil { + u.Warnf("could not source plan %v", err) return err } if t != nil { @@ -218,7 +240,9 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { } else { if schemaCols, ok := p.Conn.(schema.ConnColumns); ok { + u.Debugf("schemaCols: %#v cols=%v", p.Conn, schemaCols.Columns()) if err := buildColIndex(schemaCols, p); err != nil { + u.Warnf("could not build index %v", err) return err } } else { @@ -236,12 +260,15 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { } // Add a Non-Final Projection to choose the columns for results - if !p.Final { - err := m.WalkProjectionSource(p) - if err != nil { - return err + /* + if !p.Final { + u.Warnf("!final wtf %s", p.Stmt.String()) + err := m.WalkProjectionSource(p) + if err != nil { + return err + } } - } + */ } if needsJoinKey { @@ -256,6 +283,7 @@ func (m *PlannerDefault) WalkSourceSelect(p *Source) error { func (m *PlannerDefault) WalkProjectionSource(p *Source) error { // Add a Non-Final Projection to choose the columns for results //u.Debugf("exec.projection: %p job.proj: %p added %s", p, m.Ctx.Projection, p.Stmt.String()) + u.Infof("------------- Source Projection") proj := NewProjectionInProcess(p.Stmt.Source) //u.Debugf("source projection: %p added %s", proj, p.Stmt.Source.String()) p.Add(proj) diff --git a/plan/projection.go b/plan/projection.go index e76bbb60..03db14f8 100644 --- a/plan/projection.go +++ b/plan/projection.go @@ -12,13 +12,14 @@ import ( "github.com/araddon/qlbridge/value" ) -// A static projection has already had its column/types defined -// and doesn't need to use internal schema to find it, often internal SHOW/DESCRIBE +// NewProjectionStatic create A static projection for literal query. +// IT has already had its column/types defined and doesn't need to use internal +// schema to find it, often internal SHOW/DESCRIBE. 
func NewProjectionStatic(proj *rel.Projection) *Projection { return &Projection{Proj: proj, PlanBase: NewPlanBase(false)} } -// Final Projections project final select columns for result-writing +// NewProjectionFinal project final select columns for result-writing func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { s := &Projection{ P: p, @@ -27,10 +28,13 @@ func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { Final: true, } var err error + u.Debugf("NewProjectionFinal") if len(p.Stmt.From) == 0 { + u.Warnf("literal projection") err = s.loadLiteralProjection(ctx) } else if len(p.From) == 1 && p.From[0].Proj != nil { s.Proj = p.From[0].Proj + u.Warnf("used the projection from From[0] %#v", s.Proj.Columns) } else { err = s.loadFinal(ctx, true) } @@ -39,6 +43,9 @@ func NewProjectionFinal(ctx *Context, p *Select) (*Projection, error) { } return s, nil } + +// NewProjectionInProcess create a projection for a non-final +// projection for source. func NewProjectionInProcess(stmt *rel.SqlSelect) *Projection { s := &Projection{ Stmt: stmt, @@ -67,14 +74,12 @@ func (m *Projection) loadLiteralProjection(ctx *Context) error { case *expr.NumberNode: // number? if et.IsInt { - proj.AddColumnShort(as, value.IntType) + proj.AddColumnShort(as, value.IntType, true) } else { - proj.AddColumnShort(as, value.NumberType) + proj.AddColumnShort(as, value.NumberType, true) } - //u.Infof("number? %#v", et) default: - //u.Infof("type? %#v", et) - proj.AddColumnShort(as, value.StringType) + proj.AddColumnShort(as, value.StringType, true) } } @@ -89,11 +94,11 @@ func (m *Projection) loadLiteralProjection(ctx *Context) error { func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { - //u.Debugf("creating plan.Projection final %s", m.Stmt.String()) + u.Debugf("creating plan.Projection final %s", m.Stmt.String()) m.Proj = rel.NewProjection() - for _, from := range m.Stmt.From { + for fromi, from := range m.Stmt.From { fromName := strings.ToLower(from.SourceName()) tbl, err := ctx.Schema.Table(fromName) @@ -108,90 +113,64 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { //u.Debugf("getting cols? %v cols=%v", from.ColumnPositions()) for _, col := range from.Source.Columns { //_, right, _ := col.LeftRight() - //u.Infof("col %s", col) + u.Infof("%d from:%s col %s", fromi, from.Name, col) if col.Star { for _, f := range tbl.Fields { - m.Proj.AddColumnShort(f.Name, f.ValueType()) + m.Proj.AddColumnShort(f.Name, f.ValueType(), true) } } else { if schemaCol, ok := tbl.FieldMap[col.SourceField]; ok { - if isFinal { - if col.InFinalProjection() { - //u.Debugf("in plan final %s", col.As) - m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) - } - } else { - //u.Debugf("not final %s", col.As) - m.Proj.AddColumnShort(col.As, schemaCol.ValueType()) - } - //u.Debugf("projection: %p add col: %v %v", m.Proj, col.As, schemaCol.Type.String()) + m.Proj.AddColumnShort(col.As, schemaCol.ValueType(), col.InFinalProjection()) } else { - //u.Infof("schema col not found: final?%v col: %#v InFinal?%v", isFinal, col, col.InFinalProjection()) - if isFinal { - if col.InFinalProjection() { - m.Proj.AddColumnShort(col.As, value.StringType) - } else { - u.Warnf("not adding to projection? 
%s", col) - } - } else { - m.Proj.AddColumnShort(col.As, value.StringType) - } + u.Infof("schema col not found: final?%v col: %#v InFinal?%v", isFinal, col, col.InFinalProjection()) + m.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) } } - } } } + + for i, col := range m.Proj.Columns { + u.Debugf("%d %#v", i, col) + } return nil } func projectionForSourcePlan(plan *Source) error { plan.Proj = rel.NewProjection() + u.Infof("projection. tbl?%v plan.Final?%v source: %s", plan.Tbl != nil, plan.Final, plan.Stmt.Source) - // u.Debugf("created plan.Proj *rel.Projection %p", plan.Proj) // Not all Execution run-times support schema. ie, csv files and other "ad-hoc" structures // do not have to have pre-defined data in advance, in which case the schema output // will not be deterministic on the sql []driver.values for _, col := range plan.Stmt.Source.Columns { - //u.Debugf("col: %v star?%v", col, col.Star) + u.Debugf("%2d col: %#v star?%v inFinal?%v", len(plan.Proj.Columns), col, col.Star, col.InFinalProjection()) if plan.Tbl == nil { - if plan.Final { - if col.InFinalProjection() { - plan.Proj.AddColumn(col, value.StringType) - } - } else { - plan.Proj.AddColumn(col, value.StringType) - } + plan.Proj.AddColumn(col, value.StringType, col.InFinalProjection()) + } else if schemaCol, ok := plan.Tbl.FieldMap[col.SourceField]; ok { - if plan.Final { - if col.InFinalProjection() { - //u.Infof("col add %v for %s", schemaCol.Type.String(), col) - plan.Proj.AddColumn(col, schemaCol.ValueType()) - } else { - //u.Infof("not in final? %#v", col) - } - } else { - plan.Proj.AddColumn(col, schemaCol.ValueType()) - } - //u.Debugf("projection: %p add col: %v %v", plan.Proj, col.As, schemaCol.Type.String()) + + plan.Proj.AddColumn(col, schemaCol.ValueType(), col.InFinalProjection()) + } else if col.Star { if plan.Tbl == nil { u.Warnf("no table?? %v", plan) } else { - //u.Infof("star cols? %v fields: %v", plan.Tbl.FieldPositions, plan.Tbl.Fields) + u.Infof("star cols? %v fields: %v", plan.Tbl.FieldPositions, plan.Tbl.Fields) for _, f := range plan.Tbl.Fields { //u.Infof(" add col %v %+v", f.Name, f) - plan.Proj.AddColumnShort(f.Name, f.ValueType()) + plan.Proj.AddColumnShort(f.Name, f.ValueType(), true) } } } else { + u.Warnf("WTF %#v", plan.Tbl.FieldMap) if col.Expr != nil && strings.ToLower(col.Expr.String()) == "count(*)" { //u.Warnf("count(*) as=%v", col.As) - plan.Proj.AddColumn(col, value.IntType) + plan.Proj.AddColumn(col, value.IntType, true) } else if col.Expr != nil { // A column was included in projection that does not exist in source. // TODO: Should we allow sources to have settings that specify wether @@ -199,16 +178,16 @@ func projectionForSourcePlan(plan *Source) error { // this is fine switch nt := col.Expr.(type) { case *expr.IdentityNode, *expr.StringNode: - plan.Proj.AddColumnShort(col.As, value.StringType) + plan.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) case *expr.NumberNode: if nt.IsInt { - plan.Proj.AddColumnShort(col.As, value.IntType) + plan.Proj.AddColumnShort(col.As, value.IntType, col.InFinalProjection()) } else { - plan.Proj.AddColumnShort(col.As, value.NumberType) + plan.Proj.AddColumnShort(col.As, value.NumberType, col.InFinalProjection()) } case *expr.FuncNode, *expr.BinaryNode: // Probably not string? 
- plan.Proj.AddColumnShort(col.As, value.StringType) + plan.Proj.AddColumnShort(col.As, value.StringType, col.InFinalProjection()) default: u.Warnf("schema col not found: SourceField=%q vals=%#v", col.SourceField, col) } @@ -219,6 +198,9 @@ func projectionForSourcePlan(plan *Source) error { } } + for _, c := range plan.Proj.Columns { + u.Debugf("col %+v", c) + } //u.Infof("plan.Projection %p cols: %d", plan.Proj, len(plan.Proj.Columns)) return nil } diff --git a/rel/sql.go b/rel/sql.go index 881cb175..a74e6278 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -236,7 +236,7 @@ type ( } // Columns List of Columns in SELECT [columns] Columns []*Column - // Column represents the Column as expressed in a [SELECT] + // Column represents the Column(s) as expressed in a [SELECT COLUMNS] // expression Column struct { sourceQuoteByte byte // quote mark? [ or ` etc @@ -321,13 +321,16 @@ func NewSqlDialect() expr.DialectWriter { func NewProjection() *Projection { return &Projection{Columns: make(ResultColumns, 0), colNames: make(map[string]struct{})} } -func NewResultColumn(as string, ordinal int, col *Column, valtype value.ValueType) *ResultColumn { - rc := ResultColumn{Name: as, As: as, ColPos: ordinal, Col: col, Type: valtype} + +// NewResultColumn create a new column describing a result column, may be final or intermediate. +func NewResultColumn(as string, ordinal int, col *Column, valtype value.ValueType, final bool) *ResultColumn { + rc := ResultColumn{Name: as, As: as, ColPos: ordinal, Col: col, Type: valtype, Final: final} if col != nil { rc.Name = col.SourceField } return &rc } + func NewSqlSelect() *SqlSelect { req := &SqlSelect{} req.Columns = make(Columns, 0) @@ -390,10 +393,13 @@ func NewColumnValue(tok lex.Token) *Column { } } func NewColumn(col string) *Column { + l, r, _ := expr.LeftRight(col) + u.Debugf("col=%q l=%q r=%q", col, l, r) return &Column{ - As: col, - SourceField: col, - Expr: &expr.IdentityNode{Text: col}, + As: col, + SourceField: col, + SourceOriginal: col, + Expr: &expr.IdentityNode{Text: col}, } } @@ -470,22 +476,23 @@ func resultColumnToPb(m *ResultColumn) *ResultColumnPb { return s } -func (m *Projection) AddColumnShort(colName string, vt value.ValueType) { +func (m *Projection) AddColumnShort(colName string, vt value.ValueType, final bool) { //colName = strings.ToLower(colName) // if _, exists := m.colNames[colName]; exists { // return // } //u.Infof("adding column %s to %v", colName, m.colNames) //m.colNames[colName] = struct{}{} - m.Columns = append(m.Columns, NewResultColumn(colName, len(m.Columns), nil, vt)) + m.Columns = append(m.Columns, NewResultColumn(colName, len(m.Columns), nil, vt, final)) } -func (m *Projection) AddColumn(col *Column, vt value.ValueType) { +func (m *Projection) AddColumn(col *Column, vt value.ValueType, final bool) { //colName := strings.ToLower(col.As) // if _, exists := m.colNames[colName]; exists { // return // } //m.colNames[colName] = struct{}{} - m.Columns = append(m.Columns, NewResultColumn(col.As, len(m.Columns), col, vt)) + u.Debugf("AddColumn %#v", col) + m.Columns = append(m.Columns, NewResultColumn(col.As, len(m.Columns), col, vt, final)) } func (m *Projection) Equal(s *Projection) bool { if m == nil && s == nil { @@ -633,6 +640,11 @@ func (m Columns) Equal(cols Columns) bool { } func (m *Column) Key() string { + // if m.right == "" && m.As == "" { + // u.Warnf("WTF no col info %#v", m) + // u.WarnT(10) + // } + // u.Debugf("Key(): left=%q right=%q As=%q", m.left, m.right, m.As) if m.left != "" { return m.right } @@ -1252,7 +1264,7 
@@ func (m *SqlSelect) ColIndexes() map[string]int { cols := make(map[string]int, len(m.Columns)) for i, col := range m.Columns { //u.Debugf("aliasing: key():%-15q As:%-15q %-15q", col.Key(), col.As, col.String()) - cols[col.Key()] = i + cols[strings.ToLower(col.Key())] = i } return cols } @@ -1260,7 +1272,6 @@ func (m *SqlSelect) ColIndexes() map[string]int { func (m *SqlSelect) AddColumn(colArg Column) error { col := &colArg col.Index = len(m.Columns) - m.Columns = append(m.Columns, col) if col.Star { m.Star = true } @@ -1271,10 +1282,14 @@ func (m *SqlSelect) AddColumn(colArg Column) error { if col.Agg && !m.isAgg { m.isAgg = true } + // SELECT USER.FirstName AS fname FROM user + // col{SourceField:"FirstName"} + col.SourceField = strings.ToLower(col.SourceField) + m.Columns = append(m.Columns, col) return nil } -// Is this a select count(*) FROM ... query? +// CountStar Is this a select count(*) FROM ... query? func (m *SqlSelect) CountStar() bool { if len(m.Columns) != 1 { return false @@ -1295,20 +1310,27 @@ func (m *SqlSelect) CountStar() bool { } // Rewrite take current SqlSelect statement and re-write it -func (m *SqlSelect) Rewrite() { +func (m *SqlSelect) Rewrite() error { for _, f := range m.From { - f.Rewrite(m) + if _, err := f.Rewrite(m); err != nil { + return err + } } + return nil } // RewriteAsRawSelect We are removing Column Aliases "user_id as uid" // as well as functions - used when we are going to defer projection, aggs func (m *SqlSelect) RewriteAsRawSelect() { - RewriteSelect(m) + rewriteSelectStatement(m) } func (m *SqlSource) IsLiteral() bool { return len(m.Name) == 0 } func (m *SqlSource) Keyword() lex.TokenType { return m.Op } + +// SourceName return the sourcename for this sqlselect source, if sub-query +// get name of FROM [name] else if join get name. Corrects for namespacing +// to only get non-namedspaced name. func (m *SqlSource) SourceName() string { if m == nil { return "" @@ -1398,17 +1420,20 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } starDelta := 0 // how many columns were added due to * for _, col := range m.Source.Columns { + // if col.Key() == "" { + // u.Errorf("WTF no key? %#v", col) + // } if col.Star { starStart := len(m.colIndex) - for colIdx := range colNames { - m.colIndex[col.Key()] = colIdx + starStart + for colIdx, colName := range colNames { + m.colIndex[strings.ToLower(colName)] = colIdx + starStart } starDelta = len(colNames) } else { found := false for colIdx, colName := range colNames { _, colName, _ = expr.LeftRight(colName) - //u.Debugf("col.Key():%v sourceField:%v colName:%v", col.Key(), col.SourceField, colName) + u.Debugf("col.Key():%v sourceField:%v colName:%v", col.Key(), col.SourceField, colName) if colName == col.Key() || col.SourceField == colName { //&& //u.Debugf("build col: idx=%d key=%-15q as=%-15q col=%-15s sourcidx:%d", len(m.colIndex), col.Key(), col.As, col.String(), colIdx) m.colIndex[col.Key()] = colIdx + starDelta @@ -1418,6 +1443,8 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } } if !found && !col.IsLiteralOrFunc() { + u.Errorf("could not find col? 
%s", col) + u.WarnT(10) return fmt.Errorf("Missing Column in source: %q", col.String()) } } @@ -1426,9 +1453,17 @@ func (m *SqlSource) BuildColIndex(colNames []string) error { } // Rewrite this Source to act as a stand-alone query to backend -// @parentStmt = the parent statement that this a partial source to -func (m *SqlSource) Rewrite(parentStmt *SqlSelect) *SqlSelect { - return RewriteSqlSource(m, parentStmt) +// @parentStmt = the parent statement. IE, source is a partial (join, from where-in) source in a +// multi-source SELECT statement. We are re-writing to allow the sources to be independent. +func (m *SqlSource) Rewrite(parentStmt *SqlSelect) (*SqlSelect, error) { + sql2, err := RewriteSqlSource(m, parentStmt) + if err != nil { + return nil, err + } + m.Source = sql2 + u.Debugf("rewritten source: %s", sql2) + m.cols = sql2.UnAliasedColumns() + return sql2, err } func (m *SqlSource) findFromAliases() (string, string) { @@ -1453,8 +1488,8 @@ func (m *SqlSource) findFromAliases() (string, string) { return from1, from2 } -// Get a list of Un-Aliased Columns, ie columns with column -// names that have NOT yet been aliased +// UnAliasedColumns Get a list of Un-Aliased Columns, ie columns with column +// names that have NOT yet been aliased func (m *SqlSource) UnAliasedColumns() map[string]*Column { //u.Warnf("un-aliased %d", len(m.Source.Columns)) if len(m.cols) > 0 || m.Source != nil && len(m.Source.Columns) == 0 { @@ -1474,7 +1509,7 @@ func (m *SqlSource) UnAliasedColumns() map[string]*Column { return cols } -// Get a list of Column names to position +// ColumnPositions Get a list of Column names to position in array of columns. func (m *SqlSource) ColumnPositions() map[string]int { if len(m.colIndex) > 0 { return m.colIndex @@ -1496,7 +1531,7 @@ func (m *SqlSource) ColumnPositions() map[string]int { return m.colIndex } -// We need to be able to rewrite statements to convert a stmt such as: +// JoinNodes We need to be able to rewrite statements to convert a stmt such as: // // FROM users AS u // INNER JOIN orders AS o @@ -1518,6 +1553,8 @@ func (m *SqlSource) ColumnPositions() map[string]int { func (m *SqlSource) JoinNodes() []expr.Node { return m.joinNodes } + +// Finalize the source. func (m *SqlSource) Finalize() error { if m.final { return nil diff --git a/rel/sql_rewrite.go b/rel/sql_rewrite.go index c78d6ea2..6215e8f9 100644 --- a/rel/sql_rewrite.go +++ b/rel/sql_rewrite.go @@ -1,33 +1,83 @@ package rel import ( + "fmt" "strings" u "github.com/araddon/gou" + "github.com/araddon/qlbridge/expr" "github.com/araddon/qlbridge/lex" + "github.com/araddon/qlbridge/schema" +) + +type ( + rewriteSelect struct { + sel *SqlSelect + cols map[string]bool + matchSource string + features *schema.DataSourceFeatures + result *RewriteSelectResult + } + // RewriteSelectResult describes the result of a re-write statement to + // tell the planner which poly-fill features are needed based on re-write. 
+ RewriteSelectResult struct { + NeedsProjection bool + NeedsWhere bool + NeedsGroupBy bool + } ) -// RewriteSelect We are removing Column Aliases "user_id as uid" +func newRewriteSelect(sel *SqlSelect) *rewriteSelect { + rw := &rewriteSelect{ + sel: sel, + cols: make(map[string]bool), + features: schema.FeaturesDefault(), + result: &RewriteSelectResult{}, + } + return rw +} + +// ReWriteStatement given SqlStatement +func ReWriteStatement(input SqlStatement) error { + switch stmt := input.(type) { + case *SqlSelect: + return rewriteSelectStatement(stmt) + default: + return fmt.Errorf("Rewrite not implemented for %T", input) + } +} + +// rewriteSelectStatement We are removing Column Aliases "user_id as uid" // as well as functions - used when we are going to defer projection, aggs -func RewriteSelect(m *SqlSelect) { - originalCols := m.Columns - m.Columns = make(Columns, 0, len(originalCols)+5) - rewriteIntoProjection(m, originalCols) - rewriteIntoProjection(m, m.GroupBy) - if m.Where != nil { - colsToAdd := expr.FindAllIdentityField(m.Where.Expr) - addIntoProjection(m, colsToAdd) +func rewriteSelectStatement(sel *SqlSelect) error { + rw := newRewriteSelect(sel) + + originalCols := sel.Columns + sel.Columns = make(Columns, 0, len(originalCols)+5) + if err := rw.intoProjection(sel, originalCols, true); err != nil { + return err + } + if err := rw.intoProjection(sel, sel.GroupBy, false); err != nil { + return err } - rewriteIntoProjection(m, m.OrderBy) + if sel.Where != nil { + cols := expr.FindAllIdentityField(sel.Where.Expr) + for _, col := range cols { + nc := NewColumn(col) + nc.ParentIndex = -1 // ie, NOT in final + rw.addColumn(*nc) + } + } + return rw.intoProjection(sel, sel.OrderBy, false) } -// RewriteSqlSource this Source to act as a stand-alone query to backend +// RewriteSqlSource this SqlSource to act as a stand-alone query to backend // @parentStmt = the parent statement that this a partial source to -func RewriteSqlSource(m *SqlSource, parentStmt *SqlSelect) *SqlSelect { +func RewriteSqlSource(source *SqlSource, parentStmt *SqlSelect) (*SqlSelect, error) { - if m.Source != nil { - return m.Source + if source.Source != nil { + return source.Source, nil } // Rewrite this SqlSource for the given parent, ie // 1) find the column names we need to request from source including those used in join/where @@ -36,172 +86,223 @@ func RewriteSqlSource(m *SqlSource, parentStmt *SqlSelect) *SqlSelect { // sides should be aliased towards the left-hand join portion // 4) if we need different sort for our join algo? - newCols := make(Columns, 0) - if !parentStmt.Star { - for idx, col := range parentStmt.Columns { - left, _, hasLeft := col.LeftRight() - if !hasLeft { - // Was not left/right qualified, so use as is? or is this an error? - // what is official sql grammar on this? 
- newCol := col.Copy() - newCol.ParentIndex = idx - newCol.Index = len(newCols) - newCols = append(newCols, newCol) - - } else if hasLeft && left == m.Alias { - newCol := col.CopyRewrite(m.Alias) - newCol.ParentIndex = idx - newCol.SourceIndex = len(newCols) - newCol.Index = len(newCols) - newCols = append(newCols, newCol) - } - } - } + sql2 := &SqlSelect{Columns: make(Columns, 0), Star: parentStmt.Star} + rw := newRewriteSelect(sql2) + rw.matchSource = source.Alias + originalCols := parentStmt.Columns + if err := rw.intoProjection(sql2, originalCols, true); err != nil { + return nil, err + } + //u.Debugf("after into projection: %s", sql2.Columns) // TODO: // - rewrite the Sort // - rewrite the group-by - sql2 := &SqlSelect{Columns: newCols, Star: parentStmt.Star} - m.joinNodes = make([]expr.Node, 0) - if m.SubQuery != nil { - if len(m.SubQuery.From) != 1 { - u.Errorf("Not supported, nested subQuery %v", m.SubQuery.String()) + + source.joinNodes = make([]expr.Node, 0) + if source.SubQuery != nil { + if len(source.SubQuery.From) != 1 { + u.Errorf("Not supported, nested subQuery %v", source.SubQuery.String()) } else { - sql2.From = append(sql2.From, &SqlSource{Name: m.SubQuery.From[0].Name}) + sql2.From = append(sql2.From, &SqlSource{Name: source.SubQuery.From[0].Name}) } } else { - sql2.From = append(sql2.From, &SqlSource{Name: m.Name}) + sql2.From = append(sql2.From, &SqlSource{Name: source.Name}) } for _, from := range parentStmt.From { // We need to check each participant in the Join for possible // columns which need to be re-written - sql2.Columns = columnsFromJoin(m, from.JoinExpr, sql2.Columns) + rw.columnsFromExpression(source, from.JoinExpr) // We also need to create an expression used for evaluating // the values of Join "Keys" if from.JoinExpr != nil { - joinNodesForFrom(parentStmt, m, from.JoinExpr, 0) + rw.joinNodesForFrom(parentStmt, source, from.JoinExpr, 0) } } + //u.Debugf("after FROM: %s", sql2.Columns) if parentStmt.Where != nil { - node, cols := rewriteWhere(parentStmt, m, parentStmt.Where.Expr, make(Columns, 0)) + node := rw.rewriteWhere(parentStmt, source, parentStmt.Where.Expr) if node != nil { sql2.Where = &SqlWhere{Expr: node} } - if len(cols) > 0 { - parentIdx := len(parentStmt.Columns) - for _, col := range cols { - col.Index = len(sql2.Columns) - col.ParentIndex = parentIdx - parentIdx++ - sql2.Columns = append(sql2.Columns, col) + /* + if len(cols) > 0 { + parentIdx := len(parentStmt.Columns) + for _, col := range cols { + col.Index = len(sql2.Columns) + col.ParentIndex = parentIdx + parentIdx++ + sql2.Columns = append(sql2.Columns, col) + } } - } + */ } - m.Source = sql2 - m.cols = sql2.UnAliasedColumns() - return sql2 + //u.Debugf("after WHERE: %s", sql2.Columns) + return sql2, nil } -func rewriteIntoProjection(sel *SqlSelect, m Columns) { - if len(m) == 0 { +func (m *rewriteSelect) addColumn(col Column) { + col.Index = len(m.sel.Columns) + if col.Star { + if _, found := m.cols["*"]; found { + //u.Debugf("dupe %+v", col) + return + } + m.cols["*"] = true + m.sel.AddColumn(col) + return + } + if _, found := m.cols[col.SourceField]; found { + //u.Debugf("dupe %+v", col) return } - colsToAdd := make([]string, 0) - for _, c := range m { - // u.Infof("source=%-15s as=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) + + u.Infof("adding col %#v", col) + m.cols[col.SourceField] = true + m.sel.AddColumn(col) +} +func (m *rewriteSelect) intoProjection(sel *SqlSelect, cols Columns, final bool) error { + if len(cols) == 0 { + return nil + } + 
/* + if !parentStmt.Star { + for idx, col := range parentStmt.Columns { + left, _, hasLeft := col.LeftRight() + if !hasLeft { + // Was not left/right qualified, so use as is? or is this an error? + // what is official sql grammar on this? + newCol := col.Copy() + newCol.ParentIndex = idx + newCol.Index = len(newCols) + newCols = append(newCols, newCol) + + } else if hasLeft && left == m.Alias { + newCol := col.CopyRewrite(m.Alias) + newCol.ParentIndex = idx + newCol.SourceIndex = len(newCols) + newCol.Index = len(newCols) + newCols = append(newCols, newCol) + } + } + } + */ + for i, c := range cols { + left, _, hasLeft := c.LeftRight() + if !hasLeft { + // ?? + u.Warnf("is this possible no left? %#v", c) + } else if hasLeft && left == m.matchSource { + // ok + c = c.CopyRewrite(m.matchSource) + } else { + //u.Warnf("no.... %v", c) + continue + } + + parentIndex := i + if !final { + parentIndex = -1 + } + + u.Infof("as=%-15s source=%-15s exprT:%T expr=%s star:%v", c.As, c.SourceField, c.Expr, c.Expr, c.Star) + + var nc *Column switch n := c.Expr.(type) { case *expr.IdentityNode: - colsToAdd = append(colsToAdd, c.SourceField) + nc = NewColumn(strings.ToLower(c.SourceField)) + nc.Expr = &expr.IdentityNode{Text: strings.ToLower(n.Text)} case *expr.FuncNode: - + // TODO: use features to rewrite this. ie, idents := expr.FindAllIdentities(n) for _, in := range idents { - _, r, _ := in.LeftRight() - colsToAdd = append(colsToAdd, r) + _, right, _ := in.LeftRight() + nc := NewColumn(strings.ToLower(right)) + nc.ParentIndex = parentIndex + nc.Expr = in + m.addColumn(*nc) } - + case *expr.NumberNode, *expr.NullNode, *expr.StringNode: + // literals + nc := NewColumn(strings.ToLower(n.String())) + nc.Expr = n case nil: if c.Star { - colsToAdd = append(colsToAdd, "*") + nc = c.Copy() } else { u.Warnf("unhandled column? %T %s", n, n) } - default: u.Warnf("unhandled column? %T %s", n, n) } - } - addIntoProjection(sel, colsToAdd) -} -func addIntoProjection(sel *SqlSelect, newCols []string) { - notExists := make(map[string]bool) - for _, colName := range newCols { - colName = strings.ToLower(colName) - found := false - for _, c := range sel.Columns { - if c.SourceField == colName { - // already in projection - found = true - break - } - } - if !found { - notExists[colName] = true - if colName == "*" { - sel.AddColumn(Column{Star: true}) - } else { - nc := NewColumn(colName) - sel.AddColumn(*nc) - } + + if nc != nil { + nc.ParentIndex = parentIndex + m.addColumn(*nc) } } + return nil } -func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns) (expr.Node, Columns) { - //u.Debugf("rewrite where %s", node) + +// func (m *rewriteSelect) addIntoProjection(sel *SqlSelect, colsToAdd map[string]int) { +// for colName, idx := range colsToAdd { +// colName = strings.ToLower(colName) +// if colName == "*" { +// m.addColumn(Column{Star: true, ParentIndex: idx}) +// } else { +// nc := NewColumn(colName) +// nc.ParentIndex = idx +// m.addColumn(*nc) +// } +// } +// } +func (m *rewriteSelect) rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node) expr.Node { + u.Debugf("rewrite where %s", node) switch nt := node.(type) { case *expr.IdentityNode: if left, right, hasLeft := nt.LeftRight(); hasLeft { //u.Debugf("rewriteWhere from.Name:%v l:%v r:%v", from.alias, left, right) if left == from.alias { in := expr.IdentityNode{Text: right} - cols = append(cols, NewColumn(right)) - //u.Warnf("nice, found it! 
in = %v cols:%d", in, len(cols)) - return &in, cols + nc := *NewColumn(right) + nc.ParentIndex = -1 + m.addColumn(nc) + return &in } else { //u.Warnf("what to do? source:%v %v", from.alias, nt.String()) } } else { //u.Debugf("returning original: %s", nt) - return node, cols + return node } case *expr.NumberNode, *expr.NullNode, *expr.StringNode: - return nt, cols + return nt case *expr.BinaryNode: //u.Infof("binaryNode T:%v", nt.Operator.T.String()) switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: var n1, n2 expr.Node - n1, cols = rewriteWhere(stmt, from, nt.Args[0], cols) - n2, cols = rewriteWhere(stmt, from, nt.Args[1], cols) + n1 = m.rewriteWhere(stmt, from, nt.Args[0]) + n2 = m.rewriteWhere(stmt, from, nt.Args[1]) if n1 != nil && n2 != nil { - return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}}, cols + return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}} } else if n1 != nil { - return n1, cols + return n1 } else if n2 != nil { - return n2, cols + return n2 } else { //u.Warnf("n1=%#v n2=%#v %#v", n1, n2, nt) } case lex.TokenEqual, lex.TokenEqualEqual, lex.TokenGT, lex.TokenGE, lex.TokenLE, lex.TokenNE: var n1, n2 expr.Node - n1, cols = rewriteWhere(stmt, from, nt.Args[0], cols) - n2, cols = rewriteWhere(stmt, from, nt.Args[1], cols) + n1 = m.rewriteWhere(stmt, from, nt.Args[0]) + n2 = m.rewriteWhere(stmt, from, nt.Args[1]) //u.Debugf("n1=%#v n2=%#v %#v", n1, n2, nt) if n1 != nil && n2 != nil { - return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}}, cols + return &expr.BinaryNode{Operator: nt.Operator, Args: []expr.Node{n1, n2}} // } else if n1 != nil { // return n1 // } else if n2 != nil { @@ -212,14 +313,25 @@ func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns default: //u.Warnf("un-implemented op: %#v", nt) } + case *expr.FuncNode: + // TODO: use features. + idents := expr.FindAllIdentities(nt) + for _, in := range idents { + _, right, _ := in.LeftRight() + nc := *NewColumn(right) + nc.ParentIndex = -1 + nc.Expr = in + m.addColumn(nc) + } + default: u.Warnf("%T node types are not suppored yet for where rewrite", node) } //u.Warnf("nil?? 
%T %s %#v", node, node, node) - return nil, cols + return nil } -func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth int) expr.Node { +func (m *rewriteSelect) joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth int) expr.Node { switch nt := node.(type) { case *expr.IdentityNode: @@ -266,8 +378,8 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in //u.Infof("%v binaryNode %v", depth, nt.String()) switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: - n1 := joinNodesForFrom(stmt, from, nt.Args[0], depth+1) - n2 := joinNodesForFrom(stmt, from, nt.Args[1], depth+1) + n1 := m.joinNodesForFrom(stmt, from, nt.Args[0], depth+1) + n2 := m.joinNodesForFrom(stmt, from, nt.Args[1], depth+1) if n1 != nil && n2 != nil { //u.Debugf("%d neither nil: n1=%v n2=%v %q", depth, n1, n2, nt.String()) @@ -282,8 +394,8 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in //u.Warnf("%d n1=%#v n2=%#v %#v", depth, n1, n2, nt) } case lex.TokenEqual, lex.TokenEqualEqual, lex.TokenGT, lex.TokenGE, lex.TokenLE, lex.TokenNE: - n1 := joinNodesForFrom(stmt, from, nt.Args[0], depth+1) - n2 := joinNodesForFrom(stmt, from, nt.Args[1], depth+1) + n1 := m.joinNodesForFrom(stmt, from, nt.Args[0], depth+1) + n2 := m.joinNodesForFrom(stmt, from, nt.Args[1], depth+1) if n1 != nil && n2 != nil { //u.Debugf("%d neither nil: n1=%v n2=%v %q", depth, n1, n2, nt.String()) @@ -316,54 +428,42 @@ func joinNodesForFrom(stmt *SqlSelect, from *SqlSource, node expr.Node, depth in return nil } -// We need to find all columns used in the given Node (where/join expression) -// to ensure we have those columns in projection for sub-queries -func columnsFromJoin(from *SqlSource, node expr.Node, cols Columns) Columns { +// We need to find all columns used in the given Node (where or join expression) +// to ensure we have those columns in projection. +func (m *rewriteSelect) columnsFromExpression(from *SqlSource, node expr.Node) error { if node == nil { - return cols + return nil } //u.Debugf("columnsFromJoin() T:%T node=%q", node, node.String()) switch nt := node.(type) { case *expr.IdentityNode: if left, right, ok := nt.LeftRight(); ok { - //u.Debugf("from.Name:%v AS %v Joinnode l:%v r:%v %#v", from.Name, from.alias, left, right, nt) - //u.Warnf("check cols against join expr arg: %#v", nt) - if left == from.alias { - found := false - for _, col := range cols { - colLeft, colRight, _ := col.LeftRight() - //u.Debugf("left='%s' colLeft='%s' right='%s' %#v", left, colLeft, colRight, col) - //u.Debugf("col: From %s AS '%s' '%s'.'%s' JoinExpr: '%v'.'%v' col:%#v", from.Name, from.alias, colLeft, colRight, left, right, col) - if left == colLeft || colRight == right { - found = true - //u.Infof("columnsFromJoin from.Name:%v l:%v r:%v", from.alias, left, right) - } else { - //u.Warnf("not? 
from.Name:%v l:%v r:%v col: P:%p %#v", from.alias, left, right, col, col) - } - } - if !found { - //u.Debugf("columnsFromJoin from.Name:%v l:%v r:%v", from.alias, left, right) - newCol := &Column{As: right, SourceField: right, Expr: &expr.IdentityNode{Text: right}} - newCol.Index = len(cols) - newCol.ParentIndex = -1 // if -1, we don't need in parent index - cols = append(cols, newCol) - //u.Warnf("added col %s idx:%d pidx:%v", right, newCol.Index, newCol.Index) - } + if left != from.alias { + return nil + } + if _, found := m.cols[strings.ToLower(right)]; found { + return nil } + + newCol := Column{As: right, SourceField: right, Expr: &expr.IdentityNode{Text: right}} + newCol.ParentIndex = -1 // if -1, we don't need in parent projection + m.addColumn(newCol) + //u.Warnf("added col %s idx:%d pidx:%v", right, newCol.Index, newCol.Index) } + case *expr.FuncNode: //u.Warnf("columnsFromJoin func node: %s", nt.String()) for _, arg := range nt.Args { - cols = columnsFromJoin(from, arg, cols) + m.columnsFromExpression(from, arg) } case *expr.BinaryNode: switch nt.Operator.T { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: - cols = columnsFromJoin(from, nt.Args[0], cols) - cols = columnsFromJoin(from, nt.Args[1], cols) + m.columnsFromExpression(from, nt.Args[0]) + m.columnsFromExpression(from, nt.Args[1]) case lex.TokenEqual, lex.TokenEqualEqual: - cols = columnsFromJoin(from, nt.Args[0], cols) - cols = columnsFromJoin(from, nt.Args[1], cols) + m.columnsFromExpression(from, nt.Args[0]) + m.columnsFromExpression(from, nt.Args[1]) default: u.Warnf("un-implemented op: %v", nt.Operator) } @@ -371,7 +471,7 @@ func columnsFromJoin(from *SqlSource, node expr.Node, cols Columns) Columns { u.LogTracef(u.INFO, "whoops") u.Warnf("%T node types are not suppored yet for join rewrite %s", node, from.String()) } - return cols + return nil } // Remove any aliases diff --git a/rel/sql_rewrite_test.go b/rel/sql_rewrite_test.go index 0e2a2bd7..dad896c8 100644 --- a/rel/sql_rewrite_test.go +++ b/rel/sql_rewrite_test.go @@ -1 +1,247 @@ package rel_test + +import ( + "strings" + "testing" + + u "github.com/araddon/gou" + "github.com/stretchr/testify/assert" + + "github.com/araddon/qlbridge/rel" + "github.com/araddon/qlbridge/schema" +) + +func parseFeatures(t testing.TB, f *schema.DataSourceFeatures, q string) *rel.SqlSelect { + stmt, err := rel.ParseSqlSelect(q) + assert.Equal(t, nil, err, "expected no error but got %v for %s", err, q) + assert.NotEqual(t, nil, stmt) + err = stmt.Rewrite() + assert.Equal(t, nil, err) + return stmt +} +func parse(t testing.TB, q string) *rel.SqlSelect { + return parseFeatures(t, schema.FeaturesDefault(), q) +} + +func TestSqlSelectReWrite(t *testing.T) { + ss := parse(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)") + assert.Equal(t, 1, len(ss.From[0].Source.Columns)) + ss = parse(t, `select exists(email), email FROM users WHERE yy(reg_date) > 10;`) + assert.Equal(t, 2, len(ss.From[0].Source.Columns)) +} + +func TestSqlRewriteTemp(t *testing.T) { + + s := `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id from ORDERS + WHERE user_id IS NOT NULL AND price > 10 + ) AS o + ON u.user_id = o.user_id + ` + sql := parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, 
o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 + ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) +} + +func TestSqlRewrite(t *testing.T) { + t.Parallel() + /* + SQL Re-writing is to take select statement with multiple sources (joins, sub-select) + and rewrite these sub-statements/sources into standalone statements + and prepare the column name, index mapping + + - Do we want to send the columns fully aliased? ie + SELECT name AS u.name, email as u.email, user_id as u.user_id FROM users + */ + s := `SELECT u.name, o.item_id, u.email, o.price + FROM users AS u INNER JOIN orders AS o + ON u.user_id = o.user_id;` + sql := parseOrPanic(t, s).(*rel.SqlSelect) + err := sql.Finalize() + assert.True(t, err == nil, "no error: %v", err) + assert.True(t, len(sql.Columns) == 4, "has 4 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + // Test the Left/Right column level parsing + // TODO: This field should not be u.name? sourcefield should be name right? as = u.name? + col, _ := sql.Columns.ByName("u.name") + assert.True(t, col.As == "u.name", "col.As=%s", col.As) + left, right, ok := col.LeftRight() + //u.Debugf("left=%v right=%v ok%v", left, right, ok) + assert.True(t, left == "u" && right == "name" && ok == true) + + rw1, _ := sql.From[0].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.Equal(t, rw1.String(), "SELECT name, email, user_id FROM users", "%v", rw1.String()) + + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT item_id, price, user_id FROM orders", "%v", rw1.String()) + + // Do we change? 
+ //assert.Equal(t, sql.Columns.FieldNames(), []string{"user_id", "email", "item_id", "price"}) + + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON u.name = b.author;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + rw1, _ = sql.From[0].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT name, email FROM users", "%v", rw1.String()) + jn := sql.From[0].JoinNodes() + assert.True(t, len(jn) == 1, "%v", jn) + assert.True(t, jn[0].String() == "name", "wanted 1 node %v", jn[0].String()) + cols := sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 2, "Should have 2: %#v", cols) + //u.Infof("cols: %#v", cols) + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) + // TODO: verify that we can rewrite sql for aliases + // jn, _ = sql.From[1].JoinValueExpr() + // assert.True(t, jn.String() == "name", "%v", jn.String()) + // u.Infof("SQL?: '%v'", rw1.String()) + // assert.True(t, rw1.String() == "SELECT title, author as name FROM blog", "%v", rw1.String()) + + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON tolower(u.author) = b.author;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.Rewrite() + selu := sql.From[0].Source + assert.True(t, len(selu.Columns) == 3, "user 3 cols: %v", selu.Columns.String()) + assert.True(t, selu.String() == "SELECT name, email, author FROM users", "%v", selu.String()) + jn = sql.From[0].JoinNodes() + assert.True(t, len(jn) == 1, "wanted 1 node but got fromP: %p %v", sql.From[0], jn) + assert.True(t, jn[0].String() == "tolower(author)", "wanted 1 node %v", jn[0].String()) + cols = sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) + + // Now lets try compound join keys + s = `SELECT u.name, u.email, b.title + FROM users AS u INNER JOIN blog AS b + ON u.name = b.author and tolower(u.alias) = b.alias;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.Rewrite() + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + rw1 = sql.From[0].Source + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + //u.Infof("SQL?: '%v'", rw1.String()) + assert.True(t, rw1.String() == "SELECT name, email, alias FROM users", "%v", rw1.String()) + jn = sql.From[0].JoinNodes() + assert.True(t, len(jn) == 2, "wanted 2 join nodes but %v", len(jn)) + assert.True(t, jn[0].String() == "name", `want "name" %v`, jn[0].String()) + assert.True(t, jn[1].String() == "tolower(alias)", `want "tolower(alias)" but got %q`, jn[1].String()) + cols = sql.From[0].UnAliasedColumns() + assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) + //u.Infof("cols: %#v", cols) + rw1 = sql.From[1].Source + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + + // This test, is looking at these aspects of rewrite + // 1 the dotted notation of 'repostory.name' ensuring we have removed the p. 
+ // 2 where clause + s = ` + SELECT + p.actor, p.repository.name, a.title + FROM article AS a + INNER JOIN github_push AS p + ON p.actor = a.author + WHERE p.follow_ct > 20 AND a.email IS NOT NULL + ` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + rw0, _ := sql.From[0].Rewrite(sql) + rw1, _ = sql.From[1].Rewrite(sql) + assert.True(t, rw0 != nil, "should not be nil:") + assert.True(t, len(rw0.Columns) == 3, "has 3 cols: %v", rw0.String()) + assert.True(t, len(sql.From[0].Source.Columns) == 3, "has 3 cols? %s", sql.From[0].Source) + assert.True(t, rw0.String() == "SELECT title, author, email FROM article WHERE email != NULL", "Wrong SQL 0: %v", rw0.String()) + assert.True(t, rw1 != nil, "should not be nil:") + assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) + assert.True(t, len(sql.From[1].Source.Columns) == 3, "has 3 cols? %s", sql.From[1].Source) + assert.True(t, rw1.String() == "SELECT actor, `repository.name`, follow_ct FROM github_push WHERE follow_ct > 20", "Wrong SQL 1: %v", rw1.String()) + + // Original should still be the same + parts := strings.Split(sql.String(), "\n") + for _, p := range parts { + u.Debugf("----%v----", p) + } + assert.True(t, parts[0] == "SELECT p.actor, p.`repository.name`, a.title FROM article AS a", "Wrong Full SQL?: '%v'", parts[0]) + assert.True(t, parts[1] == ` INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", parts[1]) + assert.True(t, sql.String() == `SELECT p.actor, p.`+"`repository.name`"+`, a.title FROM article AS a + INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", sql.String()) + + s = `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id from ORDERS + WHERE user_id IS NOT NULL AND price > 10 + ) AS o + ON u.user_id = o.user_id + ` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) + assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) + + assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u + INNER JOIN ( + SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 + ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) + + // Rewrite to remove functions, and aliasing to send all fields needed down to source + // used when we are going to poly-fill + s = `SELECT count AS ct, name as nm, todate(myfield) AS mydate FROM user` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.RewriteAsRawSelect() + assert.True(t, sql.String() == `SELECT count, name, myfield FROM user`, "Wrong rewrite SQL?: '%v'", sql.String()) + + // Now ensure a group by, and where columns + s = `SELECT name as nm, todate(myfield) AS mydate FROM user WHERE created > todate("2016-01-01") GROUP BY referral;` + sql = parseOrPanic(t, s).(*rel.SqlSelect) + sql.RewriteAsRawSelect() + assert.True(t, sql.String() == `SELECT name, myfield, referral, created FROM user WHERE created > todate("2016-01-01") GROUP BY referral`, "Wrong rewrite SQL?: '%v'", sql.String()) + + //assert.True(t, sql.From[1].Name == "ORDERS", "orders? 
%q", sql.From[1].Name) + // sql.From[0].Rewrite(sql) + // sql.From[1].Rewrite(sql) + // assert.True(t, sql.From[0].Source.String() == `SELECT user_id, reg_date, email FROM users`, "Wrong Full SQL?: '%v'", sql.From[0].Source.String()) + // assert.True(t, sql.From[1].Source.String() == `SELECT item_id, price, order_date, user_id FROM ORDERS`, "Wrong Full SQL?: '%v'", sql.From[1].Source.String()) + + // s = `SELECT aa.*, + // bb.meal + // FROM table1 aa + // INNER JOIN table2 bb + // ON aa.tableseat = bb.tableseat AND + // aa.weddingtable = bb.weddingtable + // INNER JOIN + // ( + // SELECT a.tableSeat + // FROM table1 a + // INNER JOIN table2 b + // ON a.tableseat = b.tableseat AND + // a.weddingtable = b.weddingtable + // WHERE b.meal IN ('chicken', 'steak') + // GROUP by a.tableSeat + // HAVING COUNT(DISTINCT b.Meal) = 2 + // ) c ON aa.tableseat = c.tableSeat + // ` +} diff --git a/rel/sql_test.go b/rel/sql_test.go index e6db444c..f6de2b9e 100644 --- a/rel/sql_test.go +++ b/rel/sql_test.go @@ -3,7 +3,6 @@ package rel_test import ( "fmt" "reflect" - "strings" "testing" u "github.com/araddon/gou" @@ -146,222 +145,6 @@ func compareNode(t *testing.T, n1, n2 expr.Node) { assert.True(t, rv1.Kind() == rv2.Kind(), "kinds match: %T %T", n1, n2) } -func TestSqlRewriteTemp(t *testing.T) { - - s := `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id from ORDERS - WHERE user_id IS NOT NULL AND price > 10 - ) AS o - ON u.user_id = o.user_id - ` - sql := parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 - ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) -} - -func TestSqlRewrite(t *testing.T) { - t.Parallel() - /* - SQL Re-writing is to take select statement with multiple sources (joins, sub-select) - and rewrite these sub-statements/sources into standalone statements - and prepare the column name, index mapping - - - Do we want to send the columns fully aliased? ie - SELECT name AS u.name, email as u.email, user_id as u.user_id FROM users - */ - s := `SELECT u.name, o.item_id, u.email, o.price - FROM users AS u INNER JOIN orders AS o - ON u.user_id = o.user_id;` - sql := parseOrPanic(t, s).(*rel.SqlSelect) - err := sql.Finalize() - assert.True(t, err == nil, "no error: %v", err) - assert.True(t, len(sql.Columns) == 4, "has 4 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - // Test the Left/Right column level parsing - // TODO: This field should not be u.name? sourcefield should be name right? as = u.name? 
- col, _ := sql.Columns.ByName("u.name") - assert.True(t, col.As == "u.name", "col.As=%s", col.As) - left, right, ok := col.LeftRight() - //u.Debugf("left=%v right=%v ok%v", left, right, ok) - assert.True(t, left == "u" && right == "name" && ok == true) - - rw1 := sql.From[0].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.Equal(t, rw1.String(), "SELECT name, email, user_id FROM users", "%v", rw1.String()) - - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT item_id, price, user_id FROM orders", "%v", rw1.String()) - - // Do we change? - //assert.Equal(t, sql.Columns.FieldNames(), []string{"user_id", "email", "item_id", "price"}) - - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON u.name = b.author;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - rw1 = sql.From[0].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) - //u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT name, email FROM users", "%v", rw1.String()) - jn := sql.From[0].JoinNodes() - assert.True(t, len(jn) == 1, "%v", jn) - assert.True(t, jn[0].String() == "name", "wanted 1 node %v", jn[0].String()) - cols := sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 2, "Should have 2: %#v", cols) - //u.Infof("cols: %#v", cols) - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 2, "has 2 cols: %v", rw1.Columns.String()) - // TODO: verify that we can rewrite sql for aliases - // jn, _ = sql.From[1].JoinValueExpr() - // assert.True(t, jn.String() == "name", "%v", jn.String()) - // u.Infof("SQL?: '%v'", rw1.String()) - // assert.True(t, rw1.String() == "SELECT title, author as name FROM blog", "%v", rw1.String()) - - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON tolower(u.author) = b.author;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.Rewrite() - selu := sql.From[0].Source - assert.True(t, len(selu.Columns) == 3, "user 3 cols: %v", selu.Columns.String()) - assert.True(t, selu.String() == "SELECT name, email, author FROM users", "%v", selu.String()) - jn = sql.From[0].JoinNodes() - assert.True(t, len(jn) == 1, "wanted 1 node but got fromP: %p %v", sql.From[0], jn) - assert.True(t, jn[0].String() == "tolower(author)", "wanted 1 node %v", jn[0].String()) - cols = sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) - - // Now lets try compound join keys - s = `SELECT u.name, u.email, b.title - FROM users AS u INNER JOIN blog AS b - ON u.name = b.author and tolower(u.alias) = b.alias;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.Rewrite() - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - rw1 = sql.From[0].Source - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - 
//u.Infof("SQL?: '%v'", rw1.String()) - assert.True(t, rw1.String() == "SELECT name, email, alias FROM users", "%v", rw1.String()) - jn = sql.From[0].JoinNodes() - assert.True(t, len(jn) == 2, "wanted 2 join nodes but %v", len(jn)) - assert.True(t, jn[0].String() == "name", `want "name" %v`, jn[0].String()) - assert.True(t, jn[1].String() == "tolower(alias)", `want "tolower(alias)" but got %q`, jn[1].String()) - cols = sql.From[0].UnAliasedColumns() - assert.True(t, len(cols) == 3, "Should have 3: %#v", cols) - //u.Infof("cols: %#v", cols) - rw1 = sql.From[1].Source - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - - // This test, is looking at these aspects of rewrite - // 1 the dotted notation of 'repostory.name' ensuring we have removed the p. - // 2 where clause - s = ` - SELECT - p.actor, p.repository.name, a.title - FROM article AS a - INNER JOIN github_push AS p - ON p.actor = a.author - WHERE p.follow_ct > 20 AND a.email IS NOT NULL - ` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 3, "has 3 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - rw0 := sql.From[0].Rewrite(sql) - rw1 = sql.From[1].Rewrite(sql) - assert.True(t, rw0 != nil, "should not be nil:") - assert.True(t, len(rw0.Columns) == 3, "has 3 cols: %v", rw0.String()) - assert.True(t, len(sql.From[0].Source.Columns) == 3, "has 3 cols? %s", sql.From[0].Source) - assert.True(t, rw0.String() == "SELECT title, author, email FROM article WHERE email != NULL", "Wrong SQL 0: %v", rw0.String()) - assert.True(t, rw1 != nil, "should not be nil:") - assert.True(t, len(rw1.Columns) == 3, "has 3 cols: %v", rw1.Columns.String()) - assert.True(t, len(sql.From[1].Source.Columns) == 3, "has 3 cols? 
%s", sql.From[1].Source) - assert.True(t, rw1.String() == "SELECT actor, `repository.name`, follow_ct FROM github_push WHERE follow_ct > 20", "Wrong SQL 1: %v", rw1.String()) - - // Original should still be the same - parts := strings.Split(sql.String(), "\n") - for _, p := range parts { - u.Debugf("----%v----", p) - } - assert.True(t, parts[0] == "SELECT p.actor, p.`repository.name`, a.title FROM article AS a", "Wrong Full SQL?: '%v'", parts[0]) - assert.True(t, parts[1] == ` INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", parts[1]) - assert.True(t, sql.String() == `SELECT p.actor, p.`+"`repository.name`"+`, a.title FROM article AS a - INNER JOIN github_push AS p ON p.actor = a.author WHERE p.follow_ct > 20 AND a.email != NULL`, "Wrong Full SQL?: '%v'", sql.String()) - - s = `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id from ORDERS - WHERE user_id IS NOT NULL AND price > 10 - ) AS o - ON u.user_id = o.user_id - ` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - assert.True(t, len(sql.Columns) == 6, "has 6 cols: %v", len(sql.Columns)) - assert.True(t, len(sql.From) == 2, "has 2 sources: %v", len(sql.From)) - - assert.True(t, sql.String() == `SELECT u.user_id, o.item_id, u.reg_date, u.email, o.price, o.order_date FROM users AS u - INNER JOIN ( - SELECT price, order_date, user_id FROM ORDERS WHERE user_id != NULL AND price > 10 - ) AS o ON u.user_id = o.user_id`, "Wrong Full SQL?: '%v'", sql.String()) - - // Rewrite to remove functions, and aliasing to send all fields needed down to source - // used when we are going to poly-fill - s = `SELECT count AS ct, name as nm, todate(myfield) AS mydate FROM user` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.RewriteAsRawSelect() - assert.True(t, sql.String() == `SELECT count, name, myfield FROM user`, "Wrong rewrite SQL?: '%v'", sql.String()) - - // Now ensure a group by, and where columns - s = `SELECT name as nm, todate(myfield) AS mydate FROM user WHERE created > todate("2016-01-01") GROUP BY referral;` - sql = parseOrPanic(t, s).(*rel.SqlSelect) - sql.RewriteAsRawSelect() - assert.True(t, sql.String() == `SELECT name, myfield, referral, created FROM user WHERE created > todate("2016-01-01") GROUP BY referral`, "Wrong rewrite SQL?: '%v'", sql.String()) - - //assert.True(t, sql.From[1].Name == "ORDERS", "orders? 
%q", sql.From[1].Name) - // sql.From[0].Rewrite(sql) - // sql.From[1].Rewrite(sql) - // assert.True(t, sql.From[0].Source.String() == `SELECT user_id, reg_date, email FROM users`, "Wrong Full SQL?: '%v'", sql.From[0].Source.String()) - // assert.True(t, sql.From[1].Source.String() == `SELECT item_id, price, order_date, user_id FROM ORDERS`, "Wrong Full SQL?: '%v'", sql.From[1].Source.String()) - - // s = `SELECT aa.*, - // bb.meal - // FROM table1 aa - // INNER JOIN table2 bb - // ON aa.tableseat = bb.tableseat AND - // aa.weddingtable = bb.weddingtable - // INNER JOIN - // ( - // SELECT a.tableSeat - // FROM table1 a - // INNER JOIN table2 b - // ON a.tableseat = b.tableseat AND - // a.weddingtable = b.weddingtable - // WHERE b.meal IN ('chicken', 'steak') - // GROUP by a.tableSeat - // HAVING COUNT(DISTINCT b.Meal) = 2 - // ) c ON aa.tableseat = c.tableSeat - // ` -} - func TestSqlFingerPrinting(t *testing.T) { t.Parallel() // Fingerprinting allows the select statement to have a cached plan regardless diff --git a/schema/apply_schema.go b/schema/apply_schema.go index 66f36afb..b8f6827b 100644 --- a/schema/apply_schema.go +++ b/schema/apply_schema.go @@ -64,7 +64,7 @@ func (m *InMemApplyer) AddOrUpdateOnSchema(s *Schema, v interface{}) error { // Find the type of operation being updated. switch v := v.(type) { case *Table: - u.Debugf("%p:%s InfoSchema P:%p adding table %q", s, s.Name, s.InfoSchema, v.Name) + //u.Debugf("%p:%s InfoSchema P:%p adding table %q", s, s.Name, s.InfoSchema, v.Name) s.InfoSchema.DS.Init() // Wipe out cache, it is invalid s.mu.Lock() s.addTable(v) @@ -72,7 +72,7 @@ func (m *InMemApplyer) AddOrUpdateOnSchema(s *Schema, v interface{}) error { s.InfoSchema.refreshSchemaUnlocked() case *Schema: - u.Debugf("%p:%s InfoSchema P:%p adding schema %q s==v?%v", s, s.Name, s.InfoSchema, v.Name, s == v) + //u.Debugf("%p:%s InfoSchema P:%p adding schema %q s==v?%v", s, s.Name, s.InfoSchema, v.Name, s == v) if s == v { // s==v means schema has been updated m.reg.mu.Lock() diff --git a/schema/datasource.go b/schema/datasource.go index 6437e5d8..002a6b7a 100644 --- a/schema/datasource.go +++ b/schema/datasource.go @@ -32,15 +32,15 @@ type ( // Close() // Source interface { - // Init provides opportunity for those sources that require/ no configuration and - // introspect schema from their environment time to load pre-schema discovery + // Init provides opportunity for those sources that require no configuration and + // introspect schema from their environment time to load pre-schema discovery. Init() // Setup optional interface for getting the Schema injected during creation/starup. // Since the Source is a singleton, stateful manager, it has a startup/shutdown process. Setup(*Schema) error // Close this source, ensure connections, underlying resources are closed. Close() error - // Open create a connection (not thread safe) to this source. + // Open create a connection to this source (the connection is not thread safe). Open(source string) (Conn, error) // Tables is a list of table names provided by this source. Tables() []string @@ -67,6 +67,13 @@ type ( // Underlying data type of column Column(col string) (value.ValueType, bool) } + + // SourceFeatures is optional interface allowing a source to declare its features so the + // planner can be more accurate. + SourceFeatures interface { + // Features describes the features of a datasource. 
+ Features() *DataSourceFeatures + } ) type ( diff --git a/schema/schema.go b/schema/schema.go index b6adb0eb..3f5a2a20 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -275,6 +275,7 @@ func (m *Schema) SchemaForTable(tableName string) (*Schema, error) { // We always lower-case table names tableName = strings.ToLower(tableName) + u.Warnf("Schema for table schema.Name=%q", m.Name) if m.Name == "schema" { return m, nil } @@ -525,7 +526,7 @@ func (m *Table) AddField(fld *Field) { fld.idx = uint64(len(m.Fields)) m.Fields = append(m.Fields, fld) } - m.FieldMap[fld.Name] = fld + m.FieldMap[strings.ToLower(fld.Name)] = fld } // AddFieldType describe and register a new column @@ -550,7 +551,7 @@ func (m *Table) Column(col string) (value.ValueType, bool) { func (m *Table) SetColumns(cols []string) { m.FieldPositions = make(map[string]int, len(cols)) for idx, col := range cols { - //col = strings.ToLower(col) + col = strings.ToLower(col) m.FieldPositions[col] = idx cols[idx] = col } diff --git a/schema/source_features.go b/schema/source_features.go new file mode 100644 index 00000000..eaeea685 --- /dev/null +++ b/schema/source_features.go @@ -0,0 +1,37 @@ +package schema + +type ( + // DataSourceFeatures describes the features of a datasource. + DataSourceFeatures struct { + aggFuncs map[string]struct{} + projectionFuncs map[string]*FuncFeature + GroupBy bool + Having bool + Partitionable bool + } + + // FuncFeature describes the features of a function from datasource. + FuncFeature struct { + // Name of the function in underlying source. + Name string + // QLBName is the QLBridge name + QLBName string + } +) + +// FeaturesDefault is list of datasource features. +func FeaturesDefault() *DataSourceFeatures { + return &DataSourceFeatures{} +} + +// HasAgg does this datasource support Agg function (count(*), sum(*)) etc, these func's +// can be pushed down to underlying engine as part of GroupBy query. +func (m *DataSourceFeatures) HasAgg(name string) bool { + return false +} + +// HasProjectionFunc does this datasource support projection function tolower(field) +// can be pushed down to underlying engine as part of projection. +func (m *DataSourceFeatures) HasProjectionFunc(name string) (string, bool) { + return "", false +} diff --git a/testutil/testsuite.go b/testutil/testsuite.go index c22393ad..10c7d7dc 100644 --- a/testutil/testsuite.go +++ b/testutil/testsuite.go @@ -33,6 +33,10 @@ func RunDDLTests(t TestingT) { // RunTestSuite run the normal DML SQL test suite. func RunTestSuite(t TestingT) { + // TestSelect(t, `select exists(email), email FROM users WHERE yy(reg_date) > 10;`, + // [][]driver.Value{{true, "aaron@email.com"}}, + // ) + // return // Literal Queries TestSelect(t, `select 1;`, [][]driver.Value{{int64(1)}}, @@ -60,7 +64,7 @@ func RunTestSuite(t TestingT) { TestSelect(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)", [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, ) - TestSelect(t, "SELECT email FROM users WHERE interests != NULL)", + TestSelect(t, "SELECT email FROM users WHERE interests != NULL", [][]driver.Value{{"aaron@email.com"}, {"bob@email.com"}}, ) TestSelect(t, "SELECT email FROM users WHERE (`users`.`email` like \"%aaron%\");", @@ -123,6 +127,13 @@ func RunTestSuite(t TestingT) { // RunSimpleSuite run the normal DML SQL test suite. 
func RunSimpleSuite(t TestingT) { + TestSelect(t, "SELECT user_id FROM users WHERE (`users.user_id` != NULL)", + [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, + ) + // TestSelect(t, "SELECT *, user_id as uid FROM users WHERE (`users.user_id` != NULL)", + // [][]driver.Value{{"hT2impsabc345c"}, {"9Ip1aKbeZe2njCDM"}, {"hT2impsOPUREcVPc"}}, + // ) + return // // Function in select projected columns that needs to be late evaluated. // // "select json.jmespath(body,\"name\") AS name FROM article WHERE `author` = \"aaron\";", // TestSelect(t, "select json.jmespath(json_data,\"name\") AS name FROM users WHERE `email` = \"aaron@email.com\";", diff --git a/vm/vm.go b/vm/vm.go index b6a7b406..c408aeb8 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -574,7 +574,7 @@ func evalBinary(ctx expr.EvalContext, node *expr.BinaryNode, depth int) (value.V return value.NewBoolValue(false), true } // Should we evaluate strings that are non-nil to be = true? - u.Debugf("not handled: boolean %v %T=%v expr: %s", node.Operator, at.Value(), at.Val(), node.String()) + //u.Debugf("not handled: boolean %v %T=%v expr: %s", node.Operator, at.Value(), at.Val(), node.String()) return nil, false case value.Map: switch node.Operator.T {
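Editor's aside (illustrative only, not part of the changeset above): the optional schema.SourceFeatures interface added in schema/datasource.go and the DataSourceFeatures type in schema/source_features.go are meant to let the planner ask a source what it can do. Below is a minimal sketch, under stated assumptions, of how a planner-side caller might consume them. The package name, the helpers featuresOf and canPushDownCount, and the import path github.com/araddon/qlbridge/schema are assumptions for illustration; only Source, SourceFeatures, Features, DataSourceFeatures, FeaturesDefault, and HasAgg come from the diff itself.

package planner // hypothetical package, for illustration only

import (
	"github.com/araddon/qlbridge/schema" // assumed import path for this repository
)

// featuresOf is a hypothetical helper: it returns the features a source declares
// via the optional schema.SourceFeatures interface, falling back to
// schema.FeaturesDefault() (no push-down support) when the source does not implement it.
func featuresOf(src schema.Source) *schema.DataSourceFeatures {
	if fs, ok := src.(schema.SourceFeatures); ok {
		return fs.Features()
	}
	return schema.FeaturesDefault()
}

// canPushDownCount illustrates the intended use: only push a count aggregate down to
// the source when it declares support; otherwise QLBridge evaluates it itself.
func canPushDownCount(src schema.Source) bool {
	return featuresOf(src).HasAgg("count")
}

As written in the diff, FeaturesDefault and HasAgg always report no support, so this helper would currently always keep aggregate evaluation inside QLBridge; the unexported aggFuncs/projectionFuncs fields suggest per-source feature registration is intended to follow.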