diff --git a/datasource/context.go b/datasource/context.go index 7b2bf902..fa0ff062 100644 --- a/datasource/context.go +++ b/datasource/context.go @@ -209,6 +209,9 @@ func (m *ContextSimple) Commit(rowInfo []expr.SchemaInfo, row expr.RowWriter) er return nil } func (m *ContextSimple) Delete(row map[string]value.Value) error { + for k, _ := range row { + delete(m.Data, k) + } return nil } diff --git a/datasource/files/filesource.go b/datasource/files/filesource.go index 54f72b15..29d421d8 100644 --- a/datasource/files/filesource.go +++ b/datasource/files/filesource.go @@ -10,7 +10,8 @@ import ( "time" u "github.com/araddon/gou" - "github.com/dchest/siphash" + + hash "github.com/aviddiviner/go-murmur" "github.com/lytics/cloudstorage" "golang.org/x/net/context" "google.golang.org/api/iterator" @@ -46,7 +47,7 @@ type FileReaderIterator interface { type Partitioner func(uint64, *FileInfo) int func SipPartitioner(partitionCt uint64, fi *FileInfo) int { - hashU64 := siphash.Hash(0, 1, []byte(fi.Name)) + hashU64 := hash.MurmurHash64A([]byte(fi.Name), 1) return int(hashU64 % partitionCt) } diff --git a/datasource/membtree/btree.go b/datasource/membtree/btree.go index a85b1ee0..321325d8 100644 --- a/datasource/membtree/btree.go +++ b/datasource/membtree/btree.go @@ -7,7 +7,7 @@ import ( "fmt" u "github.com/araddon/gou" - "github.com/dchest/siphash" + hash "github.com/aviddiviner/go-murmur" "github.com/google/btree" "golang.org/x/net/context" @@ -78,14 +78,14 @@ func makeId(dv driver.Value) uint64 { case int64: return uint64(vt) case []byte: - return siphash.Hash(0, 1, vt) + return hash.MurmurHash64A(vt, 1) // iv, err := strconv.ParseUint(string(vt), 10, 64) // if err != nil { // u.Warnf("could not create id: %v for %v", err, dv) // } // return iv case string: - return siphash.Hash(0, 1, []byte(vt)) + return hash.MurmurHash64A([]byte(vt), 1) // iv, err := strconv.ParseUint(vt, 10, 64) // if err != nil { // u.Warnf("could not create id: %v for %v", err, dv) diff --git a/datasource/memdb/index.go b/datasource/memdb/index.go index bd0c6f27..d4a8e985 100644 --- a/datasource/memdb/index.go +++ b/datasource/memdb/index.go @@ -5,7 +5,7 @@ import ( "fmt" u "github.com/araddon/gou" - "github.com/dchest/siphash" + hash "github.com/aviddiviner/go-murmur" "github.com/hashicorp/go-memdb" "github.com/araddon/qlbridge/datasource" @@ -25,9 +25,9 @@ func makeId(dv driver.Value) uint64 { case int64: return uint64(vt) case []byte: - return siphash.Hash(456729, 1111581582, vt) + return hash.MurmurHash64A(vt, 1111581582) case string: - return siphash.Hash(456729, 1111581582, []byte(vt)) + return hash.MurmurHash64A([]byte(vt), 1111581582) //by := append(make([]byte,0,8), byte(r), byte(r>>8), byte(r>>16), byte(r>>24), byte(r>>32), byte(r>>40), byte(r>>48), byte(r>>56)) case datasource.KeyCol: return makeId(vt.Val) diff --git a/datasource/sqlite/conn.go b/datasource/sqlite/conn.go index 8b79cffd..b03526bc 100644 --- a/datasource/sqlite/conn.go +++ b/datasource/sqlite/conn.go @@ -9,7 +9,7 @@ import ( "strings" u "github.com/araddon/gou" - "github.com/dchest/siphash" + "github.com/aviddiviner/go-murmur" "github.com/google/btree" "golang.org/x/net/context" // Import driver for sqlite @@ -323,9 +323,9 @@ func MakeId(dv driver.Value) uint64 { case int64: return uint64(vt) case []byte: - return siphash.Hash(456729, 1111581582, vt) + return go-murmur.MurmurHash64A(vt, 1111581582) case string: - return siphash.Hash(456729, 1111581582, []byte(vt)) + return go-murmur.MurmurHash64A([]byte(vt), 1111581582) //by := append(make([]byte,0,8), byte(r), byte(r>>8), byte(r>>16), byte(r>>24), byte(r>>32), byte(r>>40), byte(r>>48), byte(r>>56)) case datasource.KeyCol: return MakeId(vt.Val) diff --git a/exec/exec.go b/exec/exec.go index a71fd142..62f4fa7f 100644 --- a/exec/exec.go +++ b/exec/exec.go @@ -8,11 +8,18 @@ package exec import ( "fmt" - + "database/sql/driver" "github.com/araddon/qlbridge/plan" "github.com/araddon/qlbridge/schema" ) +const ( + JOINMERGE_MAKER = "UseJoinMerge" + WHERE_MAKER = "UseWhere" + GROUPBY_MAKER = "UseGroupBy" + PROJECTION_MAKER = "UseProjection" +) + var ( // ErrShuttingDown already shutting down error ErrShuttingDown = fmt.Errorf("Received Shutdown Signal") @@ -105,6 +112,7 @@ type ( WalkHaving(p *plan.Having) (Task, error) WalkGroupBy(p *plan.GroupBy) (Task, error) WalkOrder(p *plan.Order) (Task, error) + WalkInto(p *plan.Into) (Task, error) WalkProjection(p *plan.Projection) (Task, error) // Other Statements WalkCommand(p *plan.Command) (Task, error) @@ -122,4 +130,18 @@ type ( // WalkExecSource given our plan, turn that into a Task. WalkExecSource(p *plan.Source) (Task, error) } + + // SinkMaker Sink Factory + SinkMaker func(ctx *plan.Context, dest string, params map[string]interface{}) (Sink, error) + + // Sinks are execution tasks used to direct query result set output to a destination. + Sink interface { + Open(ctx *plan.Context, destination string, params map[string]interface{}) error + Next(dest []driver.Value, colIndex map[string]int) error + Cleanup() error + Close() error + } + + // JoinMergeMaker Factory + JoinMergeMaker func(ctx *plan.Context, l, r TaskRunner, p *plan.JoinMerge) TaskRunner ) diff --git a/exec/executor.go b/exec/executor.go index a691a021..2c30bcf9 100644 --- a/exec/executor.go +++ b/exec/executor.go @@ -65,17 +65,21 @@ func BuildSqlJobPlanned(planner plan.Planner, executor Executor, ctx *plan.Conte if ctx.Raw == "" { return nil, fmt.Errorf("no sql provided") } - stmt, err := rel.ParseSql(ctx.Raw) - if err != nil { - u.Debugf("could not parse sql : %v", err) - return nil, err - } - if stmt == nil { - return nil, fmt.Errorf("Not statement for parse? %v", ctx.Raw) + var err error + var pln plan.Task + var stmt rel.SqlStatement + if ctx.Stmt == nil { // Prepared statement + stmt, err = rel.ParseSql(ctx.Raw) + if err != nil { + u.Debugf("could not parse sql : %v", err) + return nil, err + } + if stmt == nil { + return nil, fmt.Errorf("Not statement for parse? %v", ctx.Raw) + } + ctx.Stmt = stmt } - ctx.Stmt = stmt - - pln, err := plan.WalkStmt(ctx, stmt, planner) + pln, err = plan.WalkStmt(ctx, ctx.Stmt, planner) if err != nil { return nil, err @@ -197,6 +201,7 @@ func (m *JobExecutor) WalkSource(p *plan.Source) (Task, error) { } return NewSource(m.Ctx, p) } + func (m *JobExecutor) WalkSourceExec(p *plan.Source) (Task, error) { if p.Conn == nil { @@ -219,22 +224,71 @@ func (m *JobExecutor) WalkSourceExec(p *plan.Source) (Task, error) { return nil, fmt.Errorf("%T Must Implement Scanner for %q", p.Conn, p.Stmt.String()) } func (m *JobExecutor) WalkWhere(p *plan.Where) (Task, error) { - return NewWhere(m.Ctx, p), nil + + var tr TaskRunner + tr = NewWhere(m.Ctx, p) + if m.Ctx.Session != nil { + if v, ok := m.Ctx.Session.Get(WHERE_MAKER); ok { + //if factory, ok2 := v.Value().(JoinMergeMaker); !ok2 { + if factory, ok2 := v.Value().(func(ctx *plan.Context, p *plan.Where) TaskRunner); !ok2 { + return nil, fmt.Errorf("Cannot cast [%T] to WhereMaker factory.", v.Value) + } else { + tr = factory(m.Ctx, p) + } + } + } + return tr, nil } func (m *JobExecutor) WalkHaving(p *plan.Having) (Task, error) { return NewHaving(m.Ctx, p), nil } func (m *JobExecutor) WalkGroupBy(p *plan.GroupBy) (Task, error) { - return NewGroupBy(m.Ctx, p), nil + + var tr TaskRunner + tr = NewGroupBy(m.Ctx, p) + if m.Ctx.Session != nil { + if v, ok := m.Ctx.Session.Get(GROUPBY_MAKER); ok { + //if factory, ok2 := v.Value().(JoinMergeMaker); !ok2 { + if factory, ok2 := v.Value().(func(ctx *plan.Context, p *plan.GroupBy) TaskRunner); !ok2 { + return nil, fmt.Errorf("Cannot cast [%T] to GroupByMaker factory.", v.Value) + } else { + tr = factory(m.Ctx, p) + } + } + } + return tr, nil } func (m *JobExecutor) WalkOrder(p *plan.Order) (Task, error) { return NewOrder(m.Ctx, p), nil } +func (m *JobExecutor) WalkInto(p *plan.Into) (Task, error) { + return NewInto(m.Ctx, p), nil +} func (m *JobExecutor) WalkProjection(p *plan.Projection) (Task, error) { - return NewProjection(m.Ctx, p), nil + var tr TaskRunner + tr = NewProjection(m.Ctx, p) + if m.Ctx.Session != nil { + if v, ok := m.Ctx.Session.Get(PROJECTION_MAKER); ok { + //if factory, ok2 := v.Value().(JoinMergeMaker); !ok2 { + if factory, ok2 := v.Value().(func(ctx *plan.Context, p *plan.Projection) TaskRunner); !ok2 { + return nil, fmt.Errorf("Cannot cast [%T] to ProjectionMaker factory.", v.Value) + } else { + tr = factory(m.Ctx, p) + } + } + } + return tr, nil } func (m *JobExecutor) WalkJoin(p *plan.JoinMerge) (Task, error) { - execTask := NewTaskParallel(m.Ctx) + + // If the left task is already parallelized then must be a multi table join. + // No need to parallelize subsequent join tasks. + var execTask TaskRunner + if p.Left.IsParallel() { + execTask = NewTaskSequential(m.Ctx) + } else { + execTask = NewTaskParallel(m.Ctx) + } //u.Debugf("join.Left: %#v \nright:%#v", p.Left, p.Right) l, err := m.WalkPlanAll(p.Left) if err != nil { @@ -255,7 +309,21 @@ func (m *JobExecutor) WalkJoin(p *plan.JoinMerge) (Task, error) { return nil, err } - jm := NewJoinNaiveMerge(m.Ctx, l.(TaskRunner), r.(TaskRunner), p) + + var jm TaskRunner + jm = NewJoinNaiveMerge(m.Ctx, l.(TaskRunner), r.(TaskRunner), p) + if m.Ctx.Session != nil { + if v, ok := m.Ctx.Session.Get(JOINMERGE_MAKER); ok { + //if factory, ok2 := v.Value().(JoinMergeMaker); !ok2 { + if factory, ok2 := v.Value().(func(ctx *plan.Context, l, r TaskRunner, + p *plan.JoinMerge) TaskRunner); !ok2 { + return nil, fmt.Errorf("Cannot cast [%T] to JoinMergeMaker factory.", v.Value) + } else { + jm = factory(m.Ctx, l.(TaskRunner), r.(TaskRunner), p) + } + } + } + err = execTask.Add(jm) if err != nil { return nil, err @@ -273,7 +341,7 @@ func (m *JobExecutor) WalkPlanAll(p plan.Task) (Task, error) { } if len(p.Children()) > 0 { dagRoot := m.NewTask(p) - //u.Debugf("sequential?%v parallel?%v", p.IsSequential(), p.IsParallel()) + //u.Debugf("%p sequential?%v parallel?%v", p, p.IsSequential(), p.IsParallel()) err = dagRoot.Add(root) if err != nil { u.Errorf("Could not add root: %v", err) @@ -296,6 +364,8 @@ func (m *JobExecutor) WalkPlanTask(p plan.Task) (Task, error) { return m.Executor.WalkGroupBy(p) case *plan.Order: return m.Executor.WalkOrder(p) + case *plan.Into: + return m.Executor.WalkInto(p) case *plan.Projection: return m.Executor.WalkProjection(p) case *plan.JoinMerge: diff --git a/exec/into.go b/exec/into.go new file mode 100644 index 00000000..802a0810 --- /dev/null +++ b/exec/into.go @@ -0,0 +1,211 @@ +package exec + +import ( + "database/sql/driver" + "fmt" + "net/url" + "time" + u "github.com/araddon/gou" + + "github.com/araddon/qlbridge/datasource" + "github.com/araddon/qlbridge/expr" + "github.com/araddon/qlbridge/plan" + "github.com/araddon/qlbridge/rel" +) + +var ( + sinkFactories = make(map[string]SinkMaker) +) + +// Into - Write to output sink +type Into struct { + *TaskBase + p *plan.Into + complete chan bool + Closed bool + isComplete bool + colIndexes map[string]int + sink Sink +} + +// NewInto create new into exec task +func NewInto(ctx *plan.Context, p *plan.Into) *Into { + o := &Into{ + TaskBase: NewTaskBase(ctx), + p: p, + complete: make(chan bool), + } + return o +} + +// Registry for sinks +func Register(name string, factory SinkMaker) { + if factory == nil { + panic(fmt.Sprintf("SinkMaker factory %s does not exist.", name)) + } + _, registered := sinkFactories[name] + if registered { + return + } + sinkFactories[name] = factory +} + + +func (m *Into) Open(ctx *plan.Context, destination string) (err error) { + + params := make(map[string]interface{}, 0) + if m.TaskBase.Ctx.Stmt.(*rel.SqlSelect).With != nil { + params = m.TaskBase.Ctx.Stmt.(*rel.SqlSelect).With + } + + if url, err := url.Parse(destination); err == nil { + if newSink, ok := sinkFactories[url.Scheme]; !ok { + m := fmt.Sprintf("scheme [%s] not registered!", url.Scheme) + panic(m) + } else { + m.sink, err = newSink(ctx, destination, params) + } + } else { // First treat this as a output Table + if newSink, ok := sinkFactories["table"]; !ok { + m := fmt.Sprintf("INTO sink factory not found!") + panic(m) + } else { + m.sink, err = newSink(ctx, destination, params) + } + } + return +} + + +func (m *Into) Close() error { + m.Lock() + if m.Closed { + m.Unlock() + return nil + } + m.Closed = true + m.sink.Close() //FIX: handle error on close + m.Unlock() + + // what should this be? + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + //u.Infof("%p into sink final Close() waiting for complete", m) + select { + case <-ticker.C: + u.Warnf("into sink timeout???? ") + case <-m.complete: + //u.Warnf("%p got into sink complete", m) + } + + return m.TaskBase.Close() +} + +func (m *Into) Run() error { + defer m.Ctx.Recover() + defer close(m.msgOutCh) + + //outCh := m.MessageOut() + inCh := m.MessageIn() + + projCols := m.TaskBase.Ctx.Projection.Proj.Columns + cols := make(map[string]int, len(projCols)) + for i, col := range projCols { + //u.Debugf("aliasing: key():%-15q As:%-15q %-15q", col.Key(), col.As, col.String()) + cols[col.As] = i + } + + //m.colIndexes = m.TaskBase.Ctx.Stmt.(*rel.SqlSelect).ColIndexes() + m.colIndexes = cols + if m.colIndexes == nil { + u.Errorf("Cannot get column indexes for output !") + return nil + } + + // Open the output file sink + if err := m.Open(m.Ctx, m.p.Stmt.Table); err != nil { + u.Errorf("Open output sink failed! - %v", err) + return err + } + + var rowCount, lastMsgId int64 + +msgReadLoop: + for { + select { + case <-m.SigChan(): + //u.Warnf("got signal quit") + return nil + case <-m.ErrChan(): + //u.Warnf("got err signal") + m.sink.Cleanup() + return nil + case msg, ok := <-inCh: + if !ok { + //u.Warnf("NICE, got closed channel shutdown") + //close(m.TaskBase.sigCh) + break msgReadLoop + } else { + var sdm *datasource.SqlDriverMessageMap + + switch mt := msg.(type) { + case *datasource.SqlDriverMessageMap: + sdm = mt + m.sink.Next(sdm.Values(), m.colIndexes) // FIX: handle error return from Next() + rowCount++ + lastMsgId = int64(mt.Id()) + default: + + msgReader, isContextReader := msg.(expr.ContextReader) + if !isContextReader { + err := fmt.Errorf("To use Into must use SqlDriverMessageMap but got %T", msg) + u.Errorf("unrecognized msg %T", msg) + close(m.TaskBase.sigCh) + return err + } + + sdm = datasource.NewSqlDriverMessageMapCtx(msg.Id(), msgReader, m.colIndexes) + m.sink.Next(sdm.Values(), m.colIndexes) // FIX: handle error return from Next() + rowCount++ + lastMsgId = int64(msg.Id()) + } + } + } + } +//u.Warnf("HERE 1 %#v, %p, LEN = %d", m.ErrChan(), m.ErrChan(), len(m.ErrChan())) +errLoop: + for { + select { + case <-m.ErrChan(): +//u.Warnf("HERE ERR") + m.sink.Cleanup() + break errLoop + default: + } + select { + case <-m.ErrChan(): +//u.Warnf("HERE 3") + m.sink.Cleanup() + break errLoop + case <-m.SigChan(): +//u.Warnf("HERE 2") + break errLoop + case _, ok := <-inCh: +//u.Warnf("HERE 4") + if !ok { + break errLoop + } + } + } + vals := make([]driver.Value, 2) + vals[0] = lastMsgId + vals[1] = rowCount + m.msgOutCh <- datasource.NewSqlDriverMessage(0, vals) + m.isComplete = true + close(m.complete) + + return nil +} + + diff --git a/exec/join.go b/exec/join.go index 7b7bdba4..f4ac8db0 100644 --- a/exec/join.go +++ b/exec/join.go @@ -55,6 +55,7 @@ func (m *JoinKey) Run() error { defer m.Ctx.Recover() defer close(m.msgOutCh) + outCh := m.MessageOut() inCh := m.MessageIn() joinNodes := m.p.Source.Stmt.JoinNodes() @@ -89,6 +90,9 @@ func (m *JoinKey) Run() error { key := strings.Join(vals, string(byte(0))) mt.SetKeyHashed(key) outCh <- mt + case *datasource.ContextSimple: + // Just pass it along to the JoinMerge task + outCh <- mt default: return fmt.Errorf("To use JoinKey must use SqlDriverMessageMap but got %T", msg) } @@ -182,6 +186,8 @@ func (m *JoinMerge) Run() error { return } lh[key] = append(lh[key], mt) + case *datasource.ContextSimple: + // Process driver table input variables default: fatalErr = fmt.Errorf("To use Join must use SqlDriverMessageMap but got %T", msg) u.Errorf("unrecognized msg %T", msg) diff --git a/exec/mutations.go b/exec/mutations.go index 9ec4b155..eb5be553 100644 --- a/exec/mutations.go +++ b/exec/mutations.go @@ -10,6 +10,7 @@ import ( "github.com/araddon/qlbridge/plan" "github.com/araddon/qlbridge/rel" "github.com/araddon/qlbridge/schema" + "github.com/araddon/qlbridge/value" "github.com/araddon/qlbridge/vm" ) @@ -103,11 +104,12 @@ func (m *Upsert) Run() error { var err error var affectedCt int64 + var rownum int64 switch { case m.insert != nil: - affectedCt, err = m.insertRows(m.insert.Rows) + rownum, err = m.insertRows(m.insert.Rows) case m.upsert != nil && len(m.upsert.Rows) > 0: - affectedCt, err = m.insertRows(m.upsert.Rows) + rownum, err = m.insertRows(m.upsert.Rows) case m.update != nil: affectedCt, err = m.updateValues() default: @@ -117,14 +119,13 @@ func (m *Upsert) Run() error { vals := make([]driver.Value, 2) if err != nil { u.Warnf("errored, should not complete %v", err) - vals[0] = err.Error() - vals[1] = -1 + vals[0] = int64(0) + vals[1] = int64(0) m.msgOutCh <- &datasource.SqlDriverMessage{Vals: vals, IdVal: 1} return err } - vals[0] = int64(0) // status? + vals[0] = rownum vals[1] = affectedCt - u.Infof("affected? %v", affectedCt) m.msgOutCh <- &datasource.SqlDriverMessage{Vals: vals, IdVal: 1} return nil } @@ -149,6 +150,7 @@ func (m *Upsert) updateValues() (int64, error) { u.Errorf("Could not evaluate: %s", valcol.Expr) return 0, fmt.Errorf("Could not evaluate expression: %v", valcol.Expr) } + valcol.Value = value.NewValue(exprVal.Value()) valmap[key] = exprVal.Value() } else { u.Debugf("%T %v", valcol.Value.Value(), valcol.Value.Value()) @@ -182,6 +184,7 @@ func (m *Upsert) updateValues() (int64, error) { } func (m *Upsert) insertRows(rows [][]*rel.ValueColumn) (int64, error) { + for i, row := range rows { select { case <-m.SigChan(): @@ -204,9 +207,16 @@ func (m *Upsert) insertRows(rows [][]*rel.ValueColumn) (int64, error) { } } - if _, err := m.db.Put(m.Ctx.Context, nil, vals); err != nil { + if key, err := m.db.Put(m.Ctx.Context, nil, vals); err != nil { u.Errorf("Could not put values: fordb T:%T %v", m.db, err) return 0, err + } else { + rownum, ok := key.Key().(driver.Value) + if ok { + v, _ := rownum.(value.IntValue) + return v.Val(), nil + } + return 0, fmt.Errorf("cannot cast rownum to int64 it is a %T", key.Key()) } } } diff --git a/exec/order.go b/exec/order.go index 67f4b252..60d1289b 100644 --- a/exec/order.go +++ b/exec/order.go @@ -170,10 +170,11 @@ func (m *OrderMessages) Less(i, j int) bool { return false } return true - } else { + } else if key > m.l[j].keys[ki] { if m.invert[ki] { return true } + return false } } return false diff --git a/exec/sqldriver.go b/exec/sqldriver.go index 679994c0..0ee1fcee 100644 --- a/exec/sqldriver.go +++ b/exec/sqldriver.go @@ -4,7 +4,6 @@ import ( "bytes" "database/sql" "database/sql/driver" - "errors" "fmt" "io" "strconv" @@ -14,25 +13,27 @@ import ( u "github.com/araddon/gou" + "github.com/araddon/qlbridge/datasource" "github.com/araddon/qlbridge/expr" "github.com/araddon/qlbridge/plan" "github.com/araddon/qlbridge/rel" "github.com/araddon/qlbridge/schema" + "github.com/araddon/qlbridge/value" ) var ( // Ensure our driver implements appropriate database/sql interfaces - _ driver.Conn = (*qlbConn)(nil) + _ driver.Conn = (*qlbConn)(nil) _ driver.Driver = (*qlbdriver)(nil) _ driver.Execer = (*qlbConn)(nil) _ driver.Queryer = (*qlbConn)(nil) _ driver.Result = (*qlbResult)(nil) - _ driver.Rows = (*qlbRows)(nil) - _ driver.Stmt = (*qlbStmt)(nil) - //_ driver.Tx = (*driverConn)(nil) + _ driver.Rows = (*qlbRows)(nil) + _ driver.Stmt = (*qlbStmt)(nil) + //_ driver.Tx = (*driverConn)(nil) // Create an instance of our driver - qlbd = &qlbdriver{} + qlbd = &qlbdriver{} qlbDriverOnce sync.Once // Runtime Schema Config as in in-mem data structure of the @@ -60,15 +61,15 @@ func DisableRecover() { // sql.Driver Interface implementation. // // Notes about Value return types: -// Value is a value that drivers must be able to handle. -// It is either nil or an instance of one of these types: +// Value is a value that drivers must be able to handle. +// It is either nil or an instance of one of these types: // -// int64 -// float64 -// bool -// []byte -// string [*] everywhere except from Rows.Next. -// time.Time +// int64 +// float64 +// bool +// []byte +// string [*] everywhere except from Rows.Next. +// time.Time type qlbdriver struct{} // Open returns a new connection to the database. @@ -87,25 +88,27 @@ func (m *qlbdriver) Open(connInfo string) (driver.Conn, error) { if !ok || s == nil { return nil, fmt.Errorf("No schema was found for %q", connInfo) } - return &qlbConn{schema: s}, nil + return &qlbConn{schema: s, session: datasource.NewMySqlSessionVars(), stmts: make(map[*qlbStmt]struct{})}, nil } // A stateful connection to database/source // // // Execer is an optional interface that may be implemented by a Conn. -// If a Conn does not implement Execer, the sql package's DB.Exec will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn does not implement Execer, the sql package's DB.Exec will +// first prepare a query, execute the statement, and then close the +// statement. // // Queryer is an optional interface that may be implemented by a Conn. -// If a Conn does not implement Queryer, the sql package's DB.Query will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn does not implement Queryer, the sql package's DB.Query will +// first prepare a query, execute the statement, and then close the +// statement. type qlbConn struct { parallel bool // Do we Run In Background Mode? Default = true connInfo string // schema *schema.Schema + session expr.ContextReadWriter + stmts map[*qlbStmt]struct{} } // Exec may return ErrSkip. @@ -113,7 +116,10 @@ type qlbConn struct { // Execer implementation. To be used for queries that do not return any rows // such as Create Index, Insert, Upset, Delete etc func (m *qlbConn) Exec(query string, args []driver.Value) (driver.Result, error) { + stmt := &qlbStmt{conn: m, query: query} + defer stmt.Close() + stmt.numInput = strings.Count(query, "?") return stmt.Exec(args) } @@ -121,13 +127,29 @@ func (m *qlbConn) Exec(query string, args []driver.Value) (driver.Result, error) // Query may return ErrSkip // func (m *qlbConn) Query(query string, args []driver.Value) (driver.Rows, error) { + stmt := &qlbStmt{conn: m, query: query} + stmt.numInput = strings.Count(query, "?") return stmt.Query(args) } // Prepare returns a prepared statement, bound to this connection. func (m *qlbConn) Prepare(query string) (driver.Stmt, error) { - return nil, expr.ErrNotImplemented + + query = strings.TrimSpace(query) + s := strings.Split(strings.ToLower(query), " ") + stmt := &qlbStmt{conn: m, query: query} + stmt.numInput = strings.Count(query, "?") + var err error + if s[0] == "insert" { + stmt.job, err = createExecJob(strings.ReplaceAll(query, "?", "0"), m, nil, nil) + if err != nil { + return nil, err + } + stmt.sqlStmt = stmt.job.Ctx.Stmt + } + m.stmts[stmt] = struct{}{} + return stmt, nil } // Close invalidates and potentially stops any current @@ -139,7 +161,14 @@ func (m *qlbConn) Prepare(query string) (driver.Stmt, error) { // idle connections, it shouldn't be necessary for drivers to // do their own connection caching. func (m *qlbConn) Close() error { - //u.Debugf("sqlbConn.Close() do we need to do anything here?") + + if m.stmts != nil { + for k, _ := range m.stmts { + k.Close() + delete(m.stmts, k) + } + m.stmts = nil + } return nil } @@ -162,9 +191,11 @@ func (conn *qlbTx) Rollback() error { return expr.ErrNotImplemented } // used by multiple goroutines concurrently. // type qlbStmt struct { - job *JobExecutor - query string - conn *qlbConn + job *JobExecutor + query string + numInput int + conn *qlbConn + sqlStmt rel.SqlStatement } // Close closes the statement. @@ -172,9 +203,13 @@ type qlbStmt struct { // As of Go 1.1, a Stmt will not be closed if it's in use // by any queries. func (m *qlbStmt) Close() error { + if m.job != nil { m.job.Close() } + if m.conn.stmts != nil { + delete (m.conn.stmts, m) + } return nil } @@ -187,67 +222,88 @@ func (m *qlbStmt) Close() error { // NumInput may also return -1, if the driver doesn't know // its number of placeholders. In that case, the sql package // will not sanity check Exec or Query argument counts. -func (m *qlbStmt) NumInput() int { return 0 } +func (m *qlbStmt) NumInput() int { + return m.numInput +} // Exec executes a query that doesn't return rows, such // as an INSERT, UPDATE, DELETE func (m *qlbStmt) Exec(args []driver.Value) (driver.Result, error) { + + if m.query == "" { + return nil, fmt.Errorf("No query in stmt.Exec() %#p", m) + } var err error - if len(args) > 0 { - m.query, err = queryArgsConvert(m.query, args) + prepared := false + if m.conn.stmts != nil { + _, prepared = m.conn.stmts[m] // in list of prepared + if prepared { + prepared = m.sqlStmt != nil // has parsed sql + } + } + if !prepared { + m.job, err = createExecJob(m.query, m.conn, args, nil) if err != nil { return nil, err } + } else { // Previously prepared + m.job, err = createExecJob(m.query, m.conn, args, m.sqlStmt) + if err != nil { + return nil, err + } + rows := make([][]*rel.ValueColumn, 0) + rows = append(rows, argsToValueColumns(args)) + switch p := m.job.Ctx.Stmt.(type) { + case *rel.SqlInsert: + p.Rows = rows + default: + return nil, fmt.Errorf("sqldriver Exec prepared stmt type %T not implemented.", p) + } } - // Create a Job, which is Dag of Tasks that Run() - ctx := plan.NewContext(m.query) - ctx.Schema = m.conn.schema - job, err := BuildSqlJob(ctx) - if err != nil { - return nil, err - } - m.job = job - - resultWriter := NewResultExecWriter(ctx) - job.RootTask.Add(resultWriter) + resultWriter := NewResultExecWriter(m.job.Ctx) + m.job.RootTask.Add(resultWriter) + m.job.Setup() - job.Setup() //u.Infof("in qlbdriver.Exec about to run") - err = job.Run() + err = m.job.Run() //u.Debugf("After qlb driver.Run() in Exec()") if err != nil { u.Errorf("error on Query.Run(): %v", err) //resultWriter.ErrChan() <- err //job.Close() } - return resultWriter.Result(), nil + return resultWriter.Result(), err } // Query executes a query that may return rows, such as a SELECT func (m *qlbStmt) Query(args []driver.Value) (driver.Rows, error) { + var err error + qry := m.query if len(args) > 0 { - m.query, err = queryArgsConvert(m.query, args) + qry, err = queryArgsConvert(qry, args) if err != nil { return nil, err } } - u.Debugf("query: %v", m.query) + u.Debugf("stmt.query: %v", qry) // Create a Job, which is Dag of Tasks that Run() - ctx := plan.NewContext(m.query) + ctx := plan.NewContext(qry) ctx.Schema = m.conn.schema + ctx.Session = m.conn.session job, err := BuildSqlJob(ctx) if err != nil { - u.Warnf("return error? %v", err) + u.Errorf("return error? %v", err) return nil, err } m.job = job // The only type of stmt that makes sense for Query is SELECT // and we need list of columns that requires casing - sqlSelect, ok := job.Ctx.Stmt.(*rel.SqlSelect) + //sqlSelect, ok := job.Ctx.Stmt.(*rel.SqlSelect) + _, ok := job.Ctx.Stmt.(*rel.SqlSelect) if !ok { u.Warnf("ctx? %v", job.Ctx) return nil, fmt.Errorf("We could not recognize that as a select query: %T", job.Ctx.Stmt) @@ -255,10 +311,15 @@ func (m *qlbStmt) Query(args []driver.Value) (driver.Rows, error) { // Prepare a result writer, we manually append this task to end // of job? - resultWriter := NewResultRows(ctx, sqlSelect.Columns.AliasedFieldNames()) + //resultWriter := NewResultRows(ctx, sqlSelect.Columns.AliasedFieldNames()) + projCols := job.Ctx.Projection.Proj.Columns + cols := make([]string, len(projCols)) + for i, col := range projCols { + cols[i] = col.As + } + resultWriter := NewResultRows(ctx, cols) job.RootTask.Add(resultWriter) - job.Setup() // TODO: this can't run in parallel-buffered mode? @@ -289,7 +350,9 @@ func (m *qlbStmt) Query(args []driver.Value) (driver.Rows, error) { // column index. If the type of a specific column isn't known // or shouldn't be handled specially, DefaultValueConverter // can be returned. -func (conn *qlbStmt) ColumnConverter(idx int) driver.ValueConverter { return nil } +func (conn *qlbStmt) ColumnConverter(idx int) driver.ValueConverter { + return driver.DefaultParameterConverter +} // driver.Rows Interface implementation. // @@ -324,7 +387,7 @@ func (conn *qlbRows) Next(dest []driver.Value) error { return expr.ErrNotImpleme type qlbResult struct { lastId int64 affected int64 - err error + err error } // LastInsertId returns the database's auto-generated ID @@ -354,15 +417,17 @@ func queryArgsConvert(query string, args []driver.Value) (string, error) { return query, nil } // a tiny, tiny, tiny bit of string sanitization +/* if strings.ContainsAny(query, `'"`) { return "", nil } +*/ q := make([]string, 2*len(args)+1) n := 0 for _, a := range args { i := strings.IndexRune(query, '?') if i == -1 { - return "", errors.New("number of parameters doesn't match number of placeholders") + return "", fmt.Errorf("number of parameters doesn't match number of placeholders for query %s", query) } var s string switch v := a.(type) { @@ -442,3 +507,55 @@ func escapeQuotes(txt string) string { io.WriteString(&buf, txt[last:]) return buf.String() } + +func createExecJob(query string, conn *qlbConn, args []driver.Value, + stmt rel.SqlStatement) (*JobExecutor, error) { + + if query == "" { + return nil, fmt.Errorf("createExecJob no sql provided") + } + var err error + if args != nil && len(args) > 0 { + query, err = queryArgsConvert(query, args) + if err != nil { + return nil, err + } + } + + // Create a Job, which is Dag of Tasks that Run() + ctx := plan.NewContext(query) + ctx.Schema = conn.schema + ctx.Session = conn.session + ctx.Stmt = stmt + job, err := BuildSqlJob(ctx) + if err != nil { + return nil, err + } + return job, nil +} + +func argsToValueColumns(vals []driver.Value) []*rel.ValueColumn { + + row := make([]*rel.ValueColumn, len(vals)) + for i, x := range vals { + switch v := x.(type) { + case nil: + row[i] = &rel.ValueColumn{Value: value.NewNilValue()} + case float64: + row[i] = &rel.ValueColumn{Value: value.NewNumberValue(x.(float64))} + case string: + row[i] = &rel.ValueColumn{Value: value.NewStringValue(x.(string))} + case []byte: + row[i] = &rel.ValueColumn{Value: value.NewStringValue(string(x.([]byte)))} + case int64: + row[i] = &rel.ValueColumn{Value: value.NewIntValue(x.(int64))} + case time.Time: + row[i] = &rel.ValueColumn{Value: value.NewStringValue(x.(time.Time).String())} + case bool: + row[i] = &rel.ValueColumn{Value: value.NewBoolValue(x.(bool))} + default: + panic(fmt.Sprintf("%v (%T) argument can't be handled by prepared insert", v, v)) + } + } + return row +} diff --git a/exec/task_sequential.go b/exec/task_sequential.go index c17c3e25..fe5305a4 100644 --- a/exec/task_sequential.go +++ b/exec/task_sequential.go @@ -26,7 +26,7 @@ type TaskSequential struct { func NewTaskSequential(ctx *plan.Context) *TaskSequential { st := &TaskSequential{ TaskBase: NewTaskBase(ctx), - tasks: make([]Task, 0), + tasks: make([]Task, 0), runners: make([]TaskRunner, 0), } return st @@ -152,7 +152,11 @@ func (m *TaskSequential) Run() (err error) { u.Errorf("%T.Run() errored %v", task, taskErr) // TODO: what do we do with this error? send to error channel? err = taskErr + //m.ErrChan() <- taskErr m.errors = append(m.errors, taskErr) + for i := 0; i < len(m.runners); i++ { + m.runners[i].ErrChan() <- taskErr + } } //u.Debugf("%p %q exiting taskId: %p %v %T", m, m.Name, task, taskId, task) wg.Done() @@ -161,9 +165,9 @@ func (m *TaskSequential) Run() (err error) { if len(m.runners)-1 == taskId { //u.Warnf("%p got shutdown on last one, lets shutdown them all", m) for i := len(m.runners) - 2; i >= 0; i-- { - //u.Debugf("%p sending close??: %v %T", m, i, m.runners[i]) + //u.Warnf("%p sending close??: %v %T", m, i, m.runners[i]) m.runners[i].Close() - //u.Debugf("%p after close??: %v %T", m, i, m.runners[i]) + //u.Warnf("%p after close??: %v %T", m, i, m.runners[i]) } } }(i) diff --git a/expr/builtins/cast.go b/expr/builtins/cast.go index c7a0237d..15c5173a 100644 --- a/expr/builtins/cast.go +++ b/expr/builtins/cast.go @@ -59,11 +59,18 @@ func castEvalNoAs(ctx expr.EvalContext, vals []value.Value) (value.Value, bool) // http://www.cheatography.com/davechild/cheat-sheets/mysql/ if vt == value.UnknownType { - switch strings.ToLower(vals[1].ToString()) { - case "char": + x := strings.ToLower(vals[1].ToString()) + if strings.HasPrefix(x, "char") { vt = value.ByteSliceType - default: - return nil, false + } else { + switch x { + case "char": + vt = value.ByteSliceType + case "date": + vt = value.TimeType + default: + return nil, false + } } } val, err := value.Cast(vt, vals[0]) @@ -87,11 +94,18 @@ func castEval(ctx expr.EvalContext, vals []value.Value) (value.Value, bool) { // http://www.cheatography.com/davechild/cheat-sheets/mysql/ if vt == value.UnknownType { - switch strings.ToLower(vals[2].ToString()) { - case "char": + x := strings.ToLower(vals[1].ToString()) + if strings.HasPrefix(x, "char") { vt = value.ByteSliceType - default: - return nil, false + } else { + switch x { + case "char": + vt = value.ByteSliceType + case "date": + vt = value.TimeType + default: + return nil, false + } } } val, err := value.Cast(vt, vals[0]) diff --git a/expr/builtins/hash_and_encode.go b/expr/builtins/hash_and_encode.go index 641c017f..c5072106 100644 --- a/expr/builtins/hash_and_encode.go +++ b/expr/builtins/hash_and_encode.go @@ -10,8 +10,8 @@ import ( "fmt" u "github.com/araddon/gou" - "github.com/dchest/siphash" + hash "github.com/aviddiviner/go-murmur" "github.com/araddon/qlbridge/expr" "github.com/araddon/qlbridge/value" ) @@ -49,7 +49,7 @@ func hashSipEval(ctx expr.EvalContext, args []value.Value) (value.Value, bool) { return value.NewIntValue(0), false } - hash := siphash.Hash(0, 1, []byte(val)) + hash := hash.MurmurHash64A([]byte(val), 1) return value.NewIntValue(int64(hash)), true } diff --git a/expr/parse.go b/expr/parse.go index 63069fd9..d97fface 100644 --- a/expr/parse.go +++ b/expr/parse.go @@ -807,7 +807,7 @@ func (t *tree) ArrayNode(depth int) Node { if n != nil { an.Append(n) } else { - u.Warnf("invalid? %v", t.Cur()) + u.Debugf("invalid? %v", t.Cur()) return an } } diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..fe18bcd1 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/guymolinari/qlbridge + +go 1.13 + +require github.com/go-sql-driver/mysql v1.6.0 diff --git a/lex/dialect_sql.go b/lex/dialect_sql.go index 9012c143..a4bcf8a4 100644 --- a/lex/dialect_sql.go +++ b/lex/dialect_sql.go @@ -354,6 +354,8 @@ func LexInto(l *Lexer) StateFn { l.Emit(TokenTable) return nil } + // Must have been a quoted string value + return LexValue } return nil } diff --git a/lex/lexer.go b/lex/lexer.go index ee2c0930..d34e4e6f 100644 --- a/lex/lexer.go +++ b/lex/lexer.go @@ -1504,7 +1504,7 @@ func lexIdentifierOfTypeNoWs(l *Lexer, shouldIgnore bool, forToken TokenType) St // content.`field name` if lastRune == '.' { p := l.Peek() - if p == '`' || p == '[' { + if p == '`' || p == '[' || p == '@' { return lexIdentifierOfTypeNoWs(l, false, forToken) } } @@ -2360,7 +2360,7 @@ func LexExpression(l *Lexer) StateFn { l.Push("LexExpression", l.clauseState()) return LexIdentifier } - u.Warnf("un-handled? ") + u.Debug("un-handled? ") case '(': // this is a logical Grouping/Ordering and must be a single // logically valid expression l.Push("LexParenRight", LexParenRight) diff --git a/plan/plan.go b/plan/plan.go index 54789c20..33495eea 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -906,6 +906,11 @@ func NewOrder(stmt *rel.SqlSelect) *Order { return &Order{Stmt: stmt, PlanBase: NewPlanBase(false)} } +// NewInto from SqlSelect statement. +func NewInto(stmt *rel.SqlInto) *Into { + return &Into{Stmt: stmt, PlanBase: NewPlanBase(false)} +} + // Equal compares equality of two tasks. func (m *Into) Equal(t Task) bool { if m == nil && t == nil { diff --git a/plan/planner_select.go b/plan/planner_select.go index 8b0a0035..cec0196c 100644 --- a/plan/planner_select.go +++ b/plan/planner_select.go @@ -7,12 +7,17 @@ import ( "github.com/araddon/qlbridge/rel" "github.com/araddon/qlbridge/schema" + "github.com/araddon/qlbridge/expr" + "github.com/araddon/qlbridge/lex" ) func needsFinalProjection(s *rel.SqlSelect) bool { if s.Having != nil { return true } + if s.Where != nil { + return true + } // Where? if len(s.OrderBy) > 0 { return true @@ -38,16 +43,79 @@ func (m *PlannerDefault) WalkSelect(p *Select) error { p.Stmt.From[0].Source = p.Stmt // TODO: move to a Finalize() in query parser/planner - srcPlan, err := NewSource(m.Ctx, p.Stmt.From[0], true) - if err != nil { - return err - } - p.From = append(p.From, srcPlan) - p.Add(srcPlan) + var srcPlan *Source - err = m.Planner.WalkSourceSelect(srcPlan) - if err != nil { - return err + if p.Stmt.Where != nil && p.Stmt.Where.Source != nil { // Where subquery + negate := false + var parentJoin expr.Node + if n, ok := p.Stmt.Where.Expr.(*expr.BinaryNode); ok { + parentJoin = n.Args[0] + } else if n2, ok2 := p.Stmt.Where.Expr.(*expr.UnaryNode); ok2 { + parentJoin = n2.Arg + negate = true + } + p.Stmt.From[0].AddJoin(parentJoin) + + var err error + srcPlan, err = NewSource(m.Ctx, p.Stmt.From[0], false) + if err != nil { + return nil + } + //p.From = append(p.From, srcPlan) + sub := p.Stmt.Where.Source + // Inject join criteria (JoinNodes, JoinExpr) on source for subquery (back to parent) + subSqlSrc := sub.From[0] + err = m.Planner.WalkSourceSelect(srcPlan) + if err != nil { + return err + } + subSrc := rel.NewSqlSource(subSqlSrc.Name) + subSrc.Rewrite(sub) + cols := subSrc.UnAliasedColumns() + var childJoin expr.Node + if len(cols) > 1 { + return fmt.Errorf("subquery must contain only 1 select column for join") + } + for _, v := range cols { + childJoin = v.Expr + break + } + if childJoin == nil { + return fmt.Errorf("subquery must contain at least 1 select column for join") + } + p.Stmt.From[0].AddJoin(childJoin) + subSrc.AddJoin(childJoin) + subSrcPlan, err := NewSource(m.Ctx, subSrc, false) + if err != nil { + return nil + } + subSrc.AddJoin(childJoin) + if negate { + subSrc.JoinExpr = expr.NewBinaryNode(lex.TokenFromOp("!="), parentJoin, childJoin) + p.Stmt.From[0].JoinExpr = expr.NewBinaryNode(lex.TokenFromOp("!="), parentJoin, childJoin) + } else { + subSrc.JoinExpr = expr.NewBinaryNode(lex.TokenFromOp("="), parentJoin, childJoin) + p.Stmt.From[0].JoinExpr = expr.NewBinaryNode(lex.TokenFromOp("="), parentJoin, childJoin) + } + err = m.Planner.WalkSourceSelect(subSrcPlan) + if err != nil { + u.Errorf("Could not visitsubselect %v %s", err, subSrcPlan) + return err + } + subQueryTask := NewJoinMerge(srcPlan, subSrcPlan, srcPlan.Stmt, subSrcPlan.Stmt) + p.Add(subQueryTask) + } else { + var err error + srcPlan, err = NewSource(m.Ctx, p.Stmt.From[0], true) + if err != nil { + return err + } + p.From = append(p.From, srcPlan) + p.Add(srcPlan) + err = m.Planner.WalkSourceSelect(srcPlan) + if err != nil { + return err + } } if srcPlan.Complete && !needsFinalProjection(p.Stmt) { @@ -137,6 +205,10 @@ finalProjection: //u.Debugf("m.Ctx: %p m.Ctx.Projection: %T:%p", m.Ctx, m.Ctx.Projection, m.Ctx.Projection) } + if p.Stmt.Into != nil{ + p.Add(NewInto(p.Stmt.Into)) + } + return nil } diff --git a/plan/projection.go b/plan/projection.go index e76bbb60..f6091b58 100644 --- a/plan/projection.go +++ b/plan/projection.go @@ -131,7 +131,9 @@ func (m *Projection) loadFinal(ctx *Context, isFinal bool) error { if col.InFinalProjection() { m.Proj.AddColumnShort(col.As, value.StringType) } else { - u.Warnf("not adding to projection? %s", col) + if !strings.HasSuffix(col.As, "@rownum") { + u.Warnf("not adding to projection? %s", col.String()) + } } } else { m.Proj.AddColumnShort(col.As, value.StringType) diff --git a/rel/parse_sql.go b/rel/parse_sql.go index 62b4a68b..a2a08200 100644 --- a/rel/parse_sql.go +++ b/rel/parse_sql.go @@ -29,6 +29,7 @@ type ParseError struct { func ParseSql(sqlQuery string) (SqlStatement, error) { return parseSqlResolver(sqlQuery, nil) } + func parseSqlResolver(sqlQuery string, fr expr.FuncResolver) (SqlStatement, error) { l := lex.NewSqlLexer(sqlQuery) m := Sqlbridge{l: l, SqlTokenPager: NewSqlTokenPager(l), funcs: fr} @@ -647,7 +648,11 @@ func (m *Sqlbridge) parseShow() (*SqlShow, error) { case lex.TokenLike: // SHOW TABLES LIKE '%' m.Next() // Consume Like - ex, err := expr.ParseExpression(fmt.Sprintf("%s LIKE %q", likeLhs, m.Cur().V)) + vx := m.Cur().V + if len(vx) > 0 && vx[0] != '%' { + vx = "%" + vx + } + ex, err := expr.ParseExpression(fmt.Sprintf("%s LIKE %q", likeLhs, vx)) m.Next() if err != nil { u.Errorf("Error parsing fake expression: %v", err) @@ -1067,6 +1072,7 @@ func (m *Sqlbridge) parseUpdateList() (map[string]*ValueColumn, error) { return nil, err } cols[lastColName] = &ValueColumn{Expr: exprNode} + m.Backup() default: u.Warnf("don't know how to handle ? %v", m.Cur()) return nil, m.ErrMsg("expected column") @@ -1097,11 +1103,8 @@ func (m *Sqlbridge) parseValueList() ([][]*ValueColumn, error) { case lex.TokenRightParenthesis: values = append(values, row) case lex.TokenFrom, lex.TokenInto, lex.TokenLimit, lex.TokenEOS, lex.TokenEOF: - if len(row) > 0 { - values = append(values, row) - } return values, nil - case lex.TokenValue: + case lex.TokenValue, lex.TokenValueEscaped: row = append(row, &ValueColumn{Value: value.NewStringValue(m.Cur().V)}) case lex.TokenInteger: iv, err := strconv.ParseInt(m.Cur().V, 10, 64) @@ -1304,14 +1307,16 @@ func (m *Sqlbridge) parseInto(req *SqlSelect) error { if m.Cur().T != lex.TokenInto { return nil } - m.Next() // Consume Into token - if m.Cur().T != lex.TokenTable { - return m.ErrMsg("expected table") + m.Next() //Consume INTO + if strings.ToUpper(m.Cur().V) == "FROM" { + return m.ErrMsg("expected 'TABLE' got 'FROM'") } - if strings.ToLower(m.Cur().V) == "FROM" { - return m.ErrMsg("expected table") + switch m.Cur().T { + case lex.TokenTable, lex.TokenValue: + req.Into = &SqlInto{Table: m.Cur().V} + default: + return m.ErrMsg("expected TABLE name or URI") } - req.Into = &SqlInto{Table: m.Cur().V} m.Next() return nil } @@ -1373,6 +1378,12 @@ func (m *Sqlbridge) parseWhere() (*SqlWhere, error) { // to determine which type of where clause m.Next() // x t2 := m.Cur().T + negate := false + if t2 == lex.TokenNegate { + negate = true + m.Next() + t2 = m.Cur().T + } m.Next() t3 := m.Cur().T m.Next() @@ -1380,6 +1391,9 @@ func (m *Sqlbridge) parseWhere() (*SqlWhere, error) { m.Backup() m.Backup() m.Backup() + if negate { + m.Backup() + } // Check for Types of Where // t1 T2 T3 T4 @@ -1392,15 +1406,22 @@ func (m *Sqlbridge) parseWhere() (*SqlWhere, error) { // TODO: // SELECT * FROM t3 WHERE ROW(5*t2.s1,77) = ( SELECT 50,11*s1 FROM t4) switch { - case (t2 == lex.TokenIN || t2 == lex.TokenEqual) && t3 == lex.TokenLeftParenthesis && t4 == lex.TokenSelect: + case (t2 == lex.TokenIN || t2 == lex.TokenEqual || t2 == lex.TokenNE) && t3 == lex.TokenLeftParenthesis && + t4 == lex.TokenSelect: //u.Infof("in parseWhere: %v", m.Cur()) + exprNode, err := expr.ParseExprWithFuncs(m, m.funcs) + if err != nil { + return &where, err + } + where.Expr = exprNode +/* m.Next() // T1 ?? this might be udf? m.Next() // t2 (IN | =) m.Next() // t3 = ( - //m.Next() // t4 = SELECT +*/ where.Op = t2 - where.Source = &SqlSelect{} - return &where, m.parseWhereSubSelect(where.Source) + where.Source, err = m.parseSqlSelect() + return &where, err } exprNode, err := expr.ParseExprWithFuncs(m, m.funcs) if err != nil { diff --git a/rel/sql.go b/rel/sql.go index 881cb175..08420561 100644 --- a/rel/sql.go +++ b/rel/sql.go @@ -1389,6 +1389,14 @@ func (m *SqlSource) writeDialectDepth(depth int, w expr.DialectWriter) { } } +func (m *SqlSource) AddJoin(join expr.Node) { + + if m.joinNodes == nil { + m.joinNodes = make([]expr.Node, 0) + } + m.joinNodes = append(m.joinNodes, join) +} + func (m *SqlSource) BuildColIndex(colNames []string) error { if len(m.colIndex) == 0 { m.colIndex = make(map[string]int, len(colNames)) @@ -1857,7 +1865,9 @@ func (m *SqlInsert) RewriteAsPrepareable(maxRows int, mark byte) string { func (m *SqlInsert) ColumnNames() []string { cols := make([]string, 0) for _, col := range m.Columns { - cols = append(cols, col.Key()) + if col != nil { + cols = append(cols, col.Key()) + } } return cols } diff --git a/rel/sql_rewrite.go b/rel/sql_rewrite.go index c78d6ea2..f22dbade 100644 --- a/rel/sql_rewrite.go +++ b/rel/sql_rewrite.go @@ -176,7 +176,7 @@ func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns //u.Debugf("returning original: %s", nt) return node, cols } - case *expr.NumberNode, *expr.NullNode, *expr.StringNode: + case *expr.NumberNode, *expr.NullNode, *expr.StringNode, *expr.FuncNode: return nt, cols case *expr.BinaryNode: //u.Infof("binaryNode T:%v", nt.Operator.T.String()) @@ -212,6 +212,19 @@ func rewriteWhere(stmt *SqlSelect, from *SqlSource, node expr.Node, cols Columns default: //u.Warnf("un-implemented op: %#v", nt) } + case *expr.TriNode: + switch nt.Operator.T { + case lex.TokenBetween: + var n1, n2, n3 expr.Node + n1, cols = rewriteWhere(stmt, from, nt.Args[0], cols) + n2, cols = rewriteWhere(stmt, from, nt.Args[1], cols) + n3, cols = rewriteWhere(stmt, from, nt.Args[2], cols) + if n1 != nil && n2 != nil && n3 != nil { + return &expr.TriNode{Operator: nt.Operator, Args: []expr.Node{n1, n2, n3}}, cols + } + default: + u.Warnf("un-implemented op: %#v", nt) + } default: u.Warnf("%T node types are not suppored yet for where rewrite", node) } @@ -361,7 +374,7 @@ func columnsFromJoin(from *SqlSource, node expr.Node, cols Columns) Columns { case lex.TokenAnd, lex.TokenLogicAnd, lex.TokenLogicOr: cols = columnsFromJoin(from, nt.Args[0], cols) cols = columnsFromJoin(from, nt.Args[1], cols) - case lex.TokenEqual, lex.TokenEqualEqual: + case lex.TokenEqual, lex.TokenEqualEqual, lex.TokenNE: cols = columnsFromJoin(from, nt.Args[0], cols) cols = columnsFromJoin(from, nt.Args[1], cols) default: