Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pkg/engine/orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ import (

"github.com/pg-sharding/lyx/lyx"
"github.com/pg-sharding/spqr/pkg/catalog"
"github.com/pg-sharding/spqr/pkg/spqrlog"
)

func ProcessOrderBy(data [][][]byte, colOrder map[string]int, order lyx.Node) ([][][]byte, error) {

for _, r := range data {
spqrlog.Zero.Debug().Str("data", string(r[0])).Msg("print row before")
}

switch order.(type) {
case *lyx.SortBy:
ord := order.(*lyx.SortBy)
Expand Down Expand Up @@ -38,5 +43,9 @@ func ProcessOrderBy(data [][][]byte, colOrder map[string]int, order lyx.Node) ([
}
sort.Sort(sortable)
}

for _, r := range data {
spqrlog.Zero.Debug().Str("data", string(r[0])).Msg("print row after")
}
return data, nil
}
1 change: 0 additions & 1 deletion pkg/engine/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type SortableWithContext struct {
// Len reports the number of rows available for sorting.
func (a SortableWithContext) Len() int { return len(a.Data) }
// Swap exchanges rows i and j in place.
func (a SortableWithContext) Swap(i, j int) { a.Data[i], a.Data[j] = a.Data[j], a.Data[i] }
func (a SortableWithContext) Less(i, j int) bool {

if a.Order == ASC {
return a.Op.Less(a.Data[i][a.Col_index], a.Data[j][a.Col_index])
} else {
Expand Down
24 changes: 24 additions & 0 deletions pkg/models/distributions/distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,30 @@ func (rel *DistributedRelation) GetDistributionKeyColumns() ([]string, error) {
return res, nil
}

// GetDistributionKeyColumnType returns type of a distributed relation
// column, identified by name.
//
// The column is resolved in two steps: first it is matched against the
// distribution key columns themselves (typed positionally via the
// distribution's ColTypes), then against the column references of each
// key entry's routing expression.
//
// Parameters:
// - d: distribution the relation belongs to; supplies positional ColTypes.
// - col: column name to look up.
//
// Returns:
// - string: Column type.
// - bool: flag indicating fact of success.
func (rel *DistributedRelation) GetDistributionKeyColumnType(
	d *Distribution,
	col string) (string, bool) {

	for i, colEntry := range rel.DistributionKey {
		if colEntry.Column == col {
			/* Guard against inconsistent metadata: ColTypes may be
			 * shorter than the distribution key list. Report failure
			 * instead of panicking with index out of range. */
			if i >= len(d.ColTypes) {
				return "", false
			}
			return d.ColTypes[i], true
		}
		for _, tcr := range colEntry.Expr.ColRefs {
			if tcr.ColName == col {
				return tcr.ColType, true
			}
		}
	}
	return "", false
}

// GetDistributionKeyColumnNames returns array of a DistributedRelation column names.
//
// Returns:
Expand Down
156 changes: 154 additions & 2 deletions router/qrouter/proxy_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,153 @@ func (qr *ProxyQrouter) InitExecutionTargets(ctx context.Context,
}
}

// addSortToPlan wraps a multi-shard ScatterPlan into a VirtualPlan that
// gathers all rows from every execution target and applies ORDER BY
// post-processing on the router side.
//
// The plan is returned unmodified whenever router-side sorting is
// unnecessary or unsupported: single-target plans, non-scatter plans,
// sort keys that are not plain column references, columns whose type
// cannot be resolved from distribution metadata, non-varchar sort keys
// (currently the only comparable type), or sort columns absent from the
// statement's target list.
//
// Returns:
// - plan.Plan: the original plan, or a sorting VirtualPlan on top of it.
// - error: distribution metadata resolution failure.
func (qr *ProxyQrouter) addSortToPlan(
	ctx context.Context,
	rm *rmeta.RoutingMetadataContext,
	p plan.Plan,
) (plan.Plan, error) {
	/* No point in cluster-wide sorting */
	if len(p.ExecutionTargets()) == 1 {
		return p, nil
	}

	scatterSlice, ok := p.(*plan.ScatterPlan)

	if !ok {
		return p, nil
	}

	spqrlog.Zero.Debug().
		Msgf("plan select sort postprocessing %+v", p)

	switch stmt := rm.Stmt.(type) {
	case *lyx.Select:
		/* This currently support sorting for one column. */
		for _, n := range stmt.SortClause {
			switch sb := n.(type) {
			case *lyx.SortBy:
				colRef, ok := sb.Node.(*lyx.ColumnRef)

				if !ok {
					return p, nil
				}
				/* We can sort by column reference only if we know type of column.
				 * For now, all we know in advance is type of distribution column. */
				rfqn, err := rm.ResolveRelationByAlias(colRef.TableAlias)
				if err != nil {
					/* We can receive `complex query` error from ResolveRelationByAlias.
					 * log it and ignore */
					spqrlog.Zero.
						Error().
						Str("alias", colRef.TableAlias).
						Err(err).Msg("failed to resolve relation by alias")
					return p, nil
				}

				d, err := rm.GetRelationDistribution(ctx, rfqn)
				if err != nil {
					return nil, err
				}
				r, ok := d.TryGetRelation(rfqn)
				if !ok {
					return p, nil
				}
				tp, ok := r.GetDistributionKeyColumnType(d, colRef.ColName)
				if !ok {
					return p, nil
				}

				/* TODO: refactor this */
				if tp != qdb.ColumnTypeVarchar && tp != qdb.ColumnTypeVarcharHashed && tp != qdb.ColumnTypeVarcharDeprecated {
					return p, nil
				}

				/* Locate the sort column's offset in the SELECT target list;
				 * the sorter needs it to pick the right cell of each row. */
				columnOff := -1
				for i, tle := range stmt.TargetList {
					switch cf := tle.(type) {
					case *lyx.ColumnRef:
						if cf.ColName == colRef.ColName {
							columnOff = i
						}
					}
				}

				/* XXX: error out here? */
				if columnOff == -1 {
					return p, nil
				}

				/* Okay, we are ready for result post-processing sort.*/

				retSlice := &plan.VirtualPlan{
					TTS: &tupleslot.TupleTableSlot{},
				}

				retSlice.SubPlan = scatterSlice

				scatterSlice.OverwriteQuery = map[string]string{}

				for _, sh := range scatterSlice.ExecTargets {
					scatterSlice.OverwriteQuery[sh.Name] = rm.Query
				}

				scatterSlice.RunF = func(serv server.Server) error {
					spqrlog.Zero.Debug().Msg("run bottom-level plan slice")
					for _, sh := range serv.Datashards() {
						if !slices.ContainsFunc(scatterSlice.ExecTargets, func(el kr.ShardKey) bool {
							return sh.Name() == el.Name
						}) {
							continue
						}

						var errmsg *pgproto3.ErrorResponse
					shLoop:
						for {
							msg, err := serv.ReceiveShard(sh.ID())
							if err != nil {
								return err
							}

							switch v := msg.(type) {
							case *pgproto3.ReadyForQuery:
								if v.TxStatus == byte(txstatus.TXERR) {
									/* The shard may report TXERR without a preceding
									 * ErrorResponse; do not dereference a nil errmsg. */
									if errmsg != nil {
										return fmt.Errorf("failed to run inner slice, tx status error: %s", errmsg.Message)
									}
									return fmt.Errorf("failed to run inner slice, tx status error")
								}
								break shLoop
							case *pgproto3.RowDescription:
								/* Take the row description from the first shard only;
								 * all shards run the same query. */
								if len(retSlice.TTS.Desc) == 0 {
									retSlice.TTS.Desc = v.Fields
								}
							case *pgproto3.ErrorResponse:
								errmsg = v
							case *pgproto3.DataRow:
								/* Deep-copy: pgproto3 reuses the message buffer
								 * between reads. */
								vals := make([][]byte, len(v.Values))
								copy(vals, v.Values)
								retSlice.TTS.Raw = append(retSlice.TTS.Raw, vals)
							default:
								/* All ok? */
							}
						}
					}

					/* Use a closure-local error: assigning into the enclosing
					 * function's err would outlive addSortToPlan and confuse
					 * readers about ownership. */
					sorted, sortErr := engine.ProcessOrderBy(retSlice.TTS.Raw, retSlice.TTS.Desc.GetColumnsMap(), sb)
					if sortErr != nil {
						return sortErr
					}
					retSlice.TTS.Raw = sorted

					return nil
				}

				return retSlice, nil
			default:
				/* ??? */
			}
		}

	}

	return p, nil
}

func (qr *ProxyQrouter) addLimitToPlan(
ctx context.Context,
rm *rmeta.RoutingMetadataContext,
Expand All @@ -935,7 +1082,6 @@ func (qr *ProxyQrouter) addLimitToPlan(
}

spqrlog.Zero.Debug().
Bool("ok", ok).
Msgf("plan select limit postprocessing %+v", p)

switch stmt := rm.Stmt.(type) {
Expand Down Expand Up @@ -1003,7 +1149,9 @@ func (qr *ProxyQrouter) addLimitToPlan(
case *pgproto3.DataRow:

if len(retSlice.TTS.Raw) < limitVal {
retSlice.TTS.Raw = append(retSlice.TTS.Raw, v.Values)
vals := make([][]byte, len(v.Values))
copy(vals, v.Values)
retSlice.TTS.Raw = append(retSlice.TTS.Raw, vals)
}

default:
Expand Down Expand Up @@ -1048,6 +1196,10 @@ func (qr *ProxyQrouter) plannerV1(
* fix bogus limit support, if enabled. */

if config.RouterConfig().Qr.AllowPostProcessing {
p, err = qr.addSortToPlan(ctx, rm, p)
if err != nil {
return nil, err
}
p, err = qr.addLimitToPlan(ctx, rm, p)
if err != nil {
return nil, err
Expand Down
Loading