diff --git a/internal/assert/assertions.go b/internal/assert/assertions.go index 0754a411a1..29f0cf7e7c 100644 --- a/internal/assert/assertions.go +++ b/internal/assert/assertions.go @@ -1,5 +1,4 @@ // Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertions.go - // Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. // Use of this source code is governed by an MIT-style license that can be found in // the THIRD-PARTY-NOTICES file. @@ -79,7 +78,6 @@ the problem actually occurred in calling code.*/ // of each stack frame leading from the current test to the assert call that // failed. func CallerInfo() []string { - var pc uintptr var ok bool var file string @@ -307,7 +305,6 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) } return true - } // validateEqualArgs checks whether provided arguments can be safely used in the @@ -372,7 +369,6 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa } return true - } // NotNil asserts that the specified object is not nil. @@ -411,7 +407,8 @@ func isNil(object interface{}) bool { []reflect.Kind{ reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, - reflect.Ptr, reflect.Slice}, + reflect.Ptr, reflect.Slice, + }, kind) if isNilableKind && value.IsNil() { @@ -477,7 +474,6 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { } return true - } // False asserts that the specified value is false. @@ -492,7 +488,6 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { } return true - } // NotEqual asserts that the specified values are NOT equal. @@ -515,7 +510,6 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{ } return true - } // NotEqualValues asserts that two objects are not equal even when converted to the same type @@ -538,7 +532,6 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte // return (true, false) if element was not found. // return (true, true) if element was found. func containsElement(list interface{}, element interface{}) (ok, found bool) { - listValue := reflect.ValueOf(list) listType := reflect.TypeOf(list) if listType == nil { @@ -573,7 +566,6 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) { } } return true, false - } // Contains asserts that the specified string, list(array, slice...) or map contains the @@ -596,7 +588,6 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo } return true - } // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the @@ -619,12 +610,10 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) } return true - } // isEmpty gets whether the specified object is considered empty or not. func isEmpty(object interface{}) bool { - // get nil case out of the way if object == nil { return true @@ -1090,3 +1079,28 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { return pass } + +// Empty asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". +// +// assert.Empty(t, obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value +func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + pass := isEmpty(object) + if !pass { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, fmt.Sprintf("Should be empty, but was %v", object), msgAndArgs...) + } + + return pass +} diff --git a/internal/integration/mongointernal_test.go b/internal/integration/mongointernal_test.go index c6671b3cb2..19c29428da 100644 --- a/internal/integration/mongointernal_test.go +++ b/internal/integration/mongointernal_test.go @@ -71,28 +71,4 @@ func TestNewSessionWithLSID(t *testing.T) { // doesn't panic. t.Errorf("expected EndSession to panic") }) - - mt.Run("ClientSession.SetServer panics", func(mt *mtest.T) { - mt.Parallel() - - sessionID := bson.Raw(bsoncore.NewDocumentBuilder(). - AppendBinary("id", 4, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}). - Build()) - sess := mongo.NewSessionWithLSID(mt.Client, sessionID) - - // Use a defer-recover block to catch the expected panic and assert that - // the recovered error is not nil. - defer func() { - err := recover() - assert.NotNil(mt, err, "expected ClientSession.SetServer to panic") - }() - - // Expect this call to panic. - sess.ClientSession().SetServer() - - // We expect that calling ClientSession.SetServer on a Session returned - // by NewSessionWithLSID panics. This code will only be reached if - // ClientSession.SetServer doesn't panic. - t.Errorf("expected ClientSession.SetServer to panic") - }) } diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 949a86a3d4..a2f2463f95 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -169,51 +169,64 @@ func (t *T) Run(name string, callback func(mt *T)) { t.RunOpts(name, NewOptions(), callback) } -// RunOpts creates a new T instance for a sub-test with the given options. If the current environment does not satisfy -// constraints specified in the options, the new sub-test will be skipped automatically. If the test is not skipped, -// the callback will be run with the new T instance. RunOpts creates a new collection with the given name which is -// available to the callback through the T.Coll variable and is dropped after the callback returns. -func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) { - t.T.Run(name, func(wrapped *testing.T) { - sub := newT(wrapped, t.baseOpts, opts) +// Setup initializes the test client and collection for this T instance. This is +// automatically called by RunOpts but can be called manually when using New() +// directly. +func (t *T) Setup() { + // add any mock responses for this test + if t.clientType == Mock && len(t.mockResponses) > 0 { + t.AddMockResponses(t.mockResponses...) + } - // add any mock responses for this test - if sub.clientType == Mock && len(sub.mockResponses) > 0 { - sub.AddMockResponses(sub.mockResponses...) - } + if t.createClient == nil || *t.createClient { + t.createTestClient() + } - if sub.createClient == nil || *sub.createClient { - sub.createTestClient() - } + // create a collection for this test + if t.Client != nil { + t.createTestCollection() + } - // create a collection for this test - if sub.Client != nil { - sub.createTestCollection() - } + // clear any events that may have happened during setup + t.ClearEvents() +} - // defer dropping all collections if the test is using a client - defer func() { - if sub.Client == nil { - return - } +// Teardown cleans up test resources and asserts that all sessions and +// connections are closed. When using New() directly, this should be called via +// defer after Setup(). +func (t *T) Teardown() { + if t.Client == nil { + return + } - // store number of sessions and connections checked out here but assert that they're equal to 0 after - // cleaning up test resources to make sure resources are always cleared - sessions := sub.Client.NumberSessionsInProgress() - conns := sub.NumberConnectionsCheckedOut() + // store number of sessions and connections checked out here but assert that they're equal to 0 after + // cleaning up test resources to make sure resources are always cleared + sessions := t.Client.NumberSessionsInProgress() + conns := t.NumberConnectionsCheckedOut() - if sub.clientType != Mock { - sub.ClearFailPoints() - sub.ClearCollections() - } + if t.clientType != Mock { + t.ClearFailPoints() + t.ClearCollections() + } + + _ = t.Client.Disconnect(context.Background()) + assert.Equal(t, 0, sessions, "%v sessions checked out", sessions) + assert.Equal(t, 0, conns, "%v connections checked out", conns) +} + +// RunOpts creates a new T instance for a sub-test with the given options. If +// the current environment does not satisfy constraints specified in the +// options, the new sub-test will be skipped automatically. If the test is not +// skipped, the callback will be run with the new T instance. RunOpts creates a +// new collection with the given name which is available to the callback through +// the T.Coll variable and is dropped after the callback returns. +func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) { + t.T.Run(name, func(wrapped *testing.T) { + sub := newT(wrapped, t.baseOpts, opts) - _ = sub.Client.Disconnect(context.Background()) - assert.Equal(sub, 0, sessions, "%v sessions checked out", sessions) - assert.Equal(sub, 0, conns, "%v connections checked out", conns) - }() + sub.Setup() + t.Cleanup(sub.Teardown) - // clear any events that may have happened during setup and run the test - sub.ClearEvents() callback(sub) }) } diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index a6d8b4e976..afa1501d02 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -508,7 +508,6 @@ func TestSessionsProse(t *testing.T) { limitedSessMsg := "expected session count to be less than the number of operations: %v" assert.True(mt, limitedSessionUse, limitedSessMsg, len(ops)) - }) mt.ResetClient(options.Client()) @@ -584,6 +583,81 @@ func TestSessionsProse(t *testing.T) { }) } +func TestSessionsProse_21_SettingSnapshotTimeWithoutSnapshot(t *testing.T) { + // 21. Having snapshotTime set and snapshot set to false is not allowed. + mtOpts := mtest. + NewOptions(). + MinServerVersion("5.0"). + Topologies(mtest.ReplicaSet, mtest.Sharded) + + mt := mtest.New(t, mtOpts) + + mt.Setup() + mt.Cleanup(mt.Teardown) + + // Start a session by calling startSession with snapshot = false and + // snapshotTime = new Timestamp(1). + sessOpts := options.Session().SetSnapshot(false).SetSnapshotTime(bson.Timestamp{T: 1}) + + _, err := mt.Client.StartSession(sessOpts) + require.Error(t, err) + require.Contains(t, err.Error(), "snapshotTime cannot be set when snapshot is false") +} + +func TestSessionsProse_22_SnapshotTimeGetterReturnsErrorForNonSnapshotSessions(t *testing.T) { + // 22. Retrieving `snapshotTime` on a non-snapshot session raises an error + t.Skip("Skipping test for prose 22; Go driver does not have a getter that raises an error.") +} + +func TestSessionsProse_23_EnsureSnapshotTimeIsImmutable(t *testing.T) { + // 23. Ensure `snapshotTime` is Read-Only + + mtOpts := mtest. + NewOptions(). + MinServerVersion("5.0"). + Topologies(mtest.ReplicaSet, mtest.Sharded) + + mt := mtest.New(t, mtOpts) + + mt.Run("multiple ClientSession calls isolation", func(mt *mtest.T) { + sess, err := mt.Client.StartSession(options.Session().SetSnapshot(false)) + require.NoError(mt, err) + defer sess.EndSession(context.Background()) + + // Verify initial state + require.Empty(mt, sess.ClientSession().SnapshotTime) + + // Attempt mutation through one ClientSession() call + client1 := sess.ClientSession() + client1.SnapshotTime = bson.Timestamp{T: 1} + + // Second ClientSession() call should return independent copy + require.Empty(mt, sess.ClientSession().SnapshotTime) + }) + + mt.Run("snapshotTime copy is immutable", func(mt *mtest.T) { + originalTS := bson.Timestamp{T: 100, I: 5} + sess, err := mt.Client.StartSession( + options.Session().SetSnapshot(true).SetSnapshotTime(originalTS), + ) + require.NoError(mt, err) + defer sess.EndSession(context.Background()) + + // Verify initial state + cs := sess.ClientSession() + require.True(mt, cs.SnapshotTimeSet) + require.Equal(mt, originalTS, cs.SnapshotTime) + + // Mutate the copy and verify it doesn't affect the session. + cs.SnapshotTime = bson.Timestamp{T: 999, I: 888} + cs.SnapshotTimeSet = false + + cs2 := sess.ClientSession() + require.True(mt, cs2.SnapshotTimeSet) + require.Equal(mt, originalTS, cs2.SnapshotTime) + }) +} + type sessionFunction struct { name string target string diff --git a/internal/integration/unified/entity.go b/internal/integration/unified/entity.go index c709289b8d..d9219c6f07 100644 --- a/internal/integration/unified/entity.go +++ b/internal/integration/unified/entity.go @@ -254,13 +254,30 @@ func newEntityMap() *EntityMap { return em } -func (em *EntityMap) addBSONEntity(id string, val bson.RawValue) error { +func (em *EntityMap) addBSONEntity(id string, val any) error { if err := em.verifyEntityDoesNotExist(id); err != nil { return err } + typ, bytes, err := bson.MarshalValue(val) + if err != nil { + return fmt.Errorf("error marshaling BSON value for entity ID %q: %w", id, err) + } + em.allEntities[id] = struct{}{} - em.bsonValues[id] = val + + // If val is already a bson.RawValue, use it directly to preserve the original + // type and bytes. If not, construct a new bson.RawValue. + rv, ok := val.(bson.RawValue) + if !ok { + rv = bson.RawValue{ + Type: typ, + Value: bytes, + } + } + + em.bsonValues[id] = rv + return nil } @@ -790,6 +807,22 @@ func (em *EntityMap) addSessionEntity(entityOptions *entityOptions) error { sessionOpts := options.Session() if entityOptions.SessionOptions != nil { sessionOpts = entityOptions.SessionOptions.SessionOptionsBuilder + + // Resolve snapshot time from EntityMap if specified + if entityOptions.SessionOptions.snapshotTimeID != nil { + snapshotTimeID := *entityOptions.SessionOptions.snapshotTimeID + RawTS, err := em.BSONValue(snapshotTimeID) + if err != nil { + return fmt.Errorf("error retrieving snapshot time for entity %q: %w", snapshotTimeID, err) + } + + t, i, ok := RawTS.TimestampOK() + if !ok { + return fmt.Errorf("snapshot time entity %q is not a timestamp", snapshotTimeID) + } + + sessionOpts.SetSnapshotTime(bson.Timestamp{T: t, I: i}) + } } sess, err := client.StartSession(sessionOpts) diff --git a/internal/integration/unified/matches.go b/internal/integration/unified/matches.go index 40270eeb7e..08e0202ce2 100644 --- a/internal/integration/unified/matches.go +++ b/internal/integration/unified/matches.go @@ -23,8 +23,10 @@ type keyPathCtxKey struct{} // extraKeysAllowedCtxKey is used as a key for a Context object. The value conveys whether or not the document under // test can contain extra keys. For example, if the expected document is {x: 1}, the document {x: 1, y: 1} would match // if the value for this key is true. -type extraKeysAllowedCtxKey struct{} -type extraKeysAllowedRootMatchCtxKey struct{} +type ( + extraKeysAllowedCtxKey struct{} + extraKeysAllowedRootMatchCtxKey struct{} +) func makeMatchContext(ctx context.Context, keyPath string, extraKeysAllowed bool) context.Context { ctx = context.WithValue(ctx, keyPathCtxKey{}, keyPath) @@ -264,8 +266,8 @@ func evaluateSpecialComparison(ctx context.Context, assertionDoc bson.Raw, actua // Numeric values can be compared even if their types are different (e.g. if expected is an int32 and actual // is an int64). - var expectedF64 = assertionVal.AsFloat64() - var actualF64 = actual.AsFloat64() + expectedF64 := assertionVal.AsFloat64() + actualF64 := actual.AsFloat64() if actualF64 > expectedF64 { return fmt.Errorf("expected numeric value %f to be less than or equal %f", actualF64, expectedF64) diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 85fa98781c..52b55608cf 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -127,6 +127,10 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat case "withTransaction": // executeWithTransaction internally verifies results/errors for each operation, so it doesn't return a result. return newEmptyResult(), executeWithTransaction(ctx, op, loopDone) + case "getSnapshotTime": + // executeGetSnapshotTime stores the snapshot time of the session as on + // the entity map for subsequent use. + return executeGetSnapshotTime(ctx, op) // Client operations case "appendMetadata": diff --git a/internal/integration/unified/session_operation_execution.go b/internal/integration/unified/session_operation_execution.go index c9408a2af9..897024047c 100644 --- a/internal/integration/unified/session_operation_execution.go +++ b/internal/integration/unified/session_operation_execution.go @@ -112,3 +112,27 @@ func executeWithTransaction(ctx context.Context, op *operation, loopDone <-chan }, temp.TransactionOptionsBuilder) return err } + +func executeGetSnapshotTime(ctx context.Context, op *operation) (*operationResult, error) { + entityID := op.ResultEntityID + if entityID == nil { + return nil, fmt.Errorf("getSnapshotTime operation requires a result entity ID") + } + + sess, err := entities(ctx).session(op.Object) + if err != nil { + return nil, err + } + + clientSess := sess.ClientSession() + + if !clientSess.SnapshotTimeSet { + return nil, fmt.Errorf("session has no snapshot time to store in entity %q", *entityID) + } + + if err := entities(ctx).addBSONEntity(*entityID, clientSess.SnapshotTime); err != nil { + return nil, fmt.Errorf("error storing result as BSON entity: %w", err) + } + + return newEmptyResult(), nil +} diff --git a/internal/integration/unified/session_options.go b/internal/integration/unified/session_options.go index aa9dc88afa..deeeef2d72 100644 --- a/internal/integration/unified/session_options.go +++ b/internal/integration/unified/session_options.go @@ -61,16 +61,18 @@ func (to *transactionOptions) UnmarshalBSON(data []byte) error { // convert BSON documents to a sessionOptions instance. type sessionOptions struct { *options.SessionOptionsBuilder + snapshotTimeID *string // Store the ID for later lookup in EntityMap } var _ bson.Unmarshaler = (*sessionOptions)(nil) func (so *sessionOptions) UnmarshalBSON(data []byte) error { var temp struct { - Causal *bool `bson:"causalConsistency"` - TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` - Snapshot *bool `bson:"snapshot"` - Extra map[string]any `bson:",inline"` + Causal *bool `bson:"causalConsistency"` + TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` + Snapshot *bool `bson:"snapshot"` + SnapshotTime *string `bson:"snapshotTime"` + Extra map[string]any `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { return fmt.Errorf("error unmarshalling to temporary sessionOptions object: %v", err) @@ -105,5 +107,9 @@ func (so *sessionOptions) UnmarshalBSON(data []byte) error { if temp.Snapshot != nil { so.SetSnapshot(*temp.Snapshot) } + + // Store the snapshot time ID for later lookup + so.snapshotTimeID = temp.SnapshotTime + return nil } diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index 71bc76d6eb..9082ee46f4 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -443,7 +443,7 @@ func waitForEvent(ctx context.Context, args waitForEventArguments) error { } } -func extractClientSession(sess *mongo.Session) *session.Client { +func extractClientSession(sess *mongo.Session) session.Client { return sess.ClientSession() } diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index af4495cb76..2565a8093d 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -161,13 +161,15 @@ var directories = []string{ "read-write-concern/tests/operation", } -var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) -var specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)) - return reg -}() +var ( + checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) + specTestRegistry = func() *bson.Registry { + reg := bson.NewRegistry() + reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) + reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)) + return reg + }() +) func TestUnifiedSpecs(t *testing.T) { for _, specDir := range directories { @@ -425,13 +427,34 @@ func executeGridFSOperation(mt *mtest.T, bucket *mongo.GridFSBucket, op *operati } func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess *mongo.Session) error { - var clientSession *session.Client + var ( + clientSession session.Client + hasSession bool + ) + if sess != nil { clientSession = sess.ClientSession() + hasSession = true + } + + requireSession := func(opName string) error { + if !hasSession { + return fmt.Errorf("%s requires a session", opName) + } + + return nil } switch op.Name { case "targetedFailPoint": + if err := requireSession(op.Name); err != nil { + return err + } + + if clientSession.PinnedServerAddr == nil { + return fmt.Errorf("%s requires pinned session", op.Name) + } + fpDoc := op.Arguments.Lookup("failPoint") var fp failpoint.FailPoint @@ -439,9 +462,6 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, return fmt.Errorf("Unmarshal error: %w", err) } - if clientSession == nil { - return errors.New("expected valid session, got nil") - } targetHost := clientSession.PinnedServerAddr.String() opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost}) integtest.AddTestServerAPIVersion(opts) @@ -462,6 +482,10 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, } mt.SetFailPointFromDocument(fp.Document()) case "assertSessionTransactionState": + if err := requireSession(op.Name); err != nil { + return err + } + stateVal, err := op.Arguments.LookupErr("state") if err != nil { return fmt.Errorf("unable to find 'state' in arguments: %w", err) @@ -471,9 +495,6 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, return errors.New("expected 'state' argument to be string") } - if clientSession == nil { - return errors.New("expected valid session, got nil") - } actualState := clientSession.TransactionState.String() // actualState should match expectedState, but "in progress" is the same as @@ -484,20 +505,22 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState) } case "assertSessionPinned": - if clientSession == nil { - return errors.New("expected valid session, got nil") + if err := requireSession(op.Name); err != nil { + return err } + if clientSession.PinnedServerAddr == nil { return errors.New("expected pinned server, got nil") } case "assertSessionUnpinned": - if clientSession == nil { - return errors.New("expected valid session, got nil") + if err := requireSession(op.Name); err != nil { + return err } + // We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned // case provides the pinned server address in the error msg for debugging. if clientSession.PinnedServerAddr != nil { - return fmt.Errorf("expected pinned server to be nil but got %q", clientSession.PinnedServerAddr) + return fmt.Errorf("expected pinned server to be nil but got %q", clientSession.PinnedServerAddr.String()) } case "assertSameLsidOnLastTwoCommands": first, second := lastTwoIDs(mt) diff --git a/internal/require/require.go b/internal/require/require.go index 0b60613e3e..aa9190fd7f 100644 --- a/internal/require/require.go +++ b/internal/require/require.go @@ -832,3 +832,26 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { } t.FailNow() } + +// Empty asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". +// +// require.Empty(t, obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value +func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Empty(t, object, msgAndArgs...) { + return + } + t.FailNow() +} diff --git a/internal/spectest/skip.go b/internal/spectest/skip.go index d2bd7f1fb0..7bd0deeb83 100644 --- a/internal/spectest/skip.go +++ b/internal/spectest/skip.go @@ -11,7 +11,6 @@ import "testing" // skipTests is a map of "fully-qualified test name" to "the reason for skipping // the test". var skipTests = map[string][]string{ - // SPEC-1403: This test checks to see if the correct error is thrown when auto // encrypting with a server < 4.2. Currently, the test will fail because a // server < 4.2 wouldn't have mongocryptd, so Client construction would fail @@ -843,15 +842,6 @@ var skipTests = map[string][]string{ "TestSDAMSpec/errors/pre-42-ShutdownInProgress.json", }, - // TODO(GODRIVER-3663): Expose atClusterTime parameter in snapshot sessions - "Expose atClusterTime parameter in snapshot sessions (GODRIVER-3663)": { - "TestUnifiedSpec/sessions/tests/snapshot-sessions.json/Find_operation_with_snapshot_and_snapshot_time", - "TestUnifiedSpec/sessions/tests/snapshot-sessions.json/Distinct_operation_with_snapshot_and_snapshot_time", - "TestUnifiedSpec/sessions/tests/snapshot-sessions.json/Aggregate_operation_with_snapshot_and_snapshot_time", - "TestUnifiedSpec/sessions/tests/snapshot-sessions.json/countDocuments_operation_with_snapshot_and_snapshot_time", - "TestUnifiedSpec/sessions/tests/snapshot-sessions.json/Mixed_operation_with_snapshot_and_snapshotTime", - }, - // TODO(DRIVERS-3356): Unskip this test when the spec test bug is fixed. "Handshake spec test 'metadata-not-propagated.yml' fails on sharded clusters (DRIVERS-3356)": { "TestUnifiedSpec/mongodb-handshake/tests/unified/metadata-not-propagated.json/metadata_append_does_not_create_new_connections_or_close_existing_ones_and_no_hello_command_is_sent", diff --git a/mongo/client.go b/mongo/client.go index 10182276b2..06f0bdc2f9 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -234,7 +234,6 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { topology.WithAuthConfigClientOptions(clientOpts), topology.WithAuthConfigDriverInfo(client.currentDriverInfo), ) - if err != nil { return nil, err } @@ -477,10 +476,15 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (* coreOpts.DefaultReadPreference = rp } } + if sessArgs.Snapshot != nil { coreOpts.Snapshot = sessArgs.Snapshot } + if sessArgs.SnapshotTime != nil { + coreOpts.SnapshotTime = sessArgs.SnapshotTime + } + sess, err := session.NewClientSession(c.sessionPool, c.id, coreOpts) if err != nil { return nil, wrapErrors(err) @@ -891,7 +895,8 @@ func (c *Client) UseSessionWithOptions( // The opts parameter can be used to specify options for change stream creation (see the options.ChangeStreamOptions // documentation). func (c *Client) Watch(ctx context.Context, pipeline any, - opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { + opts ...options.Lister[options.ChangeStreamOptions], +) (*ChangeStream, error) { csConfig := changeStreamConfig{ readConcern: c.readConcern, readPreference: c.readPreference, @@ -931,7 +936,8 @@ type ClientBulkWrite struct { // BulkWrite performs a client-level bulk write operation. func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, - opts ...options.Lister[options.ClientBulkWriteOptions]) (*ClientBulkWriteResult, error) { + opts ...options.Lister[options.ClientBulkWriteOptions], +) (*ClientBulkWriteResult, error) { // TODO(GODRIVER-3403): Remove after support for QE with Client.bulkWrite. if c.isAutoEncryptionSet { return nil, errors.New("bulkWrite does not currently support automatic encryption") diff --git a/mongo/options/sessionoptions.go b/mongo/options/sessionoptions.go index 1f5edcf8b9..a017e8c06d 100644 --- a/mongo/options/sessionoptions.go +++ b/mongo/options/sessionoptions.go @@ -6,6 +6,8 @@ package options +import "go.mongodb.org/mongo-driver/v2/bson" + // DefaultCausalConsistency is the default value for the CausalConsistency option. var DefaultCausalConsistency = true @@ -16,6 +18,7 @@ type SessionOptions struct { CausalConsistency *bool DefaultTransactionOptions *TransactionOptionsBuilder Snapshot *bool + SnapshotTime *bson.Timestamp } // SessionOptionsBuilder represents functional options that configure a Sessionopts. @@ -69,3 +72,16 @@ func (s *SessionOptionsBuilder) SetSnapshot(b bool) *SessionOptionsBuilder { }) return s } + +// SetSnapshotTime sets the value for the SnapshotTime field. Specifies the +// timestamp to use for snapshot reads within the session. This option can only +// be set if Snapshot is set to true. If not provided, the snapshot time will be +// determined automatically from the atClusterTime of the first read operation +// performed in the session. The default value is nil. +func (s *SessionOptionsBuilder) SetSnapshotTime(t bson.Timestamp) *SessionOptionsBuilder { + s.Opts = append(s.Opts, func(opts *SessionOptions) error { + opts.SnapshotTime = &t + return nil + }) + return s +} diff --git a/mongo/session.go b/mongo/session.go index db08f12589..2214901534 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -80,8 +80,12 @@ func SessionFromContext(ctx context.Context) *Session { // // Deprecated: This method is for internal use only and should not be used (see // GODRIVER-2700). It may be changed or removed in any release. -func (s *Session) ClientSession() *session.Client { - return s.clientSession +func (s *Session) ClientSession() session.Client { + if s.clientSession == nil { + return session.Client{} + } + + return *s.clientSession // Return a copy to prevent mutation. } // ID returns the current ID document associated with the session. The ID diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 9906563100..a60a757586 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1589,7 +1589,7 @@ func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I) data, _ = bsoncore.AppendDocumentEnd(data, 0) } - if client.Snapshot && client.SnapshotTime != nil { + if client.Snapshot && client.SnapshotTimeSet { data = data[:len(data)-1] // remove the null byte data = bsoncore.AppendTimestampElement(data, "atClusterTime", client.SnapshotTime.T, client.SnapshotTime.I) data, _ = bsoncore.AppendDocumentEnd(data, 0) diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 616f71bc29..b06d7dce0c 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -102,6 +102,13 @@ type Client struct { Aborting bool Snapshot bool + // SnapshotTime is the atClusterTime value for snapshot reads. This field is + // left immutable once set for the lifetime of the session. This guards + // against users updating custom snapshot times during transactions which + // could lead to a write conflict. + SnapshotTime bson.Timestamp + SnapshotTimeSet bool + // options for the current transaction // most recently set by transactionopt CurrentRc *readconcern.ReadConcern @@ -119,7 +126,6 @@ type Client struct { PinnedServerAddr *address.Address RecoveryToken bson.Raw PinnedConnection LoadBalancedTransactionConnection - SnapshotTime *bson.Timestamp } func getClusterTime(clusterTime bson.Raw) (uint32, uint32) { @@ -192,6 +198,10 @@ func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (* if mergedOpts.Snapshot != nil { c.Snapshot = *mergedOpts.Snapshot } + if mergedOpts.SnapshotTime != nil { + c.SnapshotTime = *mergedOpts.SnapshotTime + c.SnapshotTimeSet = true + } // For explicit sessions, the default for causalConsistency is true, unless Snapshot is // enabled, then it's false. Set the default and then allow any explicit causalConsistency @@ -205,6 +215,10 @@ func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (* return nil, errors.New("causal consistency and snapshot cannot both be set for a session") } + if c.SnapshotTimeSet && !c.Snapshot { + return nil, errors.New("snapshotTime cannot be set when snapshot is false") + } + if err := c.SetServer(); err != nil { return nil, err } @@ -273,9 +287,13 @@ func (c *Client) UpdateRecoveryToken(response bson.Raw) { c.RecoveryToken = token.Document() } -// UpdateSnapshotTime updates the session's value for the atClusterTime field of ReadConcern. +// UpdateSnapshotTime updates the session's value for the atClusterTime field of +// ReadConcern. func (c *Client) UpdateSnapshotTime(response bsoncore.Document) { - if c == nil { + if c == nil || c.SnapshotTimeSet { + // Do nothing if session is nil or snapshot time is already set. The driver + // sends the same atClusterTime for all operations in a snapshot session so + // resetting is a potentially dangerous redundancy. return } @@ -291,10 +309,11 @@ func (c *Client) UpdateSnapshotTime(response bsoncore.Document) { } t, i := ssTimeElem.Timestamp() - c.SnapshotTime = &bson.Timestamp{ + c.SnapshotTime = bson.Timestamp{ T: t, I: i, } + c.SnapshotTimeSet = true } // ClearPinnedResources clears the pinned server and/or connection associated with the session. diff --git a/x/mongo/driver/session/options.go b/x/mongo/driver/session/options.go index 742b3738cf..3c4abfccea 100644 --- a/x/mongo/driver/session/options.go +++ b/x/mongo/driver/session/options.go @@ -7,6 +7,7 @@ package session import ( + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" @@ -19,6 +20,7 @@ type ClientOptions struct { DefaultWriteConcern *writeconcern.WriteConcern DefaultReadPreference *readpref.ReadPref Snapshot *bool + SnapshotTime *bson.Timestamp } // TransactionOptions represents all possible options for starting a transaction in a session. @@ -49,6 +51,9 @@ func mergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.Snapshot != nil { c.Snapshot = opt.Snapshot } + if opt.SnapshotTime != nil { + c.SnapshotTime = opt.SnapshotTime + } } return c