diff --git a/codec.go b/codec.go index 517effa..8d962b7 100644 --- a/codec.go +++ b/codec.go @@ -128,7 +128,7 @@ func (d *decompressor) passthrough(input []byte, checksum *uint32) ([]byte, erro _, _ = d.checksum.Write(input) // Update checksum (no error possible) if checksum != nil { if curChecksum := d.getChecksum(); curChecksum != *checksum { - return nil, fmt.Errorf("Invalid checksum %x; should be %x", curChecksum, *checksum) + return nil, fmt.Errorf("invalid checksum %x; should be %x", curChecksum, *checksum) } } return input, nil @@ -168,7 +168,7 @@ func (d *decompressor) decompress(input []byte, checksum uint32) ([]byte, error) return nil, err } else if n == 0 { // Nothing more to read; since checksum didn't match (above), fail: - return nil, fmt.Errorf("Invalid checksum %x; should be %x", d.getChecksum(), checksum) + return nil, fmt.Errorf("invalid checksum %x; should be %x", d.getChecksum(), checksum) } _, _ = d.checksum.Write(d.buffer[0:n]) // Update checksum (no error possible) diff --git a/codec_test.go b/codec_test.go index b87a718..e0a86ca 100644 --- a/codec_test.go +++ b/codec_test.go @@ -29,7 +29,7 @@ func init() { randomData = make([]byte, 65536) var b byte var step byte = 1 - for i, _ := range randomData { + for i := range randomData { if rando.Intn(10) == 0 { b = byte(rando.Intn(256)) step = byte(rando.Intn(4)) diff --git a/context.go b/context.go index 6c649eb..10c1f34 100644 --- a/context.go +++ b/context.go @@ -17,7 +17,6 @@ import ( "io" "math/rand" "net/http" - "runtime/debug" "strings" "sync/atomic" "time" @@ -25,17 +24,6 @@ import ( "github.com/coder/websocket" ) -// A function that handles an incoming BLIP request and optionally sends a response. -// A handler is called on a new goroutine so it can take as long as it needs to. -// For example, if it has to send a synchronous network request before it can construct -// a response, that's fine. -type Handler func(request *Message) - -// Utility function that responds to a Message with a 404 error. -func Unhandled(request *Message) { - request.Response().SetError(BLIPErrorDomain, 404, "No handler for BLIP request") -} - // Defines how incoming requests are dispatched to handler functions. type Context struct { @@ -49,14 +37,14 @@ type Context struct { // Patterns that the Origin header must match (if non-empty) origin []string - HandlerForProfile map[string]Handler // Handler function for a request Profile - DefaultHandler Handler // Handler for all otherwise unhandled requests - FatalErrorHandler func(error) // Called when connection has a fatal error - HandlerPanicHandler func(request, response *Message, err interface{}) // Called when a profile handler panics - MaxSendQueueCount int // Max # of messages being sent at once (if >0) - Logger LogFn // Logging callback; defaults to log.Printf - LogMessages bool // If true, will log about messages - LogFrames bool // If true, will log about frames (very verbose) + RequestHandler AsyncHandler // Callback that handles incoming requests + FatalErrorHandler FatalErrorHandler // Called when connection has a fatal error + HandlerPanicHandler HandlerPanicHandler // Called when a profile handler panics + MaxSendQueueCount int // Max # of messages being sent at once (if >0) + MaxDispatchedBytes int // Max total size of incoming requests being dispatched (if >0) + Logger LogFn // Logging callback; defaults to log.Printf + LogMessages bool // If true, will log about messages + LogFrames bool // If true, will log about frames (very verbose) OnExitCallback func() // OnExitCallback callback invoked when the underlying connection closes and the receive loop exits. @@ -65,12 +53,21 @@ type Context struct { // An identifier that uniquely defines the context. NOTE: Random Number Generator not seeded by go-blip. ID string + HandlerForProfile map[string]Handler // deprecated; use RequestHandler & ByProfileDispatcher + DefaultHandler Handler // deprecated; use RequestHandler & ByProfileDispatcher + bytesSent atomic.Uint64 // Number of bytes sent bytesReceived atomic.Uint64 // Number of bytes received cancelCtx context.Context // When cancelled, closes all connections. Terminates receiveLoop(s), which triggers sender and parseLoop stop } +// A function called when a Handler function panics. +type HandlerPanicHandler func(request, response *Message, err interface{}) + +// A function called when the connection closes due to a fatal protocol error. +type FatalErrorHandler func(error) + // Defines a logging interface for use within the blip codebase. Implemented by Context. // Any code that needs to take a Context just for logging purposes should take a Logger instead. type LogContext interface { @@ -116,6 +113,13 @@ func NewContextCustomID(id string, opts ContextOptions) (*Context, error) { } func (blipCtx *Context) start(ws *websocket.Conn) *Sender { + if blipCtx.RequestHandler == nil { + // Compatibility mode: If the app hasn't set a RequestHandler, set one that uses the old + // handlerForProfile and defaultHandler. + blipCtx.RequestHandler = blipCtx.compatibilityHandler + } else if len(blipCtx.HandlerForProfile) > 0 || blipCtx.DefaultHandler != nil { + panic("blip.Context cannot have both a RequestHandler and legacy handlerForProfile or defaultHandler") + } r := newReceiver(blipCtx, ws) r.sender = newSender(blipCtx, ws, r) r.sender.start() @@ -252,9 +256,9 @@ func (bwss *BlipWebsocketServer) handshake(w http.ResponseWriter, r *http.Reques protocolHeader := r.Header.Get("Sec-WebSocket-Protocol") protocol, found := includesProtocol(protocolHeader, bwss.blipCtx.SupportedSubProtocols) if !found { - stringSeperatedProtocols := strings.Join(bwss.blipCtx.SupportedSubProtocols, ",") - bwss.blipCtx.log("Error: Client doesn't support any of WS protocols: %s only %s", stringSeperatedProtocols, protocolHeader) - err := fmt.Errorf("I only speak %s protocols", stringSeperatedProtocols) + stringSeparatedProtocols := strings.Join(bwss.blipCtx.SupportedSubProtocols, ",") + bwss.blipCtx.log("Error: Client doesn't support any of WS protocols: %s only %s", stringSeparatedProtocols, protocolHeader) + err := fmt.Errorf("I only speak %s protocols", stringSeparatedProtocols) http.Error(w, err.Error(), http.StatusInternalServerError) return nil, err } @@ -286,52 +290,6 @@ func (bwss *BlipWebsocketServer) handle(ws *websocket.Conn) { ws.Close(websocket.StatusNormalClosure, "") } -//////// DISPATCHING MESSAGES: - -func (blipCtx *Context) dispatchRequest(request *Message, sender *Sender) { - defer func() { - // On return/panic, send the response: - response := request.Response() - if panicked := recover(); panicked != nil { - if blipCtx.HandlerPanicHandler != nil { - blipCtx.HandlerPanicHandler(request, response, panicked) - } else { - stack := debug.Stack() - blipCtx.log("PANIC handling BLIP request %v: %v:\n%s", request, panicked, stack) - if response != nil { - response.SetError(BLIPErrorDomain, 500, fmt.Sprintf("Panic: %v", panicked)) - } - } - } - if response != nil { - sender.send(response) - } - }() - - blipCtx.logMessage("Incoming BLIP Request: %s", request) - handler := blipCtx.HandlerForProfile[request.Properties["Profile"]] - if handler == nil { - handler = blipCtx.DefaultHandler - if handler == nil { - handler = Unhandled - } - } - handler(request) -} - -func (blipCtx *Context) dispatchResponse(response *Message) { - defer func() { - // On return/panic, log a warning: - if panicked := recover(); panicked != nil { - stack := debug.Stack() - blipCtx.log("PANIC handling BLIP response %v: %v:\n%s", response, panicked, stack) - } - }() - - blipCtx.logMessage("Incoming BLIP Response: %s", response) - //panic("UNIMPLEMENTED") //TODO -} - //////// LOGGING: func (blipCtx *Context) log(format string, params ...interface{}) { @@ -350,6 +308,8 @@ func (blipCtx *Context) logFrame(format string, params ...interface{}) { } } +//////// UTILITIES: + func includesProtocol(header string, protocols []string) (string, bool) { for _, item := range strings.Split(header, ",") { for _, protocol := range protocols { diff --git a/context_test.go b/context_test.go index eb82720..ad3d413 100644 --- a/context_test.go +++ b/context_test.go @@ -82,17 +82,8 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - http.Handle("/TestServerAbruptlyCloseConnectionBehavior", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, nil)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -100,12 +91,8 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/TestServerAbruptlyCloseConnectionBehavior", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() // Create echo request echoRequest := NewRequest() @@ -126,15 +113,7 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { // Read the echo response response := echoRequest.Response() // <--- SG #3268 was causing this to block indefinitely - responseBody, err := response.Body() - - // Assertions about echo response (these might need to be altered, maybe what's expected in this scenario is actually an error) - assert.True(t, err == nil) - assert.True(t, len(responseBody) == 0) - - // TODO: add more assertions about the response. I'm not seeing any errors, or any - // TODO: way to differentiate this response with a normal response other than having an empty body - + requireBLIPError(t, response, BLIPErrorDomain, DisconnectedCode) } /* @@ -204,12 +183,7 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { sent := clientSender.Send(echoAmplifyRequest) assert.True(t, sent) echoAmplifyResponse := echoAmplifyRequest.Response() // <--- SG #3268 was causing this to block indefinitely - echoAmplifyResponseBody, _ := echoAmplifyResponse.Body() - assert.True(t, len(echoAmplifyResponseBody) == 0) - - // TODO: add more assertions about the response. I'm not seeing any errors, or any - // TODO: way to differentiate this response with a normal response other than having an empty body - + requireBLIPError(t, echoAmplifyResponse, BLIPErrorDomain, DisconnectedCode) } // Create a blip profile handler to respond to echo requests and then abruptly close the socket @@ -239,17 +213,8 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - http.Handle("/TestClientAbruptlyCloseConnectionBehavior", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, nil)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -257,13 +222,6 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/TestClientAbruptlyCloseConnectionBehavior", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } - // Handle EchoAmplifyData that should be initiated by server in response to getting incoming echo requests dispatchEchoAmplify := func(request *Message) { _, err := request.Body() @@ -277,6 +235,9 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { } blipContextEchoClient.HandlerForProfile["BLIPTest/EchoAmplifyData"] = dispatchEchoAmplify + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() + // Create echo request echoRequest := NewRequest() echoRequest.SetProfile("BLIPTest/EchoData") @@ -299,7 +260,7 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { responseBody, err := response.Body() // Assertions about echo response (these might need to be altered, maybe what's expected in this scenario is actually an error) - assert.True(t, err == nil) + assert.NoError(t, err) assert.Equal(t, "hello", string(responseBody)) // Wait until the amplify request was received by client (from server), and that the server read the response @@ -395,21 +356,8 @@ func TestUnsupportedSubProtocol(t *testing.T) { serverCtx.LogMessages = true serverCtx.LogFrames = true - server := serverCtx.WebSocketServer() - - mux := http.NewServeMux() - mux.Handle("/someBlip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - go func() { - err := http.Serve(listener, mux) - if err != nil { - panic(err) - } - }() + listener := startTestListener(t, serverCtx) + defer listener.Close() // Client client, err := NewContext(ContextOptions{ProtocolIds: testCase.ClientProtocol}) @@ -417,13 +365,15 @@ func TestUnsupportedSubProtocol(t *testing.T) { t.Fatal(err) } port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/someBlip", port) + destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) s, err := client.Dial(destUrl) if testCase.ExpectError { assert.True(t, err != nil) } else { assert.Equal(t, nil, err) + } + if s != nil { s.Close() } @@ -681,10 +631,9 @@ func TestServerContextClose(t *testing.T) { echoRequest.SetBody(echoResponseBody) receivedRequests.Add(1) sent := sender.Send(echoRequest) - assert.True(t, sent) + require.True(t, sent) - // Read the echo response. Closed connection will result in empty response, as EOF message - // isn't currently returned by blip client + // Read the echo response if we sent something response := echoRequest.Response() responseBody, err := response.Body() assert.True(t, err == nil) @@ -692,7 +641,6 @@ func TestServerContextClose(t *testing.T) { log.Printf("empty response, connection closed") return } - assert.Equal(t, echoResponseBody, responseBody) } } @@ -728,27 +676,6 @@ func assertHandlerNoError(t *testing.T, server *BlipWebsocketServer, wg *sync.Wa } } -// Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout -// TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo. -func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { - - // Create a channel so that a goroutine waiting on the waitgroup can send it's result (if any) - wgFinished := make(chan bool) - - go func() { - wg.Wait() - wgFinished <- true - }() - - select { - case <-wgFinished: - return nil - case <-time.After(timeout): - return fmt.Errorf("Timed out waiting after %v", timeout) - } - -} - // StringPtr returns a pointer to the string value passed in func StringPtr(s string) *string { return &s diff --git a/dispatcher.go b/dispatcher.go new file mode 100644 index 0000000..9915ad9 --- /dev/null +++ b/dispatcher.go @@ -0,0 +1,203 @@ +// Copyright 2023-Present Couchbase, Inc. +// +// Use of this software is governed by the Business Source License included +// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified +// in that file, in accordance with the Business Source License, use of this +// software will be governed by the Apache License, Version 2.0, included in +// the file licenses/APL2.txt. + +package blip + +import ( + "sync" +) + +//////// REQUEST HANDLER FUNCTIONS + +// A function that handles an incoming BLIP request and optionally sends a response. +// The request is not considered handled until this function returns. +// +// After the handler returns, any response you've added to the message (by calling +// `request.Response()“) will be sent, unless the message has the NoReply flag. +// If the message needs a response but none was created, a default empty response will be sent. +// If the handler panics, an error response is sent instead. +type SynchronousHandler = func(request *Message) + +// A function that asynchronously handles an incoming BLIP request and optionally sends a response. +// The handler function may return immediately without handling the request. +// The request is not considered handled until the `onComplete` callback is called, +// from any goroutine. +// +// The callback MUST be called eventually, even if the request has the NoReply flag; +// the caller of this handler may be using the callback to track and limit the number of concurrent +// handlers, so failing to call the callback could eventually block delivery of requests. +// +// The function is allowed to call `onComplete` before it returns, i.e. run synchronously, +// but it should still try to return as quickly as possible to avoid blocking upstream code. +type AsyncHandler = func(request *Message, onComplete RequestCompletedCallback) + +type Handler = SynchronousHandler + +type RequestCompletedCallback = func() + +// Utility SynchronousHandler function that responds to a Message with a 404 error. +func Unhandled(request *Message) { + //log.Printf("Warning: Unhandled BLIP message: %v (Profile=%q)", request, request.Profile()) + request.Response().SetError(BLIPErrorDomain, NotFoundCode, "No handler for BLIP request") +} + +// Utility AsyncHandler function that responds to a Message with a 404 error. +func UnhandledAsync(request *Message, onComplete RequestCompletedCallback) { + defer onComplete() + Unhandled(request) +} + +// A utility that wraps a SynchronousHandler function in an AsyncHandler function. +// When the returned AsyncHandler is called, it will asynchronously run the wrapped handler +// *on a new goroutine* and then call the completion routine. +func AsAsyncHandler(handler SynchronousHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + go func() { + defer onComplete() + handler(request) + }() + } +} + +// AsyncHandler function that uses the Context's old handlerForProfile and defaultHandler fields. +// Used as the RequestHandler when the client hasn't set one. +// For compatibility reasons it does not copy the HandlerForProfile or DefaultHandler, +// rather it accesses the ones in the Context on every call (without any locking!) +// This is because there are tests that change those properties while the connection is open. +func (context *Context) compatibilityHandler(request *Message, onComplete RequestCompletedCallback) { + profile := request.Properties[ProfileProperty] + handler := context.HandlerForProfile[profile] + if handler == nil { + handler = context.DefaultHandler + if handler == nil { + handler = Unhandled + } + } + // Old handlers ran on individual goroutines: + go func() { + defer onComplete() + handler(request) + }() +} + +//////// DISPATCHER INTERFACE: + +// An interface that provides an AsyncHandler function for incoming requests. +// For any Dispatcher instance `d`, `d.Dispatch` without parentheses is an AsyncHandler. +type Dispatcher interface { + Dispatch(request *Message, onComplete RequestCompletedCallback) +} + +//////// BY-PROFILE DISPATCHER + +// A Dispatcher implementation that routes messages to handlers based on the Profile property. +type ByProfileDispatcher struct { + mutex sync.Mutex + profiles map[string]AsyncHandler // Handler for each Profile + defaultHandler AsyncHandler // Fallback handler +} + +// Sets the AsyncHandler function for a specific Profile property value. +// A nil value removes the current handler. +// This method is thread-safe and may be called while the connection is active. +func (d *ByProfileDispatcher) SetHandler(profile string, handler AsyncHandler) { + d.mutex.Lock() + defer d.mutex.Unlock() + if handler != nil { + if d.profiles == nil { + d.profiles = map[string]AsyncHandler{} + } + d.profiles[profile] = handler + } else { + delete(d.profiles, profile) + } +} + +// Sets the default handler, when no Profile matches. +// If this is not set, the default default handler responds synchronously with a BLIP/404 error. +// This method is thread-safe and may be called while the connection is active. +func (d *ByProfileDispatcher) SetDefaultHandler(handler AsyncHandler) { + d.mutex.Lock() + d.defaultHandler = handler + d.mutex.Unlock() +} + +func (d *ByProfileDispatcher) Dispatch(request *Message, onComplete RequestCompletedCallback) { + profile := request.Properties[ProfileProperty] + d.mutex.Lock() + handler := d.profiles[profile] + if handler == nil { + handler = d.defaultHandler + if handler == nil { + handler = UnhandledAsync + } + } + d.mutex.Unlock() + + handler(request, onComplete) +} + +//////// THROTTLING DISPATCHER + +// A Dispatcher implementation that forwards to a given AsyncHandler, but only allows a certain +// number of concurrent calls. +// Excess requests will be queued and later dispatched in the order received. +type ThrottlingDispatcher struct { + Handler AsyncHandler + MaxConcurrency int + mutex sync.Mutex + concurrency int + pending []savedDispatch +} + +type savedDispatch struct { + request *Message + onComplete RequestCompletedCallback +} + +// Initializes a ThrottlingDispatcher. +func (d *ThrottlingDispatcher) Init(maxConcurrency int, handler AsyncHandler) { + d.Handler = handler + d.MaxConcurrency = maxConcurrency +} + +func (d *ThrottlingDispatcher) Dispatch(request *Message, onComplete RequestCompletedCallback) { + d.mutex.Lock() + if d.concurrency >= d.MaxConcurrency { + // Too many running; add this one to the queue: + d.pending = append(d.pending, savedDispatch{request: request, onComplete: onComplete}) + d.mutex.Unlock() + } else { + // Dispatch it now: + d.concurrency++ + d.mutex.Unlock() + d.dispatchNow(request, onComplete) + } +} + +func (d *ThrottlingDispatcher) dispatchNow(request *Message, onComplete RequestCompletedCallback) { + d.Handler(request, func() { + // When the handler finishes, decrement the concurrency count and start a queued request: + var next savedDispatch + d.mutex.Lock() + if len(d.pending) > 0 { + next = d.pending[0] + d.pending = d.pending[1:] + } else { + d.concurrency-- + } + + d.mutex.Unlock() + + onComplete() + + if next.request != nil { + d.dispatchNow(next.request, next.onComplete) + } + }) +} diff --git a/expvar.go b/expvar.go index 0b5ea60..2dfde4a 100644 --- a/expvar.go +++ b/expvar.go @@ -25,22 +25,6 @@ func decrReceiverGoroutines() { goblipExpvar.Add("goroutines_receiver", -1) } -func incrAsyncReadGoroutines() { - goblipExpvar.Add("goroutines_async_read", 1) -} - -func decrAsyncReadGoroutines() { - goblipExpvar.Add("goroutines_async_read", -1) -} - -func incrNextFrameToSendGoroutines() { - goblipExpvar.Add("goroutines_next_frame_to_send", 1) -} - -func decrNextFrameToSendGoroutines() { - goblipExpvar.Add("goroutines_next_frame_to_send", -1) -} - func incrParseLoopGoroutines() { goblipExpvar.Add("goroutines_parse_loop", 1) } diff --git a/functional_test.go b/functional_test.go index 65e0257..165021b 100644 --- a/functional_test.go +++ b/functional_test.go @@ -41,9 +41,8 @@ func TestEchoRoundTrip(t *testing.T) { // ----------------- Setup Echo Server ------------------------- - // Create a blip profile handler to respond to echo requests and then abruptly close the socket + // Create a blip profile handler to respond to echo requests dispatchEcho := func(request *Message) { - defer receivedRequests.Done() body, err := request.Body() if err != nil { log.Printf("ERROR reading body of %s: %s", request, err) @@ -61,22 +60,12 @@ func TestEchoRoundTrip(t *testing.T) { // Blip setup blipContextEchoServer.HandlerForProfile["BLIPTest/EchoData"] = dispatchEcho - blipContextEchoServer.LogMessages = false - blipContextEchoServer.LogFrames = false + blipContextEchoServer.LogMessages = true + blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - mux := http.NewServeMux() - mux.Handle("/blip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, mux)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -84,15 +73,12 @@ func TestEchoRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() numRequests := 100 - receivedRequests.Add(numRequests) + receivedRequests.Add(numRequests * 50) for i := 0; i < numRequests; i++ { @@ -113,12 +99,14 @@ func TestEchoRoundTrip(t *testing.T) { for j := 0; j < 10; j++ { go func(m *Message) { response := m.Response() + defer receivedRequests.Done() if response == nil { t.Errorf("unexpected nil message response") return } responseBody, err := response.Body() - assert.True(t, err == nil) + assert.NoError(t, err) + //log.Printf("Got response: %s", responseBody) assert.Equal(t, "hello", string(responseBody)) }(echoRequest) } @@ -129,7 +117,7 @@ func TestEchoRoundTrip(t *testing.T) { // Wait until all requests were sent and responded to receivedRequests.Wait() - + log.Printf("*** Closing connections ***") } // TestSenderPing ensures a client configured with a WebsocketPingInterval sends ping frames on an otherwise idle connection. @@ -140,17 +128,8 @@ func TestSenderPing(t *testing.T) { if err != nil { t.Fatal(err) } - server := serverCtx.WebSocketServer() - - mux := http.NewServeMux() - mux.Handle("/blip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, mux)) - }() + listener := startTestListener(t, serverCtx) + defer listener.Close() // client clientCtx, err := NewContext(defaultContextOptions) @@ -161,17 +140,11 @@ func TestSenderPing(t *testing.T) { clientCtx.LogFrames = true clientCtx.WebsocketPingInterval = time.Millisecond * 10 - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) - // client hasn't connected yet, stats are uninitialized assert.Equal(t, int64(0), expvarToInt64(goblipExpvar.Get("sender_ping_count"))) assert.Equal(t, int64(0), expvarToInt64(goblipExpvar.Get("goroutines_sender_ping"))) - sender, err := clientCtx.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + sender := startTestClient(t, clientCtx, listener) time.Sleep(time.Millisecond * 50) diff --git a/go.mod b/go.mod index 42b2b8b..566a63f 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.23 require ( github.com/coder/websocket v1.8.12 github.com/klauspost/compress v1.17.11 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.10.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ae9b23c..5060e89 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,14 @@ github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message.go b/message.go index cba5712..bb3bd52 100644 --- a/message.go +++ b/message.go @@ -11,51 +11,61 @@ licenses/APL2.txt. package blip import ( - "bytes" "encoding/json" - "errors" "fmt" "io" - "log" - "runtime/debug" + "strconv" "sync" "sync/atomic" ) +// A Message's serial number. Outgoing requests are numbered consecutively by the sending peer. +// A response has the same number as its request (i.e. the number was created by the other peer.) type MessageNumber uint32 // A BLIP message. It could be a request or response or error, and it could be from me or the peer. type Message struct { - Outgoing bool // Is this a message created locally? - Sender *Sender // The connection that sent this message. - Properties Properties // The message's metadata, similar to HTTP headers. - body []byte // The message body. MIME type is defined by "Content-Type" property - number MessageNumber // The sequence number of the message in the connection. - flags atomic.Pointer[frameFlags] // Message flags as seen on the first frame. - bytesSent uint64 - bytesAcked uint64 - - reader io.ReadCloser // Stream that an incoming message is being read from - encoder io.ReadCloser // Stream that an outgoing message is being written to - readingBody bool // True if reader stream has been accessed by client already - complete bool // Has this message been completely received? - response *Message // Response to this message, if it's a request - inResponseTo *Message // Message this is a response to - cond *sync.Cond // Used to make Response() method block until response arrives -} - -// Closes all resources for the message. + Outgoing bool // Is this a message created locally? + Sender *Sender // The connection that sent this message. + Properties Properties // The message's metadata, similar to HTTP headers. + body []byte // The message body. MIME type is defined by "Content-Type" property + bodyReader *PipeReader // Non-nil if incoming body is being streamed. + bodyWriter *PipeWriter // Non-nil if incoming body is being streamed. + bodyMutex sync.Mutex // Synchronizes access to body properties + response *Message // The response to this outgoing request [must lock 'cond'] + inResponseTo *Message // Outgoing request that this is a response to + cond *sync.Cond // Used to make Response() method block until response arrives + onResponse func(*Message) // Application's callback to receive response + onSent func() // Application's callback when message has been sent + number MessageNumber // The sequence number of the message in the connection. + flags atomic.Pointer[frameFlags] // Message flags as seen on the first frame. + inProgress bool // True while message is being sent or received +} + func (message *Message) Close() (err error) { - if message.reader != nil { - err = message.reader.Close() + if message.bodyReader != nil { + err = message.bodyReader.Close() } - if message.encoder != nil { - err = message.encoder.Close() + if message.bodyWriter != nil { + err = message.bodyWriter.Close() } return err } -// Returns a string describing the message for debugging purposes +// The error info from a BLIP response, as a Go Error value. +type ErrorResponse struct { + Domain string + Code int + Message string +} + +// A callback function that takes a message and returns nothing +type messageCallback func(*Message) + +var ErrConnectionClosed = fmt.Errorf("BLIP connection closed") + +// Returns a string describing the message for debugging purposes. +// It contains the message number and type. A "!" indicates urgent, and "~" indicates compressed. func (message *Message) String() string { return frameString(message.number, *message.flags.Load()) } @@ -85,9 +95,7 @@ func NewRequest() *Message { // The order in which a request message was sent. // A response has the same serial number as its request even though it goes the other direction. func (message *Message) SerialNumber() MessageNumber { - if message.number == 0 { - panic("Unsent message has no serial number yet") - } + message.assertSent() return message.number } @@ -118,11 +126,9 @@ func (message *Message) SetCompressed(compressed bool) { } // Marks an outgoing message as being one-way: no reply will be sent. -func (request *Message) SetNoReply(noReply bool) { - if request.Type() != RequestType { - panic("Can't call SetNoReply on a response") - } - request.setFlag(kNoReply, noReply) +func (message *Message) SetNoReply(noReply bool) { + message.assertIsRequest() + message.setFlag(kNoReply, noReply) } func (message *Message) setFlag(flag frameFlags, value bool) { @@ -136,55 +142,115 @@ func (message *Message) setFlag(flag frameFlags, value bool) { message.flags.Store(&flags) } -func (message *Message) assertMutable() { - if !message.Outgoing || message.encoder != nil { - panic("Message can't be modified") +// Registers a callback that will be called when this message has been written to the socket. +// (This does not mean it's been delivered! If you need delivery confirmation, wait for a reply.) +func (message *Message) OnSent(callback func()) { + message.assertOutgoing() + message.onSent = callback +} + +// The value of the "Profile" property which is used to identify a request's purpose. +func (message *Message) Profile() string { + return message.Properties[ProfileProperty] +} + +// Sets the value of the "Profile" property which is used to identify a request's purpose. +func (message *Message) SetProfile(profile string) { + message.Properties[ProfileProperty] = profile +} + +// True if a message is in its completed form: an incoming message whose entire body has +// arrived, or an outgoing message that's been queued for delivery. +func (message *Message) Complete() bool { + return !message.inProgress +} + +// Writes the encoded form of a Message to a stream. +func (message *Message) WriteEncodedTo(writer io.Writer) error { + if err := message.Properties.WriteEncodedTo(writer); err != nil { + return err + } + var err error + if len(message.body) > 0 { + _, err = writer.Write(message.body) } + return err } -// Reads an incoming message's properties from the reader if necessary -func (m *Message) readProperties() error { - if m.Properties != nil { +//////// ERRORS: + +// The error information in a response whose Type is ErrorType. +// (The return type `*ErrorResponse` implements the `error` interface.) +// If the message is not an error, returns nil. +func (response *Message) Error() *ErrorResponse { + if response.Type() != ErrorType { return nil - } else if m.reader == nil { - panic("Message has no reader") } - m.Properties = Properties{} - return m.Properties.ReadFrom(m.reader) + domain := response.Properties[ErrorDomainProperty] + if domain == "" { + domain = BLIPErrorDomain + } + code, _ := strconv.ParseInt(response.Properties[ErrorCodeProperty], 10, 0) + message, _ := response.Body() + return &ErrorResponse{ + Domain: domain, + Code: int(code), + Message: string(message), + } } -// The value of the "Profile" property which is used to identify a request's purpose. -func (request *Message) Profile() string { - return request.Properties["Profile"] +// ErrorResponse implements the `error` interface. +func (err *ErrorResponse) Error() string { + if err.Message == "" { + return fmt.Sprintf("%s/%d", err.Domain, err.Code) + } else { + return fmt.Sprintf("%s (%s/%d)", err.Message, err.Domain, err.Code) + } } -// Sets the value of the "Profile" property which is used to identify a request's purpose. -func (request *Message) SetProfile(profile string) { - request.Properties["Profile"] = profile +// Changes a pending response into an error. +// It is safe (and a no-op) to call this on a nil Message. +func (response *Message) SetError(errDomain string, errCode int, message string) { + if response != nil { + response.assertMutable() + response.setError(errDomain, errCode, message) + } } +func (response *Message) setError(errDomain string, errCode int, message string) { + response.assertIsResponse() + newFlags := *response.flags.Load()&^kTypeMask | frameFlags(ErrorType) + response.flags.Store(&newFlags) + response.Properties = Properties{ + ErrorDomainProperty: errDomain, + ErrorCodeProperty: fmt.Sprintf("%d", errCode), + } + response.body = []byte(message) +} + +//////// BODY: + // Returns a Reader object from which the message body can be read. // If this is an incoming message the body will be streamed as the message arrives over -// the network (and multiple calls to BodyReader() won't work.) -func (m *Message) BodyReader() (io.Reader, error) { +// the network. +func (m *Message) BodyReader() (*PipeReader, error) { if m.Outgoing || m.body != nil { - return bytes.NewReader(m.body), nil - } - if err := m.readProperties(); err != nil { - return nil, err + r, w := NewPipe() + _, _ = w.Write(m.body) + w.Close() + return r, nil + } else { + return m.bodyReader, nil } - m.readingBody = true - return m.reader, nil } // Returns the entire message body as a byte array. // If the message is incoming, blocks until the entire body is received. func (m *Message) Body() ([]byte, error) { + m.bodyMutex.Lock() + defer m.bodyMutex.Unlock() if m.body == nil && !m.Outgoing { - if m.readingBody { - panic("Already reading body as a stream") - } - body, err := io.ReadAll(m.reader) + body, err := io.ReadAll(m.bodyReader) if err != nil { return nil, err } @@ -193,12 +259,6 @@ func (m *Message) Body() ([]byte, error) { return m.body, nil } -// Sets the entire body of an outgoing message. -func (m *Message) SetBody(body []byte) { - m.assertMutable() - m.body = body -} - // Returns the message body parsed as JSON. func (m *Message) ReadJSONBody(value interface{}) error { if bodyReader, err := m.BodyReader(); err != nil { @@ -210,6 +270,12 @@ func (m *Message) ReadJSONBody(value interface{}) error { } } +// Sets the entire body of an outgoing message. +func (m *Message) SetBody(body []byte) { + m.assertMutable() + m.body = body +} + // Sets the message body to JSON generated from the given JSON-encodable value. // As a convenience this also sets the "Content-Type" property to "application/json". func (m *Message) SetJSONBody(value interface{}) error { @@ -228,22 +294,46 @@ func (m *Message) SetJSONBodyAsBytes(jsonBytes []byte) { m.SetCompressed(true) } -// Returns the response message to this request. Its properties and body are initially empty. -// Multiple calls return the same object. +func (m *Message) bodySize() int { + if m.body != nil { + return len(m.body) + } else { + return m.bodyWriter.bytesPending() + } +} + +// Makes a copy of a Message. Only for tests. +func (message *Message) Clone() *Message { + // Make sure the body is available. This may block. + _, _ = message.Body() + + message.bodyMutex.Lock() + defer message.bodyMutex.Unlock() + m := &Message{ + Outgoing: message.Outgoing, + Properties: message.Properties, + body: message.body, + number: message.number, + } + m.flags.Store(message.flags.Load()) + return m +} + +//////// RESPONSE HANDLING: + +// Returns the response message to this request. Multiple calls return the same object. // If called on a NoReply request, this returns nil. +// - If this is an incoming request, the response is immediately available. Its properties and +// body are initially empty and ready to be filled in. +// - If this is an outgoing request, the function blocks until the response arrives over the +// network. func (request *Message) Response() *Message { + request.assertIsRequest() + request.assertSent() if *request.flags.Load()&kNoReply != 0 { return nil - } - if request.Type() != RequestType { - panic("Can't respond to this message") - } - if request.number == 0 { - panic("Can't get response before message has been sent") - } - - // block until a response has been set by responseComplete - if request.Outgoing { + } else if request.Outgoing { + // Outgoing request: block until a response has been set by responseComplete request.cond.L.Lock() for request.response == nil { request.cond.Wait() @@ -251,184 +341,266 @@ func (request *Message) Response() *Message { response := request.response request.cond.L.Unlock() return response - } - - // request is incoming, so we need to build a response - request.cond.L.Lock() - defer request.cond.L.Unlock() - // if we already have a response, return it - if request.response != nil { - return request.response - } - response := request.createResponse() - newFlags := *response.flags.Load() | *request.flags.Load()&kUrgent - response.flags.Store(&newFlags) - response.Properties = Properties{} - request.response = response - return response -} - -// Changes a pending response into an error. -// It is safe (and a no-op) to call this on a nil Message. -func (response *Message) SetError(errDomain string, errCode int, message string) { - if response != nil { - response.assertMutable() - if response.Type() == RequestType { - panic("Can't call SetError on a request") + } else { + // Incoming request: create a response for the caller to fill in + request.cond.L.Lock() + defer request.cond.L.Unlock() + // if we already have a response, return it + if request.response != nil { + return request.response } - newFlags := *response.flags.Load()&^kTypeMask | frameFlags(ErrorType) + response := request.createResponse() + newFlags := *response.flags.Load() | *request.flags.Load()&kUrgent response.flags.Store(&newFlags) - response.Properties = Properties{ - "Error-Domain": errDomain, - "Error-Code": fmt.Sprintf("%d", errCode), - } - if message != "" { - response.body = []byte(message) - } + response.Properties = Properties{} + request.response = response + return response } } -//////// INTERNALS: +// Registers a function to be called when a response to this outgoing Message arrives. +// This Message must be an outgoing request. +// Only one function can be registered; registering another causes a panic. +func (request *Message) OnResponse(callback func(*Message)) { + request.assertIsRequest() + request.assertOutgoing() + precondition(*request.flags.Load()&kNoReply == 0, "OnResponse: Message %s was sent NoReply", request) -func newIncomingMessage(sender *Sender, number MessageNumber, flags frameFlags, reader io.ReadCloser) *Message { - m := &Message{ - Sender: sender, - number: number, - reader: reader, - cond: sync.NewCond(&sync.Mutex{}), + request.cond.L.Lock() + response := request.response + if response == nil { + if request.onResponse != nil { + request.cond.L.Unlock() + panic("Message already has an OnResponse callback") + } + request.onResponse = callback } - m.flags.Store(ptr(flags | kMoreComing)) - return m -} + request.cond.L.Unlock() -// Creates an incoming message given properties and body; exposed only for testing. -func NewParsedIncomingMessage(sender *Sender, msgType MessageType, properties Properties, body []byte) *Message { - if properties == nil { - properties = Properties{} - } - if body == nil { - body = []byte{} + if response != nil { + // Response has already arrived: + callback(response) } - msg := newIncomingMessage(sender, 1, frameFlags(msgType), nil) - msg.Properties = properties - msg.body = body - return msg } +// Creates a response Message for this Message. func (request *Message) createResponse() *Message { flags := frameFlags(ResponseType) | (*request.flags.Load() & kUrgent) response := &Message{ number: request.number, Outgoing: !request.Outgoing, inResponseTo: request, + inProgress: request.Outgoing, cond: sync.NewCond(&sync.Mutex{}), } if !response.Outgoing { flags |= kMoreComing + response.bodyReader, response.bodyWriter = NewPipe() } response.flags.Store(&flags) return response } -func (request *Message) responseComplete(response *Message) { +// Notifies an outgoing Message that its response is available. +// - This unblocks any waiting `Response` methods, which will now return the response. +// - It calls any registered "OnResponse" callback. +// - It sets the `response` field so subsequent calls to `Response` will return it. +func (request *Message) responseAvailable(response *Message) { + var callback func(*Message) + request.cond.L.Lock() - defer request.cond.L.Unlock() - if request.response != nil { - panic(fmt.Sprintf("Multiple responses to %s", request)) + if existing := request.response; existing != nil { + request.cond.L.Unlock() + if existing != response { + panic(fmt.Sprintf("Multiple responses to %s", request)) + } + return } request.response = response + callback = request.onResponse + request.onResponse = nil request.cond.Broadcast() + request.cond.L.Unlock() + + if callback != nil { + callback(response) // Calling on receiver thread; callback should exit quickly + } } -//////// I/O: +//////// SENDING MESSAGES: -func (m *Message) WriteTo(writer io.Writer) error { - if err := m.Properties.WriteTo(writer); err != nil { - return err +// A wrapper around an outgoing Message that generates frame bodies to send. +type msgSender struct { + *Message + bytesSent uint64 + bytesAcked uint64 + remainingProperties []byte + remainingBody []byte +} + +func (m *msgSender) nextFrameToSend(maxSize int) ([]byte, frameFlags) { + m.assertOutgoing() + m.assertSent() + + if m.remainingProperties == nil { + // On the first call, encode the properties: + m.inProgress = true + m.remainingProperties = m.Properties.Encode() + m.remainingBody = m.body } - var err error - if len(m.body) > 0 { - _, err = writer.Write(m.body) + + var frame []byte + if plen := min(maxSize, len(m.remainingProperties)); plen > 0 { + // Send properties first: + frame = m.remainingProperties[0:plen] + m.remainingProperties = m.remainingProperties[plen:] + maxSize -= plen } - return err -} -func (m *Message) ReadFrom(reader io.Reader) error { - if err := m.Properties.ReadFrom(reader); err != nil { - return err + if blen := min(maxSize, len(m.remainingBody)); blen > 0 { + // Then the body: + frame = append(frame, m.remainingBody[0:blen]...) + m.remainingBody = m.remainingBody[blen:] } - var err error - m.body, err = io.ReadAll(reader) - return err + + flags := *m.flags.Load() + if len(m.remainingProperties) > 0 || len(m.remainingBody) > 0 { + flags |= kMoreComing + } + return frame, flags } -// Returns a write stream to write the incoming message's content into. When the stream is closed, -// the message will deliver itself. -func (m *Message) asyncRead(onComplete func(error)) io.WriteCloser { +func (m *msgSender) addBytesSent(bytesSent uint64) { + m.bytesSent += bytesSent +} - reader, writer := io.Pipe() - m.reader = reader +func (m *msgSender) needsAck() bool { + return m.bytesSent > m.bytesAcked+kMaxUnackedBytes +} - // Start a goroutine to read off the read-end of the io.Pipe until it's read everything, or the - // write end of the io.Pipe was closed, which can happen if the peer closes the connection. - go func() { - defer func() { - if p := recover(); p != nil { - err := fmt.Sprintf("PANIC in BLIP asyncRead: %v", p) - log.Printf(err+"\n%s", debug.Stack()) - reader.CloseWithError(errors.New(err)) - } - }() +func (m *msgSender) receivedAck(bytesAcked uint64) { + if bytesAcked > m.bytesAcked { + m.bytesAcked = bytesAcked + } +} - // Update Expvar stats for number of outstanding goroutines - incrAsyncReadGoroutines() - defer decrAsyncReadGoroutines() +func (m *msgSender) sent() { + if callback := m.onSent; callback != nil { + m.onSent = nil + callback() + } +} - err := m.ReadFrom(reader) - onComplete(err) - }() - return writer +// Informs an outgoing message that the connection has closed +func (m *msgSender) cancelOutgoing() { } -func (m *Message) nextFrameToSend(maxSize int) ([]byte, frameFlags) { - if m.number == 0 || !m.Outgoing { - panic("Can't send this message") - } +//////// RECEIVING MESSAGES: - if m.encoder == nil { - // Start the encoder goroutine: - var writer io.WriteCloser - m.encoder, writer = io.Pipe() - go func() { - defer func() { - if p := recover(); p != nil { - log.Printf("PANIC in BLIP nextFrameToSend: %v\n%s", p, debug.Stack()) - } - }() - defer writer.Close() +// A wrapper around an incoming Message while it's being received. +// Called only on the receiver's parseLoop goroutine. +type msgReceiver struct { + *Message + bytesWritten uint64 + propertiesBuffer []byte + dispatched bool +} - // Update Expvar stats for number of outstanding goroutines - incrNextFrameToSendGoroutines() - defer decrNextFrameToSendGoroutines() +func newIncomingMessage(sender *Sender, number MessageNumber, flags frameFlags) *msgReceiver { + m := &msgReceiver{ + Message: &Message{ + Sender: sender, + number: number, + inProgress: true, + cond: sync.NewCond(&sync.Mutex{}), + }, + } + m.flags.Store(ptr(flags | kMoreComing)) + m.bodyReader, m.bodyWriter = NewPipe() + return m +} - _ = m.WriteTo(writer) +type dispatchState struct { + atStart bool // The beginning of the message has arrived + atEnd bool // The end of the message has arrived +} - }() +// Appends a frame's data to an incoming message. +func (m *msgReceiver) addIncomingBytes(bytes []byte, complete bool) (dispatchState, error) { + state := dispatchState{atEnd: complete} + if m.Properties == nil { + // First read the message properties: + m.propertiesBuffer = append(m.propertiesBuffer, bytes...) + props, bytesRead, err := ReadProperties(m.propertiesBuffer) + if err != nil { + return state, err + } else if props == nil { + if complete { + return state, fmt.Errorf("incomplete properties in BLIP message") + } + return state, nil // incomplete properties; wait for more + } + // Got the complete properties: + m.Properties = props + bytes = m.propertiesBuffer[bytesRead:] + m.propertiesBuffer = nil + state.atStart = true } - frame := make([]byte, maxSize) - flags := *m.flags.Load() - size, err := io.ReadFull(m.encoder, frame) - if err == nil { - flags |= kMoreComing + // Now add to the body: + if complete { + m.inProgress = false + m.flags.Store(ptr(*m.flags.Load() &^ kMoreComing)) + } + if m.body != nil { + m.body = append(m.body, bytes...) } else { - frame = frame[0:size] + _, _ = m.bodyWriter.Write(bytes) + if complete { + m.bodyWriter.Close() + } } - return frame, flags + return state, nil } -// A callback function that takes a message and returns nothing -type MessageCallback func(*Message) +// Add `frameSize` to my `bytesWritten` and send an ACK message every `kAckInterval` bytes +func (m *msgReceiver) maybeSendAck(frameSize int) { + oldWritten := m.bytesWritten + m.bytesWritten += uint64(frameSize) + if oldWritten > 0 && (oldWritten/kAckInterval) < (m.bytesWritten/kAckInterval) { + m.Sender.sendAck(m.number, m.Type(), m.bytesWritten) + } +} + +// Informs an incoming message that the connection has closed +func (m *msgReceiver) cancelIncoming() { + if m.bodyWriter != nil { + _ = m.bodyWriter.CloseWithError(ErrConnectionClosed) + } + if !m.dispatched && m.inResponseTo != nil { + m.setError(BLIPErrorDomain, DisconnectedCode, "") + m.dispatched = true + m.inResponseTo.responseAvailable(m.Message) + } +} + +//////// UTILITIES + +func (message *Message) assertMutable() { + precondition(message.Outgoing && !message.inProgress, "Message %s is not modifiable", message) +} +func (message *Message) assertOutgoing() { + precondition(message.Outgoing, "Message %s is not outgoing", message) +} +func (message *Message) assertIsRequest() { + precondition(message.Type() == RequestType, "Message %s is not a request", message) +} +func (message *Message) assertIsResponse() { + precondition(message.Type() != RequestType, "Message %s is not a response", message) +} +func (message *Message) assertSent() { + precondition(message.number != 0, "Message %s has not been sent", message) +} func ptr[T any](v T) *T { return &v diff --git a/message_test.go b/message_test.go index 0bb6a3f..7540fb8 100644 --- a/message_test.go +++ b/message_test.go @@ -11,12 +11,11 @@ licenses/APL2.txt. package blip import ( - "bytes" + "fmt" "io" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func init() { @@ -27,49 +26,82 @@ func makeTestRequest() *Message { m := NewRequest() m.Properties["Content-Type"] = "ham/rye" m.Properties["X-Weather"] = "rainy" + m.SetProfile("test") m.SetBody([]byte("The white knight is sliding down the poker. He balances very badly.")) return m } func TestMessageEncoding(t *testing.T) { m := makeTestRequest() - var writer bytes.Buffer - err := m.WriteTo(&writer) - assert.Equal(t, nil, err) - serialized := writer.Bytes() - assert.Equal(t, "\x25Content-Type\x00ham/rye\x00X-Weather\x00rainy\x00The white knight is sliding down the poker. He balances very badly.", string(serialized)) - t.Logf("Encoded as %d bytes", len(serialized)) - - m2 := newIncomingMessage(nil, 1, *m.flags.Load(), nil) - reader := bytes.NewReader(serialized) - err = m2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, m.Properties, m2.Properties) - mbody, _ := m.Body() - m2body, _ := m2.Body() - assert.Equal(t, mbody, m2body) + assert.Equal(t, "test", m.Profile()) + serialized := serializeMessage(t, m) + assert.Equal(t, "\x32Content-Type\x00ham/rye\x00Profile\x00test\x00X-Weather\x00rainy\x00The white knight is sliding down the poker. He balances very badly.", string(serialized)) +} + +func TestMessageStreaming(t *testing.T) { + m := makeTestRequest() + serialized := serializeMessage(t, m) + body, _ := m.Body() + + for breakAt := 1; breakAt <= len(serialized); breakAt++ { + t.Run(fmt.Sprintf("Frame size %d", breakAt), func(t *testing.T) { + // "Receive" message m2 in two pieces, separated at offset `breakAt`. + frame1 := serialized[0:breakAt] + frame2 := serialized[breakAt:] + complete := len(frame2) == 0 + expectedState := dispatchState{ + atStart: len(frame1) >= len(serialized)-len(body), + atEnd: complete, + } + + m2 := newIncomingMessage(nil, 1, *m.flags.Load()) + state, err := m2.addIncomingBytes(frame1, complete) + assert.NoError(t, err) + assert.Equal(t, state, expectedState) + if state.atStart { + // Frame 1 completes the properties, so the reader should be available: + assert.Equal(t, m.Properties, m2.Properties) + reader, err := m2.BodyReader() + assert.NoError(t, err) + buf := make([]byte, 500) + n, err := reader.TryRead(buf) + assert.NoError(t, err) + assert.Equal(t, n, breakAt-51) + assert.Equal(t, body[0:n], buf[0:n]) + + if !expectedState.atEnd { + state, err = m2.addIncomingBytes(frame2, true) + assert.NoError(t, err) + assert.Equal(t, state, dispatchState{false, true}) + n2, err := reader.TryRead(buf[n:]) + assert.NoError(t, err) + assert.Equal(t, len(body)-n, n2) + assert.Equal(t, body, buf[0:(n+n2)]) + } + + n, err = reader.TryRead(buf) + assert.ErrorIs(t, err, io.EOF) + assert.Equal(t, n, 0) + } + }) + } } func TestMessageEncodingCompressed(t *testing.T) { m := makeTestRequest() m.SetCompressed(true) - var writer bytes.Buffer - err := m.WriteTo(&writer) - assert.Equal(t, nil, err) - serialized := writer.Bytes() + serialized := serializeMessage(t, m) // Commented due to test failure: // http://drone.couchbase.io/couchbase/go-blip/4 (test logs: https://gist.github.com/tleyden/ae2aa71978cd11ca5d9d1f6878593cdb) // goassert.Equals(t, string(serialized), "\x1a\x04\x00ham/rye\x00X-Weather\x00rainy\x00\x1f\x8b\b\x00\x00\tn\x88\x00\xff\f\xca\xd1\t\xc5 \f\x05\xd0U\xee\x04\xce\xf1\x06x\v\xd8z\xd1`\x88ń\x8a\xdb\xd7\xcf\x03\xe7߈\xd5$\x88nR[@\x1c\xaeR\xc4*\xcaX\x868\xe1\x19\x9d3\xe1G\\Y\xb3\xddt\xbc\x9c\xfb\xa8\xe8N_\x00\x00\x00\xff\xffs*\xa1\xa6C\x00\x00\x00") // log.Printf("Encoded compressed as %d bytes", len(serialized)) - m2 := newIncomingMessage(nil, 1, *m.flags.Load(), nil) - reader := bytes.NewReader(serialized) - err = m2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, m.Properties, m2.Properties) - assert.Equal(t, m.body, m2.body) - + m2 := newIncomingMessage(nil, 1, *m.flags.Load()) + state, err := m2.addIncomingBytes(serialized, true) + assert.NoError(t, err) + assert.Equal(t, state, dispatchState{true, true}) + assertEqualMessages(t, m, m2.Message) } func BenchmarkMessageEncoding(b *testing.B) { @@ -85,33 +117,13 @@ func BenchmarkMessageEncoding(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - msg.encoder = nil + msg.inProgress = false + sender := &msgSender{Message: msg} for { - _, flags := msg.nextFrameToSend(4090) + _, flags := sender.nextFrameToSend(4090) if flags&kMoreComing == 0 { break } } } } - -func TestMessageDecoding(t *testing.T) { - original := makeTestRequest() - reader, writer := io.Pipe() - go func() { - require.NoError(t, original.WriteTo(writer)) - writer.Close() - }() - - incoming := newIncomingMessage(nil, original.number, *original.flags.Load(), reader) - err := incoming.readProperties() - assert.Equal(t, nil, err) - assert.Equal(t, original.Properties, incoming.Properties) - err = incoming.readProperties() - assert.Equal(t, nil, err) - - body, err := incoming.Body() - assert.Equal(t, nil, err) - assert.Equal(t, original.body, body) - assert.Equal(t, incoming.body, body) -} diff --git a/messagequeue.go b/messagequeue.go index d2ace1b..da5ebb8 100644 --- a/messagequeue.go +++ b/messagequeue.go @@ -20,21 +20,26 @@ const kInitialQueueCapacity = 10 type messageQueue struct { logContext LogContext maxCount int - queue []*Message + queue []*msgSender numRequestsSent MessageNumber cond *sync.Cond + /* (instrumentation for perf testing) + totalSize int + highestCount int + highestTotalSize int + */ } func newMessageQueue(logContext LogContext, maxCount int) *messageQueue { return &messageQueue{ logContext: logContext, - queue: make([]*Message, 0, kInitialQueueCapacity), + queue: make([]*msgSender, 0, kInitialQueueCapacity), cond: sync.NewCond(&sync.Mutex{}), maxCount: maxCount, } } -func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock +func (q *messageQueue) _push(msg *msgSender, new bool) bool { // requires lock if !msg.Outgoing { panic("Not an outgoing message") } @@ -53,7 +58,7 @@ func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock if q.queue[index].Urgent() { index += 2 break - } else if new && q.queue[index].encoder == nil { + } else if new && !q.queue[index].inProgress { // But have to keep message starts in order index += 1 break @@ -74,20 +79,32 @@ func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock copy(q.queue[index+1:n+1], q.queue[index:n]) q.queue[index] = msg + /* (instrumentation for perf testing) + q.totalSize += len(msg.body) + if n+1 > q.highestCount { + q.highestCount = n + 1 + q.logContext.log("messageQueue total size = %d (%d messages)", q.totalSize, n+1) + } + if q.totalSize > q.highestTotalSize { + q.highestTotalSize = q.totalSize + q.logContext.log("messageQueue total size = %d (%d messages)", q.totalSize, n+1) + } + */ if len(q.queue) == 1 { q.cond.Signal() // It's non-empty now, so unblock a waiting pop() } + return true } // Push an item into the queue -func (q *messageQueue) push(msg *Message) bool { +func (q *messageQueue) push(msg *msgSender) bool { return q.pushWithCallback(msg, nil) } // Push an item into the queue, also providing a callback function that will be invoked // after the number is assigned to the message, but before pushing into the queue. -func (q *messageQueue) pushWithCallback(msg *Message, prepushCallback MessageCallback) bool { +func (q *messageQueue) pushWithCallback(msg *msgSender, prepushCallback messageCallback) bool { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -112,13 +129,13 @@ func (q *messageQueue) pushWithCallback(msg *Message, prepushCallback MessageCal } if prepushCallback != nil { - prepushCallback(msg) + prepushCallback(msg.Message) } return q._push(msg, isNew) } -func (q *messageQueue) _maybePop(actuallyPop bool) *Message { +func (q *messageQueue) _maybePop(actuallyPop bool) *msgSender { q.cond.L.Lock() defer q.cond.L.Unlock() for len(q.queue) == 0 && q.queue != nil { @@ -132,7 +149,9 @@ func (q *messageQueue) _maybePop(actuallyPop bool) *Message { msg := q.queue[0] if actuallyPop { q.queue = q.queue[1:] - + /* (instrumentation for perf testing) + q.totalSize -= len(msg.body) + */ if len(q.queue) == q.maxCount-1 { q.cond.Signal() } @@ -140,10 +159,10 @@ func (q *messageQueue) _maybePop(actuallyPop bool) *Message { return msg } -func (q *messageQueue) first() *Message { return q._maybePop(false) } -func (q *messageQueue) pop() *Message { return q._maybePop(true) } +func (q *messageQueue) first() *msgSender { return q._maybePop(false) } +func (q *messageQueue) pop() *msgSender { return q._maybePop(true) } -func (q *messageQueue) find(msgNo MessageNumber, msgType MessageType) *Message { +func (q *messageQueue) find(msgNo MessageNumber, msgType MessageType) *msgSender { q.cond.L.Lock() defer q.cond.L.Unlock() for _, message := range q.queue { @@ -159,14 +178,8 @@ func (q *messageQueue) stop() { q.cond.L.Lock() defer q.cond.L.Unlock() - // Iterate over messages and call close on every message's readcloser, since it's possible that - // a goroutine may be blocked on the reader, thus causing a resource leak. Added during SG #3268 - // diagnosis, but does not fix any reproducible issues. for _, message := range q.queue { - err := message.Close() - if err != nil { - q.logContext.logMessage("Warning: messageQueue encountered error closing message while stopping. Error: %v", err) - } + message.cancelOutgoing() } q.queue = nil diff --git a/messagequeue_test.go b/messagequeue_test.go index 54f6268..c3a74b8 100644 --- a/messagequeue_test.go +++ b/messagequeue_test.go @@ -11,8 +11,6 @@ licenses/APL2.txt. package blip import ( - "bytes" - "io" "log" "sync" "testing" @@ -28,7 +26,7 @@ func TestMessagePushPop(t *testing.T) { // Push a non-urgent message into the queue for i := 0; i < 2; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) @@ -38,7 +36,7 @@ func TestMessagePushPop(t *testing.T) { } // Push an urgent message into the queue - urgentMsg := NewRequest() + urgentMsg := &msgSender{Message: NewRequest()} urgentMsg.SetUrgent(true) pushed := mq.push(urgentMsg) assert.True(t, pushed) @@ -78,7 +76,7 @@ func TestConcurrentAccess(t *testing.T) { // Fill it up to capacity w/ normal messages for i := 0; i < maxSendQueueCount; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) @@ -88,7 +86,7 @@ func TestConcurrentAccess(t *testing.T) { doneWg.Add(2) pusher := func() { for i := 0; i < 100; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) } @@ -149,15 +147,15 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio maxSendQueueCount := 25 mq := newMessageQueue(&TestLogContext{silent: true}, maxSendQueueCount) - // Add normal messages that are "in-progress" since they have a non-nil msg.encoder + // Add normal messages that are "in-progress" for i := 0; i < 5; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) - // set the msg.encoder to something so that the next urgent message will go to the head of the line - msg.encoder = io.NopCloser(&bytes.Buffer{}) + // set inProgress so that the next urgent message will go to the head of the line + msg.inProgress = true } @@ -165,7 +163,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio assert.True(t, mq.first().SerialNumber() == MessageNumber(1)) // Push an urgent message into the queue - urgentMsg := NewRequest() + urgentMsg := &msgSender{Message: NewRequest()} urgentMsg.SetUrgent(true) pushed := mq.push(urgentMsg) assert.True(t, pushed) @@ -181,7 +179,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio mq.pop() // T6: [n5] [n4] [n3] [n2] [u6] - // Since all the normal messages have had frames sent (faked, via non-nil msg.encoder), then the + // Since all the normal messages have had frames sent (faked, via setting inProgress), then the // urgent message should have skipped to the head of the line // assert.True(t, mq.nextMessageIsUrgent()) headOfLine := mq.first() @@ -189,7 +187,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio // Push another urgent message // T7: [n5] [n4] [n3] [u7] [n2] [u6] - anotherUrgentMsg := NewRequest() + anotherUrgentMsg := &msgSender{Message: NewRequest()} anotherUrgentMsg.SetUrgent(true) pushed = mq.push(anotherUrgentMsg) assert.True(t, pushed) diff --git a/pipe.go b/pipe.go new file mode 100644 index 0000000..8435add --- /dev/null +++ b/pipe.go @@ -0,0 +1,180 @@ +/* +Copyright 2023-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. +*/ + +package blip + +import ( + "io" + "sync" +) + +// Creates a new pipe, a pair of bound streams. +// Similar to io.Pipe, except the stream is buffered so writes don't block. +func NewPipe() (*PipeReader, *PipeWriter) { + p := &pipe_shared{ + chunks: make([][]byte, 0, 50), + cond: sync.NewCond(&sync.Mutex{}), + } + return &PipeReader{shared: p}, &PipeWriter{shared: p} +} + +// The private state shared between a PipeReader and PipeWriter. +type pipe_shared struct { + chunks [][]byte // ordered list of unread byte-arrays written to the Pipe + err error // Set when closed, to io.EOF or some other error. + cond *sync.Cond // Synchronizes writer & reader +} + +// -------- PIPEWRITER + +// The write end of a pipe. Implements io.WriteCloser. +// Unlike io.Pipe, writes do not block; instead the unread data is buffered in memory. +type PipeWriter struct { + shared *pipe_shared +} + +// Standard Writer method. Does not block. +func (w *PipeWriter) Write(chunk []byte) (n int, err error) { + if len(chunk) == 0 { + // The Writer interface forbids retaining the input, so we must copy it: + copied := make([]byte, len(chunk)) + copy(copied, chunk) + chunk = copied + } + if err = w._add(chunk, nil); err == nil { + n = len(chunk) + } + return +} + +// Closes the pipe. +// The associated PipeReader can still read any remaining data, then it will get an EOF error. +func (w *PipeWriter) Close() error { + return w.CloseWithError(io.EOF) +} + +// Closes the pipe with a custom error. +func (w *PipeWriter) CloseWithError(err error) error { + if err == nil { + err = io.EOF + } + _ = w._add(nil, err) + return nil +} + +func (w *PipeWriter) _add(chunk []byte, err error) error { + // adds a chunk or sets an error; or if there's already an error, returns it. + w.shared.cond.L.Lock() + defer w.shared.cond.L.Unlock() + + if w.shared.err != nil { + return w.shared.err + } + + if err == nil || err == io.EOF { + if len(chunk) > 0 { + w.shared.chunks = append(w.shared.chunks, chunk) + } + } else { + w.shared.chunks = nil // make sure reader sees the custom error ASAP + } + w.shared.err = err + w.shared.cond.Signal() + return nil +} + +// The number of bytes written but not yet read +func (w *PipeWriter) bytesPending() int { + total := 0 + w.shared.cond.L.Lock() + for _, chunk := range w.shared.chunks { + total += len(chunk) + } + w.shared.cond.L.Unlock() + return total +} + +// -------- PIPEREADER + +// The read end of a pipe. Implements io.ReadCloser. +type PipeReader struct { + shared *pipe_shared // Shared state + chunk []byte // The data chunk currently being read from +} + +// Standard Reader method. +func (r *PipeReader) Read(p []byte) (n int, err error) { return r._read(p, true) } + +// Non-blocking read: similar to Read, but if no data is available returns (0, nil). +func (r *PipeReader) TryRead(p []byte) (n int, err error) { return r._read(p, false) } + +func (r *PipeReader) _read(p []byte, wait bool) (n int, err error) { + if len(r.chunk) == 0 { + // Current chunk is exhausted; wait for a new one: + r.chunk, err = r._nextChunk(wait) + } + if len(r.chunk) > 0 { + // Read bytes out of the current chunk: + n = copy(p, r.chunk) + r.chunk = r.chunk[n:] + } + return +} + +// Returns true if a Read call will not block, whether because there's data or EOF or an error. +func (r *PipeReader) CanRead() (ok bool) { + if len(r.chunk) > 0 { + ok = true + } else { + r.shared.cond.L.Lock() + ok = len(r.shared.chunks) > 0 || r.shared.err != nil + r.shared.cond.L.Unlock() + } + return +} + +// Closes the reader. Subsequent PipeReader.Write calls will return io.ErrClosedPipe. +func (r *PipeReader) Close() error { + return r.CloseWithError(io.ErrClosedPipe) +} + +// Closes the reader with a custom error. Subsequent PipeReader.Write calls will return this error. +func (r *PipeReader) CloseWithError(err error) error { + r.shared.cond.L.Lock() + if r.shared.err == nil { + r.shared.err = err + r.shared.chunks = nil + } + r.shared.cond.L.Unlock() + return nil +} + +func (r *PipeReader) _nextChunk(wait bool) ([]byte, error) { + // Returns the next chunk added by the writer, or the error if any. + // If neither is available and `wait` is true, it blocks. + r.shared.cond.L.Lock() + defer r.shared.cond.L.Unlock() + for { + if len(r.shared.chunks) > 0 { + // If there's a chunk in the queue, return it: + chunk := r.shared.chunks[0] + r.shared.chunks = r.shared.chunks[1:] + return chunk, nil + } else if r.shared.err != nil { + // Else if an error is set, return that: + return nil, r.shared.err + } else if !wait { + // Non-blocking call with nothing to read: + return nil, nil + } + // None of the above -- wait for something to happen: + r.shared.cond.Wait() + } +} diff --git a/properties.go b/properties.go index 21b85d9..638cc3f 100644 --- a/properties.go +++ b/properties.go @@ -13,6 +13,7 @@ package blip import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "sort" @@ -24,53 +25,55 @@ type Properties map[string]string // For testing purposes, clients can set this to true to write properties sorted by key var SortProperties = false -// Adapter for io.Reader to io.ByteReader -type byteReader struct { - reader io.Reader -} - -func (br byteReader) ReadByte() (byte, error) { - var p [1]byte - _, err := br.reader.Read(p[:]) - return p[0], err -} +// Implementation-imposed max encoded size of message properties (not part of protocol) +const maxPropertiesLength = 100 * 1024 -// Reads encoded Properties from a stream. -func (properties *Properties) ReadFrom(reader io.Reader) error { - length, err := binary.ReadUvarint(byteReader{reader}) - if err != nil { - return err +// Reads encoded Properties from a byte array. +// On success, returns the Properties map and the number of bytes read. +// If the array doesn't contain the complete properties, returns (nil, 0, nil). +// On failure, returns nil Properties and an error. +func ReadProperties(body []byte) (properties Properties, bytesRead int, err error) { + length, bytesRead := binary.Uvarint(body) + if bytesRead == 0 { + // Not enough bytes to read the varint + return nil, 0, nil + } else if bytesRead < 0 || length > maxPropertiesLength { + return nil, bytesRead, errors.New("invalid properties length in BLIP message") + } else if bytesRead+int(length) > len(body) { + // Incomplete + return nil, 0, nil } + if length == 0 { - return nil - } - body := make([]byte, length) - if _, err := io.ReadFull(reader, body); err != nil { - return err + // Empty properties + return Properties{}, bytesRead, nil } + body = body[bytesRead:] + bytesRead += int(length) + if body[length-1] != 0 { - return fmt.Errorf("Invalid properties (not NUL-terminated)") + return nil, 0, errors.New("invalid properties (not NUL-terminated)") } eachProp := bytes.Split(body[0:length-1], []byte{0}) nProps := len(eachProp) / 2 if nProps*2 != len(eachProp) { - return fmt.Errorf("Odd number of strings in properties") + return nil, bytesRead, errors.New("odd number of strings in properties") } - *properties = Properties{} + properties = Properties{} for i := 0; i < len(eachProp); i += 2 { key := string(eachProp[i]) value := string(eachProp[i+1]) - if _, exists := (*properties)[key]; exists { - return fmt.Errorf("Duplicate property name %q", key) + if _, exists := (properties)[key]; exists { + return nil, bytesRead, fmt.Errorf("duplicate property name %q", key) } - (*properties)[key] = value + properties[key] = value } - return nil + return properties, bytesRead, nil } // Writes Properties to a stream. -func (properties Properties) WriteTo(writer io.Writer) error { +func (properties Properties) WriteEncodedTo(writer io.Writer) error { // First convert the property strings into byte arrays, and add up their sizes: var strings propertyList = make(propertyList, 2*len(properties)) i := 0 @@ -103,6 +106,13 @@ func (properties Properties) WriteTo(writer io.Writer) error { return nil } +// Writes Properties to a byte array. +func (properties Properties) Encode() []byte { + var out bytes.Buffer + _ = properties.WriteEncodedTo(&out) + return out.Bytes() +} + // Properties stored as alternating keys / values type propertyList [][]byte diff --git a/properties_test.go b/properties_test.go index 5751ca6..675740f 100644 --- a/properties_test.go +++ b/properties_test.go @@ -12,6 +12,7 @@ package blip import ( "bytes" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -24,31 +25,37 @@ func init() { func TestReadWriteProperties(t *testing.T) { p := Properties{"Content-Type": "application/octet-stream", "Foo": "Bar"} var writer bytes.Buffer - err := p.WriteTo(&writer) + err := p.WriteEncodedTo(&writer) assert.Equal(t, nil, err) + writer.Write([]byte("FOOBAR")) serialized := writer.Bytes() - assert.Equal(t, "\x2EContent-Type\x00application/octet-stream\x00Foo\x00Bar\x00", string(serialized)) + propertiesLength := len(serialized) - len("FOOBAR") + assert.Equal(t, "\x2EContent-Type\x00application/octet-stream\x00Foo\x00Bar\x00FOOBAR", string(serialized)) - var p2 Properties - reader := bytes.NewReader(serialized) - err = p2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, p, p2) + for dataLen := 0; dataLen <= len(serialized); dataLen++ { + p2, bytesRead, err := ReadProperties(serialized[0:dataLen]) + assert.NoError(t, err) + if dataLen < propertiesLength { + assert.Nil(t, p2) + } else { + assert.Equal(t, p, p2) + assert.Equal(t, propertiesLength, bytesRead) + } + } } func TestReadWriteEmptyProperties(t *testing.T) { var p Properties var writer bytes.Buffer - err := p.WriteTo(&writer) + err := p.WriteEncodedTo(&writer) assert.Equal(t, nil, err) serialized := writer.Bytes() assert.Equal(t, "\x00", string(serialized)) - var p2 Properties - reader := bytes.NewReader(serialized) - err = p2.ReadFrom(reader) + p2, bytesRead, err := ReadProperties(serialized) assert.Equal(t, nil, err) - assert.Equal(t, p, p2) + assert.Equal(t, Properties{}, p2) + assert.Equal(t, len(serialized), bytesRead) } func TestReadBadProperties(t *testing.T) { @@ -56,27 +63,32 @@ func TestReadBadProperties(t *testing.T) { {"", "EOF"}, {"\x00", ""}, {"\x0C", "EOF"}, - {"\x0CX\x00Y\x00Foo\x00Ba", "unexpected EOF"}, + {"\x0CX\x00Y\x00Foo\x00Ba", "EOF"}, {"\x0CX\x00Y\x00Foo\x00Bar\x00", ""}, - {"\x14X\x00Y\x00Foo\x00Bar\x00Foo\x00Zog\x00", "Duplicate property name \"Foo\""}, + {"\x14X\x00Y\x00Foo\x00Bar\x00Foo\x00Zog\x00", "duplicate property name \"Foo\""}, - {"\x02hi", "Invalid properties (not NUL-terminated)"}, - {"\x02h\x00", "Odd number of strings in properties"}, + {"\x02hi", "invalid properties (not NUL-terminated)"}, + {"\x02h\x00", "odd number of strings in properties"}, } - var p2 Properties for i, pair := range bad { - reader := bytes.NewReader([]byte(pair[0])) - err := p2.ReadFrom(reader) - var errStr string - if err == nil { - if reader.Len() != 0 { - t.Errorf("Error decoding #%d %q: No error, but left %d bytes unread", i, pair[0], reader.Len()) + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + serialized := []byte(pair[0]) + p2, bytesRead, err := ReadProperties(serialized) + var errStr string + if err == nil { + if bytesRead == 0 { + errStr = "EOF" + } else if bytesRead != len(serialized) { + t.Errorf("Error decoding #%d %q: No error, but left %d bytes unread", i, pair[0], len(serialized)-bytesRead) + } else { + assert.NotNil(t, p2) + } + } else { + errStr = err.Error() } - } else { - errStr = err.Error() - } - if errStr != pair[1] { - t.Errorf("Error decoding #%d %q: %q (expected %q)", i, pair[0], errStr, pair[1]) - } + if errStr != pair[1] { + t.Errorf("Error decoding #%d %q: %q (expected %q)", i, pair[0], errStr, pair[1]) + } + }) } } diff --git a/protocol.go b/protocol.go index 435927b..0e67a55 100644 --- a/protocol.go +++ b/protocol.go @@ -22,9 +22,6 @@ import ( // Every sub-protocol used by a caller should begin with this string. const WebSocketSubProtocolPrefix = "BLIP_3" -// Domain used in errors returned by BLIP itself. -const BLIPErrorDomain = "BLIP" - //////// MESSAGE TYPE: // Enumeration of the different types of messages in the BLIP protocol. @@ -66,6 +63,31 @@ func (t MessageType) ackSourceType() MessageType { return t - 4 } +//////// PROPERTIES & ERRORS: + +// Message property that indicates what type of message it is +const ProfileProperty = "Profile" + +// Property of an error response that indicates the type of error; +// application-defined, except for the value "BLIP" (BLIPErrorDomain). +const ErrorDomainProperty = "Error-Domain" + +// Property of an error response containing a numeric error code, +// to be interpreted in the context of the error domain; application-defined. +const ErrorCodeProperty = "Error-Code" + +// Domain used in errors returned by BLIP itself. +const BLIPErrorDomain = "BLIP" + +// Standard error codes in the BLIP domain: +const ( + BadRequestCode = 400 // Something's invalid with the request properties or body + ForbiddenCode = 403 // You're not allowed to make this request + NotFoundCode = 404 // No handler for this Profile + HandlerFailedCode = 501 // A handler failed unexpectedly (panic, exception...) + DisconnectedCode = 503 // Fake response delivered if connection closes unexpectedly +) + //////// FRAME FLAGS: type frameFlags uint8 diff --git a/receiver.go b/receiver.go index 1bc9ea3..af7c429 100644 --- a/receiver.go +++ b/receiver.go @@ -14,24 +14,17 @@ import ( "bytes" "encoding/binary" "fmt" - "io" - "log" "runtime/debug" "sync" "sync/atomic" + "time" "github.com/coder/websocket" ) const checksumLength = 4 -type msgStreamer struct { - message *Message - writer io.WriteCloser - bytesWritten uint64 -} - -type msgStreamerMap map[MessageNumber]*msgStreamer +type msgStreamerMap map[MessageNumber]*msgReceiver // The receiving side of a BLIP connection. // Handles receiving WebSocket messages as frames and assembling them into BLIP messages. @@ -50,20 +43,31 @@ type receiver struct { pendingRequests msgStreamerMap // Unfinished REQ messages being assembled pendingResponses msgStreamerMap // Unfinished RES messages being assembled maxPendingResponseNumber MessageNumber // Largest RES # I've seen + + dispatchMutex sync.Mutex // For thread-safe access to the fields below + dispatchCond sync.Cond // Used when receiver stops reading + maxDispatchedBytes int // above this value, receiver stops reading the WebSocket + dispatchedBytes int // Size of dispatched but unhandled incoming requests + dispatchedMessageCount int // Number of dispatched but unhandled incoming requests } func newReceiver(context *Context, conn *websocket.Conn) *receiver { - return &receiver{ - conn: conn, - context: context, - channel: make(chan []byte, 10), - parseError: make(chan error, 1), - frameDecoder: getDecompressor(context), - pendingRequests: msgStreamerMap{}, - pendingResponses: msgStreamerMap{}, + rcvr := &receiver{ + conn: conn, + context: context, + channel: make(chan []byte, 10), + parseError: make(chan error, 1), + frameDecoder: getDecompressor(context), + pendingRequests: msgStreamerMap{}, + pendingResponses: msgStreamerMap{}, + maxDispatchedBytes: context.MaxDispatchedBytes, } + rcvr.dispatchCond = sync.Cond{L: &rcvr.dispatchMutex} + return rcvr } +// Reads WebSocket messages, not returning until the socket is closed. +// Spawns parseLoop in a goroutine and pushes the messages to it. func (r *receiver) receiveLoop() error { defer atomic.AddInt32(&r.activeGoroutines, -1) atomic.AddInt32(&r.activeGoroutines, 1) @@ -73,7 +77,7 @@ func (r *receiver) receiveLoop() error { for { // Receive the next raw WebSocket frame: - _, frame, err := r.conn.Read(r.context.GetCancelCtx()) + msgType, frame, err := r.conn.Read(r.context.GetCancelCtx()) if err != nil { if isCloseError(err) { // lower log level for close @@ -87,18 +91,24 @@ func (r *receiver) receiveLoop() error { return err } - r.channel <- frame + switch msgType { + case websocket.MessageBinary: + r.channel <- frame + default: + r.context.log("Warning: received WebSocket message of type %v", msgType) + } } } +// Goroutine created by receiveLoop that parses BLIP frames and dispatches messages. func (r *receiver) parseLoop() { defer func() { // Panic handler: atomic.AddInt32(&r.activeGoroutines, -1) if p := recover(); p != nil { - log.Printf("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack()) + r.context.log("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack()) err, _ := p.(error) if err == nil { - err = fmt.Errorf("Panic: %v", p) + err = fmt.Errorf("panic: %v", p) } r.fatalError(err) } @@ -119,47 +129,38 @@ func (r *receiver) parseLoop() { } r.context.logFrame("parseLoop stopped") + r.closePendingResponses() + returnDecompressor(r.frameDecoder) r.frameDecoder = nil } +// called on the parseLoop goroutine func (r *receiver) fatalError(err error) { r.context.log("Error: parseLoop closing socket due to error: %v", err) r.parseError <- err - r.stop() + r.conn.Close(websocket.StatusAbnormalClosure, "") } +// called by the sender func (r *receiver) stop() { - - r.closePendingResponses() - r.conn.Close(websocket.StatusNormalClosure, "") - waitForZeroActiveGoroutines(r.context, &r.activeGoroutines) } func (r *receiver) closePendingResponses() { - r.pendingMutex.Lock() defer r.pendingMutex.Unlock() - // There can be goroutines spawned by message.asyncRead() that are blocked waiting to - // read off their end of an io.Pipe, and if the peer abruptly closes a connection which causes - // the sender to stop(), the other side of that io.Pipe must be closed to avoid the goroutine's - // call to unblock on the read() call. This loops through any io.Pipewriters in pendingResponses and - // close them, unblocking the readers and letting the message.asyncRead() goroutines proceed. for _, msgStreamer := range r.pendingResponses { - err := msgStreamer.writer.Close() - if err != nil { - r.context.logMessage("Warning: error closing msgStreamer writer in pending responses while stopping receiver: %v", err) - } + msgStreamer.cancelIncoming() } } func (r *receiver) handleIncomingFrame(frame []byte) error { // Parse BLIP header: if len(frame) < 2 { - return fmt.Errorf("Illegally short frame") + return fmt.Errorf("illegally short frame") } r.frameBuffer.Reset() r.frameBuffer.Write(frame) @@ -191,14 +192,14 @@ func (r *receiver) handleIncomingFrame(frame []byte) error { bufferedFrame := r.frameBuffer.Bytes() frameSize := len(bufferedFrame) if len(frame) < checksumLength { - return fmt.Errorf("Illegally short frame") + return fmt.Errorf("illegally short frame") } - checksumSlice := bufferedFrame[len(bufferedFrame)-checksumLength : len(bufferedFrame)] + checksumSlice := bufferedFrame[len(bufferedFrame)-checksumLength:] checksum := binary.BigEndian.Uint32(checksumSlice) r.frameBuffer.Truncate(r.frameBuffer.Len() - checksumLength) if r.context.LogFrames { - r.context.logFrame("Received frame: %s (flags=%8b, length=%d)", + r.context.logFrame("received frame: %s (flags=%8b, length=%d)", frameString(requestNumber, flags), flags, r.frameBuffer.Len()) } @@ -215,14 +216,14 @@ func (r *receiver) handleIncomingFrame(frame []byte) error { return err } - return r.processFrame(requestNumber, flags, body, frameSize) + return r.processIncomingFrame(requestNumber, flags, body, frameSize) } } -func (r *receiver) processFrame(requestNumber MessageNumber, flags frameFlags, frame []byte, frameSize int) error { +func (r *receiver) processIncomingFrame(requestNumber MessageNumber, flags frameFlags, frame []byte, frameSize int) error { // Look up or create the writer stream for this message: complete := (flags & kMoreComing) == 0 - var msgStream *msgStreamer + var msgStream *msgReceiver var err error switch flags.messageType() { case RequestType: @@ -234,30 +235,42 @@ func (r *receiver) processFrame(requestNumber MessageNumber, flags frameFlags, f default: r.context.log("Warning: Ignoring incoming message type, with flags 0x%x", flags) } + if msgStream == nil { + return err + } // Write the decoded frame body to the stream: - if msgStream != nil { - if _, err := writeFull(frame, msgStream.writer); err != nil { - return err - } else if complete { - if err = msgStream.writer.Close(); err != nil { - r.context.log("Warning: message writer closed with error %v", err) - } - } else { - //FIX: This isn't the right place to do this, because this goroutine doesn't block even - // if the client can't read the message fast enough. The right place to send the ACK is - // in the goroutine that's running msgStream.writer. (Somehow...) - oldWritten := msgStream.bytesWritten - msgStream.bytesWritten += uint64(frameSize) - if oldWritten > 0 && (oldWritten/kAckInterval) < (msgStream.bytesWritten/kAckInterval) { - r.sender.sendAck(requestNumber, flags.messageType(), msgStream.bytesWritten) - } + state, err := msgStream.addIncomingBytes(frame, complete) + if err != nil { + return err + } + + if !complete { + // Not complete yet; send an ACK message every `kAckInterval` bytes: + //FIX: This isn't the right place to do this, because this goroutine doesn't block even + // if the client can't read the message fast enough. + msgStream.maybeSendAck(frameSize) + } + + // Dispatch at first or last frame: + if request := msgStream.inResponseTo; request != nil { + if state.atStart { + // Dispatch response to its request message as soon as properties are available: + r.context.logMessage("Received response %s", msgStream.Message) + msgStream.dispatched = true + request.responseAvailable(msgStream.Message) // Response to outgoing request + } + } else { + if /*state.atStart ||*/ state.atEnd { + // Dispatch request to the dispatcher: + msgStream.dispatched = true + r.dispatch(msgStream.Message) } } return err } -func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) { +func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgReceiver, err error) { r.pendingMutex.Lock() defer r.pendingMutex.Unlock() msgStream = r.pendingRequests[requestNumber] @@ -267,31 +280,23 @@ func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFla } } else if requestNumber == r.numRequestsReceived+1 { r.numRequestsReceived++ - request := newIncomingMessage(r.sender, requestNumber, flags, nil) - atomic.AddInt32(&r.activeGoroutines, 1) - msgStream = &msgStreamer{ - message: request, - writer: request.asyncRead(func(err error) { - r.context.dispatchRequest(request, r.sender) - atomic.AddInt32(&r.activeGoroutines, -1) - }), - } + msgStream = newIncomingMessage(r.sender, requestNumber, flags) if !complete { r.pendingRequests[requestNumber] = msgStream } } else { - return nil, fmt.Errorf("Bad incoming request number %d", requestNumber) + return nil, fmt.Errorf("bad incoming request number %d", requestNumber) } return msgStream, nil } -func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) { +func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgReceiver, err error) { r.pendingMutex.Lock() defer r.pendingMutex.Unlock() msgStream = r.pendingResponses[requestNumber] if msgStream != nil { if msgStream.bytesWritten == 0 { - msgStream.message.flags.Store(&flags) // set flags based on 1st frame of response + msgStream.flags.Store(&flags) // set flags based on 1st frame of response } if complete { delete(r.pendingResponses, requestNumber) @@ -301,20 +306,19 @@ func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFl r.context.log("Warning: Unexpected response frame to my msg #%d", requestNumber) // benign } else { // processing a response frame with a message number higher than any requests I've sent - err = fmt.Errorf("Bogus message number %d in response. Expected to be less than max pending response number (%d)", requestNumber, r.maxPendingResponseNumber) + err = fmt.Errorf("bogus message number %d in response. Expected to be less than max pending response number (%d)", requestNumber, r.maxPendingResponseNumber) } return } -// pendingResponses is accessed from both the receiveLoop goroutine and the sender's goroutine, -// so it needs synchronization. -func (r *receiver) awaitResponse(request *Message, writer io.WriteCloser) { +func (r *receiver) awaitResponse(response *Message) { + // pendingResponses is accessed from both the receiveLoop goroutine and the sender's goroutine, + // so it needs synchronization. r.pendingMutex.Lock() defer r.pendingMutex.Unlock() - number := request.number - r.pendingResponses[number] = &msgStreamer{ - message: request, - writer: writer, + number := response.number + r.pendingResponses[number] = &msgReceiver{ + Message: response, } if number > r.maxPendingResponseNumber { r.maxPendingResponseNumber = number @@ -327,16 +331,72 @@ func (r *receiver) backlog() (pendingRequest, pendingResponses int) { return len(r.pendingRequests), len(r.pendingResponses) } -// Why isn't this in the io package already, when ReadFull is? -func writeFull(buf []byte, writer io.Writer) (nWritten int, err error) { - for len(buf) > 0 { - var n int - n, err = writer.Write(buf) - if err != nil { - break +//////// REQUEST DISPATCHING & FLOW CONTROL + +func (r *receiver) dispatch(request *Message) { + sender := r.sender + requestSize := request.bodySize() + onComplete := func() { + response := request.Response() + if panicked := recover(); panicked != nil { + if handler := sender.context.HandlerPanicHandler; handler != nil { + handler(request, response, panicked) + } else { + stack := debug.Stack() + sender.context.log("PANIC handling BLIP request %v: %v:\n%s", request, panicked, stack) + if response != nil { + // (It is generally considered bad security to reveal internal state like error + // messages or stack dumps in a network response.) + response.SetError(BLIPErrorDomain, HandlerFailedCode, "Internal Error") + } + } + } + + r.subDispatchedBytes(requestSize) + + if response != nil { + sender.send(response) } - nWritten += n - buf = buf[n:] } - return + + r.addDispatchedBytes(requestSize) + r.context.RequestHandler(request, onComplete) + r.waitOnDispatchedBytes() +} + +func (r *receiver) addDispatchedBytes(n int) { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + r.dispatchedBytes += n + r.dispatchedMessageCount++ + r.dispatchMutex.Unlock() + } +} + +func (r *receiver) subDispatchedBytes(n int) { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + prevBytes := r.dispatchedBytes + r.dispatchedBytes = prevBytes - n + r.dispatchedMessageCount-- + if prevBytes > r.maxDispatchedBytes && r.dispatchedBytes <= r.maxDispatchedBytes { + r.dispatchCond.Signal() + } + r.dispatchMutex.Unlock() + } +} + +func (r *receiver) waitOnDispatchedBytes() { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + if r.dispatchedBytes > r.maxDispatchedBytes { + start := time.Now() + r.context.log("WebSocket receiver paused (%d requests being handled, %d bytes)", r.dispatchedMessageCount, r.dispatchedBytes) + for r.dispatchedBytes > r.maxDispatchedBytes { + r.dispatchCond.Wait() + } + r.context.log("...WebSocket receiver resuming after %v (now %d requests being handled, %d bytes)", time.Since(start), r.dispatchedMessageCount, r.dispatchedBytes) + } + r.dispatchMutex.Unlock() + } } diff --git a/sender.go b/sender.go index 362a5bf..22f0c53 100644 --- a/sender.go +++ b/sender.go @@ -14,11 +14,9 @@ import ( "bytes" "context" "encoding/binary" - "log" "runtime/debug" "strings" "sync" - "sync/atomic" "time" "github.com/coder/websocket" @@ -42,9 +40,8 @@ type Sender struct { conn *websocket.Conn receiver *receiver queue *messageQueue - icebox map[msgKey]*Message - curMsg *Message - numRequestsSent MessageNumber + icebox map[msgKey]*msgSender + curMsg *msgSender requeueLock sync.Mutex activeGoroutines int32 websocketPingInterval time.Duration @@ -59,7 +56,7 @@ func newSender(c *Context, conn *websocket.Conn, receiver *receiver) *Sender { conn: conn, receiver: receiver, queue: newMessageQueue(c, c.MaxSendQueueCount), - icebox: map[msgKey]*Message{}, + icebox: map[msgKey]*msgSender{}, websocketPingInterval: c.WebsocketPingInterval, ctx: ctx, ctxCancel: ctxCancel, @@ -80,7 +77,7 @@ func (sender *Sender) Send(msg *Message) bool { // Posts a request or response to be delivered asynchronously. // Returns false if the message can't be queued because the Sender has stopped. func (sender *Sender) send(msg *Message) bool { - if msg.Sender != nil || msg.encoder != nil { + if msg.Sender != nil { panic("Message is already enqueued") } msg.Sender = sender @@ -93,17 +90,12 @@ func (sender *Sender) send(msg *Message) bool { prePushCallback := func(prePushMsg *Message) { if prePushMsg.Type() == RequestType && !prePushMsg.NoReply() { response := prePushMsg.createResponse() - atomic.AddInt32(&sender.activeGoroutines, 1) - writer := response.asyncRead(func(err error) { - // TODO: the error passed into this callback is currently being ignored. Calling response.SetError() causes: "panic: Message can't be modified" - prePushMsg.responseComplete(response) - atomic.AddInt32(&sender.activeGoroutines, -1) - }) - sender.receiver.awaitResponse(response, writer) + sender.receiver.awaitResponse(response) } } - return sender.queue.pushWithCallback(msg, prePushCallback) + msgSender := &msgSender{Message: msg} + return sender.queue.pushWithCallback(msgSender, prePushCallback) } @@ -149,7 +141,7 @@ func (sender *Sender) start() { go func() { defer func() { if panicked := recover(); panicked != nil { - log.Printf("PANIC in BLIP sender: %v\n%s", panicked, debug.Stack()) + sender.context.log("PANIC in BLIP sender: %v\n%s", panicked, debug.Stack()) } }() @@ -200,9 +192,6 @@ func (sender *Sender) start() { err := sender.conn.Write(sender.ctx, websocket.MessageBinary, frameBuffer.Bytes()) if err != nil { sender.context.logFrame("Sender error writing framebuffer (len=%d). Error: %v", len(frameBuffer.Bytes()), err) - if err := msg.Close(); err != nil { - sender.context.logFrame("Sender error closing message. Error: %v", err) - } } frameBuffer.Reset() @@ -212,6 +201,8 @@ func (sender *Sender) start() { panic("empty frame should not have moreComing") } sender.requeue(msg, uint64(bytesSent)) + } else { + msg.sent() } } returnCompressor(frameEncoder) @@ -253,7 +244,7 @@ func (sender *Sender) start() { //////// FLOW CONTROL: -func (sender *Sender) popNextMessage() *Message { +func (sender *Sender) popNextMessage() *msgSender { sender.requeueLock.Lock() sender.curMsg = nil sender.requeueLock.Unlock() @@ -270,11 +261,11 @@ func (sender *Sender) popNextMessage() *Message { return msg } -func (sender *Sender) requeue(msg *Message, bytesSent uint64) { +func (sender *Sender) requeue(msg *msgSender, bytesSent uint64) { sender.requeueLock.Lock() defer sender.requeueLock.Unlock() - msg.bytesSent += bytesSent - if msg.bytesSent <= msg.bytesAcked+kMaxUnackedBytes { + msg.addBytesSent(bytesSent) + if !msg.needsAck() { // requeue it so it can send its next frame later sender.queue.push(msg) } else { @@ -289,14 +280,14 @@ func (sender *Sender) receivedAck(requestNumber MessageNumber, msgType MessageTy sender.requeueLock.Lock() defer sender.requeueLock.Unlock() if msg := sender.queue.find(requestNumber, msgType); msg != nil { - msg.bytesAcked = bytesReceived + msg.receivedAck(bytesReceived) } else if msg := sender.curMsg; msg != nil && msg.number == requestNumber && msg.Type() == msgType { - msg.bytesAcked = bytesReceived + msg.receivedAck(bytesReceived) } else { key := msgKey{msgNo: requestNumber, msgType: msgType} if msg := sender.icebox[key]; msg != nil { - msg.bytesAcked = bytesReceived - if msg.bytesSent <= msg.bytesAcked+kMaxUnackedBytes { + msg.receivedAck(bytesReceived) + if !msg.needsAck() { sender.context.logFrame("Resuming %v", msg) delete(sender.icebox, key) sender.queue.push(msg) @@ -316,5 +307,4 @@ func (sender *Sender) sendAck(msgNo MessageNumber, msgType MessageType, bytesRec if err != nil { sender.context.logFrame("Sender error writing ack. Error: %v", err) } - } diff --git a/sender_test.go b/sender_test.go index 77f814f..d7b8b56 100644 --- a/sender_test.go +++ b/sender_test.go @@ -53,7 +53,8 @@ func TestStopSenderClearsAllMessageQueues(t *testing.T) { "id": fmt.Sprint(i), } msg.Properties = msgProp - sender.queue.push(msg) + msgSender := &msgSender{Message: msg} + sender.queue.push(msgSender) } for i := 10; i < 15; i++ { msg := NewRequest() @@ -61,7 +62,8 @@ func TestStopSenderClearsAllMessageQueues(t *testing.T) { "id": fmt.Sprint(i), } msg.Properties = msgProp - sender.icebox[msgKey{msgNo: MessageNumber(i)}] = msg + msgSender := &msgSender{Message: msg} + sender.icebox[msgKey{msgNo: MessageNumber(i)}] = msgSender } // close sender diff --git a/thread_pool.go b/thread_pool.go new file mode 100644 index 0000000..0db311c --- /dev/null +++ b/thread_pool.go @@ -0,0 +1,121 @@ +/* +Copyright 2023-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. +*/ + +package blip + +import ( + "log" + "runtime" + "runtime/debug" +) + +// Runs functions asynchronously using a fixed-size pool of goroutines. +// Intended for use by Dispatchers and Handlers, though it's not limited to that. +type ThreadPool struct { + Concurrency int // The number of goroutines. 0 means GOMAXPROCS + PanicHandler func(err interface{}) // Called if a function panics + + concurrency int // Actual concurrency + channel chan<- func() // Queue of calls to be dispatched + terminator chan bool // Goroutines stop when this is closed +} + +// Starts a ThreadPool. +func (p *ThreadPool) Start() { + p.concurrency = p.Concurrency + if p.concurrency <= 0 { + p.concurrency = runtime.GOMAXPROCS(0) + } + ch := make(chan func(), 1000) //TODO: What real limit? + p.channel = ch + p.terminator = make(chan bool) + + handlerLoop := func(i int) { + // Each goroutine in the pool runs this loop, calling functions from the channel + // until it either reads a nil function, or something happens to the terminator. + for { + select { + case fn := <-ch: + if fn != nil { + p.perform(i, fn) + } else { + return + } + case <-p.terminator: + return + } + } + } + + for i := 0; i < p.concurrency; i++ { + go handlerLoop(i + 1) + } +} + +// Stops a ThreadPool after it completes all currently running or queued tasks. +func (p *ThreadPool) Stop() { + if p.channel != nil { + for i := 0; i < p.concurrency; i++ { + p.channel <- nil // each nil stops one goroutine + } + } +} + +// Stops a ThreadPool ASAP. Currently running tasks will complete; queued tasks will be dropped. +func (p *ThreadPool) StopImmediately() { + if p.terminator != nil { + close(p.terminator) + } +} + +// Schedules the function to be called on ThreadPool's next available goroutine. +func (p *ThreadPool) Go(fn func()) { + if p.channel == nil { + panic("ThreadPool has not been started") + } else if fn == nil { + panic("Invalid nil function") + } + p.channel <- fn +} + +// Given a SynchronousHandler function, returns an AsyncHandler function that will call the +// wrapped handler on one of the ThreadPool's goroutines. +func (p *ThreadPool) WrapSynchronousHandler(handler SynchronousHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + p.Go(func() { + defer onComplete() + handler(request) + }) + } +} + +// Given an AsyncHandler function, returns an AsyncHandler function that will call the +// wrapped handler on one of the ThreadPool's goroutines. +func (p *ThreadPool) WrapAsyncHandler(handler AsyncHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + p.Go(func() { + handler(request, onComplete) + }) + } +} + +func (p *ThreadPool) perform(i int, fn func()) { + defer func() { + if panicked := recover(); panicked != nil { + if p.PanicHandler != nil { + p.PanicHandler(panicked) + } else { + log.Printf("PANIC in ThreadPool[%d] function: %v\n%s", i, panicked, debug.Stack()) + } + } + }() + + fn() +} diff --git a/util.go b/util.go index 164ec95..d9ded3a 100644 --- a/util.go +++ b/util.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package blip import ( + "fmt" "sync/atomic" "time" ) @@ -46,3 +47,18 @@ func errorFromChannel(c chan error) error { } return nil } + +func min(a, b int) int { + if a < b { + return a + } else { + return b + } +} + +// Simple assertion that panics if the condition isn't met. +func precondition(condition bool, panicMessage string, args ...interface{}) { + if !condition { + panic(fmt.Sprintf("Precondition failed! "+panicMessage, args...)) + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..5611230 --- /dev/null +++ b/util_test.go @@ -0,0 +1,97 @@ +// Copyright 2023-Present Couchbase, Inc. +// +// Use of this software is governed by the Business Source License included +// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified +// in that file, in accordance with the Business Source License, use of this +// software will be governed by the Apache License, Version 2.0, included in +// the file licenses/APL2.txt. + +package blip + +import ( + "bytes" + "fmt" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Starts a WebSocket server on the Context, returning its net.Listener. +// The server runs on a background goroutine. +// Close the Listener when the test finishes. +func startTestListener(t *testing.T, serverContext *Context) net.Listener { + mux := http.NewServeMux() + mux.Handle("/blip", serverContext.WebSocketServer()) + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Error opening WebSocket listener: %v", err) + } + go func() { + _ = http.Serve(listener, mux) + }() + return listener +} + +// Connects the Context to a Listener created by `startTestListener`. +// Close the Sender when the test finishes. +func startTestClient(t *testing.T, clientContext *Context, listener net.Listener) *Sender { + port := listener.Addr().(*net.TCPAddr).Port + destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) + sender, err := clientContext.Dial(destUrl) + if err != nil { + t.Fatalf("Error opening WebSocket client: %v", err) + } + return sender +} + +// Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout +// TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo. +func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { + + // Create a channel so that a goroutine waiting on the waitgroup can send it's result (if any) + wgFinished := make(chan bool) + + go func() { + wg.Wait() + wgFinished <- true + }() + + select { + case <-wgFinished: + return nil + case <-time.After(timeout): + return fmt.Errorf("timed out waiting after %v", timeout) + } +} + +// Returns the serialized form (properties + body) of a Message. +func serializeMessage(t *testing.T, m *Message) []byte { + var writer bytes.Buffer + err := m.WriteEncodedTo(&writer) + assert.NoError(t, err) + return writer.Bytes() +} + +// Asserts that two Messages are identical +func assertEqualMessages(t *testing.T, m, m2 *Message) bool { + if !assert.Equal(t, m.flags.Load(), m2.flags.Load()) || !assert.Equal(t, m.Properties, m2.Properties) { + return false + } + mb, err := m.Body() + m2b, err2 := m2.Body() + return assert.NoError(t, err) && assert.NoError(t, err2) && assert.Equal(t, mb, m2b) + +} + +// requireBLIPError requires that `response` contains a BLIP error with the given domain and code. +func requireBLIPError(t *testing.T, response *Message, domain string, code int) { + responseError := response.Error() + require.Errorf(t, responseError, "Expected a BLIP error") + require.Equal(t, domain, responseError.Domain) + require.Equal(t, code, responseError.Code) +}