...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring_spec_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/connstring

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package connstring_test
     8  
     9  import (
    10  	"encoding/json"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"math"
    14  	"path"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	"go.mongodb.org/mongo-driver/internal/require"
    20  	"go.mongodb.org/mongo-driver/internal/spectest"
    21  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    22  )
    23  
    24  type host struct {
    25  	Type string
    26  	Host string
    27  	Port json.Number
    28  }
    29  
    30  type auth struct {
    31  	Username string
    32  	Password *string
    33  	DB       string
    34  }
    35  
    36  type testCase struct {
    37  	Description string
    38  	URI         string
    39  	Valid       bool
    40  	Warning     bool
    41  	Hosts       []host
    42  	Auth        *auth
    43  	Options     map[string]interface{}
    44  }
    45  
    46  type testContainer struct {
    47  	Tests []testCase
    48  }
    49  
    50  const connstringTestsDir = "../../../../testdata/connection-string/"
    51  const urioptionsTestDir = "../../../../testdata/uri-options/"
    52  
    53  func (h *host) toString() string {
    54  	switch h.Type {
    55  	case "unix":
    56  		return h.Host
    57  	case "ip_literal":
    58  		if len(h.Port) == 0 {
    59  			return "[" + h.Host + "]"
    60  		}
    61  		return "[" + h.Host + "]" + ":" + string(h.Port)
    62  	case "ipv4":
    63  		fallthrough
    64  	case "hostname":
    65  		if len(h.Port) == 0 {
    66  			return h.Host
    67  		}
    68  		return h.Host + ":" + string(h.Port)
    69  	}
    70  
    71  	return ""
    72  }
    73  
    74  func hostsToStrings(hosts []host) []string {
    75  	out := make([]string, len(hosts))
    76  
    77  	for i, host := range hosts {
    78  		out[i] = host.toString()
    79  	}
    80  
    81  	return out
    82  }
    83  
    84  func runTestsInFile(t *testing.T, dirname string, filename string, warningsError bool) {
    85  	filepath := path.Join(dirname, filename)
    86  	content, err := ioutil.ReadFile(filepath)
    87  	require.NoError(t, err)
    88  
    89  	var container testContainer
    90  	require.NoError(t, json.Unmarshal(content, &container))
    91  
    92  	// Remove ".json" from filename.
    93  	filename = filename[:len(filename)-5]
    94  
    95  	for _, testCase := range container.Tests {
    96  		runTest(t, filename, testCase, warningsError)
    97  	}
    98  }
    99  
   100  var skipDescriptions = map[string]struct{}{
   101  	"Valid options specific to single-threaded drivers are parsed correctly": {},
   102  }
   103  
   104  var skipKeywords = []string{
   105  	"tlsAllowInvalidHostnames",
   106  	"tlsAllowInvalidCertificates",
   107  	"tlsDisableCertificateRevocationCheck",
   108  	"serverSelectionTryOnce",
   109  }
   110  
   111  func runTest(t *testing.T, filename string, test testCase, warningsError bool) {
   112  	t.Run(filename+"/"+test.Description, func(t *testing.T) {
   113  		if _, skip := skipDescriptions[test.Description]; skip {
   114  			t.Skip()
   115  		}
   116  		for _, keyword := range skipKeywords {
   117  			if strings.Contains(test.Description, keyword) {
   118  				t.Skipf("skipping because keyword %s", keyword)
   119  			}
   120  		}
   121  
   122  		cs, err := connstring.ParseAndValidate(test.URI)
   123  		// Since we don't have warnings in Go, we return warnings as errors.
   124  		//
   125  		// This is a bit unfortunate, but since we do raise warnings as errors with the newer
   126  		// URI options, but don't with some of the older things, we do a switch on the filename
   127  		// here. We are trying to not break existing user applications that have unrecognized
   128  		// options.
   129  		if test.Valid && !(test.Warning && warningsError) {
   130  			require.NoError(t, err)
   131  		} else {
   132  			require.Error(t, err)
   133  			return
   134  		}
   135  
   136  		require.Equal(t, test.URI, cs.Original)
   137  
   138  		if test.Hosts != nil {
   139  			require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)
   140  		}
   141  
   142  		if test.Auth != nil {
   143  			require.Equal(t, test.Auth.Username, cs.Username)
   144  
   145  			if test.Auth.Password == nil {
   146  				require.False(t, cs.PasswordSet)
   147  			} else {
   148  				require.True(t, cs.PasswordSet)
   149  				require.Equal(t, *test.Auth.Password, cs.Password)
   150  			}
   151  
   152  			if test.Auth.DB != cs.Database {
   153  				require.Equal(t, test.Auth.DB, cs.AuthSource)
   154  			} else {
   155  				require.Equal(t, test.Auth.DB, cs.Database)
   156  			}
   157  		}
   158  
   159  		// Check that all options are present.
   160  		verifyConnStringOptions(t, cs, test.Options)
   161  
   162  		// Check that non-present options are unset. This will be redundant with the above checks
   163  		// for options that are present.
   164  		var ok bool
   165  
   166  		_, ok = test.Options["maxpoolsize"]
   167  		require.Equal(t, ok, cs.MaxPoolSizeSet)
   168  	})
   169  }
   170  
   171  // Test case for all connection string spec tests.
   172  func TestConnStringSpec(t *testing.T) {
   173  	for _, file := range spectest.FindJSONFilesInDir(t, connstringTestsDir) {
   174  		runTestsInFile(t, connstringTestsDir, file, false)
   175  	}
   176  }
   177  
   178  func TestURIOptionsSpec(t *testing.T) {
   179  	for _, file := range spectest.FindJSONFilesInDir(t, urioptionsTestDir) {
   180  		runTestsInFile(t, urioptionsTestDir, file, true)
   181  	}
   182  }
   183  
   184  // verifyConnStringOptions verifies the options on the connection string.
   185  func verifyConnStringOptions(t *testing.T, cs *connstring.ConnString, options map[string]interface{}) {
   186  	// Check that all options are present.
   187  	for key, value := range options {
   188  
   189  		key = strings.ToLower(key)
   190  		switch key {
   191  		case "appname":
   192  			require.Equal(t, value, cs.AppName)
   193  		case "authsource":
   194  			require.Equal(t, value, cs.AuthSource)
   195  		case "authmechanism":
   196  			require.Equal(t, value, cs.AuthMechanism)
   197  		case "authmechanismproperties":
   198  			convertedMap := value.(map[string]interface{})
   199  			require.Equal(t,
   200  				mapInterfaceToString(convertedMap),
   201  				cs.AuthMechanismProperties)
   202  		case "compressors":
   203  			require.Equal(t, convertToStringSlice(value), cs.Compressors)
   204  		case "connecttimeoutms":
   205  			require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
   206  		case "directconnection":
   207  			require.True(t, cs.DirectConnectionSet)
   208  			require.Equal(t, value, cs.DirectConnection)
   209  		case "heartbeatfrequencyms":
   210  			require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
   211  		case "journal":
   212  			require.True(t, cs.JSet)
   213  			require.Equal(t, value, cs.J)
   214  		case "loadbalanced":
   215  			require.True(t, cs.LoadBalancedSet)
   216  			require.Equal(t, value, cs.LoadBalanced)
   217  		case "localthresholdms":
   218  			require.True(t, cs.LocalThresholdSet)
   219  			require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
   220  		case "maxidletimems":
   221  			require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
   222  		case "maxpoolsize":
   223  			require.True(t, cs.MaxPoolSizeSet)
   224  			require.Equal(t, value, cs.MaxPoolSize)
   225  		case "maxstalenessseconds":
   226  			require.True(t, cs.MaxStalenessSet)
   227  			require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
   228  		case "minpoolsize":
   229  			require.True(t, cs.MinPoolSizeSet)
   230  			require.Equal(t, value, int64(cs.MinPoolSize))
   231  		case "readpreference":
   232  			require.Equal(t, value, cs.ReadPreference)
   233  		case "readpreferencetags":
   234  			sm, ok := value.([]interface{})
   235  			require.True(t, ok)
   236  			tags := make([]map[string]string, 0, len(sm))
   237  			for _, i := range sm {
   238  				m, ok := i.(map[string]interface{})
   239  				require.True(t, ok)
   240  				tags = append(tags, mapInterfaceToString(m))
   241  			}
   242  			require.Equal(t, tags, cs.ReadPreferenceTagSets)
   243  		case "readconcernlevel":
   244  			require.Equal(t, value, cs.ReadConcernLevel)
   245  		case "replicaset":
   246  			require.Equal(t, value, cs.ReplicaSet)
   247  		case "retrywrites":
   248  			require.True(t, cs.RetryWritesSet)
   249  			require.Equal(t, value, cs.RetryWrites)
   250  		case "serverselectiontimeoutms":
   251  			require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
   252  		case "srvmaxhosts":
   253  			require.Equal(t, value, float64(cs.SRVMaxHosts))
   254  		case "srvservicename":
   255  			require.Equal(t, value, cs.SRVServiceName)
   256  		case "ssl", "tls":
   257  			require.Equal(t, value, cs.SSL)
   258  		case "sockettimeoutms":
   259  			require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
   260  		case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
   261  			require.True(t, cs.SSLInsecureSet)
   262  			require.Equal(t, value, cs.SSLInsecure)
   263  		case "tlscafile":
   264  			require.True(t, cs.SSLCaFileSet)
   265  			require.Equal(t, value, cs.SSLCaFile)
   266  		case "tlscertificatekeyfile":
   267  			require.True(t, cs.SSLClientCertificateKeyFileSet)
   268  			require.Equal(t, value, cs.SSLClientCertificateKeyFile)
   269  		case "tlscertificatekeyfilepassword":
   270  			require.True(t, cs.SSLClientCertificateKeyPasswordSet)
   271  			require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
   272  		case "w":
   273  			if cs.WNumberSet {
   274  				valueInt := getIntFromInterface(value)
   275  				require.NotNil(t, valueInt)
   276  				require.Equal(t, *valueInt, int64(cs.WNumber))
   277  			} else {
   278  				require.Equal(t, value, cs.WString)
   279  			}
   280  		case "wtimeoutms":
   281  			require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
   282  		case "waitqueuetimeoutms":
   283  		case "zlibcompressionlevel":
   284  			require.Equal(t, value, float64(cs.ZlibLevel))
   285  		case "zstdcompressionlevel":
   286  			require.Equal(t, value, float64(cs.ZstdLevel))
   287  		case "tlsdisableocspendpointcheck":
   288  			require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
   289  		case "servermonitoringmode":
   290  			require.Equal(t, value, cs.ServerMonitoringMode)
   291  		default:
   292  			opt, ok := cs.UnknownOptions[key]
   293  			require.True(t, ok)
   294  			require.Contains(t, opt, fmt.Sprint(value))
   295  		}
   296  	}
   297  }
   298  
   299  // Convert each interface{} value in the map to a string.
   300  func mapInterfaceToString(m map[string]interface{}) map[string]string {
   301  	out := make(map[string]string)
   302  
   303  	for key, value := range m {
   304  		out[key] = fmt.Sprint(value)
   305  	}
   306  
   307  	return out
   308  }
   309  
   310  // getIntFromInterface attempts to convert an empty interface value to an integer.
   311  //
   312  // Returns nil if it is not possible.
   313  func getIntFromInterface(i interface{}) *int64 {
   314  	var out int64
   315  
   316  	switch v := i.(type) {
   317  	case int:
   318  		out = int64(v)
   319  	case int32:
   320  		out = int64(v)
   321  	case int64:
   322  		out = v
   323  	case float32:
   324  		f := float64(v)
   325  		if math.Floor(f) != f || f > float64(math.MaxInt64) {
   326  			break
   327  		}
   328  
   329  		out = int64(f)
   330  
   331  	case float64:
   332  		if math.Floor(v) != v || v > float64(math.MaxInt64) {
   333  			break
   334  		}
   335  
   336  		out = int64(v)
   337  	default:
   338  		return nil
   339  	}
   340  
   341  	return &out
   342  }
   343  
   344  func convertToStringSlice(i interface{}) []string {
   345  	s, ok := i.([]interface{})
   346  	if !ok {
   347  		return nil
   348  	}
   349  	ret := make([]string, 0, len(s))
   350  	for _, v := range s {
   351  		str, ok := v.(string)
   352  		if !ok {
   353  			continue
   354  		}
   355  		ret = append(ret, str)
   356  	}
   357  	return ret
   358  }
   359  

View as plain text