From b36efc8524f7adb28747a7de2a6bf6d1270bdbbb Mon Sep 17 00:00:00 2001 From: Jens Alfke Date: Tue, 4 Apr 2023 16:08:19 -0700 Subject: [PATCH 1/6] Improved concurrency; fewer goroutines; new handler API - Don't create new goroutines to encode/decode Message bodies. - Message state that's only needed while receiving or sending is broken out into new structs msgSender and msgStreamer. - New API for handling requests. Added AsyncHandler type, Dispatcher interface, and some Dispatcher implementations for typical tasks like handler-per-Profile and limited concurrency. - New ThreadPool class for handling requests asynchronously without spawning more goroutines. - receiver goroutine blocks and stops reading from the socket when the total size of received-but-unhandled requests exceeds a limit. (This exerts backpressure on the client.) CBG-2952 --- context.go | 94 +++---- context_test.go | 106 ++------ dispatcher.go | 203 +++++++++++++++ expvar.go | 16 -- functional_test.go | 62 ++--- go.mod | 3 +- go.sum | 10 +- message.go | 576 +++++++++++++++++++++++++++---------------- message_test.go | 114 +++++---- messagequeue.go | 51 ++-- messagequeue_test.go | 24 +- pipe.go | 180 ++++++++++++++ properties.go | 66 +++-- properties_test.go | 66 +++-- protocol.go | 28 ++- receiver.go | 212 ++++++++++------ sender.go | 46 ++-- thread_pool.go | 121 +++++++++ util.go | 16 ++ util_test.go | 101 ++++++++ 20 files changed, 1422 insertions(+), 673 deletions(-) create mode 100644 dispatcher.go create mode 100644 pipe.go create mode 100644 thread_pool.go create mode 100644 util_test.go diff --git a/context.go b/context.go index ed085a3..f97cbd3 100644 --- a/context.go +++ b/context.go @@ -16,7 +16,6 @@ import ( "io" "math/rand" "net/http" - "runtime/debug" "strings" "sync/atomic" "time" @@ -24,17 +23,6 @@ import ( "nhooyr.io/websocket" ) -// A function that handles an incoming BLIP request and optionally sends a response. -// A handler is called on a new goroutine so it can take as long as it needs to. -// For example, if it has to send a synchronous network request before it can construct -// a response, that's fine. -type Handler func(request *Message) - -// Utility function that responds to a Message with a 404 error. -func Unhandled(request *Message) { - request.Response().SetError(BLIPErrorDomain, 404, "No handler for BLIP request") -} - // Defines how incoming requests are dispatched to handler functions. type Context struct { @@ -45,14 +33,14 @@ type Context struct { // The currently used WebSocket subprotocol by the client, set on a successful handshake. activeSubProtocol string - HandlerForProfile map[string]Handler // Handler function for a request Profile - DefaultHandler Handler // Handler for all otherwise unhandled requests - FatalErrorHandler func(error) // Called when connection has a fatal error - HandlerPanicHandler func(request, response *Message, err interface{}) // Called when a profile handler panics - MaxSendQueueCount int // Max # of messages being sent at once (if >0) - Logger LogFn // Logging callback; defaults to log.Printf - LogMessages bool // If true, will log about messages - LogFrames bool // If true, will log about frames (very verbose) + RequestHandler AsyncHandler // Callback that handles incoming requests + FatalErrorHandler FatalErrorHandler // Called when connection has a fatal error + HandlerPanicHandler HandlerPanicHandler // Called when a profile handler panics + MaxSendQueueCount int // Max # of messages being sent at once (if >0) + MaxDispatchedBytes int // Max total size of incoming requests being dispatched (if >0) + Logger LogFn // Logging callback; defaults to log.Printf + LogMessages bool // If true, will log about messages + LogFrames bool // If true, will log about frames (very verbose) OnExitCallback func() // OnExitCallback callback invoked when the underlying connection closes and the receive loop exits. @@ -61,10 +49,19 @@ type Context struct { // An identifier that uniquely defines the context. NOTE: Random Number Generator not seeded by go-blip. ID string + HandlerForProfile map[string]Handler // deprecated; use RequestHandler & ByProfileDispatcher + DefaultHandler Handler // deprecated; use RequestHandler & ByProfileDispatcher + bytesSent atomic.Uint64 // Number of bytes sent bytesReceived atomic.Uint64 // Number of bytes received } +// A function called when a Handler function panics. +type HandlerPanicHandler func(request, response *Message, err interface{}) + +// A function called when the connection closes due to a fatal protocol error. +type FatalErrorHandler func(error) + // Defines a logging interface for use within the blip codebase. Implemented by Context. // Any code that needs to take a Context just for logging purposes should take a Logger instead. type LogContext interface { @@ -98,6 +95,13 @@ func NewContextCustomID(id string, appProtocolIds ...string) (*Context, error) { } func (context *Context) start(ws *websocket.Conn) *Sender { + if context.RequestHandler == nil { + // Compatibility mode: If the app hasn't set a RequestHandler, set one that uses the old + // handlerForProfile and defaultHandler. + context.RequestHandler = context.compatibilityHandler + } else if len(context.HandlerForProfile) > 0 || context.DefaultHandler != nil { + panic("blip.Context cannot have both a RequestHandler and legacy handlerForProfile or defaultHandler") + } r := newReceiver(context, ws) r.sender = newSender(context, ws, r) r.sender.start() @@ -166,6 +170,7 @@ func (context *Context) DialConfig(opts *DialOptions) (*Sender, error) { // If the receiveLoop terminates, stop the sender as well defer sender.Stop() + // defer context.dispatcher.stop() // Update Expvar stats for number of outstanding goroutines incrReceiverGoroutines() @@ -241,6 +246,7 @@ func (bwss *blipWebsocketServer) handle(ws *websocket.Conn) { sender := bwss.blipCtx.start(ws) err := sender.receiver.receiveLoop() sender.Stop() + // bwss.blipCtx.dispatcher.stop() if err != nil && !isCloseError(err) { bwss.blipCtx.log("BLIP/Websocket Handler exited with error: %v", err) if bwss.blipCtx.FatalErrorHandler != nil { @@ -250,52 +256,6 @@ func (bwss *blipWebsocketServer) handle(ws *websocket.Conn) { ws.Close(websocket.StatusNormalClosure, "") } -//////// DISPATCHING MESSAGES: - -func (context *Context) dispatchRequest(request *Message, sender *Sender) { - defer func() { - // On return/panic, send the response: - response := request.Response() - if panicked := recover(); panicked != nil { - if context.HandlerPanicHandler != nil { - context.HandlerPanicHandler(request, response, panicked) - } else { - stack := debug.Stack() - context.log("PANIC handling BLIP request %v: %v:\n%s", request, panicked, stack) - if response != nil { - response.SetError(BLIPErrorDomain, 500, fmt.Sprintf("Panic: %v", panicked)) - } - } - } - if response != nil { - sender.send(response) - } - }() - - context.logMessage("Incoming BLIP Request: %s", request) - handler := context.HandlerForProfile[request.Properties["Profile"]] - if handler == nil { - handler = context.DefaultHandler - if handler == nil { - handler = Unhandled - } - } - handler(request) -} - -func (context *Context) dispatchResponse(response *Message) { - defer func() { - // On return/panic, log a warning: - if panicked := recover(); panicked != nil { - stack := debug.Stack() - context.log("PANIC handling BLIP response %v: %v:\n%s", response, panicked, stack) - } - }() - - context.logMessage("Incoming BLIP Response: %s", response) - //panic("UNIMPLEMENTED") //TODO -} - //////// LOGGING: func (context *Context) log(format string, params ...interface{}) { @@ -314,6 +274,8 @@ func (context *Context) logFrame(format string, params ...interface{}) { } } +//////// UTILITIES: + func includesProtocol(header string, protocols []string) (string, bool) { for _, item := range strings.Split(header, ",") { for _, protocol := range protocols { diff --git a/context_test.go b/context_test.go index 4a56583..1cf131b 100644 --- a/context_test.go +++ b/context_test.go @@ -14,7 +14,6 @@ import ( "fmt" "log" "net" - "net/http" "sync" "testing" "time" @@ -78,17 +77,8 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - http.Handle("/TestServerAbruptlyCloseConnectionBehavior", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, nil)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -96,12 +86,8 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/TestServerAbruptlyCloseConnectionBehavior", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() // Create echo request echoRequest := NewRequest() @@ -122,15 +108,7 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { // Read the echo response response := echoRequest.Response() // <--- SG #3268 was causing this to block indefinitely - responseBody, err := response.Body() - - // Assertions about echo response (these might need to be altered, maybe what's expected in this scenario is actually an error) - assert.True(t, err == nil) - assert.True(t, len(responseBody) == 0) - - // TODO: add more assertions about the response. I'm not seeing any errors, or any - // TODO: way to differentiate this response with a normal response other than having an empty body - + assertBLIPError(t, response, BLIPErrorDomain, DisconnectedCode) } /* @@ -200,12 +178,7 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { sent := clientSender.Send(echoAmplifyRequest) assert.True(t, sent) echoAmplifyResponse := echoAmplifyRequest.Response() // <--- SG #3268 was causing this to block indefinitely - echoAmplifyResponseBody, _ := echoAmplifyResponse.Body() - assert.True(t, len(echoAmplifyResponseBody) == 0) - - // TODO: add more assertions about the response. I'm not seeing any errors, or any - // TODO: way to differentiate this response with a normal response other than having an empty body - + assertBLIPError(t, echoAmplifyResponse, BLIPErrorDomain, DisconnectedCode) } // Create a blip profile handler to respond to echo requests and then abruptly close the socket @@ -235,17 +208,8 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - http.Handle("/TestClientAbruptlyCloseConnectionBehavior", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, nil)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -253,13 +217,6 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/TestClientAbruptlyCloseConnectionBehavior", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } - // Handle EchoAmplifyData that should be initiated by server in response to getting incoming echo requests dispatchEchoAmplify := func(request *Message) { _, err := request.Body() @@ -273,6 +230,9 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { } blipContextEchoClient.HandlerForProfile["BLIPTest/EchoAmplifyData"] = dispatchEchoAmplify + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() + // Create echo request echoRequest := NewRequest() echoRequest.SetProfile("BLIPTest/EchoData") @@ -295,7 +255,7 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { responseBody, err := response.Body() // Assertions about echo response (these might need to be altered, maybe what's expected in this scenario is actually an error) - assert.True(t, err == nil) + assert.NoError(t, err) assert.Equal(t, "hello", string(responseBody)) // Wait until the amplify request was received by client (from server), and that the server read the response @@ -391,21 +351,8 @@ func TestUnsupportedSubProtocol(t *testing.T) { serverCtx.LogMessages = true serverCtx.LogFrames = true - server := serverCtx.WebSocketServer() - - mux := http.NewServeMux() - mux.Handle("/someBlip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - go func() { - err := http.Serve(listener, mux) - if err != nil { - panic(err) - } - }() + listener := startTestListener(t, serverCtx) + defer listener.Close() // Client client, err := NewContext(testCase.ClientProtocol...) @@ -413,13 +360,15 @@ func TestUnsupportedSubProtocol(t *testing.T) { t.Fatal(err) } port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/someBlip", port) + destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) s, err := client.Dial(destUrl) if testCase.ExpectError { assert.True(t, err != nil) } else { assert.Equal(t, nil, err) + } + if s != nil { s.Close() } @@ -430,24 +379,3 @@ func TestUnsupportedSubProtocol(t *testing.T) { }) } } - -// Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout -// TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo. -func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { - - // Create a channel so that a goroutine waiting on the waitgroup can send it's result (if any) - wgFinished := make(chan bool) - - go func() { - wg.Wait() - wgFinished <- true - }() - - select { - case <-wgFinished: - return nil - case <-time.After(timeout): - return fmt.Errorf("Timed out waiting after %v", timeout) - } - -} diff --git a/dispatcher.go b/dispatcher.go new file mode 100644 index 0000000..9915ad9 --- /dev/null +++ b/dispatcher.go @@ -0,0 +1,203 @@ +// Copyright 2023-Present Couchbase, Inc. +// +// Use of this software is governed by the Business Source License included +// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified +// in that file, in accordance with the Business Source License, use of this +// software will be governed by the Apache License, Version 2.0, included in +// the file licenses/APL2.txt. + +package blip + +import ( + "sync" +) + +//////// REQUEST HANDLER FUNCTIONS + +// A function that handles an incoming BLIP request and optionally sends a response. +// The request is not considered handled until this function returns. +// +// After the handler returns, any response you've added to the message (by calling +// `request.Response()“) will be sent, unless the message has the NoReply flag. +// If the message needs a response but none was created, a default empty response will be sent. +// If the handler panics, an error response is sent instead. +type SynchronousHandler = func(request *Message) + +// A function that asynchronously handles an incoming BLIP request and optionally sends a response. +// The handler function may return immediately without handling the request. +// The request is not considered handled until the `onComplete` callback is called, +// from any goroutine. +// +// The callback MUST be called eventually, even if the request has the NoReply flag; +// the caller of this handler may be using the callback to track and limit the number of concurrent +// handlers, so failing to call the callback could eventually block delivery of requests. +// +// The function is allowed to call `onComplete` before it returns, i.e. run synchronously, +// but it should still try to return as quickly as possible to avoid blocking upstream code. +type AsyncHandler = func(request *Message, onComplete RequestCompletedCallback) + +type Handler = SynchronousHandler + +type RequestCompletedCallback = func() + +// Utility SynchronousHandler function that responds to a Message with a 404 error. +func Unhandled(request *Message) { + //log.Printf("Warning: Unhandled BLIP message: %v (Profile=%q)", request, request.Profile()) + request.Response().SetError(BLIPErrorDomain, NotFoundCode, "No handler for BLIP request") +} + +// Utility AsyncHandler function that responds to a Message with a 404 error. +func UnhandledAsync(request *Message, onComplete RequestCompletedCallback) { + defer onComplete() + Unhandled(request) +} + +// A utility that wraps a SynchronousHandler function in an AsyncHandler function. +// When the returned AsyncHandler is called, it will asynchronously run the wrapped handler +// *on a new goroutine* and then call the completion routine. +func AsAsyncHandler(handler SynchronousHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + go func() { + defer onComplete() + handler(request) + }() + } +} + +// AsyncHandler function that uses the Context's old handlerForProfile and defaultHandler fields. +// Used as the RequestHandler when the client hasn't set one. +// For compatibility reasons it does not copy the HandlerForProfile or DefaultHandler, +// rather it accesses the ones in the Context on every call (without any locking!) +// This is because there are tests that change those properties while the connection is open. +func (context *Context) compatibilityHandler(request *Message, onComplete RequestCompletedCallback) { + profile := request.Properties[ProfileProperty] + handler := context.HandlerForProfile[profile] + if handler == nil { + handler = context.DefaultHandler + if handler == nil { + handler = Unhandled + } + } + // Old handlers ran on individual goroutines: + go func() { + defer onComplete() + handler(request) + }() +} + +//////// DISPATCHER INTERFACE: + +// An interface that provides an AsyncHandler function for incoming requests. +// For any Dispatcher instance `d`, `d.Dispatch` without parentheses is an AsyncHandler. +type Dispatcher interface { + Dispatch(request *Message, onComplete RequestCompletedCallback) +} + +//////// BY-PROFILE DISPATCHER + +// A Dispatcher implementation that routes messages to handlers based on the Profile property. +type ByProfileDispatcher struct { + mutex sync.Mutex + profiles map[string]AsyncHandler // Handler for each Profile + defaultHandler AsyncHandler // Fallback handler +} + +// Sets the AsyncHandler function for a specific Profile property value. +// A nil value removes the current handler. +// This method is thread-safe and may be called while the connection is active. +func (d *ByProfileDispatcher) SetHandler(profile string, handler AsyncHandler) { + d.mutex.Lock() + defer d.mutex.Unlock() + if handler != nil { + if d.profiles == nil { + d.profiles = map[string]AsyncHandler{} + } + d.profiles[profile] = handler + } else { + delete(d.profiles, profile) + } +} + +// Sets the default handler, when no Profile matches. +// If this is not set, the default default handler responds synchronously with a BLIP/404 error. +// This method is thread-safe and may be called while the connection is active. +func (d *ByProfileDispatcher) SetDefaultHandler(handler AsyncHandler) { + d.mutex.Lock() + d.defaultHandler = handler + d.mutex.Unlock() +} + +func (d *ByProfileDispatcher) Dispatch(request *Message, onComplete RequestCompletedCallback) { + profile := request.Properties[ProfileProperty] + d.mutex.Lock() + handler := d.profiles[profile] + if handler == nil { + handler = d.defaultHandler + if handler == nil { + handler = UnhandledAsync + } + } + d.mutex.Unlock() + + handler(request, onComplete) +} + +//////// THROTTLING DISPATCHER + +// A Dispatcher implementation that forwards to a given AsyncHandler, but only allows a certain +// number of concurrent calls. +// Excess requests will be queued and later dispatched in the order received. +type ThrottlingDispatcher struct { + Handler AsyncHandler + MaxConcurrency int + mutex sync.Mutex + concurrency int + pending []savedDispatch +} + +type savedDispatch struct { + request *Message + onComplete RequestCompletedCallback +} + +// Initializes a ThrottlingDispatcher. +func (d *ThrottlingDispatcher) Init(maxConcurrency int, handler AsyncHandler) { + d.Handler = handler + d.MaxConcurrency = maxConcurrency +} + +func (d *ThrottlingDispatcher) Dispatch(request *Message, onComplete RequestCompletedCallback) { + d.mutex.Lock() + if d.concurrency >= d.MaxConcurrency { + // Too many running; add this one to the queue: + d.pending = append(d.pending, savedDispatch{request: request, onComplete: onComplete}) + d.mutex.Unlock() + } else { + // Dispatch it now: + d.concurrency++ + d.mutex.Unlock() + d.dispatchNow(request, onComplete) + } +} + +func (d *ThrottlingDispatcher) dispatchNow(request *Message, onComplete RequestCompletedCallback) { + d.Handler(request, func() { + // When the handler finishes, decrement the concurrency count and start a queued request: + var next savedDispatch + d.mutex.Lock() + if len(d.pending) > 0 { + next = d.pending[0] + d.pending = d.pending[1:] + } else { + d.concurrency-- + } + + d.mutex.Unlock() + + onComplete() + + if next.request != nil { + d.dispatchNow(next.request, next.onComplete) + } + }) +} diff --git a/expvar.go b/expvar.go index 0b5ea60..2dfde4a 100644 --- a/expvar.go +++ b/expvar.go @@ -25,22 +25,6 @@ func decrReceiverGoroutines() { goblipExpvar.Add("goroutines_receiver", -1) } -func incrAsyncReadGoroutines() { - goblipExpvar.Add("goroutines_async_read", 1) -} - -func decrAsyncReadGoroutines() { - goblipExpvar.Add("goroutines_async_read", -1) -} - -func incrNextFrameToSendGoroutines() { - goblipExpvar.Add("goroutines_next_frame_to_send", 1) -} - -func decrNextFrameToSendGoroutines() { - goblipExpvar.Add("goroutines_next_frame_to_send", -1) -} - func incrParseLoopGoroutines() { goblipExpvar.Add("goroutines_parse_loop", 1) } diff --git a/functional_test.go b/functional_test.go index d968b3a..3015bcc 100644 --- a/functional_test.go +++ b/functional_test.go @@ -12,10 +12,7 @@ package blip import ( "expvar" - "fmt" "log" - "net" - "net/http" "strconv" "sync" "testing" @@ -40,9 +37,8 @@ func TestEchoRoundTrip(t *testing.T) { // ----------------- Setup Echo Server ------------------------- - // Create a blip profile handler to respond to echo requests and then abruptly close the socket + // Create a blip profile handler to respond to echo requests dispatchEcho := func(request *Message) { - defer receivedRequests.Done() body, err := request.Body() if err != nil { log.Printf("ERROR reading body of %s: %s", request, err) @@ -60,22 +56,12 @@ func TestEchoRoundTrip(t *testing.T) { // Blip setup blipContextEchoServer.HandlerForProfile["BLIPTest/EchoData"] = dispatchEcho - blipContextEchoServer.LogMessages = false - blipContextEchoServer.LogFrames = false + blipContextEchoServer.LogMessages = true + blipContextEchoServer.LogFrames = true // Websocket Server - server := blipContextEchoServer.WebSocketServer() - - // HTTP Handler wrapping websocket server - mux := http.NewServeMux() - mux.Handle("/blip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, mux)) - }() + listener := startTestListener(t, blipContextEchoServer) + defer listener.Close() // ----------------- Setup Echo Client ---------------------------------------- @@ -83,15 +69,12 @@ func TestEchoRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) - sender, err := blipContextEchoClient.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + + sender := startTestClient(t, blipContextEchoClient, listener) + defer sender.Close() numRequests := 100 - receivedRequests.Add(numRequests) + receivedRequests.Add(numRequests * 50) for i := 0; i < numRequests; i++ { @@ -112,12 +95,14 @@ func TestEchoRoundTrip(t *testing.T) { for j := 0; j < 10; j++ { go func(m *Message) { response := m.Response() + defer receivedRequests.Done() if response == nil { t.Errorf("unexpected nil message response") return } responseBody, err := response.Body() - assert.True(t, err == nil) + assert.NoError(t, err) + //log.Printf("Got response: %s", responseBody) assert.Equal(t, "hello", string(responseBody)) }(echoRequest) } @@ -128,7 +113,7 @@ func TestEchoRoundTrip(t *testing.T) { // Wait until all requests were sent and responded to receivedRequests.Wait() - + log.Printf("*** Closing connections ***") } // TestSenderPing ensures a client configured with a WebsocketPingInterval sends ping frames on an otherwise idle connection. @@ -139,17 +124,8 @@ func TestSenderPing(t *testing.T) { if err != nil { t.Fatal(err) } - server := serverCtx.WebSocketServer() - - mux := http.NewServeMux() - mux.Handle("/blip", server) - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - go func() { - t.Error(http.Serve(listener, mux)) - }() + listener := startTestListener(t, serverCtx) + defer listener.Close() // client clientCtx, err := NewContext(BlipTestAppProtocolId) @@ -160,17 +136,11 @@ func TestSenderPing(t *testing.T) { clientCtx.LogFrames = true clientCtx.WebsocketPingInterval = time.Millisecond * 10 - port := listener.Addr().(*net.TCPAddr).Port - destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) - // client hasn't connected yet, stats are uninitialized assert.Equal(t, int64(0), expvarToInt64(goblipExpvar.Get("sender_ping_count"))) assert.Equal(t, int64(0), expvarToInt64(goblipExpvar.Get("goroutines_sender_ping"))) - sender, err := clientCtx.Dial(destUrl) - if err != nil { - t.Fatalf("Error opening WebSocket: %v", err) - } + sender := startTestClient(t, clientCtx, listener) time.Sleep(time.Millisecond * 50) diff --git a/go.mod b/go.mod index 2e49a27..87798b6 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/klauspost/compress v1.15.11 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.8.2 nhooyr.io/websocket v1.8.7 ) @@ -12,4 +12,5 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 612a6ec..f1c6d72 100644 --- a/go.sum +++ b/go.sum @@ -42,9 +42,14 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= @@ -62,5 +67,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= diff --git a/message.go b/message.go index 0972c16..2850d4b 100644 --- a/message.go +++ b/message.go @@ -11,51 +11,51 @@ licenses/APL2.txt. package blip import ( - "bytes" "encoding/json" - "errors" "fmt" "io" "io/ioutil" - "log" - "runtime/debug" + "strconv" "sync" ) +// A Message's serial number. Outgoing requests are numbered consecutively by the sending peer. +// A response has the same number as its request (i.e. the number was created by the other peer.) type MessageNumber uint32 // A BLIP message. It could be a request or response or error, and it could be from me or the peer. type Message struct { - Outgoing bool // Is this a message created locally? - Sender *Sender // The connection that sent this message. - Properties Properties // The message's metadata, similar to HTTP headers. - body []byte // The message body. MIME type is defined by "Content-Type" property - number MessageNumber // The sequence number of the message in the connection. - flags frameFlags // Message flags as seen on the first frame. - bytesSent uint64 - bytesAcked uint64 - - reader io.ReadCloser // Stream that an incoming message is being read from - encoder io.ReadCloser // Stream that an outgoing message is being written to - readingBody bool // True if reader stream has been accessed by client already - complete bool // Has this message been completely received? - response *Message // Response to this message, if it's a request - inResponseTo *Message // Message this is a response to - cond *sync.Cond // Used to make Response() method block until response arrives -} - -// Closes all resources for the message. -func (message *Message) Close() (err error) { - if message.reader != nil { - err = message.reader.Close() - } - if message.encoder != nil { - err = message.encoder.Close() - } - return err + Outgoing bool // Is this a message created locally? + Sender *Sender // The connection that sent this message. + Properties Properties // The message's metadata, similar to HTTP headers. + body []byte // The message body. MIME type is defined by "Content-Type" property + bodyReader *PipeReader // Non-nil if incoming body is being streamed. + bodyWriter *PipeWriter // Non-nil if incoming body is being streamed. + bodyMutex sync.Mutex // Synchronizes access to body properties + response *Message // The response to this outgoing request [must lock 'cond'] + inResponseTo *Message // Outgoing request that this is a response to + cond *sync.Cond // Used to make Response() method block until response arrives + onResponse func(*Message) // Application's callback to receive response + onSent func() // Application's callback when message has been sent + number MessageNumber // The sequence number of the message in the connection. + flags frameFlags // Message flags as seen on the first frame. + inProgress bool // True while message is being sent or received } -// Returns a string describing the message for debugging purposes +// The error info from a BLIP response, as a Go Error value. +type ErrorResponse struct { + Domain string + Code int + Message string +} + +// A callback function that takes a message and returns nothing +type MessageCallback func(*Message) + +var ErrConnectionClosed = fmt.Errorf("BLIP connection closed") + +// Returns a string describing the message for debugging purposes. +// It contains the message number and type. A "!" indicates urgent, and "~" indicates compressed. func (message *Message) String() string { return frameString(message.number, message.flags) } @@ -84,9 +84,7 @@ func NewRequest() *Message { // The order in which a request message was sent. // A response has the same serial number as its request even though it goes the other direction. func (message *Message) SerialNumber() MessageNumber { - if message.number == 0 { - panic("Unsent message has no serial number yet") - } + message.assertSent() return message.number } @@ -117,11 +115,9 @@ func (message *Message) SetCompressed(compressed bool) { } // Marks an outgoing message as being one-way: no reply will be sent. -func (request *Message) SetNoReply(noReply bool) { - if request.Type() != RequestType { - panic("Can't call SetNoReply on a response") - } - request.setFlag(kNoReply, noReply) +func (message *Message) SetNoReply(noReply bool) { + message.assertIsRequest() + message.setFlag(kNoReply, noReply) } func (message *Message) setFlag(flag frameFlags, value bool) { @@ -133,55 +129,114 @@ func (message *Message) setFlag(flag frameFlags, value bool) { } } -func (message *Message) assertMutable() { - if !message.Outgoing || message.encoder != nil { - panic("Message can't be modified") +// Registers a callback that will be called when this message has been written to the socket. +// (This does not mean it's been delivered! If you need delivery confirmation, wait for a reply.) +func (message *Message) OnSent(callback func()) { + message.assertOutgoing() + message.onSent = callback +} + +// The value of the "Profile" property which is used to identify a request's purpose. +func (message *Message) Profile() string { + return message.Properties[ProfileProperty] +} + +// Sets the value of the "Profile" property which is used to identify a request's purpose. +func (message *Message) SetProfile(profile string) { + message.Properties[ProfileProperty] = profile +} + +// True if a message is in its completed form: an incoming message whose entire body has +// arrived, or an outgoing message that's been queued for delivery. +func (message *Message) Complete() bool { + return !message.inProgress +} + +// Writes the encoded form of a Message to a stream. +func (message *Message) WriteTo(writer io.Writer) error { + if err := message.Properties.WriteTo(writer); err != nil { + return err + } + var err error + if len(message.body) > 0 { + _, err = writer.Write(message.body) } + return err } -// Reads an incoming message's properties from the reader if necessary -func (m *Message) readProperties() error { - if m.Properties != nil { +//////// ERRORS: + +// The error information in a response whose Type is ErrorType. +// (The return type `*ErrorResponse` implements the `error` interface.) +// If the message is not an error, returns nil. +func (response *Message) Error() *ErrorResponse { + if response.Type() != ErrorType { return nil - } else if m.reader == nil { - panic("Message has no reader") } - m.Properties = Properties{} - return m.Properties.ReadFrom(m.reader) + domain := response.Properties[ErrorDomainProperty] + if domain == "" { + domain = BLIPErrorDomain + } + code, _ := strconv.ParseInt(response.Properties[ErrorCodeProperty], 10, 0) + message, _ := response.Body() + return &ErrorResponse{ + Domain: domain, + Code: int(code), + Message: string(message), + } } -// The value of the "Profile" property which is used to identify a request's purpose. -func (request *Message) Profile() string { - return request.Properties["Profile"] +// ErrorResponse implements the `error` interface. +func (err *ErrorResponse) Error() string { + if err.Message == "" { + return fmt.Sprintf("%s/%d", err.Domain, err.Code) + } else { + return fmt.Sprintf("%s (%s/%d)", err.Message, err.Domain, err.Code) + } } -// Sets the value of the "Profile" property which is used to identify a request's purpose. -func (request *Message) SetProfile(profile string) { - request.Properties["Profile"] = profile +// Changes a pending response into an error. +// It is safe (and a no-op) to call this on a nil Message. +func (response *Message) SetError(errDomain string, errCode int, message string) { + if response != nil { + response.assertMutable() + response.setError(errDomain, errCode, message) + } } +func (response *Message) setError(errDomain string, errCode int, message string) { + response.assertIsResponse() + response.flags = (response.flags &^ kTypeMask) | frameFlags(ErrorType) + response.Properties = Properties{ + ErrorDomainProperty: errDomain, + ErrorCodeProperty: fmt.Sprintf("%d", errCode), + } + response.body = []byte(message) +} + +//////// BODY: + // Returns a Reader object from which the message body can be read. // If this is an incoming message the body will be streamed as the message arrives over -// the network (and multiple calls to BodyReader() won't work.) -func (m *Message) BodyReader() (io.Reader, error) { +// the network. +func (m *Message) BodyReader() (*PipeReader, error) { if m.Outgoing || m.body != nil { - return bytes.NewReader(m.body), nil - } - if err := m.readProperties(); err != nil { - return nil, err + r, w := NewPipe() + _, _ = w.Write(m.body) + w.Close() + return r, nil + } else { + return m.bodyReader, nil } - m.readingBody = true - return m.reader, nil } // Returns the entire message body as a byte array. // If the message is incoming, blocks until the entire body is received. func (m *Message) Body() ([]byte, error) { + m.bodyMutex.Lock() + defer m.bodyMutex.Unlock() if m.body == nil && !m.Outgoing { - if m.readingBody { - panic("Already reading body as a stream") - } - body, err := ioutil.ReadAll(m.reader) + body, err := ioutil.ReadAll(m.bodyReader) if err != nil { return nil, err } @@ -190,12 +245,6 @@ func (m *Message) Body() ([]byte, error) { return m.body, nil } -// Sets the entire body of an outgoing message. -func (m *Message) SetBody(body []byte) { - m.assertMutable() - m.body = body -} - // Returns the message body parsed as JSON. func (m *Message) ReadJSONBody(value interface{}) error { if bodyReader, err := m.BodyReader(); err != nil { @@ -207,6 +256,12 @@ func (m *Message) ReadJSONBody(value interface{}) error { } } +// Sets the entire body of an outgoing message. +func (m *Message) SetBody(body []byte) { + m.assertMutable() + m.body = body +} + // Sets the message body to JSON generated from the given JSON-encodable value. // As a convenience this also sets the "Content-Type" property to "application/json". func (m *Message) SetJSONBody(value interface{}) error { @@ -225,22 +280,45 @@ func (m *Message) SetJSONBodyAsBytes(jsonBytes []byte) { m.SetCompressed(true) } -// Returns the response message to this request. Its properties and body are initially empty. -// Multiple calls return the same object. +func (m *Message) bodySize() int { + if m.body != nil { + return len(m.body) + } else { + return m.bodyWriter.bytesPending() + } +} + +// Makes a copy of a Message. Only for tests. +func (message *Message) Clone() *Message { + // Make sure the body is available. This may block. + _, _ = message.Body() + + message.bodyMutex.Lock() + defer message.bodyMutex.Unlock() + return &Message{ + Outgoing: message.Outgoing, + Properties: message.Properties, + body: message.body, + number: message.number, + flags: message.flags, + } +} + +//////// RESPONSE HANDLING: + +// Returns the response message to this request. Multiple calls return the same object. // If called on a NoReply request, this returns nil. +// - If this is an incoming request, the response is immediately available. Its properties and +// body are initially empty and ready to be filled in. +// - If this is an outgoing request, the function blocks until the response arrives over the +// network. func (request *Message) Response() *Message { + request.assertIsRequest() + request.assertSent() if request.flags&kNoReply != 0 { return nil - } - if request.Type() != RequestType { - panic("Can't respond to this message") - } - if request.number == 0 { - panic("Can't get response before message has been sent") - } - - // block until a response has been set by responseComplete - if request.Outgoing { + } else if request.Outgoing { + // Outgoing request: block until a response has been set by responseComplete request.cond.L.Lock() for request.response == nil { request.cond.Wait() @@ -248,177 +326,261 @@ func (request *Message) Response() *Message { response := request.response request.cond.L.Unlock() return response - } - - // request is incoming, so we need to build a response - request.cond.L.Lock() - defer request.cond.L.Unlock() - // if we already have a response, return it - if request.response != nil { - return request.response - } - response := request.createResponse() - response.flags |= request.flags & kUrgent - response.Properties = Properties{} - request.response = response - return response -} - -// Changes a pending response into an error. -// It is safe (and a no-op) to call this on a nil Message. -func (response *Message) SetError(errDomain string, errCode int, message string) { - if response != nil { - response.assertMutable() - if response.Type() == RequestType { - panic("Can't call SetError on a request") - } - response.flags = (response.flags &^ kTypeMask) | frameFlags(ErrorType) - response.Properties = Properties{ - "Error-Domain": errDomain, - "Error-Code": fmt.Sprintf("%d", errCode), - } - if message != "" { - response.body = []byte(message) + } else { + // Incoming request: create a response for the caller to fill in + request.cond.L.Lock() + defer request.cond.L.Unlock() + // if we already have a response, return it + if request.response != nil { + return request.response } + response := request.createResponse() + response.flags |= request.flags & kUrgent + response.Properties = Properties{} + request.response = response + return response } } -//////// INTERNALS: +// Registers a function to be called when a response to this outgoing Message arrives. +// This Message must be an outgoing request. +// Only one function can be registered; registering another causes a panic. +func (request *Message) OnResponse(callback func(*Message)) { + request.assertIsRequest() + request.assertOutgoing() + precondition(request.flags&kNoReply == 0, "OnResponse: Message %s was sent NoReply", request) -func newIncomingMessage(sender *Sender, number MessageNumber, flags frameFlags, reader io.ReadCloser) *Message { - return &Message{ - Sender: sender, - flags: flags | kMoreComing, - number: number, - reader: reader, - cond: sync.NewCond(&sync.Mutex{}), + request.cond.L.Lock() + response := request.response + if response == nil { + if request.onResponse != nil { + request.cond.L.Unlock() + panic("Message already has an OnResponse callback") + } + request.onResponse = callback } -} + request.cond.L.Unlock() -// Creates an incoming message given properties and body; exposed only for testing. -func NewParsedIncomingMessage(sender *Sender, msgType MessageType, properties Properties, body []byte) *Message { - if properties == nil { - properties = Properties{} - } - if body == nil { - body = []byte{} + if response != nil { + // Response has already arrived: + callback(response) } - msg := newIncomingMessage(sender, 1, frameFlags(msgType), nil) - msg.Properties = properties - msg.body = body - return msg } +// Creates a response Message for this Message. func (request *Message) createResponse() *Message { response := &Message{ flags: frameFlags(ResponseType) | (request.flags & kUrgent), number: request.number, Outgoing: !request.Outgoing, inResponseTo: request, + inProgress: request.Outgoing, cond: sync.NewCond(&sync.Mutex{}), } if !response.Outgoing { response.flags |= kMoreComing + response.bodyReader, response.bodyWriter = NewPipe() } return response } -func (request *Message) responseComplete(response *Message) { +// Notifies an outgoing Message that its response is available. +// - This unblocks any waiting `Response` methods, which will now return the response. +// - It calls any registered "OnResponse" callback. +// - It sets the `response` field so subsequent calls to `Response` will return it. +func (request *Message) responseAvailable(response *Message) { + var callback func(*Message) + request.cond.L.Lock() - defer request.cond.L.Unlock() - if request.response != nil { - panic(fmt.Sprintf("Multiple responses to %s", request)) + if existing := request.response; existing != nil { + request.cond.L.Unlock() + if existing != response { + panic(fmt.Sprintf("Multiple responses to %s", request)) + } + return } request.response = response + callback = request.onResponse + request.onResponse = nil request.cond.Broadcast() + request.cond.L.Unlock() + + if callback != nil { + callback(response) // Calling on receiver thread; callback should exit quickly + } } -//////// I/O: +//////// SENDING MESSAGES: -func (m *Message) WriteTo(writer io.Writer) error { - if err := m.Properties.WriteTo(writer); err != nil { - return err +// A wrapper around an outgoing Message that generates frame bodies to send. +type msgSender struct { + *Message + bytesSent uint64 + bytesAcked uint64 + remainingProperties []byte + remainingBody []byte +} + +func (m *msgSender) nextFrameToSend(maxSize int) ([]byte, frameFlags) { + m.assertOutgoing() + m.assertSent() + + if m.remainingProperties == nil { + // On the first call, encode the properties: + m.inProgress = true + m.remainingProperties = m.Properties.Encode() + m.remainingBody = m.body } - var err error - if len(m.body) > 0 { - _, err = writer.Write(m.body) + + var frame []byte + if plen := min(maxSize, len(m.remainingProperties)); plen > 0 { + // Send properties first: + frame = m.remainingProperties[0:plen] + m.remainingProperties = m.remainingProperties[plen:] + maxSize -= plen } - return err -} -func (m *Message) ReadFrom(reader io.Reader) error { - if err := m.Properties.ReadFrom(reader); err != nil { - return err + if blen := min(maxSize, len(m.remainingBody)); blen > 0 { + // Then the body: + frame = append(frame, m.remainingBody[0:blen]...) + m.remainingBody = m.remainingBody[blen:] } - var err error - m.body, err = ioutil.ReadAll(reader) - return err + + flags := m.flags + if len(m.remainingProperties) > 0 || len(m.remainingBody) > 0 { + flags |= kMoreComing + } + return frame, flags } -// Returns a write stream to write the incoming message's content into. When the stream is closed, -// the message will deliver itself. -func (m *Message) asyncRead(onComplete func(error)) io.WriteCloser { +func (m *msgSender) addBytesSent(bytesSent uint64) { + m.bytesSent += bytesSent +} - reader, writer := io.Pipe() - m.reader = reader +func (m *msgSender) needsAck() bool { + return m.bytesSent > m.bytesAcked+kMaxUnackedBytes +} - // Start a goroutine to read off the read-end of the io.Pipe until it's read everything, or the - // write end of the io.Pipe was closed, which can happen if the peer closes the connection. - go func() { - defer func() { - if p := recover(); p != nil { - err := fmt.Sprintf("PANIC in BLIP asyncRead: %v", p) - log.Printf(err+"\n%s", debug.Stack()) - reader.CloseWithError(errors.New(err)) - } - }() +func (m *msgSender) receivedAck(bytesAcked uint64) { + if bytesAcked > m.bytesAcked { + m.bytesAcked = bytesAcked + } +} - // Update Expvar stats for number of outstanding goroutines - incrAsyncReadGoroutines() - defer decrAsyncReadGoroutines() +func (m *msgSender) sent() { + if callback := m.onSent; callback != nil { + m.onSent = nil + callback() + } +} - err := m.ReadFrom(reader) - onComplete(err) - }() - return writer +// Informs an outgoing message that the connection has closed +func (m *msgSender) cancelOutgoing() { } -func (m *Message) nextFrameToSend(maxSize int) ([]byte, frameFlags) { - if m.number == 0 || !m.Outgoing { - panic("Can't send this message") - } +//////// RECEIVING MESSAGES: - if m.encoder == nil { - // Start the encoder goroutine: - var writer io.WriteCloser - m.encoder, writer = io.Pipe() - go func() { - defer func() { - if p := recover(); p != nil { - log.Printf("PANIC in BLIP nextFrameToSend: %v\n%s", p, debug.Stack()) - } - }() - defer writer.Close() +// A wrapper around an incoming Message while it's being received. +type msgReceiver struct { + *Message + bytesWritten uint64 + propertiesBuffer []byte +} - // Update Expvar stats for number of outstanding goroutines - incrNextFrameToSendGoroutines() - defer decrNextFrameToSendGoroutines() +func newIncomingMessage(sender *Sender, number MessageNumber, flags frameFlags) *msgReceiver { + m := &msgReceiver{ + Message: &Message{ + Sender: sender, + flags: flags | kMoreComing, + number: number, + inProgress: true, + cond: sync.NewCond(&sync.Mutex{}), + }, + } + m.bodyReader, m.bodyWriter = NewPipe() + return m +} - _ = m.WriteTo(writer) +type dispatchState struct { + atStart bool // The beginning of the message has arrived + atEnd bool // The end of the message has arrived +} - }() +// Appends a frame's data to an incoming message. +func (m *msgReceiver) addIncomingBytes(bytes []byte, complete bool) (dispatchState, error) { + state := dispatchState{atEnd: complete} + if m.Properties == nil { + // First read the message properties: + m.propertiesBuffer = append(m.propertiesBuffer, bytes...) + props, bytesRead, err := ReadProperties(m.propertiesBuffer) + if err != nil { + return state, err + } else if props == nil { + if complete { + return state, fmt.Errorf("incomplete properties in BLIP message") + } + return state, nil // incomplete properties; wait for more + } + // Got the complete properties: + m.Properties = props + bytes = m.propertiesBuffer[bytesRead:] + m.propertiesBuffer = nil + state.atStart = true } - frame := make([]byte, maxSize) - flags := m.flags - size, err := io.ReadFull(m.encoder, frame) - if err == nil { - flags |= kMoreComing + // Now add to the body: + if complete { + m.inProgress = false + m.flags = m.flags &^ kMoreComing + } + if m.body != nil { + m.body = append(m.body, bytes...) } else { - frame = frame[0:size] + _, _ = m.bodyWriter.Write(bytes) + if complete { + m.bodyWriter.Close() + } } - return frame, flags + return state, nil } -// A callback function that takes a message and returns nothing -type MessageCallback func(*Message) +// Add `frameSize` to my `bytesWritten` and send an ACK message every `kAckInterval` bytes +func (m *msgReceiver) maybeSendAck(frameSize int) { + oldWritten := m.bytesWritten + m.bytesWritten += uint64(frameSize) + if oldWritten > 0 && (oldWritten/kAckInterval) < (m.bytesWritten/kAckInterval) { + m.Sender.sendAck(m.number, m.Type(), m.bytesWritten) + } +} + +// Informs an incoming message that the connection has closed +func (m *msgReceiver) cancelIncoming() { + if m.inResponseTo != nil { + if m.bodyWriter != nil { + _ = m.bodyWriter.CloseWithError(ErrConnectionClosed) + } + m.setError(BLIPErrorDomain, DisconnectedCode, "Connection closed") + m.inResponseTo.responseAvailable(m.Message) + } +} + +//////// UTILITIES + +func (message *Message) assertMutable() { + precondition(message.Outgoing && !message.inProgress, "Message %s is not modifiable", message) +} +func (message *Message) assertIncoming() { + precondition(!message.Outgoing, "Message %s is not incoming", message) +} +func (message *Message) assertOutgoing() { + precondition(message.Outgoing, "Message %s is not outgoing", message) +} +func (message *Message) assertIsRequest() { + precondition(message.Type() == RequestType, "Message %s is not a request", message) +} +func (message *Message) assertIsResponse() { + precondition(message.Type() != RequestType, "Message %s is not a response", message) +} +func (message *Message) assertSent() { + precondition(message.number != 0, "Message %s has not been sent", message) +} diff --git a/message_test.go b/message_test.go index 3b1c431..a1a3c72 100644 --- a/message_test.go +++ b/message_test.go @@ -11,12 +11,11 @@ licenses/APL2.txt. package blip import ( - "bytes" + "fmt" "io" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func init() { @@ -27,49 +26,82 @@ func makeTestRequest() *Message { m := NewRequest() m.Properties["Content-Type"] = "ham/rye" m.Properties["X-Weather"] = "rainy" + m.SetProfile("test") m.SetBody([]byte("The white knight is sliding down the poker. He balances very badly.")) return m } func TestMessageEncoding(t *testing.T) { m := makeTestRequest() - var writer bytes.Buffer - err := m.WriteTo(&writer) - assert.Equal(t, nil, err) - serialized := writer.Bytes() - assert.Equal(t, "\x25Content-Type\x00ham/rye\x00X-Weather\x00rainy\x00The white knight is sliding down the poker. He balances very badly.", string(serialized)) - t.Logf("Encoded as %d bytes", len(serialized)) - - m2 := newIncomingMessage(nil, 1, m.flags, nil) - reader := bytes.NewReader(serialized) - err = m2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, m.Properties, m2.Properties) - mbody, _ := m.Body() - m2body, _ := m2.Body() - assert.Equal(t, mbody, m2body) + assert.Equal(t, "test", m.Profile()) + serialized := serializeMessage(t, m) + assert.Equal(t, "\x32Content-Type\x00ham/rye\x00Profile\x00test\x00X-Weather\x00rainy\x00The white knight is sliding down the poker. He balances very badly.", string(serialized)) +} + +func TestMessageStreaming(t *testing.T) { + m := makeTestRequest() + serialized := serializeMessage(t, m) + body, _ := m.Body() + + for breakAt := 1; breakAt <= len(serialized); breakAt++ { + t.Run(fmt.Sprintf("Frame size %d", breakAt), func(t *testing.T) { + // "Receive" message m2 in two pieces, separated at offset `breakAt`. + frame1 := serialized[0:breakAt] + frame2 := serialized[breakAt:] + complete := len(frame2) == 0 + expectedState := dispatchState{ + atStart: len(frame1) >= len(serialized)-len(body), + atEnd: complete, + } + + m2 := newIncomingMessage(nil, 1, m.flags) + state, err := m2.addIncomingBytes(frame1, complete) + assert.NoError(t, err) + assert.Equal(t, state, expectedState) + if state.atStart { + // Frame 1 completes the properties, so the reader should be available: + assert.Equal(t, m.Properties, m2.Properties) + reader, err := m2.BodyReader() + assert.NoError(t, err) + buf := make([]byte, 500) + n, err := reader.TryRead(buf) + assert.NoError(t, err) + assert.Equal(t, n, breakAt-51) + assert.Equal(t, body[0:n], buf[0:n]) + + if !expectedState.atEnd { + state, err = m2.addIncomingBytes(frame2, true) + assert.NoError(t, err) + assert.Equal(t, state, dispatchState{false, true}) + n2, err := reader.TryRead(buf[n:]) + assert.NoError(t, err) + assert.Equal(t, len(body)-n, n2) + assert.Equal(t, body, buf[0:(n+n2)]) + } + + n, err = reader.TryRead(buf) + assert.ErrorIs(t, err, io.EOF) + assert.Equal(t, n, 0) + } + }) + } } func TestMessageEncodingCompressed(t *testing.T) { m := makeTestRequest() m.SetCompressed(true) - var writer bytes.Buffer - err := m.WriteTo(&writer) - assert.Equal(t, nil, err) - serialized := writer.Bytes() + serialized := serializeMessage(t, m) // Commented due to test failure: // http://drone.couchbase.io/couchbase/go-blip/4 (test logs: https://gist.github.com/tleyden/ae2aa71978cd11ca5d9d1f6878593cdb) // goassert.Equals(t, string(serialized), "\x1a\x04\x00ham/rye\x00X-Weather\x00rainy\x00\x1f\x8b\b\x00\x00\tn\x88\x00\xff\f\xca\xd1\t\xc5 \f\x05\xd0U\xee\x04\xce\xf1\x06x\v\xd8z\xd1`\x88ń\x8a\xdb\xd7\xcf\x03\xe7߈\xd5$\x88nR[@\x1c\xaeR\xc4*\xcaX\x868\xe1\x19\x9d3\xe1G\\Y\xb3\xddt\xbc\x9c\xfb\xa8\xe8N_\x00\x00\x00\xff\xffs*\xa1\xa6C\x00\x00\x00") // log.Printf("Encoded compressed as %d bytes", len(serialized)) - m2 := newIncomingMessage(nil, 1, m.flags, nil) - reader := bytes.NewReader(serialized) - err = m2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, m.Properties, m2.Properties) - assert.Equal(t, m.body, m2.body) - + m2 := newIncomingMessage(nil, 1, m.flags) + state, err := m2.addIncomingBytes(serialized, true) + assert.NoError(t, err) + assert.Equal(t, state, dispatchState{true, true}) + assertEqualMessages(t, m, m2.Message) } func BenchmarkMessageEncoding(b *testing.B) { @@ -85,33 +117,13 @@ func BenchmarkMessageEncoding(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - msg.encoder = nil + msg.inProgress = false + sender := &msgSender{Message: msg} for { - _, flags := msg.nextFrameToSend(4090) + _, flags := sender.nextFrameToSend(4090) if flags&kMoreComing == 0 { break } } } } - -func TestMessageDecoding(t *testing.T) { - original := makeTestRequest() - reader, writer := io.Pipe() - go func() { - require.NoError(t, original.WriteTo(writer)) - writer.Close() - }() - - incoming := newIncomingMessage(nil, original.number, original.flags, reader) - err := incoming.readProperties() - assert.Equal(t, nil, err) - assert.Equal(t, original.Properties, incoming.Properties) - err = incoming.readProperties() - assert.Equal(t, nil, err) - - body, err := incoming.Body() - assert.Equal(t, nil, err) - assert.Equal(t, original.body, body) - assert.Equal(t, incoming.body, body) -} diff --git a/messagequeue.go b/messagequeue.go index d2ace1b..9d1c2c8 100644 --- a/messagequeue.go +++ b/messagequeue.go @@ -20,21 +20,26 @@ const kInitialQueueCapacity = 10 type messageQueue struct { logContext LogContext maxCount int - queue []*Message + queue []*msgSender numRequestsSent MessageNumber cond *sync.Cond + /* (instrumentation for perf testing) + totalSize int + highestCount int + highestTotalSize int + */ } func newMessageQueue(logContext LogContext, maxCount int) *messageQueue { return &messageQueue{ logContext: logContext, - queue: make([]*Message, 0, kInitialQueueCapacity), + queue: make([]*msgSender, 0, kInitialQueueCapacity), cond: sync.NewCond(&sync.Mutex{}), maxCount: maxCount, } } -func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock +func (q *messageQueue) _push(msg *msgSender, new bool) bool { // requires lock if !msg.Outgoing { panic("Not an outgoing message") } @@ -53,7 +58,7 @@ func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock if q.queue[index].Urgent() { index += 2 break - } else if new && q.queue[index].encoder == nil { + } else if new && !q.queue[index].inProgress { // But have to keep message starts in order index += 1 break @@ -74,20 +79,32 @@ func (q *messageQueue) _push(msg *Message, new bool) bool { // requires lock copy(q.queue[index+1:n+1], q.queue[index:n]) q.queue[index] = msg + /* (instrumentation for perf testing) + q.totalSize += len(msg.body) + if n+1 > q.highestCount { + q.highestCount = n + 1 + q.logContext.log("messageQueue total size = %d (%d messages)", q.totalSize, n+1) + } + if q.totalSize > q.highestTotalSize { + q.highestTotalSize = q.totalSize + q.logContext.log("messageQueue total size = %d (%d messages)", q.totalSize, n+1) + } + */ if len(q.queue) == 1 { q.cond.Signal() // It's non-empty now, so unblock a waiting pop() } + return true } // Push an item into the queue -func (q *messageQueue) push(msg *Message) bool { +func (q *messageQueue) push(msg *msgSender) bool { return q.pushWithCallback(msg, nil) } // Push an item into the queue, also providing a callback function that will be invoked // after the number is assigned to the message, but before pushing into the queue. -func (q *messageQueue) pushWithCallback(msg *Message, prepushCallback MessageCallback) bool { +func (q *messageQueue) pushWithCallback(msg *msgSender, prepushCallback MessageCallback) bool { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -112,13 +129,13 @@ func (q *messageQueue) pushWithCallback(msg *Message, prepushCallback MessageCal } if prepushCallback != nil { - prepushCallback(msg) + prepushCallback(msg.Message) } return q._push(msg, isNew) } -func (q *messageQueue) _maybePop(actuallyPop bool) *Message { +func (q *messageQueue) _maybePop(actuallyPop bool) *msgSender { q.cond.L.Lock() defer q.cond.L.Unlock() for len(q.queue) == 0 && q.queue != nil { @@ -132,7 +149,9 @@ func (q *messageQueue) _maybePop(actuallyPop bool) *Message { msg := q.queue[0] if actuallyPop { q.queue = q.queue[1:] - + /* (instrumentation for perf testing) + q.totalSize -= len(msg.body) + */ if len(q.queue) == q.maxCount-1 { q.cond.Signal() } @@ -140,10 +159,10 @@ func (q *messageQueue) _maybePop(actuallyPop bool) *Message { return msg } -func (q *messageQueue) first() *Message { return q._maybePop(false) } -func (q *messageQueue) pop() *Message { return q._maybePop(true) } +func (q *messageQueue) first() *msgSender { return q._maybePop(false) } +func (q *messageQueue) pop() *msgSender { return q._maybePop(true) } -func (q *messageQueue) find(msgNo MessageNumber, msgType MessageType) *Message { +func (q *messageQueue) find(msgNo MessageNumber, msgType MessageType) *msgSender { q.cond.L.Lock() defer q.cond.L.Unlock() for _, message := range q.queue { @@ -159,14 +178,8 @@ func (q *messageQueue) stop() { q.cond.L.Lock() defer q.cond.L.Unlock() - // Iterate over messages and call close on every message's readcloser, since it's possible that - // a goroutine may be blocked on the reader, thus causing a resource leak. Added during SG #3268 - // diagnosis, but does not fix any reproducible issues. for _, message := range q.queue { - err := message.Close() - if err != nil { - q.logContext.logMessage("Warning: messageQueue encountered error closing message while stopping. Error: %v", err) - } + message.cancelOutgoing() } q.queue = nil diff --git a/messagequeue_test.go b/messagequeue_test.go index 43484be..c3a74b8 100644 --- a/messagequeue_test.go +++ b/messagequeue_test.go @@ -11,8 +11,6 @@ licenses/APL2.txt. package blip import ( - "bytes" - "io/ioutil" "log" "sync" "testing" @@ -28,7 +26,7 @@ func TestMessagePushPop(t *testing.T) { // Push a non-urgent message into the queue for i := 0; i < 2; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) @@ -38,7 +36,7 @@ func TestMessagePushPop(t *testing.T) { } // Push an urgent message into the queue - urgentMsg := NewRequest() + urgentMsg := &msgSender{Message: NewRequest()} urgentMsg.SetUrgent(true) pushed := mq.push(urgentMsg) assert.True(t, pushed) @@ -78,7 +76,7 @@ func TestConcurrentAccess(t *testing.T) { // Fill it up to capacity w/ normal messages for i := 0; i < maxSendQueueCount; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) @@ -88,7 +86,7 @@ func TestConcurrentAccess(t *testing.T) { doneWg.Add(2) pusher := func() { for i := 0; i < 100; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) } @@ -149,15 +147,15 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio maxSendQueueCount := 25 mq := newMessageQueue(&TestLogContext{silent: true}, maxSendQueueCount) - // Add normal messages that are "in-progress" since they have a non-nil msg.encoder + // Add normal messages that are "in-progress" for i := 0; i < 5; i++ { - msg := NewRequest() + msg := &msgSender{Message: NewRequest()} pushed := mq.push(msg) assert.True(t, pushed) assert.False(t, mq.nextMessageIsUrgent()) - // set the msg.encoder to something so that the next urgent message will go to the head of the line - msg.encoder = ioutil.NopCloser(&bytes.Buffer{}) + // set inProgress so that the next urgent message will go to the head of the line + msg.inProgress = true } @@ -165,7 +163,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio assert.True(t, mq.first().SerialNumber() == MessageNumber(1)) // Push an urgent message into the queue - urgentMsg := NewRequest() + urgentMsg := &msgSender{Message: NewRequest()} urgentMsg.SetUrgent(true) pushed := mq.push(urgentMsg) assert.True(t, pushed) @@ -181,7 +179,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio mq.pop() // T6: [n5] [n4] [n3] [n2] [u6] - // Since all the normal messages have had frames sent (faked, via non-nil msg.encoder), then the + // Since all the normal messages have had frames sent (faked, via setting inProgress), then the // urgent message should have skipped to the head of the line // assert.True(t, mq.nextMessageIsUrgent()) headOfLine := mq.first() @@ -189,7 +187,7 @@ func TestUrgentMessageOrdering(t *testing.T) { // Test passes, but some assertio // Push another urgent message // T7: [n5] [n4] [n3] [u7] [n2] [u6] - anotherUrgentMsg := NewRequest() + anotherUrgentMsg := &msgSender{Message: NewRequest()} anotherUrgentMsg.SetUrgent(true) pushed = mq.push(anotherUrgentMsg) assert.True(t, pushed) diff --git a/pipe.go b/pipe.go new file mode 100644 index 0000000..8435add --- /dev/null +++ b/pipe.go @@ -0,0 +1,180 @@ +/* +Copyright 2023-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. +*/ + +package blip + +import ( + "io" + "sync" +) + +// Creates a new pipe, a pair of bound streams. +// Similar to io.Pipe, except the stream is buffered so writes don't block. +func NewPipe() (*PipeReader, *PipeWriter) { + p := &pipe_shared{ + chunks: make([][]byte, 0, 50), + cond: sync.NewCond(&sync.Mutex{}), + } + return &PipeReader{shared: p}, &PipeWriter{shared: p} +} + +// The private state shared between a PipeReader and PipeWriter. +type pipe_shared struct { + chunks [][]byte // ordered list of unread byte-arrays written to the Pipe + err error // Set when closed, to io.EOF or some other error. + cond *sync.Cond // Synchronizes writer & reader +} + +// -------- PIPEWRITER + +// The write end of a pipe. Implements io.WriteCloser. +// Unlike io.Pipe, writes do not block; instead the unread data is buffered in memory. +type PipeWriter struct { + shared *pipe_shared +} + +// Standard Writer method. Does not block. +func (w *PipeWriter) Write(chunk []byte) (n int, err error) { + if len(chunk) == 0 { + // The Writer interface forbids retaining the input, so we must copy it: + copied := make([]byte, len(chunk)) + copy(copied, chunk) + chunk = copied + } + if err = w._add(chunk, nil); err == nil { + n = len(chunk) + } + return +} + +// Closes the pipe. +// The associated PipeReader can still read any remaining data, then it will get an EOF error. +func (w *PipeWriter) Close() error { + return w.CloseWithError(io.EOF) +} + +// Closes the pipe with a custom error. +func (w *PipeWriter) CloseWithError(err error) error { + if err == nil { + err = io.EOF + } + _ = w._add(nil, err) + return nil +} + +func (w *PipeWriter) _add(chunk []byte, err error) error { + // adds a chunk or sets an error; or if there's already an error, returns it. + w.shared.cond.L.Lock() + defer w.shared.cond.L.Unlock() + + if w.shared.err != nil { + return w.shared.err + } + + if err == nil || err == io.EOF { + if len(chunk) > 0 { + w.shared.chunks = append(w.shared.chunks, chunk) + } + } else { + w.shared.chunks = nil // make sure reader sees the custom error ASAP + } + w.shared.err = err + w.shared.cond.Signal() + return nil +} + +// The number of bytes written but not yet read +func (w *PipeWriter) bytesPending() int { + total := 0 + w.shared.cond.L.Lock() + for _, chunk := range w.shared.chunks { + total += len(chunk) + } + w.shared.cond.L.Unlock() + return total +} + +// -------- PIPEREADER + +// The read end of a pipe. Implements io.ReadCloser. +type PipeReader struct { + shared *pipe_shared // Shared state + chunk []byte // The data chunk currently being read from +} + +// Standard Reader method. +func (r *PipeReader) Read(p []byte) (n int, err error) { return r._read(p, true) } + +// Non-blocking read: similar to Read, but if no data is available returns (0, nil). +func (r *PipeReader) TryRead(p []byte) (n int, err error) { return r._read(p, false) } + +func (r *PipeReader) _read(p []byte, wait bool) (n int, err error) { + if len(r.chunk) == 0 { + // Current chunk is exhausted; wait for a new one: + r.chunk, err = r._nextChunk(wait) + } + if len(r.chunk) > 0 { + // Read bytes out of the current chunk: + n = copy(p, r.chunk) + r.chunk = r.chunk[n:] + } + return +} + +// Returns true if a Read call will not block, whether because there's data or EOF or an error. +func (r *PipeReader) CanRead() (ok bool) { + if len(r.chunk) > 0 { + ok = true + } else { + r.shared.cond.L.Lock() + ok = len(r.shared.chunks) > 0 || r.shared.err != nil + r.shared.cond.L.Unlock() + } + return +} + +// Closes the reader. Subsequent PipeReader.Write calls will return io.ErrClosedPipe. +func (r *PipeReader) Close() error { + return r.CloseWithError(io.ErrClosedPipe) +} + +// Closes the reader with a custom error. Subsequent PipeReader.Write calls will return this error. +func (r *PipeReader) CloseWithError(err error) error { + r.shared.cond.L.Lock() + if r.shared.err == nil { + r.shared.err = err + r.shared.chunks = nil + } + r.shared.cond.L.Unlock() + return nil +} + +func (r *PipeReader) _nextChunk(wait bool) ([]byte, error) { + // Returns the next chunk added by the writer, or the error if any. + // If neither is available and `wait` is true, it blocks. + r.shared.cond.L.Lock() + defer r.shared.cond.L.Unlock() + for { + if len(r.shared.chunks) > 0 { + // If there's a chunk in the queue, return it: + chunk := r.shared.chunks[0] + r.shared.chunks = r.shared.chunks[1:] + return chunk, nil + } else if r.shared.err != nil { + // Else if an error is set, return that: + return nil, r.shared.err + } else if !wait { + // Non-blocking call with nothing to read: + return nil, nil + } + // None of the above -- wait for something to happen: + r.shared.cond.Wait() + } +} diff --git a/properties.go b/properties.go index 21b85d9..fde648e 100644 --- a/properties.go +++ b/properties.go @@ -24,49 +24,54 @@ type Properties map[string]string // For testing purposes, clients can set this to true to write properties sorted by key var SortProperties = false -// Adapter for io.Reader to io.ByteReader -type byteReader struct { - reader io.Reader -} - -func (br byteReader) ReadByte() (byte, error) { - var p [1]byte - _, err := br.reader.Read(p[:]) - return p[0], err -} +// Implementation-imposed max encoded size of message properties (not part of protocol) +const maxPropertiesLength = 100 * 1024 -// Reads encoded Properties from a stream. -func (properties *Properties) ReadFrom(reader io.Reader) error { - length, err := binary.ReadUvarint(byteReader{reader}) - if err != nil { - return err +// Reads encoded Properties from a byte array. +// On success, returns the Properties map and the number of bytes read. +// If the array doesn't contain the complete properties, returns (nil, 0, nil). +// On failure, returns nil Properties and an error. +func ReadProperties(body []byte) (properties Properties, bytesRead int, err error) { + length, bytesRead := binary.Uvarint(body) + if bytesRead == 0 { + // Not enough bytes to read the varint + return + } else if bytesRead < 0 || length > maxPropertiesLength { + err = fmt.Errorf("invalid properties length in BLIP message") + return + } else if bytesRead+int(length) > len(body) { + // Incomplete + return nil, 0, nil } + if length == 0 { - return nil - } - body := make([]byte, length) - if _, err := io.ReadFull(reader, body); err != nil { - return err + // Empty properties + return Properties{}, bytesRead, nil } + body = body[bytesRead:] + bytesRead += int(length) + if body[length-1] != 0 { - return fmt.Errorf("Invalid properties (not NUL-terminated)") + err = fmt.Errorf("invalid properties (not NUL-terminated)") + return } eachProp := bytes.Split(body[0:length-1], []byte{0}) nProps := len(eachProp) / 2 if nProps*2 != len(eachProp) { - return fmt.Errorf("Odd number of strings in properties") + err = fmt.Errorf("odd number of strings in properties") + return } - *properties = Properties{} + properties = Properties{} for i := 0; i < len(eachProp); i += 2 { key := string(eachProp[i]) value := string(eachProp[i+1]) - if _, exists := (*properties)[key]; exists { - return fmt.Errorf("Duplicate property name %q", key) + if _, exists := (properties)[key]; exists { + return nil, bytesRead, fmt.Errorf("duplicate property name %q", key) } - (*properties)[key] = value + properties[key] = value } - return nil + return } // Writes Properties to a stream. @@ -103,6 +108,13 @@ func (properties Properties) WriteTo(writer io.Writer) error { return nil } +// Writes Properties to a byte array. +func (properties Properties) Encode() []byte { + var out bytes.Buffer + _ = properties.WriteTo(&out) + return out.Bytes() +} + // Properties stored as alternating keys / values type propertyList [][]byte diff --git a/properties_test.go b/properties_test.go index 5751ca6..0e562a6 100644 --- a/properties_test.go +++ b/properties_test.go @@ -12,6 +12,7 @@ package blip import ( "bytes" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -26,14 +27,21 @@ func TestReadWriteProperties(t *testing.T) { var writer bytes.Buffer err := p.WriteTo(&writer) assert.Equal(t, nil, err) + writer.Write([]byte("FOOBAR")) serialized := writer.Bytes() - assert.Equal(t, "\x2EContent-Type\x00application/octet-stream\x00Foo\x00Bar\x00", string(serialized)) + propertiesLength := len(serialized) - len("FOOBAR") + assert.Equal(t, "\x2EContent-Type\x00application/octet-stream\x00Foo\x00Bar\x00FOOBAR", string(serialized)) - var p2 Properties - reader := bytes.NewReader(serialized) - err = p2.ReadFrom(reader) - assert.Equal(t, nil, err) - assert.Equal(t, p, p2) + for dataLen := 0; dataLen <= len(serialized); dataLen++ { + p2, bytesRead, err := ReadProperties(serialized[0:dataLen]) + assert.NoError(t, err) + if dataLen < propertiesLength { + assert.Nil(t, p2) + } else { + assert.Equal(t, p, p2) + assert.Equal(t, propertiesLength, bytesRead) + } + } } func TestReadWriteEmptyProperties(t *testing.T) { @@ -44,11 +52,10 @@ func TestReadWriteEmptyProperties(t *testing.T) { serialized := writer.Bytes() assert.Equal(t, "\x00", string(serialized)) - var p2 Properties - reader := bytes.NewReader(serialized) - err = p2.ReadFrom(reader) + p2, bytesRead, err := ReadProperties(serialized) assert.Equal(t, nil, err) - assert.Equal(t, p, p2) + assert.Equal(t, Properties{}, p2) + assert.Equal(t, len(serialized), bytesRead) } func TestReadBadProperties(t *testing.T) { @@ -56,27 +63,32 @@ func TestReadBadProperties(t *testing.T) { {"", "EOF"}, {"\x00", ""}, {"\x0C", "EOF"}, - {"\x0CX\x00Y\x00Foo\x00Ba", "unexpected EOF"}, + {"\x0CX\x00Y\x00Foo\x00Ba", "EOF"}, {"\x0CX\x00Y\x00Foo\x00Bar\x00", ""}, - {"\x14X\x00Y\x00Foo\x00Bar\x00Foo\x00Zog\x00", "Duplicate property name \"Foo\""}, + {"\x14X\x00Y\x00Foo\x00Bar\x00Foo\x00Zog\x00", "duplicate property name \"Foo\""}, - {"\x02hi", "Invalid properties (not NUL-terminated)"}, - {"\x02h\x00", "Odd number of strings in properties"}, + {"\x02hi", "invalid properties (not NUL-terminated)"}, + {"\x02h\x00", "odd number of strings in properties"}, } - var p2 Properties for i, pair := range bad { - reader := bytes.NewReader([]byte(pair[0])) - err := p2.ReadFrom(reader) - var errStr string - if err == nil { - if reader.Len() != 0 { - t.Errorf("Error decoding #%d %q: No error, but left %d bytes unread", i, pair[0], reader.Len()) + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + serialized := []byte(pair[0]) + p2, bytesRead, err := ReadProperties(serialized) + var errStr string + if err == nil { + if bytesRead == 0 { + errStr = "EOF" + } else if bytesRead != len(serialized) { + t.Errorf("Error decoding #%d %q: No error, but left %d bytes unread", i, pair[0], len(serialized)-bytesRead) + } else { + assert.NotNil(t, p2) + } + } else { + errStr = err.Error() } - } else { - errStr = err.Error() - } - if errStr != pair[1] { - t.Errorf("Error decoding #%d %q: %q (expected %q)", i, pair[0], errStr, pair[1]) - } + if errStr != pair[1] { + t.Errorf("Error decoding #%d %q: %q (expected %q)", i, pair[0], errStr, pair[1]) + } + }) } } diff --git a/protocol.go b/protocol.go index 435927b..0e67a55 100644 --- a/protocol.go +++ b/protocol.go @@ -22,9 +22,6 @@ import ( // Every sub-protocol used by a caller should begin with this string. const WebSocketSubProtocolPrefix = "BLIP_3" -// Domain used in errors returned by BLIP itself. -const BLIPErrorDomain = "BLIP" - //////// MESSAGE TYPE: // Enumeration of the different types of messages in the BLIP protocol. @@ -66,6 +63,31 @@ func (t MessageType) ackSourceType() MessageType { return t - 4 } +//////// PROPERTIES & ERRORS: + +// Message property that indicates what type of message it is +const ProfileProperty = "Profile" + +// Property of an error response that indicates the type of error; +// application-defined, except for the value "BLIP" (BLIPErrorDomain). +const ErrorDomainProperty = "Error-Domain" + +// Property of an error response containing a numeric error code, +// to be interpreted in the context of the error domain; application-defined. +const ErrorCodeProperty = "Error-Code" + +// Domain used in errors returned by BLIP itself. +const BLIPErrorDomain = "BLIP" + +// Standard error codes in the BLIP domain: +const ( + BadRequestCode = 400 // Something's invalid with the request properties or body + ForbiddenCode = 403 // You're not allowed to make this request + NotFoundCode = 404 // No handler for this Profile + HandlerFailedCode = 501 // A handler failed unexpectedly (panic, exception...) + DisconnectedCode = 503 // Fake response delivered if connection closes unexpectedly +) + //////// FRAME FLAGS: type frameFlags uint8 diff --git a/receiver.go b/receiver.go index d93c1c8..ca2caf6 100644 --- a/receiver.go +++ b/receiver.go @@ -15,24 +15,17 @@ import ( "context" "encoding/binary" "fmt" - "io" - "log" "runtime/debug" "sync" "sync/atomic" + "time" "nhooyr.io/websocket" ) const checksumLength = 4 -type msgStreamer struct { - message *Message - writer io.WriteCloser - bytesWritten uint64 -} - -type msgStreamerMap map[MessageNumber]*msgStreamer +type msgStreamerMap map[MessageNumber]*msgReceiver // The receiving side of a BLIP connection. // Handles receiving WebSocket messages as frames and assembling them into BLIP messages. @@ -51,18 +44,27 @@ type receiver struct { pendingRequests msgStreamerMap // Unfinished REQ messages being assembled pendingResponses msgStreamerMap // Unfinished RES messages being assembled maxPendingResponseNumber MessageNumber // Largest RES # I've seen + + dispatchMutex sync.Mutex // For thread-safe access to the fields below + dispatchCond sync.Cond // Used when receiver stops reading + maxDispatchedBytes int // above this value, receiver stops reading the WebSocket + dispatchedBytes int // Size of dispatched but unhandled incoming requests + dispatchedMessageCount int // Number of dispatched but unhandled incoming requests } func newReceiver(context *Context, conn *websocket.Conn) *receiver { - return &receiver{ - conn: conn, - context: context, - channel: make(chan []byte, 10), - parseError: make(chan error, 1), - frameDecoder: getDecompressor(context), - pendingRequests: msgStreamerMap{}, - pendingResponses: msgStreamerMap{}, + rcvr := &receiver{ + conn: conn, + context: context, + channel: make(chan []byte, 10), + parseError: make(chan error, 1), + frameDecoder: getDecompressor(context), + pendingRequests: msgStreamerMap{}, + pendingResponses: msgStreamerMap{}, + maxDispatchedBytes: context.MaxDispatchedBytes, } + rcvr.dispatchCond = sync.Cond{L: &rcvr.dispatchMutex} + return rcvr } func (r *receiver) receiveLoop() error { @@ -74,7 +76,7 @@ func (r *receiver) receiveLoop() error { for { // Receive the next raw WebSocket frame: - _, frame, err := r.conn.Read(context.TODO()) + msgType, frame, err := r.conn.Read(context.TODO()) if err != nil { if isCloseError(err) { // lower log level for close @@ -88,7 +90,12 @@ func (r *receiver) receiveLoop() error { return err } - r.channel <- frame + switch msgType { + case websocket.MessageBinary: + r.channel <- frame + default: + r.context.log("Warning: received WebSocket message of type %v", msgType) + } } } @@ -96,7 +103,7 @@ func (r *receiver) parseLoop() { defer func() { // Panic handler: atomic.AddInt32(&r.activeGoroutines, -1) if p := recover(); p != nil { - log.Printf("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack()) + r.context.log("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack()) err, _ := p.(error) if err == nil { err = fmt.Errorf("Panic: %v", p) @@ -131,7 +138,6 @@ func (r *receiver) fatalError(err error) { } func (r *receiver) stop() { - r.closePendingResponses() r.conn.Close(websocket.StatusNormalClosure, "") @@ -140,20 +146,11 @@ func (r *receiver) stop() { } func (r *receiver) closePendingResponses() { - r.pendingMutex.Lock() defer r.pendingMutex.Unlock() - // There can be goroutines spawned by message.asyncRead() that are blocked waiting to - // read off their end of an io.Pipe, and if the peer abruptly closes a connection which causes - // the sender to stop(), the other side of that io.Pipe must be closed to avoid the goroutine's - // call to unblock on the read() call. This loops through any io.Pipewriters in pendingResponses and - // close them, unblocking the readers and letting the message.asyncRead() goroutines proceed. for _, msgStreamer := range r.pendingResponses { - err := msgStreamer.writer.Close() - if err != nil { - r.context.logMessage("Warning: error closing msgStreamer writer in pending responses while stopping receiver: %v", err) - } + msgStreamer.cancelIncoming() } } @@ -216,14 +213,14 @@ func (r *receiver) handleIncomingFrame(frame []byte) error { return err } - return r.processFrame(requestNumber, flags, body, frameSize) + return r.processIncomingFrame(requestNumber, flags, body, frameSize) } } -func (r *receiver) processFrame(requestNumber MessageNumber, flags frameFlags, frame []byte, frameSize int) error { +func (r *receiver) processIncomingFrame(requestNumber MessageNumber, flags frameFlags, frame []byte, frameSize int) error { // Look up or create the writer stream for this message: complete := (flags & kMoreComing) == 0 - var msgStream *msgStreamer + var msgStream *msgReceiver var err error switch flags.messageType() { case RequestType: @@ -235,30 +232,40 @@ func (r *receiver) processFrame(requestNumber MessageNumber, flags frameFlags, f default: r.context.log("Warning: Ignoring incoming message type, with flags 0x%x", flags) } + if msgStream == nil { + return err + } // Write the decoded frame body to the stream: - if msgStream != nil { - if _, err := writeFull(frame, msgStream.writer); err != nil { - return err - } else if complete { - if err = msgStream.writer.Close(); err != nil { - r.context.log("Warning: message writer closed with error %v", err) - } - } else { - //FIX: This isn't the right place to do this, because this goroutine doesn't block even - // if the client can't read the message fast enough. The right place to send the ACK is - // in the goroutine that's running msgStream.writer. (Somehow...) - oldWritten := msgStream.bytesWritten - msgStream.bytesWritten += uint64(frameSize) - if oldWritten > 0 && (oldWritten/kAckInterval) < (msgStream.bytesWritten/kAckInterval) { - r.sender.sendAck(requestNumber, flags.messageType(), msgStream.bytesWritten) - } + state, err := msgStream.addIncomingBytes(frame, complete) + if err != nil { + return err + } + + if !complete { + // Not complete yet; send an ACK message every `kAckInterval` bytes: + //FIX: This isn't the right place to do this, because this goroutine doesn't block even + // if the client can't read the message fast enough. + msgStream.maybeSendAck(frameSize) + } + + // Dispatch at first or last frame: + if request := msgStream.inResponseTo; request != nil { + if state.atStart { + // Dispatch response to its request message as soon as properties are available: + r.context.logMessage("Received response %s", msgStream.Message) + request.responseAvailable(msgStream.Message) // Response to outgoing request + } + } else { + if /*state.atStart ||*/ state.atEnd { + // Dispatch request to the dispatcher: + r.dispatch(msgStream.Message) } } return err } -func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) { +func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgReceiver, err error) { r.pendingMutex.Lock() defer r.pendingMutex.Unlock() msgStream = r.pendingRequests[requestNumber] @@ -268,15 +275,7 @@ func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFla } } else if requestNumber == r.numRequestsReceived+1 { r.numRequestsReceived++ - request := newIncomingMessage(r.sender, requestNumber, flags, nil) - atomic.AddInt32(&r.activeGoroutines, 1) - msgStream = &msgStreamer{ - message: request, - writer: request.asyncRead(func(err error) { - r.context.dispatchRequest(request, r.sender) - atomic.AddInt32(&r.activeGoroutines, -1) - }), - } + msgStream = newIncomingMessage(r.sender, requestNumber, flags) if !complete { r.pendingRequests[requestNumber] = msgStream } @@ -286,13 +285,13 @@ func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFla return msgStream, nil } -func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) { +func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgReceiver, err error) { r.pendingMutex.Lock() defer r.pendingMutex.Unlock() msgStream = r.pendingResponses[requestNumber] if msgStream != nil { if msgStream.bytesWritten == 0 { - msgStream.message.flags = flags // set flags based on 1st frame of response + msgStream.flags = flags // set flags based on 1st frame of response } if complete { delete(r.pendingResponses, requestNumber) @@ -307,15 +306,14 @@ func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFl return } -// pendingResponses is accessed from both the receiveLoop goroutine and the sender's goroutine, -// so it needs synchronization. -func (r *receiver) awaitResponse(request *Message, writer io.WriteCloser) { +func (r *receiver) awaitResponse(response *Message) { + // pendingResponses is accessed from both the receiveLoop goroutine and the sender's goroutine, + // so it needs synchronization. r.pendingMutex.Lock() defer r.pendingMutex.Unlock() - number := request.number - r.pendingResponses[number] = &msgStreamer{ - message: request, - writer: writer, + number := response.number + r.pendingResponses[number] = &msgReceiver{ + Message: response, } if number > r.maxPendingResponseNumber { r.maxPendingResponseNumber = number @@ -328,16 +326,72 @@ func (r *receiver) backlog() (pendingRequest, pendingResponses int) { return len(r.pendingRequests), len(r.pendingResponses) } -// Why isn't this in the io package already, when ReadFull is? -func writeFull(buf []byte, writer io.Writer) (nWritten int, err error) { - for len(buf) > 0 { - var n int - n, err = writer.Write(buf) - if err != nil { - break +//////// REQUEST DISPATCHING + +func (r *receiver) dispatch(request *Message) { + sender := r.sender + requestSize := request.bodySize() + onComplete := func() { + response := request.Response() + if panicked := recover(); panicked != nil { + if handler := sender.context.HandlerPanicHandler; handler != nil { + handler(request, response, panicked) + } else { + stack := debug.Stack() + sender.context.log("PANIC handling BLIP request %v: %v:\n%s", request, panicked, stack) + if response != nil { + // (It is generally considered bad security to reveal internal state like error + // messages or stack dumps in a network response.) + response.SetError(BLIPErrorDomain, HandlerFailedCode, "Internal Error") + } + } + } + + r.subDispatchedBytes(requestSize) + + if response != nil { + sender.send(response) } - nWritten += n - buf = buf[n:] } - return + + r.addDispatchedBytes(requestSize) + r.context.RequestHandler(request, onComplete) + r.waitOnDispatchedBytes() +} + +func (r *receiver) addDispatchedBytes(n int) { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + r.dispatchedBytes += n + r.dispatchedMessageCount++ + r.dispatchMutex.Unlock() + } +} + +func (r *receiver) subDispatchedBytes(n int) { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + prevBytes := r.dispatchedBytes + r.dispatchedBytes = prevBytes - n + r.dispatchedMessageCount-- + if prevBytes > r.maxDispatchedBytes && r.dispatchedBytes <= r.maxDispatchedBytes { + r.dispatchCond.Signal() + } + r.dispatchMutex.Unlock() + } +} + +func (r *receiver) waitOnDispatchedBytes() { + if r.maxDispatchedBytes > 0 { + r.dispatchMutex.Lock() + if r.dispatchedBytes > r.maxDispatchedBytes { + start := time.Now() + r.context.log("WebSocket receiver paused (%d requests being handled, %d bytes)", r.dispatchedMessageCount, r.dispatchedBytes) + for r.dispatchedBytes > r.maxDispatchedBytes { + r.dispatchCond.Wait() + } + r.context.log("...WebSocket receiver resuming after %v (now %d requests being handled, %d bytes)", time.Since(start), r.dispatchedMessageCount, r.dispatchedBytes) + } + r.dispatchMutex.Unlock() + } } diff --git a/sender.go b/sender.go index e13e333..845889d 100644 --- a/sender.go +++ b/sender.go @@ -14,11 +14,9 @@ import ( "bytes" "context" "encoding/binary" - "log" "runtime/debug" "strings" "sync" - "sync/atomic" "time" "nhooyr.io/websocket" @@ -42,9 +40,8 @@ type Sender struct { conn *websocket.Conn receiver *receiver queue *messageQueue - icebox map[msgKey]*Message - curMsg *Message - numRequestsSent MessageNumber + icebox map[msgKey]*msgSender + curMsg *msgSender requeueLock sync.Mutex activeGoroutines int32 websocketPingInterval time.Duration @@ -59,7 +56,7 @@ func newSender(c *Context, conn *websocket.Conn, receiver *receiver) *Sender { conn: conn, receiver: receiver, queue: newMessageQueue(c, c.MaxSendQueueCount), - icebox: map[msgKey]*Message{}, + icebox: map[msgKey]*msgSender{}, websocketPingInterval: c.WebsocketPingInterval, ctx: ctx, ctxCancel: ctxCancel, @@ -80,7 +77,7 @@ func (sender *Sender) Send(msg *Message) bool { // Posts a request or response to be delivered asynchronously. // Returns false if the message can't be queued because the Sender has stopped. func (sender *Sender) send(msg *Message) bool { - if msg.Sender != nil || msg.encoder != nil { + if msg.Sender != nil { panic("Message is already enqueued") } msg.Sender = sender @@ -93,17 +90,12 @@ func (sender *Sender) send(msg *Message) bool { prePushCallback := func(prePushMsg *Message) { if prePushMsg.Type() == RequestType && !prePushMsg.NoReply() { response := prePushMsg.createResponse() - atomic.AddInt32(&sender.activeGoroutines, 1) - writer := response.asyncRead(func(err error) { - // TODO: the error passed into this callback is currently being ignored. Calling response.SetError() causes: "panic: Message can't be modified" - prePushMsg.responseComplete(response) - atomic.AddInt32(&sender.activeGoroutines, -1) - }) - sender.receiver.awaitResponse(response, writer) + sender.receiver.awaitResponse(response) } } - return sender.queue.pushWithCallback(msg, prePushCallback) + msgSender := &msgSender{Message: msg} + return sender.queue.pushWithCallback(msgSender, prePushCallback) } @@ -135,7 +127,7 @@ func (sender *Sender) start() { go func() { defer func() { if panicked := recover(); panicked != nil { - log.Printf("PANIC in BLIP sender: %v\n%s", panicked, debug.Stack()) + sender.context.log("PANIC in BLIP sender: %v\n%s", panicked, debug.Stack()) } }() @@ -186,9 +178,6 @@ func (sender *Sender) start() { err := sender.conn.Write(sender.ctx, websocket.MessageBinary, frameBuffer.Bytes()) if err != nil { sender.context.logFrame("Sender error writing framebuffer (len=%d). Error: %v", len(frameBuffer.Bytes()), err) - if err := msg.Close(); err != nil { - sender.context.logFrame("Sender error closing message. Error: %v", err) - } } frameBuffer.Reset() @@ -198,6 +187,8 @@ func (sender *Sender) start() { panic("empty frame should not have moreComing") } sender.requeue(msg, uint64(bytesSent)) + } else { + msg.sent() } } returnCompressor(frameEncoder) @@ -239,7 +230,7 @@ func (sender *Sender) start() { //////// FLOW CONTROL: -func (sender *Sender) popNextMessage() *Message { +func (sender *Sender) popNextMessage() *msgSender { sender.requeueLock.Lock() sender.curMsg = nil sender.requeueLock.Unlock() @@ -256,11 +247,11 @@ func (sender *Sender) popNextMessage() *Message { return msg } -func (sender *Sender) requeue(msg *Message, bytesSent uint64) { +func (sender *Sender) requeue(msg *msgSender, bytesSent uint64) { sender.requeueLock.Lock() defer sender.requeueLock.Unlock() - msg.bytesSent += bytesSent - if msg.bytesSent <= msg.bytesAcked+kMaxUnackedBytes { + msg.addBytesSent(bytesSent) + if !msg.needsAck() { // requeue it so it can send its next frame later sender.queue.push(msg) } else { @@ -275,14 +266,14 @@ func (sender *Sender) receivedAck(requestNumber MessageNumber, msgType MessageTy sender.requeueLock.Lock() defer sender.requeueLock.Unlock() if msg := sender.queue.find(requestNumber, msgType); msg != nil { - msg.bytesAcked = bytesReceived + msg.receivedAck(bytesReceived) } else if msg := sender.curMsg; msg != nil && msg.number == requestNumber && msg.Type() == msgType { - msg.bytesAcked = bytesReceived + msg.receivedAck(bytesReceived) } else { key := msgKey{msgNo: requestNumber, msgType: msgType} if msg := sender.icebox[key]; msg != nil { - msg.bytesAcked = bytesReceived - if msg.bytesSent <= msg.bytesAcked+kMaxUnackedBytes { + msg.receivedAck(bytesReceived) + if !msg.needsAck() { sender.context.logFrame("Resuming %v", msg) delete(sender.icebox, key) sender.queue.push(msg) @@ -302,5 +293,4 @@ func (sender *Sender) sendAck(msgNo MessageNumber, msgType MessageType, bytesRec if err != nil { sender.context.logFrame("Sender error writing ack. Error: %v", err) } - } diff --git a/thread_pool.go b/thread_pool.go new file mode 100644 index 0000000..0a1aeec --- /dev/null +++ b/thread_pool.go @@ -0,0 +1,121 @@ +/* +Copyright 2023-Present Couchbase, Inc. + +Use of this software is governed by the Business Source License included in +the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that +file, in accordance with the Business Source License, use of this software will +be governed by the Apache License, Version 2.0, included in the file +licenses/APL2.txt. +*/ + +package blip + +import ( + "log" + "runtime" + "runtime/debug" +) + +// Runs functions asynchronously using a fixed-size pool of goroutines. +// Intended for use by Dispatchers and Handlers, though it's not limited to that. +type ThreadPool struct { + Concurrency int // The number of goroutines. 0 means GOMAXPROCS + PanicHandler func(err interface{}) // Called if a function panics + + concurrency int // Actual concurrency + channel chan<- func() // Queue of calls to be dispatched + terminator chan bool // Goroutines stop when this is closed +} + +// Starts a ThreadPool. +func (p *ThreadPool) Start() { + p.concurrency = p.Concurrency + if p.concurrency <= 0 { + p.concurrency = runtime.GOMAXPROCS(0) + } + ch := make(chan func(), 1000) //TODO: What real limit? + p.channel = ch + p.terminator = make(chan bool) + + handlerLoop := func(i int) { + // Each goroutine in the pool runs this loop, calling functions from the channel + // until it either reads a nil function, or something happens to the terminator. + for { + select { + case fn := <-ch: + if fn != nil { + p.perform(fn) + } else { + return + } + case <-p.terminator: + return + } + } + } + + for i := 0; i < p.concurrency; i++ { + go handlerLoop(i + 1) + } +} + +// Stops a ThreadPool after it completes all currently running or queued tasks. +func (p *ThreadPool) Stop() { + if p.channel != nil { + for i := 0; i < p.concurrency; i++ { + p.channel <- nil // each nil stops one goroutine + } + } +} + +// Stops a ThreadPool ASAP. Currently running tasks will complete; queued tasks will be dropped. +func (p *ThreadPool) StopImmediately() { + if p.terminator != nil { + close(p.terminator) + } +} + +// Schedules the function to be called on ThreadPool's next available goroutine. +func (p *ThreadPool) Go(fn func()) { + if p.channel == nil { + panic("ThreadPool has not been started") + } else if fn == nil { + panic("Invalid nil function") + } + p.channel <- fn +} + +// Given a SynchronousHandler function, returns an AsyncHandler function that will call the +// wrapped handler on one of the ThreadPool's goroutines. +func (p *ThreadPool) WrapSynchronousHandler(handler SynchronousHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + p.Go(func() { + defer onComplete() + handler(request) + }) + } +} + +// Given an AsyncHandler function, returns an AsyncHandler function that will call the +// wrapped handler on one of the ThreadPool's goroutines. +func (p *ThreadPool) WrapAsyncHandler(handler AsyncHandler) AsyncHandler { + return func(request *Message, onComplete RequestCompletedCallback) { + p.Go(func() { + handler(request, onComplete) + }) + } +} + +func (p *ThreadPool) perform(fn func()) { + defer func() { + if panicked := recover(); panicked != nil { + if p.PanicHandler != nil { + p.PanicHandler(panicked) + } else { + log.Printf("PANIC in ThreadPool function: %v\n%s", panicked, debug.Stack()) + } + } + }() + + fn() +} diff --git a/util.go b/util.go index 164ec95..d9ded3a 100644 --- a/util.go +++ b/util.go @@ -11,6 +11,7 @@ licenses/APL2.txt. package blip import ( + "fmt" "sync/atomic" "time" ) @@ -46,3 +47,18 @@ func errorFromChannel(c chan error) error { } return nil } + +func min(a, b int) int { + if a < b { + return a + } else { + return b + } +} + +// Simple assertion that panics if the condition isn't met. +func precondition(condition bool, panicMessage string, args ...interface{}) { + if !condition { + panic(fmt.Sprintf("Precondition failed! "+panicMessage, args...)) + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..a597f7e --- /dev/null +++ b/util_test.go @@ -0,0 +1,101 @@ +// Copyright 2023-Present Couchbase, Inc. +// +// Use of this software is governed by the Business Source License included +// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified +// in that file, in accordance with the Business Source License, use of this +// software will be governed by the Apache License, Version 2.0, included in +// the file licenses/APL2.txt. + +package blip + +import ( + "bytes" + "fmt" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// Starts a WebSocket server on the Context, returning its net.Listener. +// The server runs on a background goroutine. +// Close the Listener when the test finishes. +func startTestListener(t *testing.T, serverContext *Context) net.Listener { + mux := http.NewServeMux() + mux.Handle("/blip", serverContext.WebSocketServer()) + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Error opening WebSocket listener: %v", err) + } + go func() { + _ = http.Serve(listener, mux) + }() + return listener +} + +// Connects the Context to a Listener created by `startTestListener`. +// Close the Sender when the test finishes. +func startTestClient(t *testing.T, clientContext *Context, listener net.Listener) *Sender { + port := listener.Addr().(*net.TCPAddr).Port + destUrl := fmt.Sprintf("ws://localhost:%d/blip", port) + sender, err := clientContext.Dial(destUrl) + if err != nil { + t.Fatalf("Error opening WebSocket client: %v", err) + } + return sender +} + +// Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout +// TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo. +func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { + + // Create a channel so that a goroutine waiting on the waitgroup can send it's result (if any) + wgFinished := make(chan bool) + + go func() { + wg.Wait() + wgFinished <- true + }() + + select { + case <-wgFinished: + return nil + case <-time.After(timeout): + return fmt.Errorf("Timed out waiting after %v", timeout) + } +} + +// Returns the serialized form (properties + body) of a Message. +func serializeMessage(t *testing.T, m *Message) []byte { + var writer bytes.Buffer + err := m.WriteTo(&writer) + assert.NoError(t, err) + return writer.Bytes() +} + +// Asserts that two Messages are identical +func assertEqualMessages(t *testing.T, m, m2 *Message) bool { + if !assert.Equal(t, m.flags, m2.flags) || !assert.Equal(t, m.Properties, m2.Properties) { + return false + } + mb, err := m.Body() + m2b, err2 := m2.Body() + return assert.NoError(t, err) && assert.NoError(t, err2) && assert.Equal(t, mb, m2b) + +} + +// Asserts that `response` contains a BLIP error with the given domain and code. +func assertBLIPError(t *testing.T, response *Message, domain string, code int) bool { + responseError := response.Error() + if assert.NotNil(t, responseError, "Expected a BLIP error") { + if responseError.Domain == domain && responseError.Code == code { + return true + } + assert.Fail(t, "Unexpected BLIP error", "Got %v: expected %v", + responseError, &ErrorResponse{domain, code, ""}) + } + return false +} From 011516c682c2e36970a21c8a00f1fec2355c31a0 Mon Sep 17 00:00:00 2001 From: Jens Alfke Date: Tue, 9 May 2023 13:47:44 -0700 Subject: [PATCH 2/6] Fixed all compiler warnings Most of them were about capitalized error messages... --- codec.go | 4 ++-- codec_test.go | 2 +- context.go | 6 +++--- message.go | 10 +++------- properties.go | 4 ++-- properties_test.go | 4 ++-- receiver.go | 14 +++++++------- util_test.go | 4 ++-- 8 files changed, 22 insertions(+), 26 deletions(-) diff --git a/codec.go b/codec.go index 517effa..8d962b7 100644 --- a/codec.go +++ b/codec.go @@ -128,7 +128,7 @@ func (d *decompressor) passthrough(input []byte, checksum *uint32) ([]byte, erro _, _ = d.checksum.Write(input) // Update checksum (no error possible) if checksum != nil { if curChecksum := d.getChecksum(); curChecksum != *checksum { - return nil, fmt.Errorf("Invalid checksum %x; should be %x", curChecksum, *checksum) + return nil, fmt.Errorf("invalid checksum %x; should be %x", curChecksum, *checksum) } } return input, nil @@ -168,7 +168,7 @@ func (d *decompressor) decompress(input []byte, checksum uint32) ([]byte, error) return nil, err } else if n == 0 { // Nothing more to read; since checksum didn't match (above), fail: - return nil, fmt.Errorf("Invalid checksum %x; should be %x", d.getChecksum(), checksum) + return nil, fmt.Errorf("invalid checksum %x; should be %x", d.getChecksum(), checksum) } _, _ = d.checksum.Write(d.buffer[0:n]) // Update checksum (no error possible) diff --git a/codec_test.go b/codec_test.go index b87a718..e0a86ca 100644 --- a/codec_test.go +++ b/codec_test.go @@ -29,7 +29,7 @@ func init() { randomData = make([]byte, 65536) var b byte var step byte = 1 - for i, _ := range randomData { + for i := range randomData { if rando.Intn(10) == 0 { b = byte(rando.Intn(256)) step = byte(rando.Intn(4)) diff --git a/context.go b/context.go index f97cbd3..28dc5cf 100644 --- a/context.go +++ b/context.go @@ -222,9 +222,9 @@ func (bwss *blipWebsocketServer) handshake(w http.ResponseWriter, r *http.Reques protocolHeader := r.Header.Get("Sec-WebSocket-Protocol") protocol, found := includesProtocol(protocolHeader, bwss.blipCtx.SupportedSubProtocols) if !found { - stringSeperatedProtocols := strings.Join(bwss.blipCtx.SupportedSubProtocols, ",") - bwss.blipCtx.log("Error: Client doesn't support any of WS protocols: %s only %s", stringSeperatedProtocols, protocolHeader) - return nil, fmt.Errorf("I only speak %s protocols", stringSeperatedProtocols) + stringSeparatedProtocols := strings.Join(bwss.blipCtx.SupportedSubProtocols, ",") + bwss.blipCtx.log("Error: Client doesn't support any of WS protocols: %s only %s", stringSeparatedProtocols, protocolHeader) + return nil, fmt.Errorf("I only speak %s protocols", stringSeparatedProtocols) } ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ diff --git a/message.go b/message.go index 2850d4b..98fce4b 100644 --- a/message.go +++ b/message.go @@ -14,7 +14,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "strconv" "sync" ) @@ -153,8 +152,8 @@ func (message *Message) Complete() bool { } // Writes the encoded form of a Message to a stream. -func (message *Message) WriteTo(writer io.Writer) error { - if err := message.Properties.WriteTo(writer); err != nil { +func (message *Message) WriteEncodedTo(writer io.Writer) error { + if err := message.Properties.WriteEncodedTo(writer); err != nil { return err } var err error @@ -236,7 +235,7 @@ func (m *Message) Body() ([]byte, error) { m.bodyMutex.Lock() defer m.bodyMutex.Unlock() if m.body == nil && !m.Outgoing { - body, err := ioutil.ReadAll(m.bodyReader) + body, err := io.ReadAll(m.bodyReader) if err != nil { return nil, err } @@ -569,9 +568,6 @@ func (m *msgReceiver) cancelIncoming() { func (message *Message) assertMutable() { precondition(message.Outgoing && !message.inProgress, "Message %s is not modifiable", message) } -func (message *Message) assertIncoming() { - precondition(!message.Outgoing, "Message %s is not incoming", message) -} func (message *Message) assertOutgoing() { precondition(message.Outgoing, "Message %s is not outgoing", message) } diff --git a/properties.go b/properties.go index fde648e..74df0b0 100644 --- a/properties.go +++ b/properties.go @@ -75,7 +75,7 @@ func ReadProperties(body []byte) (properties Properties, bytesRead int, err erro } // Writes Properties to a stream. -func (properties Properties) WriteTo(writer io.Writer) error { +func (properties Properties) WriteEncodedTo(writer io.Writer) error { // First convert the property strings into byte arrays, and add up their sizes: var strings propertyList = make(propertyList, 2*len(properties)) i := 0 @@ -111,7 +111,7 @@ func (properties Properties) WriteTo(writer io.Writer) error { // Writes Properties to a byte array. func (properties Properties) Encode() []byte { var out bytes.Buffer - _ = properties.WriteTo(&out) + _ = properties.WriteEncodedTo(&out) return out.Bytes() } diff --git a/properties_test.go b/properties_test.go index 0e562a6..675740f 100644 --- a/properties_test.go +++ b/properties_test.go @@ -25,7 +25,7 @@ func init() { func TestReadWriteProperties(t *testing.T) { p := Properties{"Content-Type": "application/octet-stream", "Foo": "Bar"} var writer bytes.Buffer - err := p.WriteTo(&writer) + err := p.WriteEncodedTo(&writer) assert.Equal(t, nil, err) writer.Write([]byte("FOOBAR")) serialized := writer.Bytes() @@ -47,7 +47,7 @@ func TestReadWriteProperties(t *testing.T) { func TestReadWriteEmptyProperties(t *testing.T) { var p Properties var writer bytes.Buffer - err := p.WriteTo(&writer) + err := p.WriteEncodedTo(&writer) assert.Equal(t, nil, err) serialized := writer.Bytes() assert.Equal(t, "\x00", string(serialized)) diff --git a/receiver.go b/receiver.go index ca2caf6..3637506 100644 --- a/receiver.go +++ b/receiver.go @@ -106,7 +106,7 @@ func (r *receiver) parseLoop() { r.context.log("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack()) err, _ := p.(error) if err == nil { - err = fmt.Errorf("Panic: %v", p) + err = fmt.Errorf("panic: %v", p) } r.fatalError(err) } @@ -157,7 +157,7 @@ func (r *receiver) closePendingResponses() { func (r *receiver) handleIncomingFrame(frame []byte) error { // Parse BLIP header: if len(frame) < 2 { - return fmt.Errorf("Illegally short frame") + return fmt.Errorf("illegally short frame") } r.frameBuffer.Reset() r.frameBuffer.Write(frame) @@ -189,14 +189,14 @@ func (r *receiver) handleIncomingFrame(frame []byte) error { bufferedFrame := r.frameBuffer.Bytes() frameSize := len(bufferedFrame) if len(frame) < checksumLength { - return fmt.Errorf("Illegally short frame") + return fmt.Errorf("illegally short frame") } - checksumSlice := bufferedFrame[len(bufferedFrame)-checksumLength : len(bufferedFrame)] + checksumSlice := bufferedFrame[len(bufferedFrame)-checksumLength:] checksum := binary.BigEndian.Uint32(checksumSlice) r.frameBuffer.Truncate(r.frameBuffer.Len() - checksumLength) if r.context.LogFrames { - r.context.logFrame("Received frame: %s (flags=%8b, length=%d)", + r.context.logFrame("received frame: %s (flags=%8b, length=%d)", frameString(requestNumber, flags), flags, r.frameBuffer.Len()) } @@ -280,7 +280,7 @@ func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFla r.pendingRequests[requestNumber] = msgStream } } else { - return nil, fmt.Errorf("Bad incoming request number %d", requestNumber) + return nil, fmt.Errorf("bad incoming request number %d", requestNumber) } return msgStream, nil } @@ -301,7 +301,7 @@ func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFl r.context.log("Warning: Unexpected response frame to my msg #%d", requestNumber) // benign } else { // processing a response frame with a message number higher than any requests I've sent - err = fmt.Errorf("Bogus message number %d in response. Expected to be less than max pending response number (%d)", requestNumber, r.maxPendingResponseNumber) + err = fmt.Errorf("bogus message number %d in response. Expected to be less than max pending response number (%d)", requestNumber, r.maxPendingResponseNumber) } return } diff --git a/util_test.go b/util_test.go index a597f7e..25bfd1c 100644 --- a/util_test.go +++ b/util_test.go @@ -64,14 +64,14 @@ func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { case <-wgFinished: return nil case <-time.After(timeout): - return fmt.Errorf("Timed out waiting after %v", timeout) + return fmt.Errorf("timed out waiting after %v", timeout) } } // Returns the serialized form (properties + body) of a Message. func serializeMessage(t *testing.T, m *Message) []byte { var writer bytes.Buffer - err := m.WriteTo(&writer) + err := m.WriteEncodedTo(&writer) assert.NoError(t, err) return writer.Bytes() } From a7b936f01f65e11f1af5dd739d6d1457352cd93a Mon Sep 17 00:00:00 2001 From: Jens Alfke Date: Mon, 15 May 2023 12:52:38 -0700 Subject: [PATCH 3/6] Small cleanup --- context.go | 2 -- message.go | 2 +- messagequeue.go | 2 +- receiver.go | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/context.go b/context.go index 28dc5cf..71807e6 100644 --- a/context.go +++ b/context.go @@ -170,7 +170,6 @@ func (context *Context) DialConfig(opts *DialOptions) (*Sender, error) { // If the receiveLoop terminates, stop the sender as well defer sender.Stop() - // defer context.dispatcher.stop() // Update Expvar stats for number of outstanding goroutines incrReceiverGoroutines() @@ -246,7 +245,6 @@ func (bwss *blipWebsocketServer) handle(ws *websocket.Conn) { sender := bwss.blipCtx.start(ws) err := sender.receiver.receiveLoop() sender.Stop() - // bwss.blipCtx.dispatcher.stop() if err != nil && !isCloseError(err) { bwss.blipCtx.log("BLIP/Websocket Handler exited with error: %v", err) if bwss.blipCtx.FatalErrorHandler != nil { diff --git a/message.go b/message.go index 98fce4b..62d2ffd 100644 --- a/message.go +++ b/message.go @@ -49,7 +49,7 @@ type ErrorResponse struct { } // A callback function that takes a message and returns nothing -type MessageCallback func(*Message) +type messageCallback func(*Message) var ErrConnectionClosed = fmt.Errorf("BLIP connection closed") diff --git a/messagequeue.go b/messagequeue.go index 9d1c2c8..da5ebb8 100644 --- a/messagequeue.go +++ b/messagequeue.go @@ -104,7 +104,7 @@ func (q *messageQueue) push(msg *msgSender) bool { // Push an item into the queue, also providing a callback function that will be invoked // after the number is assigned to the message, but before pushing into the queue. -func (q *messageQueue) pushWithCallback(msg *msgSender, prepushCallback MessageCallback) bool { +func (q *messageQueue) pushWithCallback(msg *msgSender, prepushCallback messageCallback) bool { q.cond.L.Lock() defer q.cond.L.Unlock() diff --git a/receiver.go b/receiver.go index 3637506..fcc97f5 100644 --- a/receiver.go +++ b/receiver.go @@ -326,7 +326,7 @@ func (r *receiver) backlog() (pendingRequest, pendingResponses int) { return len(r.pendingRequests), len(r.pendingResponses) } -//////// REQUEST DISPATCHING +//////// REQUEST DISPATCHING & FLOW CONTROL func (r *receiver) dispatch(request *Message) { sender := r.sender From 3f1855f3bfe9fcdac379590847ffd83a3e2deeb9 Mon Sep 17 00:00:00 2001 From: Jens Alfke Date: Tue, 6 Jun 2023 11:24:23 -0700 Subject: [PATCH 4/6] Fixed race conditions handling a closed WebSocket If an incomplete incoming reply has already been dispatched to the app, don't change its body to an error when the connection closes. Also made receiver.closePendingResponses() be called on the receiver's parseLoop's goroutine, not the sender's. --- message.go | 11 +++++++---- receiver.go | 14 ++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/message.go b/message.go index 62d2ffd..f265556 100644 --- a/message.go +++ b/message.go @@ -480,10 +480,12 @@ func (m *msgSender) cancelOutgoing() { //////// RECEIVING MESSAGES: // A wrapper around an incoming Message while it's being received. +// Called only on the receiver's parseLoop goroutine. type msgReceiver struct { *Message bytesWritten uint64 propertiesBuffer []byte + dispatched bool } func newIncomingMessage(sender *Sender, number MessageNumber, flags frameFlags) *msgReceiver { @@ -554,11 +556,12 @@ func (m *msgReceiver) maybeSendAck(frameSize int) { // Informs an incoming message that the connection has closed func (m *msgReceiver) cancelIncoming() { - if m.inResponseTo != nil { - if m.bodyWriter != nil { - _ = m.bodyWriter.CloseWithError(ErrConnectionClosed) - } + if m.bodyWriter != nil { + _ = m.bodyWriter.CloseWithError(ErrConnectionClosed) + } + if !m.dispatched && m.inResponseTo != nil { m.setError(BLIPErrorDomain, DisconnectedCode, "Connection closed") + m.dispatched = true m.inResponseTo.responseAvailable(m.Message) } } diff --git a/receiver.go b/receiver.go index fcc97f5..a3d9d8a 100644 --- a/receiver.go +++ b/receiver.go @@ -67,6 +67,8 @@ func newReceiver(context *Context, conn *websocket.Conn) *receiver { return rcvr } +// Reads WebSocket messages, not returning until the socket is closed. +// Spawns parseLoop in a goroutine and pushes the messages to it. func (r *receiver) receiveLoop() error { defer atomic.AddInt32(&r.activeGoroutines, -1) atomic.AddInt32(&r.activeGoroutines, 1) @@ -99,6 +101,7 @@ func (r *receiver) receiveLoop() error { } } +// Goroutine created by receiveLoop that parses BLIP frames and dispatches messages. func (r *receiver) parseLoop() { defer func() { // Panic handler: atomic.AddInt32(&r.activeGoroutines, -1) @@ -127,21 +130,22 @@ func (r *receiver) parseLoop() { } r.context.logFrame("parseLoop stopped") + r.closePendingResponses() + returnDecompressor(r.frameDecoder) r.frameDecoder = nil } +// called on the parseLoop goroutine func (r *receiver) fatalError(err error) { r.context.log("Error: parseLoop closing socket due to error: %v", err) r.parseError <- err - r.stop() + r.conn.Close(websocket.StatusAbnormalClosure, "") } +// called by the sender func (r *receiver) stop() { - r.closePendingResponses() - r.conn.Close(websocket.StatusNormalClosure, "") - waitForZeroActiveGoroutines(r.context, &r.activeGoroutines) } @@ -254,11 +258,13 @@ func (r *receiver) processIncomingFrame(requestNumber MessageNumber, flags frame if state.atStart { // Dispatch response to its request message as soon as properties are available: r.context.logMessage("Received response %s", msgStream.Message) + msgStream.dispatched = true request.responseAvailable(msgStream.Message) // Response to outgoing request } } else { if /*state.atStart ||*/ state.atEnd { // Dispatch request to the dispatcher: + msgStream.dispatched = true r.dispatch(msgStream.Message) } } From 178e422fd034755113d80ecb694b7a95e101c879 Mon Sep 17 00:00:00 2001 From: Ben Brooks Date: Tue, 25 Mar 2025 12:53:29 +0000 Subject: [PATCH 5/6] Bump testify lib --- go.mod | 4 ++-- go.sum | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 42b2b8b..566a63f 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.23 require ( github.com/coder/websocket v1.8.12 github.com/klauspost/compress v1.17.11 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.10.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ae9b23c..5060e89 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,14 @@ github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 438f75c1dd2ea5b2552757add71b268b6791bd5b Mon Sep 17 00:00:00 2001 From: Ben Brooks Date: Tue, 25 Mar 2025 13:02:54 +0000 Subject: [PATCH 6/6] lint fixes --- context_test.go | 4 ++-- message.go | 5 +++-- properties.go | 14 ++++++-------- thread_pool.go | 6 +++--- util_test.go | 16 ++++++---------- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/context_test.go b/context_test.go index 7262a59..ad3d413 100644 --- a/context_test.go +++ b/context_test.go @@ -113,7 +113,7 @@ func TestServerAbruptlyCloseConnectionBehavior(t *testing.T) { // Read the echo response response := echoRequest.Response() // <--- SG #3268 was causing this to block indefinitely - assertBLIPError(t, response, BLIPErrorDomain, DisconnectedCode) + requireBLIPError(t, response, BLIPErrorDomain, DisconnectedCode) } /* @@ -183,7 +183,7 @@ func TestClientAbruptlyCloseConnectionBehavior(t *testing.T) { sent := clientSender.Send(echoAmplifyRequest) assert.True(t, sent) echoAmplifyResponse := echoAmplifyRequest.Response() // <--- SG #3268 was causing this to block indefinitely - assertBLIPError(t, echoAmplifyResponse, BLIPErrorDomain, DisconnectedCode) + requireBLIPError(t, echoAmplifyResponse, BLIPErrorDomain, DisconnectedCode) } // Create a blip profile handler to respond to echo requests and then abruptly close the socket diff --git a/message.go b/message.go index 84715b1..bb3bd52 100644 --- a/message.go +++ b/message.go @@ -309,13 +309,14 @@ func (message *Message) Clone() *Message { message.bodyMutex.Lock() defer message.bodyMutex.Unlock() - return &Message{ + m := &Message{ Outgoing: message.Outgoing, Properties: message.Properties, body: message.body, number: message.number, - flags: message.flags, } + m.flags.Store(message.flags.Load()) + return m } //////// RESPONSE HANDLING: diff --git a/properties.go b/properties.go index 74df0b0..638cc3f 100644 --- a/properties.go +++ b/properties.go @@ -13,6 +13,7 @@ package blip import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "sort" @@ -35,10 +36,9 @@ func ReadProperties(body []byte) (properties Properties, bytesRead int, err erro length, bytesRead := binary.Uvarint(body) if bytesRead == 0 { // Not enough bytes to read the varint - return + return nil, 0, nil } else if bytesRead < 0 || length > maxPropertiesLength { - err = fmt.Errorf("invalid properties length in BLIP message") - return + return nil, bytesRead, errors.New("invalid properties length in BLIP message") } else if bytesRead+int(length) > len(body) { // Incomplete return nil, 0, nil @@ -53,14 +53,12 @@ func ReadProperties(body []byte) (properties Properties, bytesRead int, err erro bytesRead += int(length) if body[length-1] != 0 { - err = fmt.Errorf("invalid properties (not NUL-terminated)") - return + return nil, 0, errors.New("invalid properties (not NUL-terminated)") } eachProp := bytes.Split(body[0:length-1], []byte{0}) nProps := len(eachProp) / 2 if nProps*2 != len(eachProp) { - err = fmt.Errorf("odd number of strings in properties") - return + return nil, bytesRead, errors.New("odd number of strings in properties") } properties = Properties{} for i := 0; i < len(eachProp); i += 2 { @@ -71,7 +69,7 @@ func ReadProperties(body []byte) (properties Properties, bytesRead int, err erro } properties[key] = value } - return + return properties, bytesRead, nil } // Writes Properties to a stream. diff --git a/thread_pool.go b/thread_pool.go index 0a1aeec..0db311c 100644 --- a/thread_pool.go +++ b/thread_pool.go @@ -44,7 +44,7 @@ func (p *ThreadPool) Start() { select { case fn := <-ch: if fn != nil { - p.perform(fn) + p.perform(i, fn) } else { return } @@ -106,13 +106,13 @@ func (p *ThreadPool) WrapAsyncHandler(handler AsyncHandler) AsyncHandler { } } -func (p *ThreadPool) perform(fn func()) { +func (p *ThreadPool) perform(i int, fn func()) { defer func() { if panicked := recover(); panicked != nil { if p.PanicHandler != nil { p.PanicHandler(panicked) } else { - log.Printf("PANIC in ThreadPool function: %v\n%s", panicked, debug.Stack()) + log.Printf("PANIC in ThreadPool[%d] function: %v\n%s", i, panicked, debug.Stack()) } } }() diff --git a/util_test.go b/util_test.go index 2589440..5611230 100644 --- a/util_test.go +++ b/util_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Starts a WebSocket server on the Context, returning its net.Listener. @@ -87,15 +88,10 @@ func assertEqualMessages(t *testing.T, m, m2 *Message) bool { } -// Asserts that `response` contains a BLIP error with the given domain and code. -func assertBLIPError(t *testing.T, response *Message, domain string, code int) bool { +// requireBLIPError requires that `response` contains a BLIP error with the given domain and code. +func requireBLIPError(t *testing.T, response *Message, domain string, code int) { responseError := response.Error() - if assert.NotNil(t, responseError, "Expected a BLIP error") { - if responseError.Domain == domain && responseError.Code == code { - return true - } - assert.Fail(t, "Unexpected BLIP error", "Got %v: expected %v", - responseError, &ErrorResponse{domain, code, ""}) - } - return false + require.Errorf(t, responseError, "Expected a BLIP error") + require.Equal(t, domain, responseError.Domain) + require.Equal(t, code, responseError.Code) }