diff --git a/internal/verifier/change_reader.go b/internal/verifier/change_reader.go index 056fdbc5..75bda5e9 100644 --- a/internal/verifier/change_reader.go +++ b/internal/verifier/change_reader.go @@ -156,6 +156,10 @@ func (rc *ChangeReaderCommon) getEventsPerSecond() option.Option[float64] { } func (rc *ChangeReaderCommon) persistResumeToken(ctx context.Context, token bson.Raw) error { + if len(token) == 0 { + panic("internal error: resume token is empty but should never be") + } + coll := rc.metaDB.Collection(changeReaderCollectionName) _, err := coll.ReplaceOne( ctx, diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go index 81732825..5e5ef327 100644 --- a/internal/verifier/change_stream.go +++ b/internal/verifier/change_stream.go @@ -9,6 +9,7 @@ import ( "github.com/10gen/migration-verifier/internal/retry" "github.com/10gen/migration-verifier/internal/util" "github.com/10gen/migration-verifier/mbson" + "github.com/10gen/migration-verifier/mmongo" "github.com/10gen/migration-verifier/option" mapset "github.com/deckarep/golang-set/v2" clone "github.com/huandu/go-clone/generic" @@ -376,11 +377,13 @@ func (csr *ChangeStreamReader) createChangeStream( csStartLogEvent := csr.logger.Info() - if token, hasToken := savedResumeToken.Get(); hasToken { + resumetoken, hasSavedToken := savedResumeToken.Get() + + if hasSavedToken { logEvent := csStartLogEvent. - Stringer(csr.resumeTokenDocID(), token) + Stringer(csr.resumeTokenDocID(), resumetoken) - ts, err := csr.resumeTokenTSExtractor(token) + ts, err := csr.resumeTokenTSExtractor(resumetoken) if err == nil { logEvent = addTimestampToLogEvent(ts, logEvent) } else { @@ -392,9 +395,9 @@ func (csr *ChangeStreamReader) createChangeStream( logEvent.Msg("Starting change stream from persisted resume token.") if util.ClusterHasChangeStreamStartAfter([2]int(csr.clusterInfo.VersionArray)) { - opts = opts.SetStartAfter(token) + opts = opts.SetStartAfter(resumetoken) } else { - opts = opts.SetResumeAfter(token) + opts = opts.SetResumeAfter(resumetoken) } } else { csStartLogEvent.Msgf("Starting change stream from current %s cluster time.", csr.readerType) @@ -410,9 +413,22 @@ func (csr *ChangeStreamReader) createChangeStream( return nil, nil, bson.Timestamp{}, errors.Wrap(err, "opening change stream") } - err = csr.persistResumeToken(ctx, changeStream.ResumeToken()) - if err != nil { - return nil, nil, bson.Timestamp{}, err + if !hasSavedToken { + // Usually the change stream’s initial response is empty, but sometimes + // there are events right away. We can discard those events because + // they’ve already happened, and our initial scan is yet to come. + if len(changeStream.ResumeToken()) == 0 { + _, _, err := mmongo.GetBatch(ctx, changeStream, nil, nil) + + if err != nil { + return nil, nil, bson.Timestamp{}, errors.Wrap(err, "discarding change stream’s initial events") + } + } + + err = csr.persistResumeToken(ctx, changeStream.ResumeToken()) + if err != nil { + return nil, nil, bson.Timestamp{}, errors.Wrapf(err, "persisting initial resume token") + } } startTs, err := csr.resumeTokenTSExtractor(changeStream.ResumeToken()) @@ -428,14 +444,19 @@ func (csr *ChangeStreamReader) createChangeStream( return nil, nil, bson.Timestamp{}, errors.Wrap(err, "failed to read cluster time from session") } - csr.logger.Debug(). - Any("resumeTokenTimestamp", startTs). - Any("clusterTime", clusterTime). - Stringer("changeStreamReader", csr). - Msg("Using earlier time as start timestamp.") - if startTs.After(clusterTime) { + csr.logger.Debug(). + Any("resumeTokenTimestamp", startTs). + Any("clusterTime", clusterTime). + Stringer("changeStreamReader", csr). + Msg("Cluster time predates resume token; using it as start timestamp.") + startTs = clusterTime + } else { + csr.logger.Debug(). + Any("resumeTokenTimestamp", startTs). + Stringer("changeStreamReader", csr). + Msg("Got start timestamp from change stream.") } return changeStream, sess, startTs, nil @@ -532,6 +553,10 @@ func (csr *ChangeStreamReader) String() string { } func extractTSFromChangeStreamResumeToken(resumeToken bson.Raw) (bson.Timestamp, error) { + if len(resumeToken) == 0 { + panic("internal error: resume token is empty but should never be") + } + // Change stream token is always a V1 keystring in the _data field tokenDataRV, err := resumeToken.LookupErr("_data") diff --git a/internal/verifier/change_stream_test.go b/internal/verifier/change_stream_test.go index 62a39bca..5066e4f8 100644 --- a/internal/verifier/change_stream_test.go +++ b/internal/verifier/change_stream_test.go @@ -28,6 +28,49 @@ import ( "golang.org/x/sync/errgroup" ) +func (suite *IntegrationTestSuite) TestChangeStreamFilter_InitialNonempty() { + zerolog.SetGlobalLevel(zerolog.TraceLevel) // gets restored automatically + + ctx := suite.Context() + dbName := suite.DBNameForTest() + + go func() { + for ctx.Err() == nil { + coll := suite.srcMongoClient. + Database(dbName). + Collection("coll") + + _, _ = coll.InsertOne(ctx, bson.D{{"_id", 123}}) + _, _ = coll.DeleteOne(ctx, bson.D{{"_id", 123}}) + } + }() + + for i := range 100 { + suite.Run( + fmt.Sprint(i), + func() { + ctx, cancel := contextplus.WithCancelCause(ctx) + defer cancel(fmt.Errorf("subtest is done")) + + verifier := suite.BuildVerifier() + + rdr, ok := verifier.srcChangeReader.(*ChangeStreamReader) + if !ok { + suite.T().Skipf("source change reader is a %T; this test needs a %T", verifier.srcChangeReader, rdr) + } + + eg, egCtx := contextplus.ErrGroup(ctx) + suite.Require().NoError(rdr.start(egCtx, eg)) + + suite.Require().NoError( + verifier.metaClient.Database(verifier.metaDBName).Drop(ctx), + ) + }, + ) + + } +} + func (suite *IntegrationTestSuite) TestChangeStreamFilter_NoNamespaces() { ctx := suite.Context() diff --git a/mmongo/cursor.go b/mmongo/cursor.go new file mode 100644 index 00000000..0c1e4827 --- /dev/null +++ b/mmongo/cursor.go @@ -0,0 +1,69 @@ +package mmongo + +import ( + "context" + "fmt" + "slices" + + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +type cursorLike interface { + TryNext(context.Context) bool + RemainingBatchLength() int + Err() error +} + +// GetBatch returns a batch of documents from a cursor. It does so by appending +// to passed-in slices, which lets you optimize memory handling. +func GetBatch[T cursorLike]( + ctx context.Context, + cursor T, + docs []bson.Raw, + buffer []byte, +) ([]bson.Raw, []byte, error) { + var docsCount, expectedCount int + + var curDoc bson.Raw + + for hasDocs := true; hasDocs; hasDocs = cursor.RemainingBatchLength() > 0 { + got := cursor.TryNext(ctx) + + if cursor.Err() != nil { + return nil, nil, errors.Wrap(cursor.Err(), "cursor iteration failed") + } + + if !got { + if docsCount != 0 { + panic(fmt.Sprintf("Docs batch ended after %d but expected %d", docsCount, expectedCount)) + } + + break + } + + // This ensures we only reallocate once (if at all): + if docsCount == 0 { + expectedCount = 1 + cursor.RemainingBatchLength() + docs = slices.Grow(docs, expectedCount) + } + + docsCount++ + + switch typedCursor := any(cursor).(type) { + case *mongo.Cursor: + curDoc = typedCursor.Current + case *mongo.ChangeStream: + curDoc = typedCursor.Current + default: + panic(fmt.Sprintf("unknown cursor type: %T", cursor)) + } + + docPos := len(buffer) + buffer = append(buffer, curDoc...) + docs = append(docs, buffer[docPos:]) + } + + return docs, buffer, nil +} diff --git a/mmongo/cursor_all_test.go b/mmongo/cursor_all_test.go deleted file mode 100644 index aec794d1..00000000 --- a/mmongo/cursor_all_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package mmongo - -import ( - "os" - "testing" - - "github.com/samber/lo" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mongodb.org/mongo-driver/v2/bson" - "go.mongodb.org/mongo-driver/v2/mongo" - "go.mongodb.org/mongo-driver/v2/mongo/options" -) - -func TestUnmarshalCursor(t *testing.T) { - ctx := t.Context() - - connStr := os.Getenv("MVTEST_META") - if connStr == "" { - t.Skipf("No MVTEST_META found; skipping.") - } - - client, err := mongo.Connect( - options.Client().ApplyURI(connStr), - ) - require.NoError(t, err) - - cursor, err := client.Database("admin").Aggregate( - ctx, - mongo.Pipeline{ - {{"$documents", lo.RepeatBy( - 30, - func(index int) bson.D { - return bson.D{{"foo", index}} - }, - )}}, - }, - options.Aggregate().SetBatchSize(10), - ) - require.NoError(t, err) - - batch, err := UnmarshalCursor(ctx, cursor, []unmarshaler{}) - require.NoError(t, err) - - assert.Equal( - t, - lo.RepeatBy( - 30, - func(index int) unmarshaler { - return unmarshaler{int32(index)} - }, - ), - batch, - "should be as expected", - ) -} - -type unmarshaler struct { - Foo int32 -} - -func (u *unmarshaler) UnmarshalFromBSON(in []byte) error { - u.Foo = lo.Must(bson.Raw(in).LookupErr("foo")).Int32() - - return nil -} diff --git a/mmongo/cursor_test.go b/mmongo/cursor_test.go new file mode 100644 index 00000000..f7c2a253 --- /dev/null +++ b/mmongo/cursor_test.go @@ -0,0 +1,130 @@ +package mmongo + +import ( + "os" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readconcern" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" +) + +func TestGetBatch(t *testing.T) { + ctx := t.Context() + + client := getClientFromEnv(t) + + coll := client.Database(t.Name()).Collection( + "coll", + options.Collection(). + SetWriteConcern(writeconcern.Majority()). + SetReadConcern(readconcern.Majority()), + ) + + sess, err := client.StartSession(options.Session().SetCausalConsistency(true)) + require.NoError(t, err) + + sctx := mongo.NewSessionContext(ctx, sess) + + docsCount := 1_000 + const batchSize = 100 + + _, err = coll.InsertMany( + sctx, + lo.RepeatBy( + docsCount, + func(index int) any { + return bson.D{} + }, + ), + ) + require.NoError(t, err) + + cursor, err := coll.Find( + ctx, + bson.D{}, + options.Find().SetBatchSize(batchSize), + ) + require.NoError(t, err) + + cursor.SetBatchSize(batchSize) + + var docs []bson.Raw + var buf []byte + for range docsCount / batchSize { + docs = docs[:0] + buf = buf[:0] + + docs, buf, err = GetBatch(ctx, cursor, docs, buf) + require.NoError(t, err) + + assert.Len(t, docs, 100, "should get expected batch") + } + + assert.False(t, cursor.TryNext(ctx), "cursor should be done") + require.NoError(t, cursor.Err(), "should be no error") +} + +func TestUnmarshalCursor(t *testing.T) { + ctx := t.Context() + + client := getClientFromEnv(t) + + cursor, err := client.Database("admin").Aggregate( + ctx, + mongo.Pipeline{ + {{"$documents", lo.RepeatBy( + 30, + func(index int) bson.D { + return bson.D{{"foo", index}} + }, + )}}, + }, + options.Aggregate().SetBatchSize(10), + ) + require.NoError(t, err) + + batch, err := UnmarshalCursor(ctx, cursor, []unmarshaler{}) + require.NoError(t, err) + + assert.Equal( + t, + lo.RepeatBy( + 30, + func(index int) unmarshaler { + return unmarshaler{int32(index)} + }, + ), + batch, + "should be as expected", + ) +} + +func getClientFromEnv(t *testing.T) *mongo.Client { + connStr := os.Getenv("MVTEST_META") + if connStr == "" { + t.Skipf("No MVTEST_META found; skipping.") + } + + client, err := mongo.Connect( + options.Client().ApplyURI(connStr), + ) + require.NoError(t, err) + + return client +} + +type unmarshaler struct { + Foo int32 +} + +func (u *unmarshaler) UnmarshalFromBSON(in []byte) error { + u.Foo = lo.Must(bson.Raw(in).LookupErr("foo")).Int32() + + return nil +}