diff --git a/internal/driverutil/operation.go b/internal/driverutil/operation.go index 32704312ff..e37cba5903 100644 --- a/internal/driverutil/operation.go +++ b/internal/driverutil/operation.go @@ -28,4 +28,5 @@ const ( ListIndexesOp = "listIndexes" // ListIndexesOp is the name for listing indexes ListDatabasesOp = "listDatabases" // ListDatabasesOp is the name for listing databases UpdateOp = "update" // UpdateOp is the name for updating + BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write ) diff --git a/mongo/client.go b/mongo/client.go index cebd06559c..25b0ab1379 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -851,9 +851,10 @@ func (c *Client) createBaseCursorOptions() driver.CursorOptions { } } -// BulkWrite performs a client-levelbulk write operation. -func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, +// BulkWrite performs a client-level bulk write operation. +func (c *Client) BulkWrite(ctx context.Context, models ClientWriteModels, opts ...*options.ClientBulkWriteOptions) (*ClientBulkWriteResult, error) { + // TODO: Remove once DRIVERS-2888 is implemented. if c.isAutoEncryptionSet { return nil, errors.New("bulkWrite does not currently support automatic encryption") } @@ -886,6 +887,9 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, wc = bwo.WriteConcern } if !writeconcern.AckWrite(wc) { + if bwo.Ordered == nil || *bwo.Ordered { + return nil, errors.New("cannot request unacknowledged write concern and ordered writes") + } sess = nil } @@ -896,7 +900,7 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, selector := makePinnedSelector(sess, writeSelector) op := clientBulkWrite{ - models: models.models, + models: models, ordered: bwo.Ordered, bypassDocumentValidation: bwo.BypassDocumentValidation, comment: bwo.Comment, @@ -908,6 +912,8 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, } if bwo.VerboseResults == nil || !(*bwo.VerboseResults) { op.errorsOnly = true + } else if !writeconcern.AckWrite(wc) { + return nil, errors.New("cannot request unacknowledged write concern and verbose results") } if err = op.execute(ctx); err != nil { return nil, replaceErrors(err) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index cccc49e226..ea48963d87 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -24,6 +25,10 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) +const ( + database = "admin" +) + // bulkWrite performs a bulkwrite operation type clientBulkWrite struct { models []clientWriteModel @@ -42,12 +47,17 @@ type clientBulkWrite struct { func (bw *clientBulkWrite) execute(ctx context.Context) error { if len(bw.models) == 0 { - return errors.New("empty write models") + return ErrEmptySlice + } + for _, m := range bw.models { + if m.model == nil { + return ErrNilDocument + } } batches := &modelBatches{ session: bw.session, client: bw.client, - ordered: bw.ordered, + ordered: bw.ordered == nil || *bw.ordered, models: bw.models, result: &bw.result, retryMode: driver.RetryOnce, @@ -61,7 +71,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { Type: driver.Write, Batches: batches, CommandMonitor: bw.client.monitor, - Database: "admin", + Database: database, Deployment: bw.client.deployment, Selector: bw.selector, WriteConcern: bw.writeConcern, @@ -70,7 +80,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { Timeout: bw.client.timeout, Logger: bw.client.logger, Authenticator: bw.client.authenticator, - Name: "bulkWrite", + Name: driverutil.BulkWriteOp, }.Execute(ctx) var exception *ClientBulkWriteException switch tt := err.(type) { @@ -96,7 +106,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { } if exception != nil { var hasSuccess bool - if bw.ordered == nil || *bw.ordered { + if batches.ordered { _, ok := batches.writeErrors[0] hasSuccess = !ok } else { @@ -125,9 +135,7 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) } dst = bsoncore.AppendValueElement(dst, "comment", comment) } - if bw.ordered != nil { - dst = bsoncore.AppendBooleanElement(dst, "ordered", *bw.ordered) - } + dst = bsoncore.AppendBooleanElement(dst, "ordered", bw.ordered == nil || *bw.ordered) if bw.let != nil { let, err := marshal(bw.let, bw.client.bsonOpts, bw.client.registry) if err != nil { @@ -173,7 +181,7 @@ type modelBatches struct { session *session.Client client *Client - ordered *bool + ordered bool models []clientWriteModel offset int @@ -188,7 +196,7 @@ type modelBatches struct { } func (mb *modelBatches) IsOrdered() *bool { - return mb.ordered + return &mb.ordered } func (mb *modelBatches) AdvanceBatches(n int) { @@ -272,7 +280,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD nsIdx, nsDst := fn.appendStart(nil, "nsInfo") totalSize -= 1000 - size := (len(dst) - l) * 2 + size := len(dst) + len(nsDst) var n int for i := mb.offset; i < len(mb.models); i++ { if n == maxCount { @@ -362,9 +370,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD return 0, nil, err } length := len(doc) - if length > maxDocSize { - break - } if !exists { length += len(ns) } @@ -372,6 +377,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD if size >= totalSize { break } + if maxDocSize > 0 && length > maxDocSize+16*1024 { + return 0, nil, driver.ErrDocumentTooLarge + } dst = fn.appendDocument(dst, strconv.Itoa(n), doc) if !exists { @@ -389,6 +397,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD dst = fn.updateLength(dst, opsIdx, int32(len(dst[opsIdx:]))) nsDst = fn.updateLength(nsDst, nsIdx, int32(len(nsDst[nsIdx:]))) dst = append(dst, nsDst...) + if maxDocSize > 0 && len(dst) > maxDocSize+16*1024 { + return 0, nil, driver.ErrDocumentTooLarge + } mb.retryMode = driver.RetryNone if mb.client.retryWrites && canRetry { @@ -424,6 +435,19 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum if err != nil { return err } + if !res.Ok { + return ClientBulkWriteException{ + TopLevelError: &WriteError{ + Code: int(res.Code), + Message: res.Errmsg, + Raw: bson.Raw(resp), + }, + WriteConcernErrors: mb.writeConcernErrors, + WriteErrors: mb.writeErrors, + PartialResult: mb.result, + } + } + mb.result.DeletedCount += int64(res.NDeleted) mb.result.InsertedCount += int64(res.NInserted) mb.result.MatchedCount += int64(res.NMatched) @@ -470,21 +494,12 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum if err != nil { return err } - isOrdered := mb.ordered == nil || *mb.ordered - if isOrdered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) { - exception := ClientBulkWriteException{ + if mb.ordered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) { + return ClientBulkWriteException{ WriteConcernErrors: mb.writeConcernErrors, WriteErrors: mb.writeErrors, PartialResult: mb.result, } - if !res.Ok { - exception.TopLevelError = &WriteError{ - Code: int(res.Code), - Message: res.Errmsg, - Raw: bson.Raw(resp), - } - } - return exception } return nil } @@ -558,6 +573,8 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool { type clientInsertDoc struct { namespace int document interface{} + + sizeLimit int } func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (interface{}, bsoncore.Document, error) { @@ -568,6 +585,9 @@ func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bsonc if err != nil { return nil, nil, err } + if d.sizeLimit > 0 && len(f) > d.sizeLimit { + return nil, nil, driver.ErrDocumentTooLarge + } var id interface{} f, id, err = ensureID(f, primitive.NilObjectID, bsonOpts, registry) if err != nil { @@ -588,6 +608,8 @@ type clientUpdateDoc struct { upsert *bool multi bool checkDollarKey bool + + sizeLimit int } func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) { @@ -605,6 +627,9 @@ func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsonc if err != nil { return nil, err } + if d.sizeLimit > 0 && len(u.Data) > d.sizeLimit { + return nil, driver.ErrDocumentTooLarge + } doc = bsoncore.AppendValueElement(doc, "updateMods", u) doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi) diff --git a/mongo/client_bulk_write_models.go b/mongo/client_bulk_write_models.go index 4a2259a5c9..feabfd30f7 100644 --- a/mongo/client_bulk_write_models.go +++ b/mongo/client_bulk_write_models.go @@ -13,9 +13,7 @@ import ( ) // ClientWriteModels is a struct that can be used in a client-level BulkWrite operation. -type ClientWriteModels struct { - models []clientWriteModel -} +type ClientWriteModels []clientWriteModel type clientWriteModel struct { namespace string @@ -23,12 +21,9 @@ type clientWriteModel struct { } // AppendInsertOne appends ClientInsertOneModels. -func (m *ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -37,12 +32,9 @@ func (m *ClientWriteModels) AppendInsertOne(database, collection string, models } // AppendUpdateOne appends ClientUpdateOneModels. -func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -51,12 +43,9 @@ func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models } // AppendUpdateMany appends ClientUpdateManyModels. -func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -65,12 +54,9 @@ func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models } // AppendReplaceOne appends ClientReplaceOneModels. -func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -79,12 +65,9 @@ func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models } // AppendDeleteOne appends ClientDeleteOneModels. -func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -93,12 +76,9 @@ func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models } // AppendDeleteMany appends ClientDeleteManyModels. -func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } +func (m ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) ClientWriteModels { for _, model := range models { - m.models = append(m.models, clientWriteModel{ + m = append(m, clientWriteModel{ namespace: fmt.Sprintf("%s.%s", database, collection), model: model, }) @@ -106,7 +86,7 @@ func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models return m } -// ClientInsertOneModel is used to insert a single document in a BulkWrite operation. +// ClientInsertOneModel is used to insert a single document in a client-level BulkWrite operation. type ClientInsertOneModel struct { Document interface{} } @@ -166,7 +146,7 @@ func (uom *ClientUpdateOneModel) SetCollation(collation *options.Collation) *Cli } // SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If -// an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the +// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the // ClientBulkWriteResult. func (uom *ClientUpdateOneModel) SetUpsert(upsert bool) *ClientUpdateOneModel { uom.Upsert = &upsert @@ -219,7 +199,7 @@ func (umm *ClientUpdateManyModel) SetCollation(collation *options.Collation) *Cl } // SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If -// an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the +// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the // ClientBulkWriteResult. func (umm *ClientUpdateManyModel) SetUpsert(upsert bool) *ClientUpdateManyModel { umm.Upsert = &upsert @@ -265,7 +245,7 @@ func (rom *ClientReplaceOneModel) SetCollation(collation *options.Collation) *Cl } // SetUpsert specifies whether or not the replacement document should be inserted if no document matching the filter is -// found. If an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the +// found. If an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the // BulkWriteResult. func (rom *ClientReplaceOneModel) SetUpsert(upsert bool) *ClientReplaceOneModel { rom.Upsert = &upsert diff --git a/mongo/errors.go b/mongo/errors.go index 547d27827a..5340d632cc 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -611,14 +611,20 @@ func (bwe BulkWriteException) serverError() {} // ClientBulkWriteException is the error type returned by ClientBulkWrite operations. type ClientBulkWriteException struct { + // A top-level error that occurred when attempting to communicate with the server + // or execute the bulk write. This value may not be populated if the exception was + // thrown due to errors occurring on individual writes. TopLevelError *WriteError // The write concern errors that occurred. WriteConcernErrors []WriteConcernError // The write errors that occurred during individual operation execution. + // This map will contain at most one entry if the bulk write was ordered. WriteErrors map[int]WriteError + // The results of any successful operations that were performed before the error + // was encountered. PartialResult *ClientBulkWriteResult } diff --git a/mongo/integration/crud_prose_test.go b/mongo/integration/crud_prose_test.go index 89aa7f1296..1ec7ddad75 100644 --- a/mongo/integration/crud_prose_test.go +++ b/mongo/integration/crud_prose_test.go @@ -438,7 +438,7 @@ func TestClientBulkWrite(t *testing.T) { MaxWriteBatchSize int } require.NoError(mt, mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello), "Hello error") - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ @@ -473,7 +473,7 @@ func TestClientBulkWrite(t *testing.T) { MaxMessageSizeBytes int } require.NoError(mt, mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello), "Hello error") - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 for i := 0; i < numModels; i++ { models. @@ -511,7 +511,7 @@ func TestClientBulkWrite(t *testing.T) { }, }) - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ @@ -549,7 +549,7 @@ func TestClientBulkWrite(t *testing.T) { } err = mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error") - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels for i := 0; i < hello.MaxWriteBatchSize+1; i++ { models. AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ @@ -773,12 +773,12 @@ func TestClientBulkWrite(t *testing.T) { err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error") - newModels := func() (int, *mongo.ClientWriteModels) { + newModels := func() (int, mongo.ClientWriteModels) { maxBsonObjectSize := hello.MaxBsonObjectSize opsBytes := hello.MaxMessageSizeBytes - 1122 numModels := opsBytes / maxBsonObjectSize - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels n := numModels for i := 0; i < n; i++ { models. @@ -915,7 +915,7 @@ func TestClientBulkWrite(t *testing.T) { require.NoError(mt, err, "Drop error") numModels := hello.MaxMessageSizeBytes / hello.MaxBsonObjectSize - models := &mongo.ClientWriteModels{} + var models mongo.ClientWriteModels for i := 0; i < numModels+1; i++ { models. AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ diff --git a/mongo/integration/unified/client_operation_execution.go b/mongo/integration/unified/client_operation_execution.go index 6e18af44c9..b199670603 100644 --- a/mongo/integration/unified/client_operation_execution.go +++ b/mongo/integration/unified/client_operation_execution.go @@ -178,7 +178,7 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return nil, err } - wirteModels := &mongo.ClientWriteModels{} + var wirteModels mongo.ClientWriteModels opts := options.ClientBulkWrite() elems, err := operation.Arguments.Elements() @@ -288,7 +288,7 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return newDocumentResult(rawBuilder.Build(), err), nil } -func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error { +func appendClientBulkWriteModel(key string, value bson.Raw, model mongo.ClientWriteModels) error { switch key { case "insertOne": namespace, m, err := createClientInsertOneModel(value) diff --git a/mongo/options/clientbulkwriteoptions.go b/mongo/options/clientbulkwriteoptions.go index ad91f37488..7c460e47ce 100644 --- a/mongo/options/clientbulkwriteoptions.go +++ b/mongo/options/clientbulkwriteoptions.go @@ -12,21 +12,19 @@ import ( // ClientBulkWriteOptions represents options that can be used to configure a client-level BulkWrite operation. type ClientBulkWriteOptions struct { - // If true, writes executed as part of the operation will opt out of document-level validation on the server. This - // option is valid for MongoDB versions >= 3.2 and is ignored for previous server versions. The default value is - // false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information about document - // validation. + // If true, writes executed as part of the operation will opt out of document-level validation on the server. The + // default value is false. See https://www.mongodb.com/docs/manual/core/schema-validation/ for more information + // about document validation. BypassDocumentValidation *bool // A string or document that will be included in server logs, profiling logs, and currentOp queries to help trace - // the operation. The default value is nil, which means that no comment will be included in the logs. + // the operation. The default value is nil, which means that no comment will be included in the logs. Comment interface{} // If true, no writes will be executed after one fails. The default value is true. Ordered *bool - // Specifies parameters for all update and delete commands in the BulkWrite. This option is only valid for MongoDB - // versions >= 5.0. Older servers will report an error for using this option. This must be a document mapping + // Specifies parameters for all update and delete commands in the BulkWrite. This must be a document mapping // parameter names to values. Values must be constant or closed expressions that do not reference document fields. // Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). Let interface{} @@ -63,8 +61,7 @@ func (b *ClientBulkWriteOptions) SetBypassDocumentValidation(bypass bool) *Clien return b } -// SetLet sets the value for the Let field. Let specifies parameters for all update and delete commands in the BulkWrite. -// This option is only valid for MongoDB versions >= 5.0. Older servers will report an error for using this option. +// SetLet sets the value for the Let field. Let specifies parameters for all update and delete commands in the ClientBulkWrite. // This must be a document mapping parameter names to values. Values must be constant or closed expressions that do not // reference document fields. Parameters can then be accessed as variables in an aggregate expression context (e.g. "$$var"). func (b *ClientBulkWriteOptions) SetLet(let interface{}) *ClientBulkWriteOptions { @@ -84,8 +81,8 @@ func (b *ClientBulkWriteOptions) SetVerboseResults(verboseResults bool) *ClientB return b } -// MergeClientBulkWriteOptions combines the given BulkWriteOptions instances into a single BulkWriteOptions in a last-one-wins -// fashion. +// MergeClientBulkWriteOptions combines the given ClientBulkWriteOptions instances into a single +// ClientBulkWriteOptions in a last-one-wins fashion. // // Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a // single options struct instead. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index c2f4601947..21573488c5 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1349,6 +1349,8 @@ func (op Operation) createWireMessage( var wmindex int32 var err error + unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) + fIdx := -1 isLegacy := isLegacyHandshake(op, desc) switch { @@ -1374,10 +1376,10 @@ func (op Operation) createWireMessage( default: wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg) fIdx = len(dst) - appendBatches := func(dst []byte) ([]byte, error) { + appendBatches := func(dst []byte, maxCount, maxDocSize, totalSize int) ([]byte, error) { var processedBatches int dsOffset := len(dst) - processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxMessageSize)) + processedBatches, dst, err = op.Batches.AppendBatchSequence(dst, maxCount, maxDocSize, totalSize) if err != nil { return nil, err } @@ -1401,12 +1403,16 @@ func (op Operation) createWireMessage( case *Batches: dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) if err == nil && op.Batches != nil { - dst, err = appendBatches(dst) + dst, err = appendBatches(dst, int(desc.MaxBatchCount), int(desc.MaxDocumentSize), int(desc.MaxDocumentSize)) } default: var batches []byte if op.Batches != nil { - batches, err = appendBatches(batches) + maxDocSize := -1 + if unacknowledged { + maxDocSize = int(desc.MaxDocumentSize) + } + batches, err = appendBatches(batches, int(desc.MaxBatchCount), maxDocSize, int(desc.MaxMessageSize)) } if err == nil { dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, op.CommandFn) @@ -1423,7 +1429,6 @@ func (op Operation) createWireMessage( var moreToCome bool // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either // aren't batching or we are encoding the last batch. - unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) batching := op.Batches != nil && op.Batches.Size() > info.processedBatches if fIdx > 0 && unacknowledged && !batching { dst[fIdx] |= byte(wiremessage.MoreToCome)