From 4fd629ff29fc4898f3946f5c6dac6899e378451d Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 19 Aug 2025 17:15:30 -0400 Subject: [PATCH 1/2] Add custom options to client bulkWrite. --- internal/integration/client_test.go | 39 +++++++++++++++++++++++++++-- mongo/client.go | 10 ++++++++ mongo/client_bulk_write.go | 19 ++++++++++++-- x/mongo/driver/xoptions/options.go | 19 ++++++++++++++ 4 files changed, 83 insertions(+), 4 deletions(-) diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 8b37f12b47..0d4ca92e67 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -31,7 +31,9 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions" "golang.org/x/sync/errgroup" ) @@ -838,8 +840,41 @@ func TestClient_BulkWrite(t *testing.T) { } _, err := mt.Client.BulkWrite(context.Background(), writes) - require.NoError(t, err) - assert.Equal(t, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites) + require.NoError(mt, err) + assert.Equal(mt, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites) + }) + mt.Run("test options callback", func(mt *mtest.T) { + mt.Parallel() + + insertOneModel := mongo.NewClientInsertOneModel().SetDocument(bson.D{{"x", 1}}) + writes := []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: insertOneModel, + }} + + marshalValue := func(val interface{}) bson.RawValue { + t.Helper() + + valType, data, err := bson.MarshalValue(val) + require.Nil(t, err, "MarshalValue error: %v", err) + return bson.RawValue{ + Type: valType, + Value: data, + } + } + + opts := options.ClientBulkWrite() + err := xoptions.SetInternalClientBulkWriteOptions(opts, "commandCallback", func(dst []byte, _ description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendStringElement(dst, "foo", "bar") + return dst, nil + }) + require.NoError(mt, err) + _, _ = mt.Client.BulkWrite(context.Background(), writes, opts) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("foo") + expected := marshalValue("bar") + assert.Equal(mt, expected, val, "expected value to be %s", expected.String()) }) } diff --git a/mongo/client.go b/mongo/client.go index cb2e54944e..803ea6e9af 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -963,6 +963,16 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, op.rawData = &rawData } } + if bypassEmptyTsReplacementOpt := optionsutil.Value(bwo.Internal, "bypassEmptyTsReplacement"); bypassEmptyTsReplacementOpt != nil { + if bypassEmptyTsReplacement, ok := bypassEmptyTsReplacementOpt.(bool); ok { + op.bypassEmptyTsReplacement = &bypassEmptyTsReplacement + } + } + if commandCallbackOpt := optionsutil.Value(bwo.Internal, "commandCallback"); commandCallbackOpt != nil { + if commandCallback, ok := commandCallbackOpt.(func([]byte, description.SelectedServer) ([]byte, error)); ok { + op.commandCallback = commandCallback + } + } if bwo.VerboseResults == nil || !(*bwo.VerboseResults) { op.errorsOnly = true } else if !acknowledged { diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 310fdbc301..a0f394f6fd 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -45,6 +45,8 @@ type clientBulkWrite struct { selector description.ServerSelector writeConcern *writeconcern.WriteConcern rawData *bool + bypassEmptyTsReplacement *bool + commandCallback func([]byte, description.SelectedServer) ([]byte, error) result ClientBulkWriteResult } @@ -122,7 +124,8 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { } func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) { - return func(dst []byte, desc description.SelectedServer) ([]byte, error) { + return func(cmd []byte, desc description.SelectedServer) ([]byte, error) { + var dst []byte dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1) dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly) @@ -148,7 +151,19 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) if bw.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) { dst = bsoncore.AppendBooleanElement(dst, "rawData", *bw.rawData) } - return dst, nil + if bw.bypassEmptyTsReplacement != nil { + dst = bsoncore.AppendBooleanElement(dst, "bypassEmptyTsReplacement", *bw.bypassEmptyTsReplacement) + } + if bw.commandCallback != nil { + var err error + dst, err = bw.commandCallback(dst, desc) + if err != nil { + return nil, err + } + } + + cmd = append(cmd, dst...) + return cmd, nil } } diff --git a/x/mongo/driver/xoptions/options.go b/x/mongo/driver/xoptions/options.go index fa11dd60b8..8aa059555f 100644 --- a/x/mongo/driver/xoptions/options.go +++ b/x/mongo/driver/xoptions/options.go @@ -12,6 +12,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" ) // SetInternalClientOptions sets internal options for ClientOptions. @@ -101,6 +102,24 @@ func SetInternalClientBulkWriteOptions(a *options.ClientBulkWriteOptionsBuilder, opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "bypassEmptyTsReplacement": + b, ok := option.(bool) + if !ok { + return typeErrFunc("bool") + } + a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, b) + return nil + }) + case "commandCallback": + cb, ok := option.(func([]byte, description.SelectedServer) ([]byte, error)) + if !ok { + return typeErrFunc("func([]byte, description.SelectedServer) ([]byte, error)") + } + a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, cb) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } From 4dd4232e535abca636cd56963c533145dd27ed84 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 21 Aug 2025 13:49:05 -0400 Subject: [PATCH 2/2] ensure callback a final mutator --- mongo/client_bulk_write.go | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index a0f394f6fd..78542fbe3e 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -60,6 +60,18 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { return fmt.Errorf("error from model at index %d: %w", i, ErrNilDocument) } } + newCommand := func(dst []byte, desc description.SelectedServer) ([]byte, error) { + var cmd []byte + cmd, err := bw.newCommand()(cmd, desc) + if err == nil && bw.commandCallback != nil { + cmd, err = bw.commandCallback(cmd, desc) + } + if err != nil { + return nil, err + } + dst = append(dst, cmd...) + return dst, nil + } batches := &modelBatches{ session: bw.session, client: bw.client, @@ -69,7 +81,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { retryMode: driver.RetryOnce, } err := driver.Operation{ - CommandFn: bw.newCommand(), + CommandFn: newCommand, ProcessResponseFn: batches.processResponse, Client: bw.session, Clock: bw.client.clock, @@ -124,8 +136,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { } func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) { - return func(cmd []byte, desc description.SelectedServer) ([]byte, error) { - var dst []byte + return func(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1) dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly) @@ -154,16 +165,7 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) if bw.bypassEmptyTsReplacement != nil { dst = bsoncore.AppendBooleanElement(dst, "bypassEmptyTsReplacement", *bw.bypassEmptyTsReplacement) } - if bw.commandCallback != nil { - var err error - dst, err = bw.commandCallback(dst, desc) - if err != nil { - return nil, err - } - } - - cmd = append(cmd, dst...) - return cmd, nil + return dst, nil } }