Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion parser/parser_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package parser

import (
"encoding/binary"
"errors"
"fmt"

"github.com/datastax/go-cassandra-native-protocol/datatype"
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/datastax/go-cassandra-native-protocol/primitive"
)

const (
Expand Down Expand Up @@ -193,3 +194,44 @@ func parseIdentifiers(l *lexer, t token) (err error) {
func isDMLTerminator(t token) bool {
return t == tkEOF || t == tkEOS || t == tkInsert || t == tkUpdate || t == tkDelete || t == tkApply
}

// PatchQueryConsistency modifies the consistency level of a QUERY message in-place
// by locating the consistency field directly in the frame body.
//
// Layout based on the CQL native protocol v4 spec:
// /* <query: long string><consistency: short><flags: byte>... */
func PatchQueryConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
Comment thread
shari0311 marked this conversation as resolved.
Outdated
if len(body) < 6 {
return fmt.Errorf("body too short for QUERY")
}
queryLen := binary.BigEndian.Uint32(body[0:4])
offset := 4 + int(queryLen)
if len(body) < offset+2 {
return fmt.Errorf("not enough bytes to patch QUERY consistency")
}
binary.BigEndian.PutUint16(body[offset:offset+2], uint16(newConsistency))
return nil
}

// PatchExecuteConsistency modifies the consistency level of an EXECUTE message in-place
// by locating the consistency field directly after the prepared statement ID.
//
// Layout based on the CQL native protocol v4 spec:
// /* <id: short bytes><consistency: short><flags: byte>... */
func PatchExecuteConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
if len(body) < 2 {
return fmt.Errorf("body too short for EXECUTE")
}
idLen := int(binary.BigEndian.Uint16(body[0:2]))
offset := 2 + idLen
if len(body) < offset+2 {
return fmt.Errorf("not enough bytes to patch EXECUTE consistency")
}
binary.BigEndian.PutUint16(body[offset:offset+2], uint16(newConsistency))
return nil
}
Comment thread
shari0311 marked this conversation as resolved.
Outdated

func PatchBatchConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
//TODO: Implement this
return nil
}
Comment thread
shari0311 marked this conversation as resolved.
Outdated
71 changes: 71 additions & 0 deletions parser/parser_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package parser

import (
"encoding/binary"
"testing"

"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/stretchr/testify/assert"
)

func TestPatchQueryConsistency(t *testing.T) {
t.Run("valid QUERY body", func(t *testing.T) {
query := []byte("SELECT * FROM users;")
queryLen := uint32(len(query))

body := make([]byte, 4+len(query)+2)
binary.BigEndian.PutUint32(body[0:4], queryLen)
copy(body[4:], query)
binary.BigEndian.PutUint16(body[4+len(query):], uint16(primitive.ConsistencyLevelOne))

err := PatchQueryConsistency(body, primitive.ConsistencyLevelQuorum)
assert.NoError(t, err)

offset := 4 + len(query)
got := binary.BigEndian.Uint16(body[offset : offset+2])
assert.Equal(t, uint16(primitive.ConsistencyLevelQuorum), got)
})

t.Run("too short body", func(t *testing.T) {
err := PatchQueryConsistency([]byte{0x01, 0x02}, primitive.ConsistencyLevelQuorum)
assert.Error(t, err)
})

t.Run("not enough space for consistency", func(t *testing.T) {
body := make([]byte, 4)
binary.BigEndian.PutUint32(body[0:4], 10)
err := PatchQueryConsistency(body, primitive.ConsistencyLevelQuorum)
assert.Error(t, err)
})
}

func TestPatchExecuteConsistency(t *testing.T) {
t.Run("valid EXECUTE body", func(t *testing.T) {
id := []byte{0xCA, 0xFE, 0xBA, 0xBE}
idLen := len(id)

body := make([]byte, 2+idLen+2)
binary.BigEndian.PutUint16(body[0:2], uint16(idLen))
copy(body[2:], id)
binary.BigEndian.PutUint16(body[2+idLen:], uint16(primitive.ConsistencyLevelLocalQuorum))

err := PatchExecuteConsistency(body, primitive.ConsistencyLevelAll)
assert.NoError(t, err)

offset := 2 + idLen
got := binary.BigEndian.Uint16(body[offset : offset+2])
assert.Equal(t, uint16(primitive.ConsistencyLevelAll), got)
})

t.Run("too short body", func(t *testing.T) {
err := PatchExecuteConsistency([]byte{0x00}, primitive.ConsistencyLevelOne)
assert.Error(t, err)
})

t.Run("not enough space for consistency", func(t *testing.T) {
body := make([]byte, 2)
binary.BigEndian.PutUint16(body[0:2], 4)
err := PatchExecuteConsistency(body, primitive.ConsistencyLevelOne)
assert.Error(t, err)
})
}
25 changes: 9 additions & 16 deletions proxy/codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *partialQueryCodec) Decode(source io.Reader, _ primitive.ProtocolVersion
if query, err := primitive.ReadLongString(source); err != nil {
return nil, err
} else {
return &partialQuery{query, primitive.ConsistencyLevelLocalQuorum}, nil
return &partialQuery{query}, nil
}
}

Expand All @@ -51,8 +51,7 @@ func (c *partialQueryCodec) GetOpCode() primitive.OpCode {
}

type partialQuery struct {
query string
Consistency primitive.ConsistencyLevel
query string
}

func (p *partialQuery) IsResponse() bool {
Expand All @@ -63,17 +62,12 @@ func (p *partialQuery) GetOpCode() primitive.OpCode {
return primitive.OpCodeQuery
}

func (p *partialQuery) Clone() message.Message {
return &partialQuery{p.query, p.Consistency}
}

func (p *partialQuery) DeepCopyMessage() message.Message {
return &partialQuery{p.query, p.Consistency}
return &partialQuery{p.query}
}

type partialExecute struct {
queryId []byte
Consistency primitive.ConsistencyLevel
queryId []byte
}

func (m *partialExecute) IsResponse() bool {
Expand All @@ -87,7 +81,7 @@ func (m *partialExecute) GetOpCode() primitive.OpCode {
func (m *partialExecute) DeepCopyMessage() message.Message {
queryId := make([]byte, len(m.queryId))
copy(queryId, m.queryId)
return &partialExecute{queryId, m.Consistency}
return &partialExecute{queryId}
}

func (m *partialExecute) String() string {
Expand All @@ -105,7 +99,7 @@ func (c *partialExecuteCodec) EncodedLength(_ message.Message, _ primitive.Proto
}

func (c *partialExecuteCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (msg message.Message, err error) {
execute := &partialExecute{Consistency: primitive.ConsistencyLevelLocalQuorum}
execute := &partialExecute{}
if execute.queryId, err = primitive.ReadShortBytes(source); err != nil {
return nil, fmt.Errorf("cannot read EXECUTE query id: %w", err)
} else if len(execute.queryId) == 0 {
Expand All @@ -119,8 +113,7 @@ func (c *partialExecuteCodec) GetOpCode() primitive.OpCode {
}

type partialBatch struct {
queryOrIds []interface{}
Consistency primitive.ConsistencyLevel
queryOrIds []interface{}
}

func (p partialBatch) IsResponse() bool {
Expand All @@ -134,7 +127,7 @@ func (p partialBatch) GetOpCode() primitive.OpCode {
func (p partialBatch) DeepCopyMessage() message.Message {
queryOrIds := make([]interface{}, len(p.queryOrIds))
copy(queryOrIds, p.queryOrIds)
return &partialBatch{queryOrIds, p.Consistency}
return &partialBatch{queryOrIds}
}

type partialBatchCodec struct{}
Expand Down Expand Up @@ -184,7 +177,7 @@ func (p partialBatchCodec) Decode(source io.Reader, version primitive.ProtocolVe
}
queryOrIds[i] = queryOrId
}
return &partialBatch{queryOrIds, primitive.ConsistencyLevelLocalQuorum}, nil
return &partialBatch{queryOrIds}, nil
}

func (p partialBatchCodec) GetOpCode() primitive.OpCode {
Expand Down
14 changes: 11 additions & 3 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type Config struct {
DC string
Tokens []string
Peers []PeerConfig
DatabaseType string
// PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max
// capacity of ~100MB.
PreparedCache proxycore.PreparedCache
Expand Down Expand Up @@ -575,13 +576,20 @@ func (c *client) Receive(reader io.Reader) error {
case *message.Prepare:
c.handlePrepare(raw, msg)
case *partialExecute:
msg.Consistency = primitive.ConsistencyLevelQuorum
if c.proxy.config.DatabaseType == "astra" {
Comment thread
shari0311 marked this conversation as resolved.
Outdated
_ = parser.PatchExecuteConsistency(raw.Body, primitive.ConsistencyLevelLocalQuorum)
}
c.handleExecute(raw, msg, body.CustomPayload)
case *partialQuery:
msg.Consistency = primitive.ConsistencyLevelQuorum
if c.proxy.config.DatabaseType == "astra" {
_ = parser.PatchQueryConsistency(raw.Body, primitive.ConsistencyLevelLocalQuorum)
}
raw.DeepCopy()
Comment thread
shari0311 marked this conversation as resolved.
Outdated
c.handleQuery(raw, msg, body.CustomPayload)
case *partialBatch:
msg.Consistency = primitive.ConsistencyLevelQuorum
if c.proxy.config.DatabaseType == "astra" {
_ = parser.PatchBatchConsistency(raw.Body, primitive.ConsistencyLevelLocalQuorum)
}
c.execute(raw, notDetermined, c.keyspace, msg)
default:
c.send(raw.Header, &message.ProtocolError{ErrorMessage: "Unsupported operation"})
Expand Down
15 changes: 0 additions & 15 deletions proxy/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ func (r *request) executeInternal(next bool) {
r.done = true
r.send(&message.ServerError{ErrorMessage: "Proxy exhausted query plan and there are no more hosts available to try"})
} else {
r.client.proxy.logger.Info("sending request to host", zap.Stringer("host", r.host),
zap.String("consistency", r.getConsistencyLevel(r.msg).String()))
err := r.session.Send(r.host, r)
if err == nil {
break
Expand All @@ -80,19 +78,6 @@ func (r *request) executeInternal(next bool) {
}
}

func (r *request) getConsistencyLevel(message message.Message) primitive.ConsistencyLevel {
switch m := message.(type) {
case *partialQuery:
return m.Consistency
case *partialBatch:
return m.Consistency
case *partialExecute:
return m.Consistency
default:
return primitive.ConsistencyLevelOne
}
}

func (r *request) send(msg message.Message) {
_ = r.client.conn.Write(proxycore.SenderFunc(func(writer io.Writer) error {
return codec.EncodeFrame(frame.NewFrame(r.raw.Header.Version, r.stream, msg), writer)
Expand Down
2 changes: 2 additions & 0 deletions proxy/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type runConfig struct {
DataCenter string `yaml:"data-center" help:"Data center to use in system tables" env:"DATA_CENTER"`
Tokens []string `yaml:"tokens" help:"Tokens to use in the system tables. It's not recommended" env:"TOKENS"`
Peers []PeerConfig `yaml:"peers" kong:"-"` // Not available as a CLI flag
DatabaseType string `yaml:"database-type" help:"Type of database to proxy for (cassandra, dse, astra)" default:"astra" env:"DATABASE_TYPE"`
Comment thread
shari0311 marked this conversation as resolved.
Outdated
}

// Run starts the proxy command. 'args' shouldn't include the executable (i.e. os.Args[1:]). It returns the exit code
Expand Down Expand Up @@ -186,6 +187,7 @@ func Run(ctx context.Context, args []string) int {
Tokens: cfg.Tokens,
Peers: cfg.Peers,
IdempotentGraph: cfg.IdempotentGraph,
DatabaseType: cfg.DatabaseType,
Comment thread
shari0311 marked this conversation as resolved.
Outdated
})

cfg.Bind = maybeAddPort(cfg.Bind, "9042")
Expand Down