1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "testing"
13 "time"
14
15 "go.mongodb.org/mongo-driver/bson"
16 "go.mongodb.org/mongo-driver/bson/bsoncodec"
17 "go.mongodb.org/mongo-driver/internal/assert"
18 "go.mongodb.org/mongo-driver/mongo/options"
19 "go.mongodb.org/mongo-driver/mongo/readconcern"
20 "go.mongodb.org/mongo-driver/mongo/readpref"
21 "go.mongodb.org/mongo-driver/mongo/writeconcern"
22 "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
23 )
24
25 func setupDb(name string, opts ...*options.DatabaseOptions) *Database {
26 client := setupClient()
27 return client.Database(name, opts...)
28 }
29
30 func compareDbs(t *testing.T, expected, got *Database) {
31 t.Helper()
32 assert.Equal(t, expected.readPreference, got.readPreference,
33 "expected read preference %v, got %v", expected.readPreference, got.readPreference)
34 assert.Equal(t, expected.readConcern, got.readConcern,
35 "expected read concern %v, got %v", expected.readConcern, got.readConcern)
36 assert.Equal(t, expected.writeConcern, got.writeConcern,
37 "expected write concern %v, got %v", expected.writeConcern, got.writeConcern)
38 assert.Equal(t, expected.registry, got.registry,
39 "expected write concern %v, got %v", expected.registry, got.registry)
40 }
41
42 func TestDatabase(t *testing.T) {
43 t.Run("initialize", func(t *testing.T) {
44 name := "foo"
45 db := setupDb(name)
46 assert.Equal(t, name, db.Name(), "expected db name %v, got %v", name, db.Name())
47 assert.NotNil(t, db.Client(), "expected valid client, got nil")
48 })
49 t.Run("options", func(t *testing.T) {
50 t.Run("custom", func(t *testing.T) {
51 rpPrimary := readpref.Primary()
52 rpSecondary := readpref.Secondary()
53 wc1 := writeconcern.New(writeconcern.W(5))
54 wc2 := writeconcern.New(writeconcern.W(10))
55 rcLocal := readconcern.Local()
56 rcMajority := readconcern.Majority()
57 reg := bsoncodec.NewRegistryBuilder().Build()
58
59 opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1).
60 SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg)
61 expected := &Database{
62 readPreference: rpSecondary,
63 readConcern: rcMajority,
64 writeConcern: wc2,
65 registry: reg,
66 }
67 got := setupDb("foo", opts)
68 compareDbs(t, expected, got)
69 })
70 t.Run("inherit", func(t *testing.T) {
71 rpPrimary := readpref.Primary()
72 rcLocal := readconcern.Local()
73 wc1 := writeconcern.New(writeconcern.W(10))
74 reg := bsoncodec.NewRegistryBuilder().Build()
75
76 client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg))
77 got := client.Database("foo", options.Database().SetWriteConcern(wc1))
78 expected := &Database{
79 readPreference: rpPrimary,
80 readConcern: rcLocal,
81 writeConcern: wc1,
82 registry: reg,
83 }
84 compareDbs(t, expected, got)
85 })
86 })
87 t.Run("replace topology error", func(t *testing.T) {
88 db := setupDb("foo")
89 err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
90 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
91
92 err = db.Drop(bgCtx)
93 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
94
95 _, err = db.ListCollections(bgCtx, bson.D{})
96 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
97 })
98 t.Run("TransientTransactionError label", func(t *testing.T) {
99 client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
100 err := client.Connect(bgCtx)
101 defer func() { _ = client.Disconnect(bgCtx) }()
102 assert.Nil(t, err, "expected nil, got %v", err)
103
104 t.Run("negative case of non-transaction", func(t *testing.T) {
105 var sse topology.ServerSelectionError
106 var le LabeledError
107
108 err := client.Ping(bgCtx, nil)
109 assert.NotNil(t, err, "expected error, got nil")
110 assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
111 if errors.As(err, &le) {
112 assert.False(t, le.HasErrorLabel("TransientTransactionError"), `expected error not to include the "TransientTransactionError" label`)
113 }
114 })
115
116 t.Run("positive case of transaction", func(t *testing.T) {
117 var sse topology.ServerSelectionError
118 var le LabeledError
119
120 sess, err := client.StartSession()
121 assert.Nil(t, err, "expected nil, got %v", err)
122 defer sess.EndSession(bgCtx)
123
124 sessCtx := NewSessionContext(bgCtx, sess)
125 err = sess.StartTransaction()
126 assert.Nil(t, err, "expected nil, got %v", err)
127
128 err = client.Ping(sessCtx, nil)
129 assert.NotNil(t, err, "expected error, got nil")
130 assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
131 assert.True(t, errors.As(err, &le), `expected error to implement the "LabeledError" interface`)
132 assert.True(t, le.HasErrorLabel("TransientTransactionError"), `expected error to include the "TransientTransactionError" label`)
133 })
134 })
135 t.Run("nil document error", func(t *testing.T) {
136 db := setupDb("foo")
137
138 err := db.RunCommand(bgCtx, nil).Err()
139 assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
140
141 _, err = db.Watch(context.Background(), nil)
142 watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
143 assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err)
144
145 _, err = db.ListCollections(context.Background(), nil)
146 assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
147
148 _, err = db.ListCollectionNames(context.Background(), nil)
149 assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
150 })
151 }
152
View as plain text