diff --git a/internal/ptrutil/ptr.go b/internal/ptrutil/ptr.go new file mode 100644 index 0000000000..bf64aad178 --- /dev/null +++ b/internal/ptrutil/ptr.go @@ -0,0 +1,12 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package ptrutil + +// Ptr will return the memory location of the given value. +func Ptr[T any](val T) *T { + return &val +} diff --git a/mongo/collection.go b/mongo/collection.go index dbe238a9e3..523b12d23a 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1215,13 +1215,14 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, // // See DRIVERS-2722 for more detail. _, deadlineSet := ctx.Deadline() - return coll.find(ctx, filter, deadlineSet, opts...) + return coll.find(ctx, filter, deadlineSet, false, opts...) } func (coll *Collection) find( ctx context.Context, filter interface{}, omitCSOTMaxTimeMS bool, + unsafeAllowSeperateMaxTimeMS bool, opts ...*options.FindOptions, ) (cur *Cursor, err error) { @@ -1260,7 +1261,8 @@ func (coll *Collection) find( ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger). - OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator) + OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator). + UnsafeAllowSeperateMaxTimeMS(unsafeAllowSeperateMaxTimeMS) cursorOpts := coll.client.createBaseCursorOptions() @@ -1408,6 +1410,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, ctx = context.Background() } + var unsafeAllowSeperateMaxTimeMS bool findOpts := make([]*options.FindOptions, 0, len(opts)) for _, opt := range opts { if opt == nil { @@ -1433,12 +1436,16 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, Snapshot: opt.Snapshot, Sort: opt.Sort, }) + + if opt.UnsafeAllowSeperateMaxTimeMS { + unsafeAllowSeperateMaxTimeMS = opt.UnsafeAllowSeperateMaxTimeMS + } } // Unconditionally send a limit to make sure only one document is returned and the cursor is not kept open // by the server. findOpts = append(findOpts, options.Find().SetLimit(-1)) - cursor, err := coll.find(ctx, filter, false, findOpts...) + cursor, err := coll.find(ctx, filter, false, unsafeAllowSeperateMaxTimeMS, findOpts...) return &SingleResult{ ctx: ctx, cur: cursor, diff --git a/mongo/integration/csot_test.go b/mongo/integration/csot_test.go index da622cb94e..5b5fc06ac4 100644 --- a/mongo/integration/csot_test.go +++ b/mongo/integration/csot_test.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/eventtest" + "go.mongodb.org/mongo-driver/internal/ptrutil" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/integration/mtest" @@ -279,6 +280,9 @@ func TestCSOT_maxTimeMS(t *testing.T) { evt := getStartedEvent(mt, command) maxTimeVal := evt.Command.Lookup("maxTimeMS") + if len(maxTimeVal.Value) == 0 { + return -1 + } require.Greater(mt, len(maxTimeVal.Value), @@ -591,6 +595,150 @@ func TestCSOT_maxTimeMS(t *testing.T) { maxTimeMS, "expected maxTimeMS to be equal to the MaxTime value") }) + + mt.Run("UnsafeAllowSeperateMaxTimeMSWithCSOT", func(mt *mtest.T) { + ops := []struct { + name string + commandName string + fn func(ctx context.Context, coll *mongo.Collection, maxTime *time.Duration) error + cursorOp bool + }{ + { + name: "FindOne", + commandName: "find", + fn: func(ctx context.Context, coll *mongo.Collection, maxTime *time.Duration) error { + opts := options.FindOne() + opts.UnsafeAllowSeperateMaxTimeMS = true + if maxTime != nil { + opts.SetMaxTime(*maxTime) + } + res := coll.FindOne(ctx, bson.D{}, opts) + return res.Err() + }, + cursorOp: false, + }, + //{ + // name: "Find", + // commandName: "find", + // fn: func(ctx context.Context, coll *mongo.Collection, maxTime *time.Duration) error { + // opts := options.Find() + // if maxTime != nil { + // opts.SetMaxTime(*maxTime) + // } + // _, err := coll.Find(ctx, bson.D{}, opts) + // return err + // }, + // cursorOp: true, + // }, + //{ + // name: "FindOneAndUpdate", + // commandName: "findAndModify", + // fn: func(ctx context.Context, coll *mongo.Collection, maxTime *time.Duration) error { + // opts := options.FindOneAndUpdate() + // if maxTime != nil { + // opts.SetMaxTime(*maxTime) + // } + // res := coll.FindOneAndUpdate(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}, opts) + // return res.Err() + // }, + // cursorOp: false, + // }, + //{ + // name: "Aggregate", + // commandName: "aggregate", + // fn: func(ctx context.Context, coll *mongo.Collection, maxTime *time.Duration) error { + // opts := options.Aggregate() + // if maxTime != nil { + // opts.SetMaxTime(*maxTime) + // } + // _, err := coll.Aggregate(ctx, bson.D{}, opts) + // return err + // }, + // cursorOp: true, + // }, + } + + for _, op := range ops { + mt.Run(op.name, func(mt *mtest.T) { + testCases := []struct { + name string + ctxTimeout *time.Duration + maxTime *time.Duration + wantMS int + wantDelta float64 + }{ + { + name: "CSOT with context deadline with maxTime", + ctxTimeout: ptrutil.Ptr(10 * time.Second), + maxTime: ptrutil.Ptr(5 * time.Second), + wantMS: 5_000, + wantDelta: 0, + }, + { + name: "CSOT with context deadline without maxTime", + ctxTimeout: ptrutil.Ptr(10 * time.Second), + maxTime: nil, + wantMS: 10_000, + wantDelta: 500, + }, + { + name: "CSOT without context deadline with maxTime", + ctxTimeout: nil, + maxTime: ptrutil.Ptr(5 * time.Second), + wantMS: 5_000, + wantDelta: 0, + }, + { + name: "CSOT without context deadline with maxTime", + ctxTimeout: nil, + maxTime: nil, + wantMS: 15_000, + wantDelta: 500, + }, + } + + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + // driver.UnsafeAllowSeperateMaxTimeMSWithCSOT = true + // defer func() { driver.UnsafeAllowSeperateMaxTimeMSWithCSOT = false }() + + // Enable CSOT + mt.ResetClient(options.Client().SetTimeout(15 * time.Second)) + + var hasDeadline bool + ctx := context.Background() + if tc.ctxTimeout != nil { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, *tc.ctxTimeout) + defer cancel() + + hasDeadline = true + } + + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + err := op.fn(ctx, mt.Coll, tc.maxTime) + require.NoError(mt, err) + + // Assert that maxTimeMS is set and that it's equal to the MaxTime + // value. + maxTimeMS := getMaxTimeMS(mt, op.commandName) + if op.cursorOp && tc.maxTime == nil && hasDeadline { + assert.Equal(mt, int64(-1), maxTimeMS) + } else { + assert.InDelta(mt, + tc.wantMS, + maxTimeMS, + tc.wantDelta, + "expected maxTimeMS to be equal to the MaxTime value") + } + }) + } + }) + } + }) } func TestCSOT_errors(t *testing.T) { diff --git a/mongo/options/findoptions.go b/mongo/options/findoptions.go index fa3bf1197a..44791d3b18 100644 --- a/mongo/options/findoptions.go +++ b/mongo/options/findoptions.go @@ -418,6 +418,19 @@ type FindOneOptions struct { // A document specifying the sort order to apply to the query. The first document in the sorted order will be // returned. The driver will return an error if the sort parameter is a multi-key map. Sort interface{} + + // UnsafeAllowSeperateMaxTimeMS is allows setting maxTimeMS independently of + // the context deadline when CSOT is enabled (client.Timeout >=0). If a user + // provides a context deadline it will be used for all blocking client-side + // logic (e.g. socket timeouts, checking out connections, etc). + // + // This switch is untested and experimental. + // + // ⚠️ **USE WITH CAUTION** ⚠️ + // + // Deprecated: This option is for internal use only and should not be set. It + // may be changed or removed in any release. + UnsafeAllowSeperateMaxTimeMS bool } // FindOne creates a new FindOneOptions instance. @@ -615,6 +628,9 @@ func MergeFindOneOptions(opts ...*FindOneOptions) *FindOneOptions { if opt.Sort != nil { fo.Sort = opt.Sort } + if opt.UnsafeAllowSeperateMaxTimeMS { + fo.UnsafeAllowSeperateMaxTimeMS = opt.UnsafeAllowSeperateMaxTimeMS + } } return fo diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index ec6f69eca0..cb0c8ce286 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -324,6 +324,19 @@ type Operation struct { // where a default read preference is used when the operation // ReadPreference is not specified. omitReadPreference bool + + // UnsafeAllowSeperateMaxTimeMS is allows setting maxTimeMS independently of + // the context deadline when CSOT is enabled (client.Timeout >=0). If a user + // provides a context deadline it will be used for all blocking client-side + // logic (e.g. socket timeouts, checking out connections, etc). + // + // This switch is untested and experimental. + // + // ⚠️ **USE WITH CAUTION** ⚠️ + // + // Deprecated: This option is for internal use only and should not be set. It + // may be changed or removed in any release. + UnsafeAllowSeperateMaxTimeMS bool } // shouldEncrypt returns true if this operation should automatically be encrypted. @@ -1593,6 +1606,8 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is // not a Timeout context, calculateMaxTimeMS returns 0. func (op Operation) calculateMaxTimeMS(ctx context.Context, mon RTTMonitor) (uint64, error) { + unsafelyOverrideCSOT := op.UnsafeAllowSeperateMaxTimeMS && op.MaxTime != nil + // If CSOT is enabled and we're not omitting the CSOT-calculated maxTimeMS // value, then calculate maxTimeMS. // @@ -1603,7 +1618,7 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, mon RTTMonitor) (uin // TODO(GODRIVER-2944): Remove or refactor this logic when we add the // "timeoutMode" option, which will allow users to opt-in to the // CSOT-calculated maxTimeMS values if that's the behavior they want. - if csot.IsTimeoutContext(ctx) && !op.OmitCSOTMaxTimeMS { + if csot.IsTimeoutContext(ctx) && !op.OmitCSOTMaxTimeMS && !unsafelyOverrideCSOT { if deadline, ok := ctx.Deadline(); ok { remainingTimeout := time.Until(deadline) rtt90 := mon.P90() diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index c71b7d755e..02e109530a 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -25,46 +25,47 @@ import ( // Find performs a find operation. type Find struct { - authenticator driver.Authenticator - allowDiskUse *bool - allowPartialResults *bool - awaitData *bool - batchSize *int32 - collation bsoncore.Document - comment *string - filter bsoncore.Document - hint bsoncore.Value - let bsoncore.Document - limit *int64 - max bsoncore.Document - maxTime *time.Duration - min bsoncore.Document - noCursorTimeout *bool - oplogReplay *bool - projection bsoncore.Document - returnKey *bool - showRecordID *bool - singleBatch *bool - skip *int64 - snapshot *bool - sort bsoncore.Document - tailable *bool - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - readConcern *readconcern.ReadConcern - readPreference *readpref.ReadPref - selector description.ServerSelector - retry *driver.RetryMode - result driver.CursorResponse - serverAPI *driver.ServerAPIOptions - timeout *time.Duration - omitCSOTMaxTimeMS bool - logger *logger.Logger + authenticator driver.Authenticator + allowDiskUse *bool + allowPartialResults *bool + awaitData *bool + batchSize *int32 + collation bsoncore.Document + comment *string + filter bsoncore.Document + hint bsoncore.Value + let bsoncore.Document + limit *int64 + max bsoncore.Document + maxTime *time.Duration + min bsoncore.Document + noCursorTimeout *bool + oplogReplay *bool + projection bsoncore.Document + returnKey *bool + showRecordID *bool + singleBatch *bool + skip *int64 + snapshot *bool + sort bsoncore.Document + tailable *bool + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + readConcern *readconcern.ReadConcern + readPreference *readpref.ReadPref + selector description.ServerSelector + retry *driver.RetryMode + result driver.CursorResponse + serverAPI *driver.ServerAPIOptions + timeout *time.Duration + omitCSOTMaxTimeMS bool + logger *logger.Logger + unsafeAllowSeperateMaxTimeMS bool } // NewFind constructs and returns a new Find. @@ -93,27 +94,28 @@ func (f *Find) Execute(ctx context.Context) error { } return driver.Operation{ - CommandFn: f.command, - ProcessResponseFn: f.processResponse, - RetryMode: f.retry, - Type: driver.Read, - Client: f.session, - Clock: f.clock, - CommandMonitor: f.monitor, - Crypt: f.crypt, - Database: f.database, - Deployment: f.deployment, - MaxTime: f.maxTime, - ReadConcern: f.readConcern, - ReadPreference: f.readPreference, - Selector: f.selector, - Legacy: driver.LegacyFind, - ServerAPI: f.serverAPI, - Timeout: f.timeout, - Logger: f.logger, - Name: driverutil.FindOp, - OmitCSOTMaxTimeMS: f.omitCSOTMaxTimeMS, - Authenticator: f.authenticator, + CommandFn: f.command, + ProcessResponseFn: f.processResponse, + RetryMode: f.retry, + Type: driver.Read, + Client: f.session, + Clock: f.clock, + CommandMonitor: f.monitor, + Crypt: f.crypt, + Database: f.database, + Deployment: f.deployment, + MaxTime: f.maxTime, + ReadConcern: f.readConcern, + ReadPreference: f.readPreference, + Selector: f.selector, + Legacy: driver.LegacyFind, + ServerAPI: f.serverAPI, + Timeout: f.timeout, + Logger: f.logger, + Name: driverutil.FindOp, + OmitCSOTMaxTimeMS: f.omitCSOTMaxTimeMS, + Authenticator: f.authenticator, + UnsafeAllowSeperateMaxTimeMS: f.unsafeAllowSeperateMaxTimeMS, }.Execute(ctx) } @@ -587,3 +589,13 @@ func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { f.authenticator = authenticator return f } + +// UnsafeAllowSeperateMaxTimeMS allows CSOT with independent maxTimeMS. +func (f *Find) UnsafeAllowSeperateMaxTimeMS(val bool) *Find { + if f == nil { + f = new(Find) + } + + f.unsafeAllowSeperateMaxTimeMS = val + return f +}