Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Oct 21, 2024
1 parent b287ccf commit fe83c0b
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 90 deletions.
1 change: 1 addition & 0 deletions internal/driverutil/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ const (
ListIndexesOp = "listIndexes" // ListIndexesOp is the name for listing indexes
ListDatabasesOp = "listDatabases" // ListDatabasesOp is the name for listing databases
UpdateOp = "update" // UpdateOp is the name for updating
BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write
)
12 changes: 9 additions & 3 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,9 +851,10 @@ func (c *Client) createBaseCursorOptions() driver.CursorOptions {
}
}

// BulkWrite performs a client-levelbulk write operation.
func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
// BulkWrite performs a client-level bulk write operation.
func (c *Client) BulkWrite(ctx context.Context, models ClientWriteModels,
opts ...*options.ClientBulkWriteOptions) (*ClientBulkWriteResult, error) {
// TODO: Remove once DRIVERS-2888 is implemented.
if c.isAutoEncryptionSet {
return nil, errors.New("bulkWrite does not currently support automatic encryption")
}
Expand Down Expand Up @@ -886,6 +887,9 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
wc = bwo.WriteConcern
}
if !writeconcern.AckWrite(wc) {
if bwo.Ordered == nil || *bwo.Ordered {
return nil, errors.New("cannot request unacknowledged write concern and ordered writes")
}
sess = nil
}

Expand All @@ -896,7 +900,7 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
selector := makePinnedSelector(sess, writeSelector)

op := clientBulkWrite{
models: models.models,
models: models,
ordered: bwo.Ordered,
bypassDocumentValidation: bwo.BypassDocumentValidation,
comment: bwo.Comment,
Expand All @@ -908,6 +912,8 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
}
if bwo.VerboseResults == nil || !(*bwo.VerboseResults) {
op.errorsOnly = true
} else if !writeconcern.AckWrite(wc) {
return nil, errors.New("cannot request unacknowledged write concern and verbose results")
}
if err = op.execute(ctx); err != nil {
return nil, replaceErrors(err)
Expand Down
75 changes: 50 additions & 25 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/driverutil"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
Expand All @@ -24,6 +25,10 @@ import (
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

const (
database = "admin"
)

// bulkWrite performs a bulkwrite operation
type clientBulkWrite struct {
models []clientWriteModel
Expand All @@ -42,12 +47,17 @@ type clientBulkWrite struct {

func (bw *clientBulkWrite) execute(ctx context.Context) error {
if len(bw.models) == 0 {
return errors.New("empty write models")
return ErrEmptySlice
}
for _, m := range bw.models {
if m.model == nil {
return ErrNilDocument
}
}
batches := &modelBatches{
session: bw.session,
client: bw.client,
ordered: bw.ordered,
ordered: bw.ordered == nil || *bw.ordered,
models: bw.models,
result: &bw.result,
retryMode: driver.RetryOnce,
Expand All @@ -61,7 +71,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
Type: driver.Write,
Batches: batches,
CommandMonitor: bw.client.monitor,
Database: "admin",
Database: database,
Deployment: bw.client.deployment,
Selector: bw.selector,
WriteConcern: bw.writeConcern,
Expand All @@ -70,7 +80,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
Timeout: bw.client.timeout,
Logger: bw.client.logger,
Authenticator: bw.client.authenticator,
Name: "bulkWrite",
Name: driverutil.BulkWriteOp,
}.Execute(ctx)
var exception *ClientBulkWriteException
switch tt := err.(type) {
Expand All @@ -96,7 +106,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
}
if exception != nil {
var hasSuccess bool
if bw.ordered == nil || *bw.ordered {
if batches.ordered {
_, ok := batches.writeErrors[0]
hasSuccess = !ok
} else {
Expand Down Expand Up @@ -125,9 +135,7 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer)
}
dst = bsoncore.AppendValueElement(dst, "comment", comment)
}
if bw.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *bw.ordered)
}
dst = bsoncore.AppendBooleanElement(dst, "ordered", bw.ordered == nil || *bw.ordered)
if bw.let != nil {
let, err := marshal(bw.let, bw.client.bsonOpts, bw.client.registry)
if err != nil {
Expand Down Expand Up @@ -173,7 +181,7 @@ type modelBatches struct {
session *session.Client
client *Client

ordered *bool
ordered bool
models []clientWriteModel

offset int
Expand All @@ -188,7 +196,7 @@ type modelBatches struct {
}

func (mb *modelBatches) IsOrdered() *bool {
return mb.ordered
return &mb.ordered
}

func (mb *modelBatches) AdvanceBatches(n int) {
Expand Down Expand Up @@ -272,7 +280,7 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
nsIdx, nsDst := fn.appendStart(nil, "nsInfo")

totalSize -= 1000
size := (len(dst) - l) * 2
size := len(dst) + len(nsDst)
var n int
for i := mb.offset; i < len(mb.models); i++ {
if n == maxCount {
Expand Down Expand Up @@ -362,16 +370,16 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
return 0, nil, err
}
length := len(doc)
if length > maxDocSize {
break
}
if !exists {
length += len(ns)
}
size += length
if size >= totalSize {
break
}
if maxDocSize > 0 && length > maxDocSize+16*1024 {
return 0, nil, driver.ErrDocumentTooLarge
}

dst = fn.appendDocument(dst, strconv.Itoa(n), doc)
if !exists {
Expand All @@ -389,6 +397,9 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
dst = fn.updateLength(dst, opsIdx, int32(len(dst[opsIdx:])))
nsDst = fn.updateLength(nsDst, nsIdx, int32(len(nsDst[nsIdx:])))
dst = append(dst, nsDst...)
if maxDocSize > 0 && len(dst) > maxDocSize+16*1024 {
return 0, nil, driver.ErrDocumentTooLarge
}

mb.retryMode = driver.RetryNone
if mb.client.retryWrites && canRetry {
Expand Down Expand Up @@ -424,6 +435,19 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
if err != nil {
return err
}
if !res.Ok {
return ClientBulkWriteException{
TopLevelError: &WriteError{
Code: int(res.Code),
Message: res.Errmsg,
Raw: bson.Raw(resp),
},
WriteConcernErrors: mb.writeConcernErrors,
WriteErrors: mb.writeErrors,
PartialResult: mb.result,
}
}

mb.result.DeletedCount += int64(res.NDeleted)
mb.result.InsertedCount += int64(res.NInserted)
mb.result.MatchedCount += int64(res.NMatched)
Expand Down Expand Up @@ -470,21 +494,12 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
if err != nil {
return err
}
isOrdered := mb.ordered == nil || *mb.ordered
if isOrdered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) {
exception := ClientBulkWriteException{
if mb.ordered && (writeCmdErr.WriteConcernError != nil || !ok || !res.Ok || res.NErrors > 0) {
return ClientBulkWriteException{
WriteConcernErrors: mb.writeConcernErrors,
WriteErrors: mb.writeErrors,
PartialResult: mb.result,
}
if !res.Ok {
exception.TopLevelError = &WriteError{
Code: int(res.Code),
Message: res.Errmsg,
Raw: bson.Raw(resp),
}
}
return exception
}
return nil
}
Expand Down Expand Up @@ -558,6 +573,8 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
type clientInsertDoc struct {
namespace int
document interface{}

sizeLimit int
}

func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (interface{}, bsoncore.Document, error) {
Expand All @@ -568,6 +585,9 @@ func (d *clientInsertDoc) marshal(bsonOpts *options.BSONOptions, registry *bsonc
if err != nil {
return nil, nil, err
}
if d.sizeLimit > 0 && len(f) > d.sizeLimit {
return nil, nil, driver.ErrDocumentTooLarge
}
var id interface{}
f, id, err = ensureID(f, primitive.NilObjectID, bsonOpts, registry)
if err != nil {
Expand All @@ -588,6 +608,8 @@ type clientUpdateDoc struct {
upsert *bool
multi bool
checkDollarKey bool

sizeLimit int
}

func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) {
Expand All @@ -605,6 +627,9 @@ func (d *clientUpdateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsonc
if err != nil {
return nil, err
}
if d.sizeLimit > 0 && len(u.Data) > d.sizeLimit {
return nil, driver.ErrDocumentTooLarge
}
doc = bsoncore.AppendValueElement(doc, "updateMods", u)
doc = bsoncore.AppendBooleanElement(doc, "multi", d.multi)

Expand Down
54 changes: 17 additions & 37 deletions mongo/client_bulk_write_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,17 @@ import (
)

// ClientWriteModels is a struct that can be used in a client-level BulkWrite operation.
type ClientWriteModels struct {
models []clientWriteModel
}
type ClientWriteModels []clientWriteModel

type clientWriteModel struct {
namespace string
model interface{}
}

// AppendInsertOne appends ClientInsertOneModels.
func (m *ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
Expand All @@ -37,12 +32,9 @@ func (m *ClientWriteModels) AppendInsertOne(database, collection string, models
}

// AppendUpdateOne appends ClientUpdateOneModels.
func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
Expand All @@ -51,12 +43,9 @@ func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models
}

// AppendUpdateMany appends ClientUpdateManyModels.
func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
Expand All @@ -65,12 +54,9 @@ func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models
}

// AppendReplaceOne appends ClientReplaceOneModels.
func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
Expand All @@ -79,12 +65,9 @@ func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models
}

// AppendDeleteOne appends ClientDeleteOneModels.
func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
Expand All @@ -93,20 +76,17 @@ func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models
}

// AppendDeleteMany appends ClientDeleteManyModels.
func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) *ClientWriteModels {
if m == nil {
m = &ClientWriteModels{}
}
func (m ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) ClientWriteModels {
for _, model := range models {
m.models = append(m.models, clientWriteModel{
m = append(m, clientWriteModel{
namespace: fmt.Sprintf("%s.%s", database, collection),
model: model,
})
}
return m
}

// ClientInsertOneModel is used to insert a single document in a BulkWrite operation.
// ClientInsertOneModel is used to insert a single document in a client-level BulkWrite operation.
type ClientInsertOneModel struct {
Document interface{}
}
Expand Down Expand Up @@ -166,7 +146,7 @@ func (uom *ClientUpdateOneModel) SetCollation(collation *options.Collation) *Cli
}

// SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If
// an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the
// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the
// ClientBulkWriteResult.
func (uom *ClientUpdateOneModel) SetUpsert(upsert bool) *ClientUpdateOneModel {
uom.Upsert = &upsert
Expand Down Expand Up @@ -219,7 +199,7 @@ func (umm *ClientUpdateManyModel) SetCollation(collation *options.Collation) *Cl
}

// SetUpsert specifies whether or not a new document should be inserted if no document matching the filter is found. If
// an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the
// an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the
// ClientBulkWriteResult.
func (umm *ClientUpdateManyModel) SetUpsert(upsert bool) *ClientUpdateManyModel {
umm.Upsert = &upsert
Expand Down Expand Up @@ -265,7 +245,7 @@ func (rom *ClientReplaceOneModel) SetCollation(collation *options.Collation) *Cl
}

// SetUpsert specifies whether or not the replacement document should be inserted if no document matching the filter is
// found. If an upsert is performed, the _id of the upserted document can be retrieved from the UpsertedIDs field of the
// found. If an upsert is performed, the _id of the upserted document can be retrieved from the UpdateResults field of the
// BulkWriteResult.
func (rom *ClientReplaceOneModel) SetUpsert(upsert bool) *ClientReplaceOneModel {
rom.Upsert = &upsert
Expand Down
Loading

0 comments on commit fe83c0b

Please sign in to comment.