Skip to content

Commit

Permalink
[signal] Fix context propagation in signal server (#3251)
Browse files Browse the repository at this point in the history
  • Loading branch information
4thel00z authored Feb 7, 2025
1 parent 05415f7 commit 58b2eb4
Showing 1 changed file with 53 additions and 48 deletions.
101 changes: 53 additions & 48 deletions signal/server/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}

dispatcher, err := dispatcher.NewDispatcher(ctx, meter)
d, err := dispatcher.NewDispatcher(ctx, meter)
if err != nil {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}

s := &Server{
dispatcher: dispatcher,
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
}
Expand All @@ -75,7 +75,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
return &proto.EncryptedMessage{}, nil
}

return s.dispatcher.SendMessage(context.Background(), msg)
return s.dispatcher.SendMessage(ctx, msg)
}

// ConnectStream connects to the exchange stream
Expand All @@ -98,76 +98,81 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)

for {
// read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}

log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

_, err = s.dispatcher.SendMessage(stream.Context(), msg)
if err != nil {
log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
select {
case <-stream.Context().Done():
log.Debugf("stream closed for peer [%s] [streamID %d] due to context cancellation", p.Id, p.StreamID)
return stream.Context().Err()
default:
// read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}

log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

_, err = s.dispatcher.SendMessage(stream.Context(), msg)
if err != nil {
log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
}
}
}

<-stream.Context().Done()
return stream.Context().Err()
}

func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
log.Debugf("registering new peer")
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
if id, found := meta[proto.HeaderId]; found {
p := peer.NewPeer(id[0], stream)

s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)

return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId)
}
} else {
meta, hasMeta := metadata.FromIncomingContext(stream.Context())
if !hasMeta {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
}

id, found := meta[proto.HeaderId]
if !found {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
}

p := peer.NewPeer(id[0], stream)
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
return p, nil
}

func (s *Server) DeregisterPeer(p *peer.Peer) {
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
s.registry.Deregister(p)

s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
}

func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

getRegistrationStart := time.Now()

// lookup the target peer where the message is going to
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
// forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil {
log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
// todo respond to the sender?
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
} else {
// in milliseconds
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
}
} else {
dstPeer, found := s.registry.Get(msg.RemoteKey)

if !found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
// todo respond to the sender?
}

s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()

// forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil {
log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
// todo respond to the sender?
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
return
}

// in milliseconds
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
}

0 comments on commit 58b2eb4

Please sign in to comment.