diff --git a/go.mod b/go.mod
index e49aafb..a3f4ece 100644
--- a/go.mod
+++ b/go.mod
@@ -8,6 +8,7 @@ require (
 	github.com/datastax/go-cassandra-native-protocol v0.0.0-20211124104234-f6aea54fa801
 	github.com/hashicorp/golang-lru v0.5.4
 	github.com/stretchr/testify v1.7.0
+	github.com/twmb/murmur3 v1.1.6 // indirect
 	go.uber.org/atomic v1.8.0
 	go.uber.org/multierr v1.7.0 // indirect
 	go.uber.org/zap v1.17.0
diff --git a/go.sum b/go.sum
index 1cc2577..b051834 100644
--- a/go.sum
+++ b/go.sum
@@ -98,6 +98,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg=
+github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
 github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
 github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn0=
 github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
diff --git a/proxy/proxy.go b/proxy/proxy.go
index 2deebec..17ab7ce 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -344,8 +344,8 @@ func (p *Proxy) findSession(version primitive.ProtocolVersion, keyspace string)
 	}
 }

-func (p *Proxy) newQueryPlan() proxycore.QueryPlan {
-	return p.lb.NewQueryPlan()
+func (p *Proxy) newQueryPlan(keyspace string, token proxycore.Token) proxycore.QueryPlan {
+	return p.lb.NewQueryPlan(keyspace, token)
 }

 var (
@@ -575,7 +575,8 @@ func (c *client) Receive(reader io.Reader) error {
 	case *partialQuery:
 		c.handleQuery(raw, msg, body.CustomPayload)
 	case *partialBatch:
-		c.execute(raw, notDetermined, c.keyspace, msg)
+		// FIXME: Calculate token for partition key
+		c.execute(raw, notDetermined, c.keyspace, nil, msg)
 	default:
 		c.send(raw.Header, &message.ProtocolError{ErrorMessage: "Unsupported operation"})
 	}
@@ -583,7 +584,7 @@ func (c *client) Receive(reader io.Reader) error {
 	return nil
 }

-func (c *client) execute(raw *frame.RawFrame, state idempotentState, keyspace string, msg message.Message) {
+func (c *client) execute(raw *frame.RawFrame, state idempotentState, keyspace string, token proxycore.Token, msg message.Message) {
 	if sess, err := c.proxy.findSession(raw.Header.Version, c.keyspace); err == nil {
 		req := &request{
 			client:   c,
@@ -593,7 +594,7 @@ func (c *client) execute(raw *frame.RawFrame, state idempotentState, keyspace st
 			keyspace: keyspace,
 			done:     false,
 			stream:   raw.Header.StreamId,
-			qp:       c.proxy.newQueryPlan(),
+			qp:       c.proxy.newQueryPlan(keyspace, token),
 			raw:      raw,
 		}
 		req.Execute(true)
@@ -649,7 +650,8 @@ func (c *client) handlePrepare(raw *frame.RawFrame, msg *message.Prepare) {
 		}
 	} else {
-		c.execute(raw, isIdempotent, keyspace, msg) // Prepared statements can be retried themselves
+		// FIXME: Calculate token for partition key
+		c.execute(raw, isIdempotent, keyspace, nil, msg) // Prepared statements can be retried themselves
 	}
 }

@@ -658,7 +660,8 @@ func (c *client) handleExecute(raw *frame.RawFrame, msg *partialExecute, customP
 	if stmt, ok := c.preparedSystemQuery[id]; ok {
 		c.interceptSystemQuery(raw.Header, stmt)
 	} else {
-		c.execute(raw, c.getDefaultIdempotency(customPayload), "", msg)
+		// FIXME: Calculate token for partition key
+		c.execute(raw, notDetermined, "", nil, msg)
 	}
 }

@@ -675,7 +678,8 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery, customPaylo
 			c.interceptSystemQuery(raw.Header, stmt)
 		}
 	} else {
-		c.execute(raw, c.getDefaultIdempotency(customPayload), c.keyspace, msg)
+		// FIXME: Calculate token for partition key
+		c.execute(raw, notDetermined, c.keyspace, nil, msg)
 	}
 }

diff --git a/proxycore/cluster.go b/proxycore/cluster.go
index 79275ba..dd44e12 100644
--- a/proxycore/cluster.go
+++ b/proxycore/cluster.go
@@ -64,6 +64,8 @@ func (a UpEvent) isEvent() {

 type BootstrapEvent struct {
 	Hosts []*Host
+	Partitioner
+	Keyspaces map[string]ReplicationStrategy
 }

 func (b BootstrapEvent) isEvent() {
@@ -110,7 +112,7 @@ type ClusterConfig struct {
 }

 type ClusterInfo struct {
-	Partitioner    string
+	Partitioner    Partitioner
 	ReleaseVersion string
 	CQLVersion     string
 	LocalDC        string
@@ -301,7 +303,12 @@ func (c *Cluster) queryHosts(ctx context.Context, conn *ClientConn, version prim
 	row := rs.Row(0)
 	localDC := hosts[0].DC

-	partitioner, err := row.StringByName("partitioner")
+	partitionerName, err := row.StringByName("partitioner")
+	if err != nil {
+		return nil, ClusterInfo{}, err
+	}
+
+	partitioner, err := NewPartitionerFromName(partitionerName)
 	if err != nil {
 		return nil, ClusterInfo{}, err
 	}
@@ -348,7 +355,7 @@ func (c *Cluster) addHosts(hosts []*Host, rs *ResultSet) []*Host {
 	for i := 0; i < rs.RowCount(); i++ {
 		row := rs.Row(i)
 		if endpoint, err := c.config.Resolver.NewEndpoint(row); err == nil {
-			if host, err := NewHostFromRow(endpoint, row); err == nil {
+			if host, err := NewHostFromRow(endpoint, c.Info.Partitioner, row); err == nil {
 				hosts = append(hosts, host)
 			} else {
 				c.logger.Error("unable to create new host", zap.Stringer("endpoint", endpoint), zap.Error(err))
@@ -449,7 +456,7 @@ func (c *Cluster) stayConnected() {
 					continue
 				}
 			}
-			newListener.OnEvent(&BootstrapEvent{c.hosts})
+			newListener.OnEvent(&BootstrapEvent{c.hosts, c.Info.Partitioner, map[string]ReplicationStrategy{}}) // FIXME: Get keyspace info
 			c.listeners = append(c.listeners, newListener)
 		case <-refreshTimer.C:
 			c.refreshHosts()
diff --git a/proxycore/host.go b/proxycore/host.go
index 155ec14..e52ba11 100644
--- a/proxycore/host.go
+++ b/proxycore/host.go
@@ -14,23 +14,31 @@
 package proxycore

+import "fmt"
+
 type Host struct {
 	Endpoint
-	DC string
+	DC     string
+	Rack   string
+	Tokens []Token
 }

-func NewHostFromRow(endpoint Endpoint, row Row) (*Host, error) {
-	dc, err := row.StringByName("data_center")
+func NewHostFromRow(endpoint Endpoint, partitioner Partitioner, row Row) (*Host, error) {
+	dc, err := row.ByName("data_center")
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("error attempting to get 'data_center' column: %v", err)
 	}
-	return &Host{endpoint, dc}, nil
-}
-
-func (h *Host) Key() string {
-	return h.Endpoint.Key()
-}
-
-func (h *Host) String() string {
-	return h.Endpoint.String()
+	rack, err := row.ByName("rack")
+	if err != nil {
+		return nil, fmt.Errorf("error attempting to get 'rack' column: %v", err)
+	}
+	tokensVal, err := row.ByName("tokens")
+	if err != nil {
+		return nil, fmt.Errorf("error attempting to get 'tokens' column: %v", err)
+	}
+	tokens := make([]Token, 0, len(tokensVal.([]string)))
+	for _, token := range tokensVal.([]string) {
+		tokens = append(tokens, partitioner.FromString(token))
+	}
+	return &Host{endpoint, dc.(string), rack.(string), tokens}, nil
 }
diff --git a/proxycore/lb.go b/proxycore/lb.go
index 84cf722..cbb52c8 100644
--- a/proxycore/lb.go
+++ b/proxycore/lb.go
@@ -25,13 +25,11 @@ type QueryPlan interface {

 type LoadBalancer interface {
 	ClusterListener
-	NewQueryPlan() QueryPlan
+	NewQueryPlan(keyspace string, token Token) QueryPlan
 }

 func NewRoundRobinLoadBalancer() LoadBalancer {
-	lb := &roundRobinLoadBalancer{
-		mu: &sync.Mutex{},
-	}
+	lb := &roundRobinLoadBalancer{}
 	lb.hosts.Store(make([]*Host, 0))
 	return lb
 }
@@ -39,7 +37,7 @@ func NewRoundRobinLoadBalancer() LoadBalancer {
 type roundRobinLoadBalancer struct {
 	hosts atomic.Value
 	index uint32
-	mu    *sync.Mutex
+	mu    sync.Mutex
 }

 func (l *roundRobinLoadBalancer) OnEvent(event Event) {
@@ -69,7 +67,7 @@ func (l *roundRobinLoadBalancer) copy() []*Host {
 	return cpy
 }

-func (l *roundRobinLoadBalancer) NewQueryPlan() QueryPlan {
+func (l *roundRobinLoadBalancer) NewQueryPlan(_ string, _ Token) QueryPlan {
 	return &roundRobinQueryPlan{
 		hosts:  l.hosts.Load().([]*Host),
 		offset: atomic.AddUint32(&l.index, 1) - 1,
@@ -92,3 +90,52 @@ func (p *roundRobinQueryPlan) Next() *Host {
 	p.index++
 	return host
 }
+
+type tokenAwareLoadBalancer struct {
+	tokenMap    *TokenMap
+	partitioner Partitioner
+	mu          sync.Mutex
+}
+
+func (l *tokenAwareLoadBalancer) OnEvent(event Event) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	switch evt := event.(type) {
+	case *BootstrapEvent:
+		l.tokenMap = NewTokenMap(evt.Hosts, evt.Keyspaces)
+		l.partitioner = evt.Partitioner
+	case *AddEvent:
+		l.tokenMap.AddHost(evt.Host)
+	case *RemoveEvent:
+		l.tokenMap.RemoveHost(evt.Host)
+	}
+	//TODO implement me
+	panic("implement me")
+}
+
+func (l *tokenAwareLoadBalancer) NewQueryPlan(keyspace string, token Token) QueryPlan {
+	if token != nil {
+		replicas, err := l.tokenMap.GetReplicas(keyspace, token)
+		if err == nil {
+			return &tokenAwareQueryPlan{replicas: replicas}
+		} else {
+			//TODO implement me
+			panic("implement me")
+		}
+	} else {
+		//TODO implement me
+		panic("implement me")
+	}
+	return nil
+}
+
+type tokenAwareQueryPlan struct {
+	replicas []*Host
+	index    int
+}
+
+func (t tokenAwareQueryPlan) Next() *Host {
+	//TODO implement me
+	panic("implement me")
+}
diff --git a/proxycore/token_map.go b/proxycore/token_map.go
new file mode 100644
index 0000000..a563093
--- /dev/null
+++ b/proxycore/token_map.go
@@ -0,0 +1,457 @@
+package proxycore
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/twmb/murmur3"
+)
+
+const NetworkTopologyStrategy = "NetworkTopologyStrategy"
+const SimpleStrategy = "SimpleStrategy"
+
+type Token interface {
+	fmt.Stringer
+	LessThan(Token) bool
+}
+
+type Partitioner interface {
+	fmt.Stringer
+	Hash(partitionKey []byte) Token
+	FromString(token string) Token
+}
+
+type TokenHost struct {
+	Token
+	*Host
+}
+
+type TokenReplicas struct {
+	Token
+	Replicas []*Host
+}
+
+type Datacenter struct {
+	numNodes int
+	racks    map[string]struct{}
+}
+
+type ReplicationStrategy interface {
+	BuildTokenMap(tokens []TokenHost, dcs map[string]*Datacenter) []TokenReplicas
+	Key() string
+}
+
+type TokenMap struct {
+	hosts         map[string]*Host
+	dcs           map[string]*Datacenter
+	tokens        []TokenHost
+	keyspaces     map[string]ReplicationStrategy
+	tokenReplicas map[string][]TokenReplicas // Uses replication strategy Key()
+	rwMutex       sync.RWMutex
+	updateMu      sync.Mutex // Single updater
+}
+
+func NewTokenMap(hosts []*Host, keyspaces map[string]ReplicationStrategy) *TokenMap {
+	tokens := make([]TokenHost, 0)
+	hostsMap := make(map[string]*Host)
+
+	for _, host := range hosts {
+		hostsMap[host.Key()] = host
+		for _, token := range host.Tokens {
+			tokens = append(tokens, TokenHost{
+				Token: token,
+				Host:  host,
+			})
+		}
+	}
+
+	sortTokens(tokens)
+
+	dcs := buildDcs(hostsMap)
+
+	return &TokenMap{
+		hosts:         hostsMap,
+		dcs:           dcs,
+		tokens:        tokens,
+		keyspaces:     keyspaces,
+		tokenReplicas: buildTokenReplicas(tokens, dcs, keyspaces),
+	}
+}
+
+func (t *TokenMap) GetReplicas(keyspace string, token Token) (replicas []*Host, err error) {
+	t.rwMutex.RLock()
+	defer t.rwMutex.RUnlock()
+	if rs, ok := t.keyspaces[keyspace]; ok {
+		tokenReplicas := t.tokenReplicas[rs.Key()]
+		index := sort.Search(len(tokenReplicas), func(i int) bool { return token.LessThan(tokenReplicas[i].Token) })
+		if index >= len(tokenReplicas) { // Wrap around to the start of the ring
+			return tokenReplicas[0].Replicas, nil
+		} else {
+			return tokenReplicas[index].Replicas, nil
+		}
+	} else {
+		return nil, fmt.Errorf("'%s' keyspace does not exist in token map", keyspace)
+	}
+}
+
+func (t *TokenMap) AddHost(host *Host) {
+	t.updateMu.Lock()
+	defer t.updateMu.Unlock()
+
+	tokensCopy := make([]TokenHost, len(t.tokens), len(t.tokens)+len(host.Tokens))
+	copy(tokensCopy, t.tokens)
+
+	for _, token := range host.Tokens {
+		tokensCopy = append(tokensCopy, TokenHost{
+			Token: token,
+			Host:  host,
+		})
+	}
+
+	t.hosts[host.Key()] = host
+	t.dcs = buildDcs(t.hosts)
+
+	sortTokens(tokensCopy)
+
+	tokenReplicasCopy := buildTokenReplicas(tokensCopy, t.dcs, t.keyspaces)
+
+	t.rwMutex.Lock()
+	t.tokens = tokensCopy
+	t.tokenReplicas = tokenReplicasCopy
+	t.rwMutex.Unlock()
+}
+
+func (t *TokenMap) RemoveHost(host *Host) {
+	t.updateMu.Lock()
+	defer t.updateMu.Unlock()
+
+	tokensCopy := make([]TokenHost, 0)
+
+	for _, tokenHost := range t.tokens {
+		if tokenHost.Host != host && tokenHost.Host.Key() != host.Key() {
+			tokensCopy = append(tokensCopy, tokenHost)
+		}
+	}
+
+	delete(t.hosts, host.Key())
+	t.dcs = buildDcs(t.hosts)
+
+	sortTokens(tokensCopy)
+
+	tokenReplicasCopy := buildTokenReplicas(tokensCopy, t.dcs, t.keyspaces)
+
+	t.rwMutex.Lock()
+	t.tokens = tokensCopy
+	t.tokenReplicas = tokenReplicasCopy
+	t.rwMutex.Unlock()
+}
+
+func (t *TokenMap) AddKeyspace(keyspace string, rs ReplicationStrategy) {
+	t.updateMu.Lock()
+	defer t.updateMu.Unlock()
+
+	if _, ok := t.tokenReplicas[rs.Key()]; !ok {
+		tokenMap := rs.BuildTokenMap(t.tokens, t.dcs)
+
+		t.rwMutex.Lock()
+		t.keyspaces[keyspace] = rs
+		t.tokenReplicas[rs.Key()] = tokenMap
+		t.rwMutex.Unlock()
+	} else {
+		t.rwMutex.Lock()
+		t.keyspaces[keyspace] = rs
+		t.rwMutex.Unlock()
+	}
+}
+
+func buildTokenReplicas(tokens []TokenHost, dcs map[string]*Datacenter, keyspaces map[string]ReplicationStrategy) map[string][]TokenReplicas {
+	tokenReplicas := make(map[string][]TokenReplicas)
+	for _, rs := range keyspaces {
+		if _, ok := tokenReplicas[rs.Key()]; !ok {
+			tokenReplicas[rs.Key()] = rs.BuildTokenMap(tokens, dcs)
+		}
+	}
+	return tokenReplicas
+}
+
+func buildDcs(hosts map[string]*Host) map[string]*Datacenter {
+	dcs := make(map[string]*Datacenter)
+	for _, host := range hosts {
+		if dc, ok := dcs[host.DC]; ok {
+			dc.racks[host.Rack] = struct{}{}
+			dc.numNodes++
+		} else {
+			dcs[host.DC] = &Datacenter{
+				numNodes: 1,
+				racks:    map[string]struct{}{host.Rack: {}},
+			}
+		}
+	}
+	return dcs
+}
+
+func sortTokens(tokens []TokenHost) {
+	sort.SliceStable(tokens, func(i, j int) bool {
+		return tokens[i].LessThan(tokens[j].Token)
+	})
+}
+
+type murmur3Token struct {
+	hash int64
+}
+
+func (m murmur3Token) String() string {
+	return strconv.FormatInt(m.hash, 10)
+}
+
+func (m murmur3Token) LessThan(token Token) bool {
+	if t, ok := token.(*murmur3Token); ok {
+		return m.hash < t.hash
+	} else {
+		panic("tried comparing incompatible token types")
+	}
+}
+
+func NewPartitionerFromName(name string) (Partitioner, error) {
+	if strings.HasSuffix(name, "Murmur3Partitioner") {
+		return NewMurmur3Partitioner(), nil
+	} else {
+		return nil, fmt.Errorf("'%s' is an unsupported partitioner", name)
+	}
+}
+
+type murmur3Partitioner struct {
+}
+
+func NewMurmur3Partitioner() Partitioner {
+	return &murmur3Partitioner{}
+}
+
+func (m murmur3Partitioner) String() string {
+	return "Murmur3Partitioner"
+}
+
+func (m murmur3Partitioner) Hash(partitionKey []byte) Token {
+	return &murmur3Token{int64(murmur3.Sum64(partitionKey))}
+}
+
+func (m murmur3Partitioner) FromString(token string) Token {
+	hash, _ := strconv.ParseInt(token, 10, 64) // TODO: Don't ignore error
+	return &murmur3Token{hash}
+}
+
+type simpleReplicationStrategy struct {
+	replicationFactor int
+	key               string
+}
+
+func (s simpleReplicationStrategy) BuildTokenMap(tokens []TokenHost, _ map[string]*Datacenter) []TokenReplicas {
+	numReplicas := s.replicationFactor
+	numTokens := len(tokens)
+	result := make([]TokenReplicas, 0, numTokens)
+
+	if numTokens < numReplicas {
+		numReplicas = numTokens
+	}
+
+	for i, token := range tokens {
+		replicas := make([]*Host, 0, numReplicas)
+		for j := 0; j < numTokens && len(replicas) < numReplicas; j++ {
+			replicas = append(replicas, tokens[i].Host)
+			i++
+			if i >= numTokens {
+				i = 0
+			}
+		}
+		result = append(result, TokenReplicas{token.Token, replicas})
+	}
+
+	return result
+}
+
+func (s simpleReplicationStrategy) Key() string {
+	return s.key
+}
+
+type networkTopologyReplicationStrategy struct {
+	dcReplicationFactors map[string]int
+	key                  string
+}
+
+type dcState struct {
+	skippedEndpoints []*Host
+	racksObserved    map[string]struct{}
+	replicaCount     int
+}
+
+type dcInfo struct {
+	replicationFactor int
+	numRacks          int
+}
+
+func appendReplica(replicas []*Host, replicaCountThisDc int, replicaToAdd *Host) ([]*Host, int) {
+	for _, replica := range replicas {
+		if replica == replicaToAdd || replica.Key() == replicaToAdd.Key() {
+			return replicas, replicaCountThisDc
+		}
+	}
+	replicaCountThisDc++
+	return append(replicas, replicaToAdd), replicaCountThisDc
+}
+
+func (n networkTopologyReplicationStrategy) BuildTokenMap(tokens []TokenHost, dcs map[string]*Datacenter) []TokenReplicas {
+	infos := make(map[string]dcInfo)
+
+	numTokens := len(tokens)
+	result := make([]TokenReplicas, 0, numTokens)
+
+	numReplicas := 0
+
+	for dcName, rf := range n.dcReplicationFactors {
+		if dc, ok := dcs[dcName]; ok {
+			numReplicas += rf
+			infos[dcName] = dcInfo{
+				replicationFactor: rf,
+				numRacks:          len(dc.racks),
+			}
+		}
+	}
+
+	if numReplicas == 0 {
+		return result
+	}
+
+	for i, token := range tokens {
+		replicas := make([]*Host, 0, numReplicas)
+		states := make(map[string]*dcState)
+
+		for j := 0; j < numTokens && len(replicas) < numReplicas; j++ {
+			host := tokens[i].Host
+
+			// Move to the next token, we got the host for the current token in the previous step
+			i++
+			if i >= numTokens { // Wrap to the first token
+				i = 0
+			}
+
+			if info, ok := infos[host.DC]; !ok {
+				continue // Not a valid datacenter, go to the next token
+			} else {
+				var state *dcState
+				if state, ok = states[host.DC]; !ok {
+					state = &dcState{
+						skippedEndpoints: nil,
+						racksObserved:    make(map[string]struct{}),
+						replicaCount:     0,
+					}
+					states[host.DC] = state
+				}
+
+				if state.replicaCount >= info.replicationFactor {
+					continue
+				}
+
+				if len(host.Rack) == 0 || len(state.racksObserved) == info.numRacks {
+					replicas, state.replicaCount = appendReplica(replicas, state.replicaCount, host)
+				} else {
+					if _, ok = state.racksObserved[host.Rack]; ok {
+						state.skippedEndpoints = append(state.skippedEndpoints, host)
+					} else {
+						replicas, state.replicaCount = appendReplica(replicas, state.replicaCount, host)
+						state.racksObserved[host.Rack] = struct{}{} // Observe the rack
+
+						if len(state.racksObserved) == info.numRacks {
+							for len(state.skippedEndpoints) > 0 && state.replicaCount < info.replicationFactor {
+								replicas, state.replicaCount = appendReplica(replicas, state.replicaCount, state.skippedEndpoints[0])
+								state.skippedEndpoints = state.skippedEndpoints[1:]
+							}
+						}
+					}
+				}
+			}
+		}
+		result = append(result, TokenReplicas{token.Token, replicas})
+	}
+
+	return result
+}
+
+func (n networkTopologyReplicationStrategy) Key() string {
+	return n.key
+}
+
+func NewReplicationFactor(row Row) (ReplicationStrategy, error) {
+	replicationFactors := make(map[string]int)
+	replicationColumn, err := row.ByName("replication")
+	var class string
+	if err == ColumnNameNotFound {
+		strategyClass, err := row.ByName("strategy_class")
+		if err != nil {
+			return nil, errors.New("couldn't find 'strategy_class' column in keyspace metadata")
+		}
+		class = strategyClass.(string)
+		strategyOptions, err := row.ByName("strategy_options")
+		if err != nil {
+			return nil, errors.New("couldn't find 'strategy_options' column in keyspace metadata")
+		}
+		options := make(map[string]string)
+		err = json.Unmarshal([]byte(strategyOptions.(string)), &options)
+		if err != nil {
+			return nil, fmt.Errorf("'strategy_options' column is invalid: %v", err)
+		}
+		for k, v := range options {
+			switch k {
+			case "replication_factor":
+				rf, err := strconv.Atoi(v)
+				if err != nil {
+					return nil, fmt.Errorf("invalid replication factor: %s. Expected an integer value", v)
+				}
+				replicationFactors["rf"] = rf
+			default:
+				rf, err := strconv.Atoi(v)
+				if err != nil {
+					return nil, fmt.Errorf("invalid replication factor: %s. Expected an integer value", v)
+				}
+				replicationFactors[k] = rf // Key should be a data center
+			}
+		}
+	} else {
+		replication := replicationColumn.(map[string]string)
+		for k, v := range replication {
+			switch k {
+			case "class":
+				class = v
+			case "replication_factor":
+				rf, err := strconv.Atoi(v)
+				if err != nil {
+					return nil, fmt.Errorf("invalid replication factor: %s. Expected an integer value", v)
+				}
+				replicationFactors["rf"] = rf
+			default:
+				rf, err := strconv.Atoi(v)
+				if err != nil {
+					return nil, fmt.Errorf("invalid replication factor: %s. Expected an integer value", v)
+				}
+				replicationFactors[k] = rf // Key should be a data center
+			}
+		}
+	}
+
+	if strings.EqualFold(class, NetworkTopologyStrategy) {
+		return &networkTopologyReplicationStrategy{
+			dcReplicationFactors: replicationFactors,
+			key:                  fmt.Sprintf("%v", replicationFactors),
+		}, nil
+	} else if strings.EqualFold(class, SimpleStrategy) {
+		return &simpleReplicationStrategy{
+			replicationFactor: replicationFactors["rf"],
+			key:               fmt.Sprintf("%d", replicationFactors["rf"]),
+		}, nil
+	} else {
+		return nil, fmt.Errorf("invalid replication strategy: '%s'", class)
+	}
+}