(POC) GODRIVER-3414 Complete pending reads on conn checkout #1937

Draft: wants to merge 5 commits into base: v1
33 changes: 19 additions & 14 deletions event/monitoring.go
@@ -91,17 +91,20 @@ const (

// strings for pool command monitoring types
const (
PoolCreated = "ConnectionPoolCreated"
PoolReady = "ConnectionPoolReady"
PoolCleared = "ConnectionPoolCleared"
PoolClosedEvent = "ConnectionPoolClosed"
ConnectionCreated = "ConnectionCreated"
ConnectionReady = "ConnectionReady"
ConnectionClosed = "ConnectionClosed"
GetStarted = "ConnectionCheckOutStarted"
GetFailed = "ConnectionCheckOutFailed"
GetSucceeded = "ConnectionCheckedOut"
ConnectionReturned = "ConnectionCheckedIn"
PoolCreated = "ConnectionPoolCreated"
PoolReady = "ConnectionPoolReady"
PoolCleared = "ConnectionPoolCleared"
PoolClosedEvent = "ConnectionPoolClosed"
ConnectionCreated = "ConnectionCreated"
ConnectionReady = "ConnectionReady"
ConnectionClosed = "ConnectionClosed"
GetStarted = "ConnectionCheckOutStarted"
GetFailed = "ConnectionCheckOutFailed"
GetSucceeded = "ConnectionCheckedOut"
ConnectionReturned = "ConnectionCheckedIn"
ConnectionPendingReadStarted = "ConnectionPendingReadStarted"
ConnectionPendingReadSucceeded = "ConnectionPendingReadSucceeded"
ConnectionPendingReadFailed = "ConnectionPendingReadFailed"
)

// MonitorPoolOptions contains pool options as formatted in pool events
@@ -121,9 +124,11 @@ type PoolEvent struct {
Reason string `json:"reason"`
// ServiceID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field
// can be used to distinguish between individual servers in a load balanced deployment.
ServiceID *primitive.ObjectID `json:"serviceId"`
Interruption bool `json:"interruptInUseConnections"`
Error error `json:"error"`
ServiceID *primitive.ObjectID `json:"serviceId"`
Interruption bool `json:"interruptInUseConnections"`
Error error `json:"error"`
RequestID int32 `json:"requestId"`
RemainingTime time.Duration `json:"remainingTime"`
}

// PoolMonitor is a function that allows the user to gain access to events occurring in the pool
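
For context only (not part of the diff): a minimal sketch of how an application could observe the new pending-read events through the existing PoolMonitor hook, assuming the constants and the RequestID/RemainingTime fields land as shown above.

```go
package main

import (
	"log"

	"go.mongodb.org/mongo-driver/event"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	monitor := &event.PoolMonitor{
		Event: func(pe *event.PoolEvent) {
			switch pe.Type {
			case event.ConnectionPendingReadStarted:
				log.Printf("pending read started: conn=%d request=%d", pe.ConnectionID, pe.RequestID)
			case event.ConnectionPendingReadSucceeded:
				log.Printf("pending read succeeded: conn=%d duration=%s", pe.ConnectionID, pe.Duration)
			case event.ConnectionPendingReadFailed:
				log.Printf("pending read failed: conn=%d remaining=%s err=%v",
					pe.ConnectionID, pe.RemainingTime, pe.Error)
			}
		},
	}

	// Attach the monitor when constructing the client.
	_ = options.Client().ApplyURI("mongodb://localhost:27017").SetPoolMonitor(monitor)
}
```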
34 changes: 34 additions & 0 deletions internal/driverutil/context.go
@@ -0,0 +1,34 @@
// Copyright (C) MongoDB, Inc. 2025-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package driverutil

import "context"

type ContextKey string

const (
ContextKeyHasMaxTimeMS ContextKey = "hasMaxTimeMS"
ContextKeyRequestID ContextKey = "requestID"
)

func WithValueHasMaxTimeMS(parentCtx context.Context, val bool) context.Context {
return context.WithValue(parentCtx, ContextKeyHasMaxTimeMS, val)
}

func WithRequestID(parentCtx context.Context, requestID int32) context.Context {
return context.WithValue(parentCtx, ContextKeyRequestID, requestID)
}

func HasMaxTimeMS(ctx context.Context) bool {
return ctx.Value(ContextKeyHasMaxTimeMS) != nil
}

func GetRequestID(ctx context.Context) (int32, bool) {
val, ok := ctx.Value(ContextKeyRequestID).(int32)

return val, ok
}
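
A small sketch of how these helpers fit together: Operation.Execute tags the context (see the operation.go change below) and connection.read consults the tags before recording pending-read state. The package is internal, so this only compiles inside the driver module; the request ID 42 is a stand-in for startedInfo.requestID.

```go
package driverutil_test // hypothetical external test package inside the driver module

import (
	"context"
	"fmt"
	"testing"

	"go.mongodb.org/mongo-driver/internal/driverutil"
)

func TestContextTagging(t *testing.T) {
	// Operation.Execute tags the context when the outgoing command carries maxTimeMS.
	ctx := driverutil.WithValueHasMaxTimeMS(context.Background(), true)
	ctx = driverutil.WithRequestID(ctx, 42)

	// connection.read later checks the tags before recording pending-read state.
	if driverutil.HasMaxTimeMS(ctx) {
		if id, ok := driverutil.GetRequestID(ctx); ok {
			fmt.Println("pending read belongs to request", id) // prints 42
		}
	}
}
```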
3 changes: 3 additions & 0 deletions internal/logger/component.go
@@ -28,6 +28,9 @@ const (
ConnectionCheckoutFailed = "Connection checkout failed"
ConnectionCheckedOut = "Connection checked out"
ConnectionCheckedIn = "Connection checked in"
ConnectionPendingReadStarted = "Pending read started"
ConnectionPendingReadSucceeded = "Pending read succeeded"
ConnectionPendingReadFailed = "Pending read failed"
ServerSelectionFailed = "Server selection failed"
ServerSelectionStarted = "Server selection started"
ServerSelectionSucceeded = "Server selection succeeded"
12 changes: 12 additions & 0 deletions internal/ptrutil/ptr.go
@@ -0,0 +1,12 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package ptrutil

// Ptr returns a pointer to a copy of the given value.
func Ptr[T any](val T) *T {
return &val
}
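
A trivial usage sketch. The helper is reimplemented here for illustration since internal packages are not importable outside the driver; the pool change below uses it the same way to seed the pending-read time budget (prs.remainingTime = ptrutil.Ptr(PendingReadTimeout)).

```go
package main

import (
	"fmt"
	"time"
)

// Ptr mirrors internal/ptrutil.Ptr for illustration.
func Ptr[T any](val T) *T { return &val }

func main() {
	remaining := Ptr(1000 * time.Millisecond)
	fmt.Println(*remaining) // 1s
}
```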
6 changes: 3 additions & 3 deletions mongo/integration/csot_test.go
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ func TestCSOT_maxTimeMS(t *testing.T) {
},
sendsMaxTimeMSWithTimeoutMS: true,
sendsMaxTimeMSWithContextDeadline: false,
preventsConnClosureWithTimeoutMS: true,
preventsConnClosureWithTimeoutMS: false,
},
{
desc: "FindOneAndDelete",
@@ -206,7 +206,7 @@ func TestCSOT_maxTimeMS(t *testing.T) {
},
sendsMaxTimeMSWithTimeoutMS: true,
sendsMaxTimeMSWithContextDeadline: false,
preventsConnClosureWithTimeoutMS: true,
preventsConnClosureWithTimeoutMS: false,
},
{
desc: "Watch",
@@ -220,7 +220,7 @@ func TestCSOT_maxTimeMS(t *testing.T) {
},
sendsMaxTimeMSWithTimeoutMS: true,
sendsMaxTimeMSWithContextDeadline: true,
preventsConnClosureWithTimeoutMS: true,
preventsConnClosureWithTimeoutMS: false,
// Change Streams aren't supported on standalone topologies.
topologies: []mtest.TopologyKind{
mtest.ReplicaSet,
519 changes: 519 additions & 0 deletions testdata/client-side-operations-timeout/pending-reads.json

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions testdata/client-side-operations-timeout/pending-reads.yml
@@ -0,0 +1,312 @@
description: "Operation timeouts do not cause connection churn"

schemaVersion: "1.9"

runOnRequirements:
- minServerVersion: "4.4"
# TODO(SERVER-96344): When using failpoints, mongos returns MaxTimeMSExpired
# after maxTimeMS, whereas mongod returns it after
# max(blockTimeMS, maxTimeMS). Until this ticket is resolved, these tests
# will not pass on sharded clusters.
topologies: ["standalone", "replicaset"]

createEntities:
- client:
id: &failPointClient failPointClient
useMultipleMongoses: false
- client:
id: &client client
uriOptions:
maxPoolSize: 1
useMultipleMongoses: false
observeEvents:
- commandFailedEvent
- commandSucceededEvent
- connectionCheckedOutEvent
- connectionCheckedInEvent
- connectionClosedEvent
- database:
id: &database test
client: *client
databaseName: *database
- collection:
id: &collection coll
database: *database
collectionName: *collection

initialData:
- collectionName: *collection
databaseName: *database
documents: []

tests:
- description: "Write operation with successful pending read"
operations:
# Create a failpoint to block the first operation
- name: failPoint
object: testRunner
arguments:
client: *failPointClient
failPoint:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: ["insert"]
blockConnection: true
blockTimeMS: 750

# Execute operation with timeout less than block time
- name: insertOne
object: *collection
arguments:
timeoutMS: 50
document: { _id: 3, x: 1 }
expectError:
isTimeoutError: true

# Execute a subsequent operation to complete the read
- name: findOne
object: *collection
arguments:
filter: { _id: 1 }

expectEvents:
- client: *client
events:
- commandFailedEvent:
commandName: insert
- commandSucceededEvent:
commandName: find
- client: *client
eventType: cmap
events:
- connectionCheckedOutEvent: {} # insert
- connectionCheckedInEvent: {} # insert fails
- connectionCheckedOutEvent: {} # find
- connectionCheckedInEvent: {} # find succeeds

- description: "Concurrent write operation with successful pending read"
operations:
# Create a failpoint to block the first operation
- name: failPoint
object: testRunner
arguments:
client: *failPointClient
failPoint:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: ["insert"]
blockConnection: true
blockTimeMS: 750

# Start threads.
- name: createEntities
object: testRunner
arguments:
entities:
- thread:
id: &thread0 thread0
- thread:
id: &thread1 thread1

# Run an insert in two threads. We expect the first to time out and the
# second to finish the pending read from the first and complete
# successfully.
- name: runOnThread
object: testRunner
arguments:
thread: *thread0
operation:
name: insertOne
object: *collection
arguments:
timeoutMS: 500
document:
_id: 2
expectError:
isTimeoutError: true

# Ensure the first thread checks out a connection before executing the
# operation in the second thread. This maintains concurrent behavior but
# presents the worst case scenario.
- name: waitForEvent
object: testRunner
arguments:
client: *client
event:
connectionCheckedOutEvent: {}
count: 1

- name: runOnThread
object: testRunner
arguments:
thread: *thread1
operation:
name: insertOne
object: *collection
arguments:
document:
_id: 3

# Stop threads.
- name: waitForThread
object: testRunner
arguments:
thread: *thread1

expectEvents:
- client: *client
events:
- commandFailedEvent:
commandName: insert
- commandSucceededEvent:
commandName: insert
- client: *client
eventType: cmap
events:
- connectionCheckedOutEvent: {} # first insert
- connectionCheckedInEvent: {} # first insert fails
- connectionCheckedOutEvent: {} # second insert
- connectionCheckedInEvent: {} # second insert succeeds

- description: "Write operation with unsuccessful pending read"
operations:
# Create a failpoint to block the first operation
- name: failPoint
object: testRunner
arguments:
client: *failPointClient
failPoint:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: ["insert"]
blockConnection: true
blockTimeMS: 1100

# Execute operation with timeout less than block time
- name: insertOne
object: *collection
arguments:
timeoutMS: 50
document: { _id: 3, x: 1 }
expectError:
isTimeoutError: true

# The pending read should fail
- name: insertOne
object: *collection
arguments:
timeoutMS: 1000
document: { _id: 3, x: 1 }
expectError:
isTimeoutError: true

expectEvents:
- client: *client
events:
- commandFailedEvent:
commandName: insert
# No second failed event since we timed out attempting to check out
# the connection for the second operation
- client: *client
eventType: cmap
events:
- connectionCheckedOutEvent: {} # first insert
- connectionCheckedInEvent: {} # first insert fails
- connectionClosedEvent: # second insert times out pending read in checkout, closes
reason: error

- description: "Read operation with successful pending read"
operations:
# Create a failpoint to block the first operation
- name: failPoint
object: testRunner
arguments:
client: *failPointClient
failPoint:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: ["find"]
blockConnection: true
blockTimeMS: 750

# Execute operation with timeout less than block time
- name: findOne
object: *collection
arguments:
timeoutMS: 50
filter: { _id: 1 }
expectError:
isTimeoutError: true

# Execute a subsequent operation to complete the read
- name: findOne
object: *collection
arguments:
filter: { _id: 1 }

expectEvents:
- client: *client
events:
- commandFailedEvent:
commandName: find
- commandSucceededEvent:
commandName: find
- client: *client
eventType: cmap
events:
- connectionCheckedOutEvent: {} # first find
- connectionCheckedInEvent: {} # first find fails
- connectionCheckedOutEvent: {} # second find
- connectionCheckedInEvent: {} # second find succeeds

- description: "Read operation with unsuccessful pending read"
operations:
# Create a failpoint to block the first operation
- name: failPoint
object: testRunner
arguments:
client: *failPointClient
failPoint:
configureFailPoint: failCommand
mode: { times: 1 }
data:
failCommands: ["find"]
blockConnection: true
blockTimeMS: 1100

# Execute operation with timeout less than block time
- name: findOne
object: *collection
arguments:
timeoutMS: 50
filter: { _id: 1 }
expectError:
isTimeoutError: true

# The pending read should fail
- name: findOne
object: *collection
arguments:
timeoutMS: 1000
filter: { _id: 1 }
expectError:
isTimeoutError: true

expectEvents:
- client: *client
events:
- commandFailedEvent:
commandName: find
# No second failed event since we timed out attempting to check out
# the connection for the second operation
- client: *client
eventType: cmap
events:
- connectionCheckedOutEvent: {} # first find
- connectionCheckedInEvent: {} # first find fails
- connectionClosedEvent: # second find times out pending read in checkout, closes
reason: error
6 changes: 6 additions & 0 deletions x/mongo/driver/operation.go
@@ -807,6 +807,12 @@ func (op Operation) Execute(ctx context.Context) error {
if moreToCome {
roundTrip = op.moreToComeRoundTrip
}

if maxTimeMS != 0 {
ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, startedInfo.requestID)
}

res, err = roundTrip(ctx, conn, *wm)

if ep, ok := srvr.(ErrorProcessor); ok {
30 changes: 24 additions & 6 deletions x/mongo/driver/topology/connection.go
@@ -21,6 +21,7 @@ import (
"time"

"go.mongodb.org/mongo-driver/internal/csot"
"go.mongodb.org/mongo-driver/internal/driverutil"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
@@ -46,6 +47,12 @@ var (

func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }

type pendingReadState struct {
remainingBytes int32
requestID int32
remainingTime *time.Duration
}

type connection struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
@@ -83,7 +90,8 @@ type connection struct {

// awaitRemainingBytes indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitRemainingBytes *int32
// pendingReadState tracks the partially read server response (remaining bytes,
// originating request ID, and remaining time budget) left on the connection when
// an operation timed out before the reply was fully read.
pendingReadState *pendingReadState

// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
// accessTokens in the OIDC authenticator cache.
@@ -452,7 +460,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if c.awaitRemainingBytes == nil {
if c.pendingReadState == nil {
// If the connection was not marked as awaiting response, use the
// pre-CSOT behavior and close the connection because we don't know
// if there are other bytes left to read.
@@ -523,8 +531,13 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
// reading messages from an exhaust cursor.
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &l
if l := int32(n); l == 0 && isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) {
requestID, _ := driverutil.GetRequestID(ctx)

c.pendingReadState = &pendingReadState{
remainingBytes: l,
requestID: requestID,
}
}
return nil, "incomplete read of message header", err
}
@@ -539,8 +552,13 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &remainingBytes
if remainingBytes > 0 && isCSOTTimeout(err) && driverutil.HasMaxTimeMS(ctx) {
requestID, _ := driverutil.GetRequestID(ctx)

c.pendingReadState = &pendingReadState{
remainingBytes: remainingBytes,
requestID: requestID,
}
}
return dst, "incomplete read of full message", err
}
233 changes: 166 additions & 67 deletions x/mongo/driver/topology/pool.go
@@ -8,6 +8,7 @@ package topology

import (
"context"
"errors"
"fmt"
"io"
"net"
@@ -18,6 +19,7 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/logger"
"go.mongodb.org/mongo-driver/internal/ptrutil"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/x/mongo/driver"
)
@@ -573,6 +575,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
return nil, w.err
}

if err := awaitPendingRead(ctx, p, w.conn); err != nil {
return nil, err
}

duration = time.Since(start)
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
@@ -629,6 +635,10 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
return nil, w.err
}

if err := awaitPendingRead(ctx, p, w.conn); err != nil {
return nil, err
}

duration := time.Since(start)
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
@@ -768,83 +778,187 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro
return nil
}

var (
// BGReadTimeout is the maximum amount of time to wait when trying to read
// the server reply on a connection after an operation timed out. The
// default is 1 second.
//
// Deprecated: BGReadTimeout is intended for internal use only and may be
// removed or modified at any time.
BGReadTimeout = 1 * time.Second
// PendingReadTimeout is the maximum amount of time to wait when trying to read
// the server reply on a connection after an operation timed out. The
// default is 1 second.
//
// Deprecated: PendingReadTimeout is intended for internal use only and may be
// removed or modified at any time.
var PendingReadTimeout = 1000 * time.Millisecond

// awaitPendingRead sets a new read deadline on the provided connection and
// tries to read any bytes returned by the server. If there are any errors, the
// connection will be checked back into the pool to be retried.
func awaitPendingRead(ctx context.Context, pool *pool, conn *connection) error {
// If there are no bytes pending read, do nothing.
if conn.pendingReadState == nil {
return nil
}

// BGReadCallback is a callback for monitoring the behavior of the
// background-read-on-timeout connection preserving mechanism.
//
// Deprecated: BGReadCallback is intended for internal use only and may be
// removed or modified at any time.
BGReadCallback func(addr string, start, read time.Time, errs []error, connClosed bool)
)
prs := conn.pendingReadState
if prs.remainingTime == nil {
prs.remainingTime = ptrutil.Ptr(PendingReadTimeout)
}

// bgRead sets a new read deadline on the provided connection (1 second in the
// future) and tries to read any bytes returned by the server. If successful, it
// checks the connection into the provided pool. If there are any errors, it
// closes the connection.
//
// It calls the package-global BGReadCallback function, if set, with the
// address, timings, and any errors that occurred.
func bgRead(pool *pool, conn *connection, size int32) {
var err error
start := time.Now()
if mustLogPoolMessage(pool) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
logger.KeyRequestID, prs.requestID,
}

logPoolMessage(pool, logger.ConnectionPendingReadStarted, keysAndValues...)
}

if pool.monitor != nil {
event := &event.PoolEvent{
Type: event.ConnectionPendingReadStarted,
ConnectionID: conn.driverConnectionID,
RequestID: prs.requestID,
}

pool.monitor.Event(event)
}

size := prs.remainingBytes

checkIn := false
var someErr error

defer func() {
read := time.Now()
errs := make([]error, 0)
connClosed := false
if err != nil {
errs = append(errs, err)
connClosed = true
err = conn.close()
if err != nil {
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err))
if mustLogPoolMessage(pool) && someErr != nil {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
logger.KeyRequestID, prs.requestID,
logger.KeyReason, someErr.Error(),
logger.KeyRemainingTimeMS, *prs.remainingTime,
}

logPoolMessage(pool, logger.ConnectionPendingReadFailed, keysAndValues...)
}

if pool.monitor != nil && someErr != nil {
event := &event.PoolEvent{
Type: event.ConnectionPendingReadFailed,
Address: pool.address.String(),
ConnectionID: conn.driverConnectionID,
RequestID: prs.requestID,
RemainingTime: *prs.remainingTime,
Reason: someErr.Error(),
Error: someErr,
}

pool.monitor.Event(event)
}

// If we have exceeded the time limit, then close the connection.
if prs.remainingTime != nil && *prs.remainingTime < 0 {
if err := conn.close(); err != nil {
panic(err)
}

return
}

if !checkIn {
return
}

// No matter what happens, always check the connection back into the
// pool, which will either make it available for other operations or
// remove it from the pool if it was closed.
err = pool.checkInNoEvent(conn)
if err != nil {
errs = append(errs, fmt.Errorf("error checking in: %w", err))
}
//
// TODO(GODRIVER-3385): Figure out how to handle this error. It's possible
// that a single connection can be checked out to handle multiple concurrent
// operations. This is likely a bug in the Go Driver. So it's possible that
// the connection is idle at the point of check-in.
_ = pool.checkInNoEvent(conn)
}()

if BGReadCallback != nil {
BGReadCallback(conn.addr.String(), start, read, errs, connClosed)
dl, contextDeadlineUsed := ctx.Deadline()
if !contextDeadlineUsed {
// If there is a remainingTime, use that. If not, use the static
// PendingReadTimeout. This is required since a user could provide a timeout
// for the first try that does not exceed the pending read timeout, fail,
// and then not use a timeout for a subsequent try.
if prs.remainingTime != nil {
dl = time.Now().Add(*prs.remainingTime)
} else {
dl = time.Now().Add(PendingReadTimeout)
}
}()
}

err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout))
err := conn.nc.SetReadDeadline(dl)
if err != nil {
err = fmt.Errorf("error setting a read deadline: %w", err)
return
checkIn = true

someErr = fmt.Errorf("error setting a read deadline: %w", err)

return someErr
}

if size == 0 {
st := time.Now()

if size == 0 { // Question: Would this always equal zero?
var sizeBuf [4]byte
_, err = io.ReadFull(conn.nc, sizeBuf[:])
if err != nil {
err = fmt.Errorf("error reading the message size: %w", err)
return
if _, err := io.ReadFull(conn.nc, sizeBuf[:]); err != nil {
prs.remainingTime = ptrutil.Ptr(*prs.remainingTime - time.Since(st))
checkIn = true

err = transformNetworkError(ctx, err, contextDeadlineUsed)
someErr = fmt.Errorf("error reading the message size: %w", err)

return someErr
}
size, err = conn.parseWmSizeBytes(sizeBuf)
if err != nil {
return
checkIn = true
someErr = transformNetworkError(ctx, err, contextDeadlineUsed)

return someErr
}
size -= 4
}
_, err = io.CopyN(io.Discard, conn.nc, int64(size))

n, err := io.CopyN(io.Discard, conn.nc, int64(size))
if err != nil {
err = fmt.Errorf("error discarding %d byte message: %w", size, err)
// If the read times out, record the bytes left to read before exiting.
nerr := net.Error(nil)
if l := int32(n); l == 0 && errors.As(err, &nerr) && nerr.Timeout() {
prs.remainingBytes = l + prs.remainingBytes
prs.remainingTime = ptrutil.Ptr(*prs.remainingTime - time.Since(st))
}

checkIn = true

err = transformNetworkError(ctx, err, contextDeadlineUsed)
someErr = fmt.Errorf("error discarding %d byte message: %w", size, err)

return someErr
}

if mustLogPoolMessage(pool) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
logger.KeyRequestID, prs.requestID,
}

logPoolMessage(pool, logger.ConnectionPendingReadSucceeded, keysAndValues...)
}

if pool.monitor != nil {
event := &event.PoolEvent{
Type: event.ConnectionPendingReadSucceeded,
Address: pool.address.String(),
ConnectionID: conn.driverConnectionID,
Duration: time.Since(st),
}

pool.monitor.Event(event)
}

conn.pendingReadState = nil

return nil
}

// checkIn returns an idle connection to the pool. If the connection is perished or the pool is
@@ -886,21 +1000,6 @@ func (p *pool) checkInNoEvent(conn *connection) error {
return ErrWrongPool
}

// If the connection has an awaiting server response, try to read the
// response in another goroutine before checking it back into the pool.
//
// Do this here because we want to publish checkIn events when the operation
// is done with the connection, not when it's ready to be used again. That
// means that connections in "awaiting response" state are checked in but
// not usable, which is not covered by the current pool events. We may need
// to add pool event information in the future to communicate that.
if conn.awaitRemainingBytes != nil {
size := *conn.awaitRemainingBytes
conn.awaitRemainingBytes = nil
go bgRead(p, conn, size)
return nil
}

// Bump the connection idle start time here because we're about to make the
// connection "available". The idle start time is used to determine how long
// a connection has been idle and when it has reached its max idle time and
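
A standalone sketch (not driver API) of the time-budget bookkeeping awaitPendingRead applies across checkout attempts: each failed attempt subtracts the time it spent from the remaining budget, and once the budget goes negative the connection is closed instead of the read being retried.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	budget := 1000 * time.Millisecond // default PendingReadTimeout
	attempts := []time.Duration{400 * time.Millisecond, 700 * time.Millisecond}

	for i, spent := range attempts {
		budget -= spent
		if budget < 0 {
			fmt.Printf("attempt %d: budget exhausted; connection would be closed\n", i+1)
			return
		}
		fmt.Printf("attempt %d: %s left for the next pending read\n", i+1, budget)
	}
}
```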
218 changes: 123 additions & 95 deletions x/mongo/driver/topology/pool_test.go
@@ -18,6 +18,7 @@ import (
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/csot"
"go.mongodb.org/mongo-driver/internal/driverutil"
"go.mongodb.org/mongo-driver/internal/eventtest"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/mongo/address"
@@ -1198,24 +1199,10 @@ func TestPool(t *testing.T) {
})
}

func TestBackgroundRead(t *testing.T) {
func TestAwaitPendingRead(t *testing.T) {
t.Parallel()

newBGReadCallback := func(errsCh chan []error) func(string, time.Time, time.Time, []error, bool) {
return func(_ string, _, _ time.Time, errs []error, _ bool) {
errsCh <- errs
close(errsCh)
}
}

t.Run("incomplete read of message header", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
@@ -1241,22 +1228,18 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
close(errsCh) // this line causes a double close if BGReadCallback is ever called.
assert.Nil(t, conn.pendingReadState, "conn.pendingReadState should be nil")
})
t.Run("timeout reading message header, successful background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1270,40 +1253,48 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
})

var pendingReadError error
monitor := &event.PoolMonitor{
Event: func(pe *event.PoolEvent) {
if pe.Type == event.ConnectionPendingReadFailed {
pendingReadError = pe.Error
}
},
}

p := newPool(
poolConfig{Address: address.Address(addr.String())},
poolConfig{
Address: address.Address(addr.String()),
PoolMonitor: monitor,
},
)
defer p.close(context.Background())
err := p.ready()
require.NoError(t, err)

conn, err := p.checkOut(context.Background())
require.NoError(t, err)

ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
require.NoError(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 0, "expected no error from bgRead()")

_, err = p.checkOut(context.Background())
require.NoError(t, err)

require.NoError(t, pendingReadError)
})
t.Run("timeout reading message header, incomplete head during background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1317,8 +1308,20 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
})

var pendingReadError error
monitor := &event.PoolMonitor{
Event: func(pe *event.PoolEvent) {
if pe.Type == event.ConnectionPendingReadFailed {
pendingReadError = pe.Error
}
},
}

p := newPool(
poolConfig{Address: address.Address(addr.String())},
poolConfig{
Address: address.Address(addr.String()),
PoolMonitor: monitor,
},
)
defer p.close(context.Background())
err := p.ready()
@@ -1328,30 +1331,24 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
require.NoError(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
assert.EqualError(t, bgErrs[0], "error reading the message size: unexpected EOF")

_, err = p.checkOut(context.Background())
require.Error(t, err)

assert.EqualError(t, pendingReadError, "error reading the message size: unexpected EOF")
})
t.Run("timeout reading message header, background read timeout", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

cleanup := make(chan struct{})
@@ -1369,44 +1366,52 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
})

var pendingReadError error
monitor := &event.PoolMonitor{
Event: func(pe *event.PoolEvent) {
if pe.Type == event.ConnectionPendingReadFailed {
pendingReadError = pe.Error
}
},
}

p := newPool(
poolConfig{Address: address.Address(addr.String())},
poolConfig{
Address: address.Address(addr.String()),
PoolMonitor: monitor,
},
)

defer p.close(context.Background())
err := p.ready()
require.NoError(t, err)

conn, err := p.checkOut(context.Background())
require.NoError(t, err)

ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
require.NoError(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")

_, err = p.checkOut(context.Background())
require.Error(t, err)

wantErr := regexp.MustCompile(
`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, wantErr.MatchString(bgErrs[0].Error()), "error %q does not match pattern %q", bgErrs[0], wantErr)
assert.True(t, wantErr.MatchString(pendingReadError.Error()), "error %q does not match pattern %q", pendingReadError, wantErr)
})
t.Run("timeout reading full message, successful background read", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1423,9 +1428,22 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
})

var pendingReadError error
monitor := &event.PoolMonitor{
Event: func(pe *event.PoolEvent) {
if pe.Type == event.ConnectionPendingReadFailed {
pendingReadError = pe.Error
}
},
}

p := newPool(
poolConfig{Address: address.Address(addr.String())},
poolConfig{
Address: address.Address(addr.String()),
PoolMonitor: monitor,
},
)

defer p.close(context.Background())
err := p.ready()
require.NoError(t, err)
@@ -1434,29 +1452,24 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
require.NoError(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 0, "expected no error from bgRead()")

_, err = p.checkOut(context.Background())
require.NoError(t, err)

require.NoError(t, pendingReadError)
})
t.Run("timeout reading full message, background read EOF", func(t *testing.T) {
errsCh := make(chan []error)
var originalCallback func(string, time.Time, time.Time, []error, bool)
originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh)
t.Cleanup(func() {
BGReadCallback = originalCallback
})

const timeout = 10 * time.Millisecond

addr := bootstrapConnections(t, 1, func(nc net.Conn) {
@@ -1473,32 +1486,47 @@ func TestBackgroundRead(t *testing.T) {
require.NoError(t, err)
})

var pendingReadError error
monitor := &event.PoolMonitor{
Event: func(pe *event.PoolEvent) {
if pe.Type == event.ConnectionPendingReadFailed {
pendingReadError = pe.Error
}
},
}

p := newPool(
poolConfig{Address: address.Address(addr.String())},
poolConfig{
Address: address.Address(addr.String()),
PoolMonitor: monitor,
},
)

defer p.close(context.Background())
err := p.ready()
require.NoError(t, err)

conn, err := p.checkOut(context.Background())
require.NoError(t, err)

ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
defer cancel()

ctx = driverutil.WithValueHasMaxTimeMS(ctx, true)
ctx = driverutil.WithRequestID(ctx, -1)

_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
require.NoError(t, err)
var bgErrs []error
select {
case bgErrs = <-errsCh:
case <-time.After(3 * time.Second):
assert.Fail(t, "did not receive expected error after waiting for 3 seconds")
}
require.Len(t, bgErrs, 1, "expected 1 error from bgRead()")
assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF")

_, err = p.checkOut(context.Background())
require.Error(t, err)

assert.EqualError(t, pendingReadError, "error discarding 3 byte message: EOF")
})
}