1
2
3
4
5
6
7 package connstring_test
8
9 import (
10 "encoding/json"
11 "fmt"
12 "io/ioutil"
13 "math"
14 "path"
15 "strings"
16 "testing"
17 "time"
18
19 "go.mongodb.org/mongo-driver/internal/require"
20 "go.mongodb.org/mongo-driver/internal/spectest"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
22 )
23
// host is one expected entry of a spec test case's "hosts" array.
type host struct {
	// Type selects how toString renders the address ("unix", "ip_literal",
	// "ipv4" or "hostname").
	Type string `json:"type"`
	// Host is the address or socket path without any port.
	Host string `json:"host"`
	// Port may be empty (for example when the JSON value is null), in which
	// case toString omits it.
	Port json.Number `json:"port"`
}
29
// auth is the expected credential section of a spec test case.
type auth struct {
	Username string `json:"username"`
	// Password is nil when the case expects no password to be set.
	Password *string `json:"password"`
	// DB is the expected authentication database.
	DB string `json:"db"`
}
35
36 type testCase struct {
37 Description string
38 URI string
39 Valid bool
40 Warning bool
41 Hosts []host
42 Auth *auth
43 Options map[string]interface{}
44 }
45
46 type testContainer struct {
47 Tests []testCase
48 }
49
// Directories holding the spec test JSON files, relative to this package.
const (
	connstringTestsDir = "../../../../testdata/connection-string/"
	urioptionsTestDir  = "../../../../testdata/uri-options/"
)
52
53 func (h *host) toString() string {
54 switch h.Type {
55 case "unix":
56 return h.Host
57 case "ip_literal":
58 if len(h.Port) == 0 {
59 return "[" + h.Host + "]"
60 }
61 return "[" + h.Host + "]" + ":" + string(h.Port)
62 case "ipv4":
63 fallthrough
64 case "hostname":
65 if len(h.Port) == 0 {
66 return h.Host
67 }
68 return h.Host + ":" + string(h.Port)
69 }
70
71 return ""
72 }
73
74 func hostsToStrings(hosts []host) []string {
75 out := make([]string, len(hosts))
76
77 for i, host := range hosts {
78 out[i] = host.toString()
79 }
80
81 return out
82 }
83
84 func runTestsInFile(t *testing.T, dirname string, filename string, warningsError bool) {
85 filepath := path.Join(dirname, filename)
86 content, err := ioutil.ReadFile(filepath)
87 require.NoError(t, err)
88
89 var container testContainer
90 require.NoError(t, json.Unmarshal(content, &container))
91
92
93 filename = filename[:len(filename)-5]
94
95 for _, testCase := range container.Tests {
96 runTest(t, filename, testCase, warningsError)
97 }
98 }
99
// Spec test cases that runTest skips.
var (
	// skipDescriptions lists exact test descriptions to skip.
	skipDescriptions = map[string]struct{}{
		"Valid options specific to single-threaded drivers are parsed correctly": {},
	}

	// skipKeywords causes any test whose description contains one of these
	// substrings to be skipped.
	skipKeywords = []string{
		"tlsAllowInvalidHostnames",
		"tlsAllowInvalidCertificates",
		"tlsDisableCertificateRevocationCheck",
		"serverSelectionTryOnce",
	}
)
110
111 func runTest(t *testing.T, filename string, test testCase, warningsError bool) {
112 t.Run(filename+"/"+test.Description, func(t *testing.T) {
113 if _, skip := skipDescriptions[test.Description]; skip {
114 t.Skip()
115 }
116 for _, keyword := range skipKeywords {
117 if strings.Contains(test.Description, keyword) {
118 t.Skipf("skipping because keyword %s", keyword)
119 }
120 }
121
122 cs, err := connstring.ParseAndValidate(test.URI)
123
124
125
126
127
128
129 if test.Valid && !(test.Warning && warningsError) {
130 require.NoError(t, err)
131 } else {
132 require.Error(t, err)
133 return
134 }
135
136 require.Equal(t, test.URI, cs.Original)
137
138 if test.Hosts != nil {
139 require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)
140 }
141
142 if test.Auth != nil {
143 require.Equal(t, test.Auth.Username, cs.Username)
144
145 if test.Auth.Password == nil {
146 require.False(t, cs.PasswordSet)
147 } else {
148 require.True(t, cs.PasswordSet)
149 require.Equal(t, *test.Auth.Password, cs.Password)
150 }
151
152 if test.Auth.DB != cs.Database {
153 require.Equal(t, test.Auth.DB, cs.AuthSource)
154 } else {
155 require.Equal(t, test.Auth.DB, cs.Database)
156 }
157 }
158
159
160 verifyConnStringOptions(t, cs, test.Options)
161
162
163
164 var ok bool
165
166 _, ok = test.Options["maxpoolsize"]
167 require.Equal(t, ok, cs.MaxPoolSizeSet)
168 })
169 }
170
171
172 func TestConnStringSpec(t *testing.T) {
173 for _, file := range spectest.FindJSONFilesInDir(t, connstringTestsDir) {
174 runTestsInFile(t, connstringTestsDir, file, false)
175 }
176 }
177
178 func TestURIOptionsSpec(t *testing.T) {
179 for _, file := range spectest.FindJSONFilesInDir(t, urioptionsTestDir) {
180 runTestsInFile(t, urioptionsTestDir, file, true)
181 }
182 }
183
184
185 func verifyConnStringOptions(t *testing.T, cs *connstring.ConnString, options map[string]interface{}) {
186
187 for key, value := range options {
188
189 key = strings.ToLower(key)
190 switch key {
191 case "appname":
192 require.Equal(t, value, cs.AppName)
193 case "authsource":
194 require.Equal(t, value, cs.AuthSource)
195 case "authmechanism":
196 require.Equal(t, value, cs.AuthMechanism)
197 case "authmechanismproperties":
198 convertedMap := value.(map[string]interface{})
199 require.Equal(t,
200 mapInterfaceToString(convertedMap),
201 cs.AuthMechanismProperties)
202 case "compressors":
203 require.Equal(t, convertToStringSlice(value), cs.Compressors)
204 case "connecttimeoutms":
205 require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
206 case "directconnection":
207 require.True(t, cs.DirectConnectionSet)
208 require.Equal(t, value, cs.DirectConnection)
209 case "heartbeatfrequencyms":
210 require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
211 case "journal":
212 require.True(t, cs.JSet)
213 require.Equal(t, value, cs.J)
214 case "loadbalanced":
215 require.True(t, cs.LoadBalancedSet)
216 require.Equal(t, value, cs.LoadBalanced)
217 case "localthresholdms":
218 require.True(t, cs.LocalThresholdSet)
219 require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
220 case "maxidletimems":
221 require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
222 case "maxpoolsize":
223 require.True(t, cs.MaxPoolSizeSet)
224 require.Equal(t, value, cs.MaxPoolSize)
225 case "maxstalenessseconds":
226 require.True(t, cs.MaxStalenessSet)
227 require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
228 case "minpoolsize":
229 require.True(t, cs.MinPoolSizeSet)
230 require.Equal(t, value, int64(cs.MinPoolSize))
231 case "readpreference":
232 require.Equal(t, value, cs.ReadPreference)
233 case "readpreferencetags":
234 sm, ok := value.([]interface{})
235 require.True(t, ok)
236 tags := make([]map[string]string, 0, len(sm))
237 for _, i := range sm {
238 m, ok := i.(map[string]interface{})
239 require.True(t, ok)
240 tags = append(tags, mapInterfaceToString(m))
241 }
242 require.Equal(t, tags, cs.ReadPreferenceTagSets)
243 case "readconcernlevel":
244 require.Equal(t, value, cs.ReadConcernLevel)
245 case "replicaset":
246 require.Equal(t, value, cs.ReplicaSet)
247 case "retrywrites":
248 require.True(t, cs.RetryWritesSet)
249 require.Equal(t, value, cs.RetryWrites)
250 case "serverselectiontimeoutms":
251 require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
252 case "srvmaxhosts":
253 require.Equal(t, value, float64(cs.SRVMaxHosts))
254 case "srvservicename":
255 require.Equal(t, value, cs.SRVServiceName)
256 case "ssl", "tls":
257 require.Equal(t, value, cs.SSL)
258 case "sockettimeoutms":
259 require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
260 case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
261 require.True(t, cs.SSLInsecureSet)
262 require.Equal(t, value, cs.SSLInsecure)
263 case "tlscafile":
264 require.True(t, cs.SSLCaFileSet)
265 require.Equal(t, value, cs.SSLCaFile)
266 case "tlscertificatekeyfile":
267 require.True(t, cs.SSLClientCertificateKeyFileSet)
268 require.Equal(t, value, cs.SSLClientCertificateKeyFile)
269 case "tlscertificatekeyfilepassword":
270 require.True(t, cs.SSLClientCertificateKeyPasswordSet)
271 require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
272 case "w":
273 if cs.WNumberSet {
274 valueInt := getIntFromInterface(value)
275 require.NotNil(t, valueInt)
276 require.Equal(t, *valueInt, int64(cs.WNumber))
277 } else {
278 require.Equal(t, value, cs.WString)
279 }
280 case "wtimeoutms":
281 require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
282 case "waitqueuetimeoutms":
283 case "zlibcompressionlevel":
284 require.Equal(t, value, float64(cs.ZlibLevel))
285 case "zstdcompressionlevel":
286 require.Equal(t, value, float64(cs.ZstdLevel))
287 case "tlsdisableocspendpointcheck":
288 require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
289 case "servermonitoringmode":
290 require.Equal(t, value, cs.ServerMonitoringMode)
291 default:
292 opt, ok := cs.UnknownOptions[key]
293 require.True(t, ok)
294 require.Contains(t, opt, fmt.Sprint(value))
295 }
296 }
297 }
298
299
// mapInterfaceToString returns a new map with the same keys as m, where each
// value is replaced by its fmt.Sprint rendering. The result is never nil.
func mapInterfaceToString(m map[string]interface{}) map[string]string {
	converted := make(map[string]string, len(m))
	for k, v := range m {
		converted[k] = fmt.Sprint(v)
	}
	return converted
}
309
310
311
312
// getIntFromInterface converts a numeric value to an int64.
//
// int, int32 and int64 are converted directly. float32 and float64 values
// (encoding/json decodes JSON numbers into interface{} as float64) are
// accepted only when they are whole numbers within the int64 range. The
// function returns nil for any other type, and for floats that are
// fractional, NaN, infinite or out of range.
func getIntFromInterface(i interface{}) *int64 {
	var out int64

	switch v := i.(type) {
	case int:
		out = int64(v)
	case int32:
		out = int64(v)
	case int64:
		out = v
	case float32:
		return floatToInt64(float64(v))
	case float64:
		return floatToInt64(v)
	default:
		return nil
	}

	return &out
}

// floatToInt64 returns f as an int64 when f is a whole number representable
// as int64, and nil otherwise.
func floatToInt64(f float64) *int64 {
	// float64(math.MaxInt64) rounds up to 2^63, which is itself out of range,
	// so the upper bound is exclusive. float64(math.MinInt64) is exact and in
	// range. NaN fails the Trunc comparison, and ±Inf fail the range checks.
	if math.Trunc(f) != f || f < float64(math.MinInt64) || f >= float64(math.MaxInt64) {
		return nil
	}
	out := int64(f)
	return &out
}
343
// convertToStringSlice extracts the string elements of i, which is expected
// to be a []interface{} (as produced by encoding/json for a JSON array).
// Non-string elements are dropped. It returns nil when i is not a
// []interface{}, and otherwise a non-nil slice.
func convertToStringSlice(i interface{}) []string {
	items, isSlice := i.([]interface{})
	if !isSlice {
		return nil
	}

	strs := make([]string, 0, len(items))
	for _, item := range items {
		if str, isString := item.(string); isString {
			strs = append(strs, str)
		}
	}
	return strs
}
359
View as plain text