diff --git a/gbus/abstractions.go b/gbus/abstractions.go index aa3ebe6..585a08b 100644 --- a/gbus/abstractions.go +++ b/gbus/abstractions.go @@ -6,7 +6,6 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/streadway/amqp" ) @@ -284,3 +283,25 @@ type Logged interface { SetLogger(entry logrus.FieldLogger) Log() logrus.FieldLogger } + +type Transport interface { + Messaging + Logged + Health + + Start() error + Stop() error + + RPCChannel() <-chan *BusMessage + MessageChannel() <-chan *BusMessage + + ErrorChan() <-chan error + BackPressureChannel() <-chan bool + + ListenOnEvent(exchange, topic string) error + + Ack(message *BusMessage) error + Reject(message *BusMessage, requeue bool) error +} + +type NewTransport func(svcName, connString, DLX string, prefetchCount, maxRetryCount uint, purgeOnStartup, withConfirms bool, logger logrus.FieldLogger) Transport diff --git a/gbus/messages.go b/gbus/messages.go index 3bb54ee..7ae244e 100644 --- a/gbus/messages.go +++ b/gbus/messages.go @@ -1,27 +1,26 @@ package gbus import ( - "errors" - "fmt" "strings" "github.com/opentracing/opentracing-go/log" "github.com/rs/xid" - "github.com/sirupsen/logrus" - "github.com/streadway/amqp" ) //BusMessage the structure that gets sent to the underlying transport type BusMessage struct { - ID string - IdempotencyKey string - CorrelationID string - SagaID string - SagaCorrelationID string - Semantics Semantics /*cmd or evt*/ - Payload Message - PayloadFQN string - RPCID string + ID string + IdempotencyKey string + CorrelationID string + SagaID string + SagaCorrelationID string + Semantics Semantics /*cmd or evt*/ + Payload Message + PayloadFQN string + RPCID string + ResurrectedFromDeath bool + RawPayload []byte + ContentType string } //NewBusMessage factory method for creating a BusMessage that wraps the given payload @@ -34,65 +33,6 @@ func NewBusMessage(payload Message) *BusMessage { return bm } -//NewFromDelivery creates a BusMessage from an amqp delivery -func NewFromDelivery(delivery amqp.Delivery) (*BusMessage, error) { - bm := &BusMessage{} - bm.SetFromAMQPHeaders(delivery) - - bm.ID = delivery.MessageId - bm.CorrelationID = delivery.CorrelationId - if delivery.Exchange != "" { - bm.Semantics = EVT - } else { - bm.Semantics = CMD - } - if bm.PayloadFQN == "" || bm.Semantics == "" { - errMsg := fmt.Sprintf("missing critical headers. message_name:%s semantics: %s", bm.PayloadFQN, bm.Semantics) - return nil, errors.New(errMsg) - } - return bm, nil -} - -//GetMessageName extracts the valuee of the custom x-msg-name header from an amq delivery -func GetMessageName(delivery amqp.Delivery) string { - return castToString(delivery.Headers["x-msg-name"]) -} - -func (bm *BusMessage) GetAMQPHeaders() (headers amqp.Table) { - headers = amqp.Table{ - "x-msg-name": bm.Payload.SchemaName(), - "x-idempotency-key": bm.IdempotencyKey, - } - - /* - only set the following headers if they contain a value - https://github.com/wework/grabbit/issues/221 - */ - setNonEmpty(headers, "x-msg-saga-id", bm.SagaID) - setNonEmpty(headers, "x-msg-saga-correlation-id", bm.SagaCorrelationID) - setNonEmpty(headers, "x-grabbit-msg-rpc-id", bm.RPCID) - - return -} - -func setNonEmpty(headers amqp.Table, headerName, headerValue string) { - - if headerValue != "" { - headers[headerName] = headerValue - } -} - -//SetFromAMQPHeaders convert from AMQP headers Table everything but a payload -func (bm *BusMessage) SetFromAMQPHeaders(delivery amqp.Delivery) { - headers := delivery.Headers - bm.IdempotencyKey = castToString(headers["x-idempotency-key"]) - bm.SagaID = castToString(headers["x-msg-saga-id"]) - bm.SagaCorrelationID = castToString(headers["x-msg-saga-correlation-id"]) - bm.RPCID = castToString(headers["x-grabbit-msg-rpc-id"]) - bm.PayloadFQN = GetMessageName(delivery) - -} - //SetPayload sets the payload and makes sure that Name is saved func (bm *BusMessage) SetPayload(payload Message) { bm.PayloadFQN = payload.SchemaName() @@ -121,21 +61,6 @@ func (bm *BusMessage) GetTraceLog() (fields []log.Field) { log.String("RPCID", bm.RPCID), } } - -func GetDeliveryLogEntries(delivery amqp.Delivery) logrus.Fields { - - return logrus.Fields{ - "message_name": castToString(delivery.Headers["x-msg-name"]), - "message_id": delivery.MessageId, - "routing_key": delivery.RoutingKey, - "exchange": delivery.Exchange, - "idempotency_key": castToString(delivery.Headers["x-idempotency-key"]), - "correlation_id": castToString(delivery.CorrelationId), - "rpc_id": castToString(delivery.Headers["x-grabbit-msg-rpc-id"]), - } - -} - func castToString(i interface{}) string { v, ok := i.(string) if !ok { @@ -155,8 +80,3 @@ type SagaTimeoutMessage struct { func (SagaTimeoutMessage) SchemaName() string { return "grabbit.timeout" } - -func isResurrectedMessage(delivery amqp.Delivery) bool { - isResurrected, ok := delivery.Headers[ResurrectedHeaderName].(bool) - return ok && isResurrected -} diff --git a/gbus/transport/rabbitmq/base.go b/gbus/transport/rabbitmq/base.go new file mode 100644 index 0000000..c3740e1 --- /dev/null +++ b/gbus/transport/rabbitmq/base.go @@ -0,0 +1,495 @@ +package rabbitmq + +import ( + "context" + "fmt" + "math/rand" + "time" + + "emperror.dev/emperror" + "emperror.dev/errors" + "github.com/Rican7/retry" + "github.com/Rican7/retry/strategy" + "github.com/opentracing-contrib/go-amqp/amqptracer" + "github.com/opentracing/opentracing-go" + "github.com/rs/xid" + "github.com/sirupsen/logrus" + "github.com/streadway/amqp" + + "github.com/wework/grabbit/gbus" +) + +var _ gbus.Transport = &transport{} + +type queueBinding struct { + topic, exchange string +} + +type transport struct { + *gbus.Glogged + *gbus.Safety + connString, svcName, deadLetterExchange string + prefetchCount, maxRetryCount, instanceId uint + purgeOnStartup, withConfirms, started bool + + ingressConn, egressConn *amqp.Connection + ingressChannel, egressChannel *amqp.Channel + + amqpErrors chan *amqp.Error + amqpBlocks chan amqp.Blocking + + serviceQueue, rpcQueue amqp.Queue + + rawMessages, rawRPCMessages <-chan amqp.Delivery + + // The actual channels used within the gbus/worker + rpcChannel, messageChannel chan *gbus.BusMessage + + // Used to allow registration of subscription before we start + delayedSubscriptions []*queueBinding + + // Monitoring fields we should consider a different path + backpressure bool + amqpConnected bool + healthChan, errorChan chan error + inFlightMsgs map[string]amqp.Delivery + backPressureChannel chan bool +} + +func (t *transport) RPCChannel() <-chan *gbus.BusMessage { + return t.rpcChannel +} + +func (t *transport) MessageChannel() <-chan *gbus.BusMessage { + return t.messageChannel +} + +func (t *transport) Ack(message *gbus.BusMessage) error { + delivery, err := t.popDelivery(message.ID) + if err != nil { + return err + } + return t.ack(delivery) +} + +func (t *transport) Reject(message *gbus.BusMessage, requeue bool) error { + delivery, err := t.popDelivery(message.ID) + if err != nil { + return err + } + return t.reject(requeue, delivery) +} + +func (t *transport) Send(ctx context.Context, toService string, command *gbus.BusMessage, policies ...gbus.MessagePolicy) error { + headers := getAMQPHeaders(command) + span, _ := opentracing.StartSpanFromContext(ctx, "SendMessage") + err := amqptracer.Inject(span, headers) + if err != nil { + t.Log().WithError(err).Error("failed injecting tracer span") + } + msg := amqp.Publishing{ + Type: command.PayloadFQN, + Body: command.RawPayload, + ReplyTo: t.svcName, + MessageId: command.ID, + CorrelationId: command.CorrelationID, + ContentType: command.ContentType, + Headers: headers, + } + + span.LogFields(command.GetTraceLog()...) + + for _, policy := range policies { + policy.Apply(&msg) + } + return t.egressChannel.Publish("", + toService, /*key*/ + false, /*mandatory*/ + false, /*immediate*/ + msg /*msg*/) +} + +func (t *transport) Publish(ctx context.Context, exchange, topic string, event *gbus.BusMessage, policies ...gbus.MessagePolicy) error { + panic("implement me") +} + +func (t *transport) RPC(ctx context.Context, service string, request, reply *gbus.BusMessage, timeout time.Duration) (*gbus.BusMessage, error) { + panic("implement me") +} + +func (t *transport) Start() error { + var err error + t.started = true + + // Open connections and channels + if t.ingressConn, err = t.connect(); err != nil { + return err + } + if t.egressConn, err = t.connect(); err != nil { + return err + } + if t.ingressChannel, err = t.ingressConn.Channel(); err != nil { + return err + } + + if t.egressChannel, err = t.egressConn.Channel(); err != nil { + return err + } + + // register on failure notifications + t.ingressConn.NotifyClose(t.amqpErrors) + t.ingressConn.NotifyBlocked(t.amqpBlocks) + t.egressConn.NotifyClose(t.amqpErrors) + t.egressConn.NotifyBlocked(t.amqpBlocks) + t.egressChannel.NotifyClose(t.amqpErrors) + t.ingressChannel.NotifyClose(t.amqpErrors) + + // declare queue + if err = t.createServiceQueue(); err != nil { + t.Log().WithError(err).Error("failed creating service queue") + return err + } + + // bind queue + if err = t.bindServiceQueue(); err != nil { + t.Log().WithError(err).Error("failed binding service queue") + return err + } + + // declare rpc queue + if t.rpcQueue, err = t.createRPCQueue(); err != nil { + t.Log().WithError(err).Error("failed binding RPC queue") + return err + } + + // start monitoring on amqp related errors + go t.monitorAMQPErrors() + + if t.rawMessages, err = t.createMessageChannel(t.serviceQueue, ""); err != nil { + t.Log().WithError(err).Error("failed creating a message channel") + return err + } + + if t.rawRPCMessages, err = t.createMessageChannel(t.serviceQueue, "_rpc"); err != nil { + t.Log().WithError(err).Error("failed creating a message channel") + return err + } + + go t.consumeMessages() + go t.consumeRPC() + return nil +} + +func (t *transport) Stop() error { + if !t.started { + t.Log().Info("stopping a non started transport") + return nil + } + builder := emperror.NewMultiErrorBuilder() + + builder.Add(t.ingressChannel.Cancel(t.consumerTag(""), false)) + builder.Add(t.ingressChannel.Cancel(t.consumerTag("_rpc"), false)) + builder.Add(t.ingressConn.Close()) + builder.Add(t.egressConn.Close()) + return builder.ErrOrNil() +} + +func (t *transport) ErrorChan() <-chan error { + return t.errorChan +} + +func (t *transport) BackPressureChannel() <-chan bool { + return t.backPressureChannel +} + +func (t *transport) ListenOnEvent(exchange, topic string) error { + + subscription := &queueBinding{ + topic: topic, + exchange: exchange, + } + + if t.started { + return t.addSubscription(subscription) + } + t.delayedSubscriptions = append(t.delayedSubscriptions, subscription) + return nil +} + +func (t *transport) connect() (*amqp.Connection, error) { + var conn *amqp.Connection + err := t.SafeWithRetries(func() error { + var err error + conn, err = amqp.Dial(t.connString) + return err + }, t.maxRetryCount) + return conn, err +} + +func (t *transport) createServiceQueue() error { + if t.purgeOnStartup { + purgedMsgs, err := t.ingressChannel.QueueDelete(t.svcName, false, false, false) + if err != nil { + t.Log().WithError(err).WithField("purged_messages", purgedMsgs) + return err + } + } + + args := amqp.Table{} + if t.deadLetterExchange != "" { + args["x-dead-letter-exchange"] = t.deadLetterExchange + } + var err error + t.serviceQueue, err = t.ingressChannel.QueueDeclare(t.svcName, + true, /*durable*/ + false, /*autoDelete*/ + false, /*exclusive*/ + false, /*noWait*/ + args /*args*/) + + if err != nil { + t.Log().WithError(err).Error("failed to declare queue") + return err + } + return nil +} + +func (t *transport) bindServiceQueue() error { + if t.deadLetterExchange != "" { + err := t.ingressChannel.ExchangeDeclare( + t.deadLetterExchange, + "fanout", + true, + false, + false, + false, + nil) + if err != nil { + t.Log().WithError(err). + WithField("dead_letter_exchange", t.deadLetterExchange). + Error("failed declaring dead letter exchange") + return err + } + + err = t.bindQueue(&queueBinding{topic: "", exchange: t.deadLetterExchange}) + if err != nil { + t.Log().WithError(err). + WithField("dead_letter_exchange", t.deadLetterExchange). + Error("failed binding to dead_letter_exchange queue") + return err + } + } + + for _, subscription := range t.delayedSubscriptions { + err := t.addSubscription(subscription) + if err != nil { + t.Log().WithError(err).Error("failed adding subscription") + return err + } + } + return nil + +} + +func (t *transport) addSubscription(subscription *queueBinding) error { + err := t.ingressChannel.ExchangeDeclare( + subscription.exchange, + "topic", + true, + false, + false, + false, + nil, + ) + if err != nil { + t.Log().WithError(err). + WithField("exchange", subscription.exchange). + WithField("topic", subscription.topic). + Error("failed to declare the proper exchange") + return err + } + err = t.bindQueue(subscription) + if err != nil { + t.Log().WithError(err). + WithField("exchange", subscription.exchange). + WithField("topic", subscription.topic). + Error("failed to bind to queue") + return err + } + return nil +} + +func (t *transport) bindQueue(s *queueBinding) error { + return t.ingressChannel.QueueBind(t.serviceQueue.Name, s.topic, s.exchange, false, nil) +} + +func (t *transport) createRPCQueue() (amqp.Queue, error) { + /* + the RPC queue is a queue per service instance (as opposed to the service queue which + is shared between service instances to allow for round-robin load balancing) in order to + support synchronous RPC style calls.amqpit is not durable and is auto-deleted once the service + instance process terminates + */ + uid := xid.New().String() + qName := t.svcName + "_rpc_" + uid + q, e := t.ingressChannel.QueueDeclare(qName, + false, /*durable*/ + true, /*autoDelete*/ + false, /*exclusive*/ + false, /*noWait*/ + nil /*args*/) + return q, e + +} + +func (t *transport) monitorAMQPErrors() { + // TODO(vlad): refactor this to make sure that this makes + // logic maybe implement https://github.com/AppsFlyer/go-sundheit ? + for t.started { + select { + case blocked := <-t.amqpBlocks: + if blocked.Active { + t.Log().WithField("reason", blocked.Reason).Warn("amqp connection blocked") + } else { + t.Log().WithField("reason", blocked.Reason).Info("amqp connection unblocked") + } + t.backpressure = blocked.Active + t.backPressureChannel <- t.backpressure + case amqpErr := <-t.amqpErrors: + t.amqpConnected = false + t.Log().WithField("amqp_error", amqpErr).Error("amqp error") + if t.healthChan != nil { + t.healthChan <- amqpErr + } + } + } +} + +func (t *transport) NotifyHealth(health chan error) { + if health == nil { + panic("can't pass nil as health channel") + } + t.healthChan = health +} + +func (t *transport) GetHealth() gbus.HealthCard { + // TODO(vlad): refactor health messages in grabbit between components + return gbus.HealthCard{ + RabbitBackPressure: t.backpressure, + RabbitConnected: t.amqpConnected, + } +} + +func (t *transport) createMessageChannel(queue amqp.Queue, suffix string) (<-chan amqp.Delivery, error) { + consumerTag := t.consumerTag(suffix) + msgs, err := t.ingressChannel.Consume( + queue.Name, + consumerTag, + false, + false, + false, + false, + nil, + ) + if err != nil { + t.Log().WithError(err). + WithField("queue_name", queue.Name). + WithField("consumer_tag", consumerTag). + Error("failed to consume from queue") + return nil, err + } + return msgs, nil +} + +func (t *transport) consumerTag(suffix string) string { + return fmt.Sprintf("%s_worker_%d%s", t.svcName, t.instanceId, suffix) +} + +func (t *transport) consumeMessages() { + t.consume(t.rawMessages) +} + +func (t *transport) consumeRPC() { + t.consume(t.rawRPCMessages) +} + +func (t *transport) consume(c <-chan amqp.Delivery) { + for msg := range c { + bm, err := newFromDelivery(msg) + if err != nil { + t.errorChan <- err + if e := t.reject(false, msg); e != nil { + t.errorChan <- e + } + } else { + t.inFlightMsgs[bm.ID] = msg + t.rpcChannel <- bm + } + } +} + +func (t *transport) reject(requeue bool, delivery amqp.Delivery) error { + reject := func(attempts uint) error { return delivery.Reject(requeue) } + l := t.Log().WithField("message_id", delivery.MessageId). + WithField("requeue", requeue) + err := retry.Retry(reject, + strategy.Wait(100*time.Millisecond)) + if err != nil { + l.WithError(err). + Error("failed rejecting message") + return err + } + l.Debug("successfully rejected message") + return nil +} + +func (t *transport) ack(delivery amqp.Delivery) error { + ack := func(attempts uint) error { return delivery.Ack(false /*multiple*/) } + l := t.Log().WithField("message_id", delivery.MessageId) + err := retry.Retry(ack, + strategy.Wait(100*time.Millisecond)) + if err != nil { + l.WithError(err). + Error("failed acking message") + return err + } + l.Debug("successfully acked message") + return nil +} + +func (t *transport) popDelivery(id string) (amqp.Delivery, error) { + delivery, ok := t.inFlightMsgs[id] + if !ok { + t.Log().Error("the message is not in our management we should make sure that we are not crossing channels and go routines somewhere") + return delivery, errors.NewWithDetails("the message we are rejecting is not in our hands", "message_id", message.ID) + } + delete(t.inFlightMsgs, id) + return delivery, nil +} + +var _ gbus.NewTransport = NewTransport + +// NewTransport creates a new AMQP transport +func NewTransport(svcName, connString, DLX string, prefetchCount, maxRetryCount uint, purgeOnStartup, withConfirms bool, logger logrus.FieldLogger) gbus.Transport { + t := &transport{ + connString: connString, + svcName: svcName, + deadLetterExchange: DLX, + prefetchCount: prefetchCount, + purgeOnStartup: purgeOnStartup, + withConfirms: withConfirms, + maxRetryCount: maxRetryCount, + delayedSubscriptions: make([]*queueBinding, 0), + instanceId: uint(rand.Intn(100)), + errorChan: make(chan error), + messageChannel: make(chan *gbus.BusMessage), + rpcChannel: make(chan *gbus.BusMessage), + inFlightMsgs: make(map[string]amqp.Delivery), + amqpBlocks: make(chan amqp.Blocking), + amqpErrors: make(chan *amqp.Error), + backPressureChannel: make(chan bool), + } + t.SetLogger(logger) + + return t +} diff --git a/gbus/transport/rabbitmq/message.go b/gbus/transport/rabbitmq/message.go new file mode 100644 index 0000000..42580ed --- /dev/null +++ b/gbus/transport/rabbitmq/message.go @@ -0,0 +1,87 @@ +package rabbitmq + +import ( + "emperror.dev/errors" + "github.com/sirupsen/logrus" + "github.com/streadway/amqp" + + "github.com/wework/grabbit/gbus" +) + +var ( + ResurrectedHeaderName = "x-resurrected-from-death" +) + +func newFromDelivery(delivery amqp.Delivery) (*gbus.BusMessage, error) { + bm := &gbus.BusMessage{} + setFromAMQPHeaders(delivery, bm) + + bm.ID = delivery.MessageId + bm.CorrelationID = delivery.CorrelationId + if delivery.Exchange != "" { + bm.Semantics = gbus.EVT + } else { + bm.Semantics = gbus.CMD + } + if bm.PayloadFQN == "" || bm.Semantics == "" { + return nil, errors.NewWithDetails("missing critical headers", "message_name", bm.PayloadFQN, "semantics", bm.Semantics) + } + return bm, nil +} + +func setFromAMQPHeaders(delivery amqp.Delivery, bm *gbus.BusMessage) { + headers := delivery.Headers + bm.IdempotencyKey = castToString(headers["x-idempotency-key"]) + bm.SagaID = castToString(headers["x-msg-saga-id"]) + bm.SagaCorrelationID = castToString(headers["x-msg-saga-correlation-id"]) + bm.RPCID = castToString(headers["x-grabbit-msg-rpc-id"]) + bm.PayloadFQN = castToString(delivery.Headers["x-msg-name"]) +} + +func getAMQPHeaders(bm *gbus.BusMessage) (headers amqp.Table) { + headers = amqp.Table{ + "x-msg-name": bm.Payload.SchemaName(), + "x-idempotency-key": bm.IdempotencyKey, + } + + /* + only set the following headers if they contain a value + https://github.com/wework/grabbit/issues/221 + */ + setNonEmpty(headers, "x-msg-saga-id", bm.SagaID) + setNonEmpty(headers, "x-msg-saga-correlation-id", bm.SagaCorrelationID) + setNonEmpty(headers, "x-grabbit-msg-rpc-id", bm.RPCID) + + return +} +func setNonEmpty(headers amqp.Table, headerName, headerValue string) { + if headerValue != "" { + headers[headerName] = headerValue + } +} +func castToString(i interface{}) string { + v, ok := i.(string) + if !ok { + return "" + } + return v +} + +func getDeliveryLogEntries(delivery amqp.Delivery) logrus.Fields { + + return logrus.Fields{ + "message_name": castToString(delivery.Headers["x-msg-name"]), + "message_id": delivery.MessageId, + "routing_key": delivery.RoutingKey, + "exchange": delivery.Exchange, + "idempotency_key": castToString(delivery.Headers["x-idempotency-key"]), + "correlation_id": castToString(delivery.CorrelationId), + "rpc_id": castToString(delivery.Headers["x-grabbit-msg-rpc-id"]), + } + +} + +func isResurrectedMessage(delivery amqp.Delivery) bool { + isResurrected, ok := delivery.Headers[ResurrectedHeaderName].(bool) + return ok && isResurrected +} diff --git a/go.mod b/go.mod index ff5a3f4..8093633 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/wework/grabbit require ( + emperror.dev/emperror v0.21.3 emperror.dev/errors v0.4.3 emperror.dev/handler/logrus v0.1.0 github.com/Rican7/retry v0.1.0 diff --git a/go.sum b/go.sum index f73ddfa..f0e545b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +emperror.dev/emperror v0.21.3 h1:/S3xa/ljmXKTsrxN8ttCE/eq7fmY/4H4xyqbiunObss= +emperror.dev/emperror v0.21.3/go.mod h1:aeDoz3ERR3yJblyjfKojXoFFsXSd6K8Wfd4Zb1eEbZg= emperror.dev/errors v0.4.1/go.mod h1:cA5SMsyzo+KXq997DKGK+lTV1DGx5TXLQUNtYe9p2p0= emperror.dev/errors v0.4.3 h1:yfhVxX1vzHgCDXh0KL+gVKfKhXlJCabmc79jS6QQuus= emperror.dev/errors v0.4.3/go.mod h1:cA5SMsyzo+KXq997DKGK+lTV1DGx5TXLQUNtYe9p2p0=