...

Source file src/go.mongodb.org/mongo-driver/mongo/database_test.go

Documentation: go.mongodb.org/mongo-driver/mongo

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"testing"
    13  	"time"
    14  
    15  	"go.mongodb.org/mongo-driver/bson"
    16  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    17  	"go.mongodb.org/mongo-driver/internal/assert"
    18  	"go.mongodb.org/mongo-driver/mongo/options"
    19  	"go.mongodb.org/mongo-driver/mongo/readconcern"
    20  	"go.mongodb.org/mongo-driver/mongo/readpref"
    21  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    22  	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    23  )
    24  
    25  func setupDb(name string, opts ...*options.DatabaseOptions) *Database {
    26  	client := setupClient()
    27  	return client.Database(name, opts...)
    28  }
    29  
    30  func compareDbs(t *testing.T, expected, got *Database) {
    31  	t.Helper()
    32  	assert.Equal(t, expected.readPreference, got.readPreference,
    33  		"expected read preference %v, got %v", expected.readPreference, got.readPreference)
    34  	assert.Equal(t, expected.readConcern, got.readConcern,
    35  		"expected read concern %v, got %v", expected.readConcern, got.readConcern)
    36  	assert.Equal(t, expected.writeConcern, got.writeConcern,
    37  		"expected write concern %v, got %v", expected.writeConcern, got.writeConcern)
    38  	assert.Equal(t, expected.registry, got.registry,
    39  		"expected write concern %v, got %v", expected.registry, got.registry)
    40  }
    41  
    42  func TestDatabase(t *testing.T) {
    43  	t.Run("initialize", func(t *testing.T) {
    44  		name := "foo"
    45  		db := setupDb(name)
    46  		assert.Equal(t, name, db.Name(), "expected db name %v, got %v", name, db.Name())
    47  		assert.NotNil(t, db.Client(), "expected valid client, got nil")
    48  	})
    49  	t.Run("options", func(t *testing.T) {
    50  		t.Run("custom", func(t *testing.T) {
    51  			rpPrimary := readpref.Primary()
    52  			rpSecondary := readpref.Secondary()
    53  			wc1 := writeconcern.New(writeconcern.W(5))
    54  			wc2 := writeconcern.New(writeconcern.W(10))
    55  			rcLocal := readconcern.Local()
    56  			rcMajority := readconcern.Majority()
    57  			reg := bsoncodec.NewRegistryBuilder().Build()
    58  
    59  			opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1).
    60  				SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg)
    61  			expected := &Database{
    62  				readPreference: rpSecondary,
    63  				readConcern:    rcMajority,
    64  				writeConcern:   wc2,
    65  				registry:       reg,
    66  			}
    67  			got := setupDb("foo", opts)
    68  			compareDbs(t, expected, got)
    69  		})
    70  		t.Run("inherit", func(t *testing.T) {
    71  			rpPrimary := readpref.Primary()
    72  			rcLocal := readconcern.Local()
    73  			wc1 := writeconcern.New(writeconcern.W(10))
    74  			reg := bsoncodec.NewRegistryBuilder().Build()
    75  
    76  			client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg))
    77  			got := client.Database("foo", options.Database().SetWriteConcern(wc1))
    78  			expected := &Database{
    79  				readPreference: rpPrimary,
    80  				readConcern:    rcLocal,
    81  				writeConcern:   wc1,
    82  				registry:       reg,
    83  			}
    84  			compareDbs(t, expected, got)
    85  		})
    86  	})
    87  	t.Run("replace topology error", func(t *testing.T) {
    88  		db := setupDb("foo")
    89  		err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
    90  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    91  
    92  		err = db.Drop(bgCtx)
    93  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    94  
    95  		_, err = db.ListCollections(bgCtx, bson.D{})
    96  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    97  	})
    98  	t.Run("TransientTransactionError label", func(t *testing.T) {
    99  		client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
   100  		err := client.Connect(bgCtx)
   101  		defer func() { _ = client.Disconnect(bgCtx) }()
   102  		assert.Nil(t, err, "expected nil, got %v", err)
   103  
   104  		t.Run("negative case of non-transaction", func(t *testing.T) {
   105  			var sse topology.ServerSelectionError
   106  			var le LabeledError
   107  
   108  			err := client.Ping(bgCtx, nil)
   109  			assert.NotNil(t, err, "expected error, got nil")
   110  			assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
   111  			if errors.As(err, &le) {
   112  				assert.False(t, le.HasErrorLabel("TransientTransactionError"), `expected error not to include the "TransientTransactionError" label`)
   113  			}
   114  		})
   115  
   116  		t.Run("positive case of transaction", func(t *testing.T) {
   117  			var sse topology.ServerSelectionError
   118  			var le LabeledError
   119  
   120  			sess, err := client.StartSession()
   121  			assert.Nil(t, err, "expected nil, got %v", err)
   122  			defer sess.EndSession(bgCtx)
   123  
   124  			sessCtx := NewSessionContext(bgCtx, sess)
   125  			err = sess.StartTransaction()
   126  			assert.Nil(t, err, "expected nil, got %v", err)
   127  
   128  			err = client.Ping(sessCtx, nil)
   129  			assert.NotNil(t, err, "expected error, got nil")
   130  			assert.True(t, errors.As(err, &sse), `expected error to be a "topology.ServerSelectionError"`)
   131  			assert.True(t, errors.As(err, &le), `expected error to implement the "LabeledError" interface`)
   132  			assert.True(t, le.HasErrorLabel("TransientTransactionError"), `expected error to include the "TransientTransactionError" label`)
   133  		})
   134  	})
   135  	t.Run("nil document error", func(t *testing.T) {
   136  		db := setupDb("foo")
   137  
   138  		err := db.RunCommand(bgCtx, nil).Err()
   139  		assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
   140  
   141  		_, err = db.Watch(context.Background(), nil)
   142  		watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
   143  		assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err)
   144  
   145  		_, err = db.ListCollections(context.Background(), nil)
   146  		assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
   147  
   148  		_, err = db.ListCollectionNames(context.Background(), nil)
   149  		assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
   150  	})
   151  }
   152  

View as plain text