Skip to content

Commit 52442f5

Browse files
committed
Add custom options to client bulkWrite.
1 parent 54bab6d commit 52442f5

File tree

4 files changed

+82
-4
lines changed

4 files changed

+82
-4
lines changed

internal/integration/client_test.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ import (
3131
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
3232
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
3333
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
34+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
3435
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage"
36+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions"
3537
"golang.org/x/sync/errgroup"
3638
)
3739

@@ -838,8 +840,40 @@ func TestClient_BulkWrite(t *testing.T) {
838840
}
839841

840842
_, err := mt.Client.BulkWrite(context.Background(), writes)
841-
require.NoError(t, err)
842-
assert.Equal(t, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites)
843+
require.NoError(mt, err)
844+
assert.Equal(mt, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites)
845+
})
846+
mt.Run("test options callback", func(mt *mtest.T) {
847+
mt.Parallel()
848+
849+
insertOneModel := mongo.NewClientInsertOneModel().SetDocument(bson.D{{"x", 1}})
850+
writes := []mongo.ClientBulkWrite{{
851+
Database: "foo",
852+
Collection: "bar",
853+
Model: insertOneModel,
854+
}}
855+
856+
marshalValue := func(val interface{}) bson.RawValue {
857+
t.Helper()
858+
859+
valType, data, err := bson.MarshalValue(val)
860+
require.Nil(t, err, "MarshalValue error: %v", err)
861+
return bson.RawValue{
862+
Type: valType,
863+
Value: data,
864+
}
865+
}
866+
867+
opts := options.ClientBulkWrite()
868+
xoptions.SetInternalClientBulkWriteOptions(opts, "commandCallback", func(dst []byte, _ description.SelectedServer) ([]byte, error) {
869+
dst = bsoncore.AppendStringElement(dst, "foo", "bar")
870+
return dst, nil
871+
})
872+
_, _ = mt.Client.BulkWrite(context.Background(), writes, opts)
873+
evt := mt.GetStartedEvent()
874+
val := evt.Command.Lookup("foo")
875+
expected := marshalValue("bar")
876+
assert.Equal(mt, expected, val, "expected value to be %s", expected.String())
843877
})
844878
}
845879

mongo/client.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,16 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite,
963963
op.rawData = &rawData
964964
}
965965
}
966+
if bypassEmptyTsReplacementOpt := optionsutil.Value(bwo.Internal, "bypassEmptyTsReplacement"); bypassEmptyTsReplacementOpt != nil {
967+
if bypassEmptyTsReplacement, ok := bypassEmptyTsReplacementOpt.(bool); ok {
968+
op.bypassEmptyTsReplacement = &bypassEmptyTsReplacement
969+
}
970+
}
971+
if commandCallbackOpt := optionsutil.Value(bwo.Internal, "commandCallback"); commandCallbackOpt != nil {
972+
if commandCallback, ok := commandCallbackOpt.(func([]byte, description.SelectedServer) ([]byte, error)); ok {
973+
op.commandCallback = commandCallback
974+
}
975+
}
966976
if bwo.VerboseResults == nil || !(*bwo.VerboseResults) {
967977
op.errorsOnly = true
968978
} else if !acknowledged {

mongo/client_bulk_write.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ type clientBulkWrite struct {
4545
selector description.ServerSelector
4646
writeConcern *writeconcern.WriteConcern
4747
rawData *bool
48+
bypassEmptyTsReplacement *bool
49+
commandCallback func([]byte, description.SelectedServer) ([]byte, error)
4850

4951
result ClientBulkWriteResult
5052
}
@@ -122,7 +124,8 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
122124
}
123125

124126
func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) ([]byte, error) {
125-
return func(dst []byte, desc description.SelectedServer) ([]byte, error) {
127+
return func(cmd []byte, desc description.SelectedServer) ([]byte, error) {
128+
var dst []byte
126129
dst = bsoncore.AppendInt32Element(dst, "bulkWrite", 1)
127130

128131
dst = bsoncore.AppendBooleanElement(dst, "errorsOnly", bw.errorsOnly)
@@ -148,7 +151,19 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer)
148151
if bw.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) {
149152
dst = bsoncore.AppendBooleanElement(dst, "rawData", *bw.rawData)
150153
}
151-
return dst, nil
154+
if bw.bypassEmptyTsReplacement != nil {
155+
dst = bsoncore.AppendBooleanElement(dst, "bypassEmptyTsReplacement", *bw.bypassEmptyTsReplacement)
156+
}
157+
if bw.commandCallback != nil {
158+
var err error
159+
dst, err = bw.commandCallback(dst, desc)
160+
if err != nil {
161+
return nil, err
162+
}
163+
}
164+
165+
cmd = append(cmd, dst...)
166+
return cmd, nil
152167
}
153168
}
154169

x/mongo/driver/xoptions/options.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"go.mongodb.org/mongo-driver/v2/internal/optionsutil"
1313
"go.mongodb.org/mongo-driver/v2/mongo/options"
1414
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
15+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
1516
)
1617

1718
// SetInternalClientOptions sets internal options for ClientOptions.
@@ -101,6 +102,24 @@ func SetInternalClientBulkWriteOptions(a *options.ClientBulkWriteOptionsBuilder,
101102
opts.Internal = optionsutil.WithValue(opts.Internal, key, b)
102103
return nil
103104
})
105+
case "bypassEmptyTsReplacement":
106+
b, ok := option.(bool)
107+
if !ok {
108+
return typeErrFunc("bool")
109+
}
110+
a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error {
111+
opts.Internal = optionsutil.WithValue(opts.Internal, key, b)
112+
return nil
113+
})
114+
case "commandCallback":
115+
cb, ok := option.(func([]byte, description.SelectedServer) ([]byte, error))
116+
if !ok {
117+
return typeErrFunc("func([]byte, description.SelectedServer) ([]byte, error)")
118+
}
119+
a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error {
120+
opts.Internal = optionsutil.WithValue(opts.Internal, key, cb)
121+
return nil
122+
})
104123
default:
105124
return fmt.Errorf("unsupported option: %q", key)
106125
}

0 commit comments

Comments
 (0)