1
2
3
4
5
6
7 package connstring
8
9 import (
10 "errors"
11 "fmt"
12 "net"
13 "net/url"
14 "strconv"
15 "strings"
16 "time"
17
18 "go.mongodb.org/mongo-driver/internal/randutil"
19 "go.mongodb.org/mongo-driver/mongo/writeconcern"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/dns"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
22 )
23
// ServerMonitoringModeAuto is the "auto" value accepted for the
// serverMonitoringMode URI option (see IsValidServerMonitoringMode).
const ServerMonitoringModeAuto = "auto"

// ServerMonitoringModePoll is the "poll" value accepted for the
// serverMonitoringMode URI option (see IsValidServerMonitoringMode).
const ServerMonitoringModePoll = "poll"

// ServerMonitoringModeStream is the "stream" value accepted for the
// serverMonitoringMode URI option (see IsValidServerMonitoringMode).
const ServerMonitoringModeStream = "stream"
43
// ErrLoadBalancedWithMultipleHosts is returned by Validate when loadBalanced
// is true and more than one host is present.
var ErrLoadBalancedWithMultipleHosts = errors.New("loadBalanced cannot be set to true if multiple hosts are specified")

// ErrLoadBalancedWithReplicaSet is returned by Validate when loadBalanced is
// true and a replicaSet name is present.
var ErrLoadBalancedWithReplicaSet = errors.New("loadBalanced cannot be set to true if a replica set name is specified")

// ErrLoadBalancedWithDirectConnection is returned by Validate when loadBalanced
// is true and a direct connection was requested (connect=direct or
// directConnection=true).
var ErrLoadBalancedWithDirectConnection = errors.New("loadBalanced cannot be set to true if the direct connection option is specified")

// ErrSRVMaxHostsWithReplicaSet is returned by Validate when srvMaxHosts is
// positive and a replicaSet name is present.
var ErrSRVMaxHostsWithReplicaSet = errors.New("srvMaxHosts cannot be a positive value if a replica set name is specified")

// ErrSRVMaxHostsWithLoadBalanced is returned by Validate when srvMaxHosts is
// positive and loadBalanced is true.
var ErrSRVMaxHostsWithLoadBalanced = errors.New("srvMaxHosts cannot be a positive value if loadBalanced is set to true")
70
71
// random is used by parse to shuffle SRV-resolved hosts before they are
// truncated to srvMaxHosts. It is created once, at package initialization, by
// randutil.NewLockedRand; the name suggests it is safe to share between
// concurrent parses — NOTE(review): confirm against the randutil package.
var random = randutil.NewLockedRand()
73
74
75
76 func ParseAndValidate(s string) (*ConnString, error) {
77 connStr, err := Parse(s)
78 if err != nil {
79 return nil, err
80 }
81 err = connStr.Validate()
82 if err != nil {
83 return nil, fmt.Errorf("error validating uri: %w", err)
84 }
85 return connStr, nil
86 }
87
88
89
90
91 func Parse(s string) (*ConnString, error) {
92 p := parser{dnsResolver: dns.DefaultResolver}
93 connStr, err := p.parse(s)
94 if err != nil {
95 return nil, fmt.Errorf("error parsing uri: %w", err)
96 }
97 return connStr, err
98 }
99
100
// ConnString is a parsed connection string. Most options come with a companion
// XxxSet field that addOptions sets to true when the option appears in the URI
// query string (or an SRV TXT record). This lets an explicit zero value be told
// apart from an absent option.
type ConnString struct {
	// Original is the unparsed connection string; String returns it.
	Original string

	// AppName is the appName option.
	AppName string

	// AuthMechanism is the authMechanism option as written; this package
	// lowercases it before comparing.
	AuthMechanism string

	// AuthMechanismProperties holds the key:value pairs of the
	// authMechanismProperties option. For GSSAPI, parsing fills in a default
	// SERVICE_NAME of "mongodb" when it is missing or empty.
	AuthMechanismProperties    map[string]string
	AuthMechanismPropertiesSet bool

	// AuthSource is the authSource option. If it is still empty after options
	// are applied, parsing may default it from the URI database, "admin", or
	// "$external", depending on the mechanism.
	AuthSource    string
	AuthSourceSet bool

	// Compressors is the compressors option split on ','.
	Compressors []string

	// Connect is SingleConnect for connect=direct and AutoConnect otherwise.
	Connect    ConnectMode
	ConnectSet bool

	// DirectConnection is the directConnection option; when both it and
	// connect are given they must agree.
	DirectConnection    bool
	DirectConnectionSet bool

	// ConnectTimeout is connectTimeoutMS as a Duration.
	ConnectTimeout    time.Duration
	ConnectTimeoutSet bool

	// Database is the URI path segment after the hosts, URL-unescaped.
	Database string

	// HeartbeatInterval comes from heartbeatIntervalMS or heartbeatFrequencyMS.
	HeartbeatInterval    time.Duration
	HeartbeatIntervalSet bool

	// Hosts are the hosts to connect to: RawHosts for mongodb:// URIs, or the
	// hosts returned by the DNS resolver for mongodb+srv:// URIs (possibly
	// shuffled and truncated to SRVMaxHosts).
	Hosts []string

	// J is the journal option.
	J    bool
	JSet bool

	// LoadBalanced is the loadBalanced option.
	LoadBalanced    bool
	LoadBalancedSet bool

	// LocalThreshold is localThresholdMS as a Duration.
	LocalThreshold    time.Duration
	LocalThresholdSet bool

	// MaxConnIdleTime is maxIdleTimeMS as a Duration.
	MaxConnIdleTime    time.Duration
	MaxConnIdleTimeSet bool

	// MaxPoolSize is the maxPoolSize option (non-negative).
	MaxPoolSize    uint64
	MaxPoolSizeSet bool

	// MinPoolSize is the minPoolSize option (non-negative).
	MinPoolSize    uint64
	MinPoolSizeSet bool

	// MaxConnecting is the maxConnecting option (non-negative).
	MaxConnecting    uint64
	MaxConnectingSet bool

	// Password is the path-unescaped password from the user info. PasswordSet
	// is true whenever the user info contained a ':' separator, even if the
	// password itself is empty.
	Password    string
	PasswordSet bool

	// RawHosts are the hosts exactly as listed in the URI, after unescaping
	// and port validation, with empty entries dropped.
	RawHosts []string

	// ReadConcernLevel is the readConcernLevel option, stored verbatim.
	ReadConcernLevel string

	// ReadPreference is the readPreference option, stored verbatim.
	ReadPreference string

	// ReadPreferenceTagSets holds one map per readPreferenceTags occurrence, in
	// order; an empty value contributes an empty map.
	ReadPreferenceTagSets []map[string]string

	// RetryWrites is the retryWrites option.
	RetryWrites    bool
	RetryWritesSet bool

	// RetryReads is the retryReads option.
	RetryReads    bool
	RetryReadsSet bool

	// MaxStaleness comes from maxStalenessSeconds (or maxStaleness), in seconds.
	MaxStaleness    time.Duration
	MaxStalenessSet bool

	// ReplicaSet is the replicaSet option.
	ReplicaSet string

	// Scheme is SchemeMongoDB or SchemeMongoDBSRV.
	Scheme string

	// ServerMonitoringMode is the serverMonitoringMode option, one of the
	// ServerMonitoringMode* constants, or empty when absent.
	ServerMonitoringMode string

	// ServerSelectionTimeout is serverSelectionTimeoutMS as a Duration.
	ServerSelectionTimeout    time.Duration
	ServerSelectionTimeoutSet bool

	// SocketTimeout is socketTimeoutMS as a Duration.
	SocketTimeout    time.Duration
	SocketTimeoutSet bool

	// SRVMaxHosts is the srvMaxHosts option; only accepted for mongodb+srv URIs.
	SRVMaxHosts int

	// SRVServiceName is the srvServiceName option (1-62 characters); only
	// accepted for mongodb+srv URIs.
	SRVServiceName string

	// SSL reports whether TLS is enabled. It is set by the tls/ssl options and
	// forced on by the TLS file options, tlsDisableOCSPEndpointCheck, and by
	// SRV resolution.
	SSL    bool
	SSLSet bool

	// SSLClientCertificateKeyFile comes from tlsCertificateKeyFile or
	// sslClientCertificateKeyFile.
	SSLClientCertificateKeyFile    string
	SSLClientCertificateKeyFileSet bool

	// SSLClientCertificateKeyPassword returns the value of
	// tlsCertificateKeyFilePassword (or sslClientCertificateKeyPassword).
	SSLClientCertificateKeyPassword    func() string
	SSLClientCertificateKeyPasswordSet bool

	// SSLCertificateFile is the tlsCertificateFile option.
	SSLCertificateFile    string
	SSLCertificateFileSet bool

	// SSLPrivateKeyFile is the tlsPrivateKeyFile option.
	SSLPrivateKeyFile    string
	SSLPrivateKeyFileSet bool

	// SSLInsecure comes from tlsInsecure or sslInsecure.
	SSLInsecure    bool
	SSLInsecureSet bool

	// SSLCaFile comes from tlsCAFile or sslCertificateAuthorityFile.
	SSLCaFile    string
	SSLCaFileSet bool

	// SSLDisableOCSPEndpointCheck is the tlsDisableOCSPEndpointCheck option.
	SSLDisableOCSPEndpointCheck    bool
	SSLDisableOCSPEndpointCheckSet bool

	// Timeout is timeoutMS as a Duration.
	Timeout    time.Duration
	TimeoutSet bool

	// WString and WNumber hold the w option: a non-negative integer value goes
	// to WNumber (WNumberSet true, WString cleared); anything else goes to
	// WString (WNumberSet false).
	WString    string
	WNumber    int
	WNumberSet bool

	// Username is the path-unescaped username; UsernameSet is true whenever
	// the URI contained user info.
	Username    string
	UsernameSet bool

	// ZlibLevel is zlibCompressionLevel in [-1, 9]; -1 is replaced by
	// wiremessage.DefaultZlibLevel.
	ZlibLevel    int
	ZlibLevelSet bool

	// ZstdLevel is zstdCompressionLevel in [-1, 22]; -1 is replaced by
	// wiremessage.DefaultZstdLevel.
	ZstdLevel    int
	ZstdLevelSet bool

	// WTimeout comes from wtimeoutMS, or from wtimeout when wtimeoutMS has not
	// already been seen; only wtimeoutMS sets WTimeoutSet. When
	// WTimeoutSetFromOption is true at the end of parsing, WTimeoutSet is
	// forced true. NOTE(review): WTimeoutSetFromOption is never assigned in
	// this file; presumably callers set it — confirm.
	WTimeout              time.Duration
	WTimeoutSet           bool
	WTimeoutSetFromOption bool

	// Options maps every option key seen (lowercased, recognized or not) to
	// its values in order; UnknownOptions holds only unrecognized keys.
	Options        map[string][]string
	UnknownOptions map[string][]string
}
190
191 func (u *ConnString) String() string {
192 return u.Original
193 }
194
195
196
197 func (u *ConnString) HasAuthParameters() bool {
198
199
200 return u.AuthMechanism != "" || u.AuthMechanismProperties != nil || u.UsernameSet || u.PasswordSet
201 }
202
203
204 func (u *ConnString) Validate() error {
205 var err error
206
207 if err = u.validateAuth(); err != nil {
208 return err
209 }
210
211 if err = u.validateSSL(); err != nil {
212 return err
213 }
214
215
216 if u.WNumberSet && u.WNumber == 0 && u.JSet && u.J {
217 return writeconcern.ErrInconsistent
218 }
219
220
221 if (u.ConnectSet && u.Connect == SingleConnect) ||
222 (u.DirectConnectionSet && u.DirectConnection) {
223 if len(u.Hosts) > 1 {
224 return errors.New("a direct connection cannot be made if multiple hosts are specified")
225 }
226 if u.Scheme == SchemeMongoDBSRV {
227 return errors.New("a direct connection cannot be made if an SRV URI is used")
228 }
229 if u.LoadBalancedSet && u.LoadBalanced {
230 return ErrLoadBalancedWithDirectConnection
231 }
232 }
233
234
235 if u.LoadBalancedSet && u.LoadBalanced {
236 if len(u.Hosts) > 1 {
237 return ErrLoadBalancedWithMultipleHosts
238 }
239 if u.ReplicaSet != "" {
240 return ErrLoadBalancedWithReplicaSet
241 }
242 }
243
244
245 if u.SRVMaxHosts > 0 {
246 if u.ReplicaSet != "" {
247 return ErrSRVMaxHostsWithReplicaSet
248 }
249 if u.LoadBalanced {
250 return ErrSRVMaxHostsWithLoadBalanced
251 }
252 }
253
254 return nil
255 }
256
257 func (u *ConnString) setDefaultAuthParams(dbName string) error {
258
259
260 if u.AuthSourceSet && u.AuthSource == "" {
261 return errors.New("authSource must be non-empty when supplied in a URI")
262 }
263
264 switch strings.ToLower(u.AuthMechanism) {
265 case "plain":
266 if u.AuthSource == "" {
267 u.AuthSource = dbName
268 if u.AuthSource == "" {
269 u.AuthSource = "$external"
270 }
271 }
272 case "gssapi":
273 if u.AuthMechanismProperties == nil {
274 u.AuthMechanismProperties = map[string]string{
275 "SERVICE_NAME": "mongodb",
276 }
277 } else if v, ok := u.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
278 u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
279 }
280 fallthrough
281 case "mongodb-aws", "mongodb-x509":
282 if u.AuthSource == "" {
283 u.AuthSource = "$external"
284 } else if u.AuthSource != "$external" {
285 return fmt.Errorf("auth source must be $external")
286 }
287 case "mongodb-cr":
288 fallthrough
289 case "scram-sha-1":
290 fallthrough
291 case "scram-sha-256":
292 if u.AuthSource == "" {
293 u.AuthSource = dbName
294 if u.AuthSource == "" {
295 u.AuthSource = "admin"
296 }
297 }
298 case "":
299
300 if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) {
301 u.AuthSource = dbName
302 if u.AuthSource == "" {
303 u.AuthSource = "admin"
304 }
305 }
306 default:
307 return fmt.Errorf("invalid auth mechanism")
308 }
309 return nil
310 }
311
312 func (u *ConnString) addOptions(connectionArgPairs []string) error {
313 var tlsssl *bool
314 for _, pair := range connectionArgPairs {
315 kv := strings.SplitN(pair, "=", 2)
316 if len(kv) != 2 || kv[0] == "" {
317 return fmt.Errorf("invalid option")
318 }
319
320 key, err := url.QueryUnescape(kv[0])
321 if err != nil {
322 return fmt.Errorf("invalid option key %q: %w", kv[0], err)
323 }
324
325 value, err := url.QueryUnescape(kv[1])
326 if err != nil {
327 return fmt.Errorf("invalid option value %q: %w", kv[1], err)
328 }
329
330 lowerKey := strings.ToLower(key)
331 switch lowerKey {
332 case "appname":
333 u.AppName = value
334 case "authmechanism":
335 u.AuthMechanism = value
336 case "authmechanismproperties":
337 u.AuthMechanismProperties = make(map[string]string)
338 pairs := strings.Split(value, ",")
339 for _, pair := range pairs {
340 kv := strings.SplitN(pair, ":", 2)
341 if len(kv) != 2 || kv[0] == "" {
342 return fmt.Errorf("invalid authMechanism property")
343 }
344 u.AuthMechanismProperties[kv[0]] = kv[1]
345 }
346 u.AuthMechanismPropertiesSet = true
347 case "authsource":
348 u.AuthSource = value
349 u.AuthSourceSet = true
350 case "compressors":
351 compressors := strings.Split(value, ",")
352 if len(compressors) < 1 {
353 return fmt.Errorf("must have at least 1 compressor")
354 }
355 u.Compressors = compressors
356 case "connect":
357 switch strings.ToLower(value) {
358 case "automatic":
359 case "direct":
360 u.Connect = SingleConnect
361 default:
362 return fmt.Errorf("invalid 'connect' value: %q", value)
363 }
364 if u.DirectConnectionSet {
365 expectedValue := u.Connect == SingleConnect
366 if u.DirectConnection != expectedValue {
367 return fmt.Errorf("options connect=%q and directConnection=%v conflict", value, u.DirectConnection)
368 }
369 }
370
371 u.ConnectSet = true
372 case "directconnection":
373 switch strings.ToLower(value) {
374 case "true":
375 u.DirectConnection = true
376 case "false":
377 default:
378 return fmt.Errorf("invalid 'directConnection' value: %q", value)
379 }
380
381 if u.ConnectSet {
382 expectedValue := AutoConnect
383 if u.DirectConnection {
384 expectedValue = SingleConnect
385 }
386
387 if u.Connect != expectedValue {
388 return fmt.Errorf("options connect=%q and directConnection=%q conflict", u.Connect, value)
389 }
390 }
391 u.DirectConnectionSet = true
392 case "connecttimeoutms":
393 n, err := strconv.Atoi(value)
394 if err != nil || n < 0 {
395 return fmt.Errorf("invalid value for %q: %q", key, value)
396 }
397 u.ConnectTimeout = time.Duration(n) * time.Millisecond
398 u.ConnectTimeoutSet = true
399 case "heartbeatintervalms", "heartbeatfrequencyms":
400 n, err := strconv.Atoi(value)
401 if err != nil || n < 0 {
402 return fmt.Errorf("invalid value for %q: %q", key, value)
403 }
404 u.HeartbeatInterval = time.Duration(n) * time.Millisecond
405 u.HeartbeatIntervalSet = true
406 case "journal":
407 switch value {
408 case "true":
409 u.J = true
410 case "false":
411 u.J = false
412 default:
413 return fmt.Errorf("invalid value for %q: %q", key, value)
414 }
415
416 u.JSet = true
417 case "loadbalanced":
418 switch value {
419 case "true":
420 u.LoadBalanced = true
421 case "false":
422 u.LoadBalanced = false
423 default:
424 return fmt.Errorf("invalid value for %q: %q", key, value)
425 }
426
427 u.LoadBalancedSet = true
428 case "localthresholdms":
429 n, err := strconv.Atoi(value)
430 if err != nil || n < 0 {
431 return fmt.Errorf("invalid value for %q: %q", key, value)
432 }
433 u.LocalThreshold = time.Duration(n) * time.Millisecond
434 u.LocalThresholdSet = true
435 case "maxidletimems":
436 n, err := strconv.Atoi(value)
437 if err != nil || n < 0 {
438 return fmt.Errorf("invalid value for %q: %q", key, value)
439 }
440 u.MaxConnIdleTime = time.Duration(n) * time.Millisecond
441 u.MaxConnIdleTimeSet = true
442 case "maxpoolsize":
443 n, err := strconv.Atoi(value)
444 if err != nil || n < 0 {
445 return fmt.Errorf("invalid value for %q: %q", key, value)
446 }
447 u.MaxPoolSize = uint64(n)
448 u.MaxPoolSizeSet = true
449 case "minpoolsize":
450 n, err := strconv.Atoi(value)
451 if err != nil || n < 0 {
452 return fmt.Errorf("invalid value for %q: %q", key, value)
453 }
454 u.MinPoolSize = uint64(n)
455 u.MinPoolSizeSet = true
456 case "maxconnecting":
457 n, err := strconv.Atoi(value)
458 if err != nil || n < 0 {
459 return fmt.Errorf("invalid value for %q: %q", key, value)
460 }
461 u.MaxConnecting = uint64(n)
462 u.MaxConnectingSet = true
463 case "readconcernlevel":
464 u.ReadConcernLevel = value
465 case "readpreference":
466 u.ReadPreference = value
467 case "readpreferencetags":
468 if value == "" {
469
470
471 u.ReadPreferenceTagSets = append(u.ReadPreferenceTagSets, map[string]string{})
472 break
473 }
474
475 tags := make(map[string]string)
476 items := strings.Split(value, ",")
477 for _, item := range items {
478 parts := strings.Split(item, ":")
479 if len(parts) != 2 {
480 return fmt.Errorf("invalid value for %q: %q", key, value)
481 }
482 tags[parts[0]] = parts[1]
483 }
484 u.ReadPreferenceTagSets = append(u.ReadPreferenceTagSets, tags)
485 case "maxstaleness", "maxstalenessseconds":
486 n, err := strconv.Atoi(value)
487 if err != nil || n < 0 {
488 return fmt.Errorf("invalid value for %q: %q", key, value)
489 }
490 u.MaxStaleness = time.Duration(n) * time.Second
491 u.MaxStalenessSet = true
492 case "replicaset":
493 u.ReplicaSet = value
494 case "retrywrites":
495 switch value {
496 case "true":
497 u.RetryWrites = true
498 case "false":
499 u.RetryWrites = false
500 default:
501 return fmt.Errorf("invalid value for %q: %q", key, value)
502 }
503
504 u.RetryWritesSet = true
505 case "retryreads":
506 switch value {
507 case "true":
508 u.RetryReads = true
509 case "false":
510 u.RetryReads = false
511 default:
512 return fmt.Errorf("invalid value for %q: %q", key, value)
513 }
514
515 u.RetryReadsSet = true
516 case "servermonitoringmode":
517 if !IsValidServerMonitoringMode(value) {
518 return fmt.Errorf("invalid value for %q: %q", key, value)
519 }
520
521 u.ServerMonitoringMode = value
522 case "serverselectiontimeoutms":
523 n, err := strconv.Atoi(value)
524 if err != nil || n < 0 {
525 return fmt.Errorf("invalid value for %q: %q", key, value)
526 }
527 u.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
528 u.ServerSelectionTimeoutSet = true
529 case "sockettimeoutms":
530 n, err := strconv.Atoi(value)
531 if err != nil || n < 0 {
532 return fmt.Errorf("invalid value for %q: %q", key, value)
533 }
534 u.SocketTimeout = time.Duration(n) * time.Millisecond
535 u.SocketTimeoutSet = true
536 case "srvmaxhosts":
537
538 if u.Scheme != SchemeMongoDBSRV {
539 return fmt.Errorf("cannot specify srvMaxHosts on non-SRV URI")
540 }
541
542 n, err := strconv.Atoi(value)
543 if err != nil || n < 0 {
544 return fmt.Errorf("invalid value for %q: %q", key, value)
545 }
546 u.SRVMaxHosts = n
547 case "srvservicename":
548
549 if u.Scheme != SchemeMongoDBSRV {
550 return fmt.Errorf("cannot specify srvServiceName on non-SRV URI")
551 }
552
553
554
555
556
557 if len(value) < 1 || len(value) > 62 {
558 return fmt.Errorf("srvServiceName value must be between 1 and 62 characters")
559 }
560 u.SRVServiceName = value
561 case "ssl", "tls":
562 switch value {
563 case "true":
564 u.SSL = true
565 case "false":
566 u.SSL = false
567 default:
568 return fmt.Errorf("invalid value for %q: %q", key, value)
569 }
570 if tlsssl == nil {
571 tlsssl = new(bool)
572 *tlsssl = u.SSL
573 } else if *tlsssl != u.SSL {
574 return errors.New("tls and ssl options, when both specified, must be equivalent")
575 }
576
577 u.SSLSet = true
578 case "sslclientcertificatekeyfile", "tlscertificatekeyfile":
579 u.SSL = true
580 u.SSLSet = true
581 u.SSLClientCertificateKeyFile = value
582 u.SSLClientCertificateKeyFileSet = true
583 case "sslclientcertificatekeypassword", "tlscertificatekeyfilepassword":
584 u.SSLClientCertificateKeyPassword = func() string { return value }
585 u.SSLClientCertificateKeyPasswordSet = true
586 case "tlscertificatefile":
587 u.SSL = true
588 u.SSLSet = true
589 u.SSLCertificateFile = value
590 u.SSLCertificateFileSet = true
591 case "tlsprivatekeyfile":
592 u.SSL = true
593 u.SSLSet = true
594 u.SSLPrivateKeyFile = value
595 u.SSLPrivateKeyFileSet = true
596 case "sslinsecure", "tlsinsecure":
597 switch value {
598 case "true":
599 u.SSLInsecure = true
600 case "false":
601 u.SSLInsecure = false
602 default:
603 return fmt.Errorf("invalid value for %q: %q", key, value)
604 }
605
606 u.SSLInsecureSet = true
607 case "sslcertificateauthorityfile", "tlscafile":
608 u.SSL = true
609 u.SSLSet = true
610 u.SSLCaFile = value
611 u.SSLCaFileSet = true
612 case "timeoutms":
613 n, err := strconv.Atoi(value)
614 if err != nil || n < 0 {
615 return fmt.Errorf("invalid value for %q: %q", key, value)
616 }
617 u.Timeout = time.Duration(n) * time.Millisecond
618 u.TimeoutSet = true
619 case "tlsdisableocspendpointcheck":
620 u.SSL = true
621 u.SSLSet = true
622
623 switch value {
624 case "true":
625 u.SSLDisableOCSPEndpointCheck = true
626 case "false":
627 u.SSLDisableOCSPEndpointCheck = false
628 default:
629 return fmt.Errorf("invalid value for %q: %q", key, value)
630 }
631 u.SSLDisableOCSPEndpointCheckSet = true
632 case "w":
633 if w, err := strconv.Atoi(value); err == nil {
634 if w < 0 {
635 return fmt.Errorf("invalid value for %q: %q", key, value)
636 }
637
638 u.WNumber = w
639 u.WNumberSet = true
640 u.WString = ""
641 break
642 }
643
644 u.WString = value
645 u.WNumberSet = false
646
647 case "wtimeoutms":
648 n, err := strconv.Atoi(value)
649 if err != nil || n < 0 {
650 return fmt.Errorf("invalid value for %q: %q", key, value)
651 }
652 u.WTimeout = time.Duration(n) * time.Millisecond
653 u.WTimeoutSet = true
654 case "wtimeout":
655
656 if u.WTimeoutSet {
657 break
658 }
659 n, err := strconv.Atoi(value)
660 if err != nil || n < 0 {
661 return fmt.Errorf("invalid value for %q: %q", key, value)
662 }
663 u.WTimeout = time.Duration(n) * time.Millisecond
664 case "zlibcompressionlevel":
665 level, err := strconv.Atoi(value)
666 if err != nil || (level < -1 || level > 9) {
667 return fmt.Errorf("invalid value for %q: %q", key, value)
668 }
669
670 if level == -1 {
671 level = wiremessage.DefaultZlibLevel
672 }
673 u.ZlibLevel = level
674 u.ZlibLevelSet = true
675 case "zstdcompressionlevel":
676 const maxZstdLevel = 22
677 level, err := strconv.Atoi(value)
678 if err != nil || (level < -1 || level > maxZstdLevel) {
679 return fmt.Errorf("invalid value for %q: %q", key, value)
680 }
681
682 if level == -1 {
683 level = wiremessage.DefaultZstdLevel
684 }
685 u.ZstdLevel = level
686 u.ZstdLevelSet = true
687 default:
688 if u.UnknownOptions == nil {
689 u.UnknownOptions = make(map[string][]string)
690 }
691 u.UnknownOptions[lowerKey] = append(u.UnknownOptions[lowerKey], value)
692 }
693
694 if u.Options == nil {
695 u.Options = make(map[string][]string)
696 }
697 u.Options[lowerKey] = append(u.Options[lowerKey], value)
698 }
699 return nil
700 }
701
702 func (u *ConnString) validateAuth() error {
703 switch strings.ToLower(u.AuthMechanism) {
704 case "mongodb-cr":
705 if u.Username == "" {
706 return fmt.Errorf("username required for MONGO-CR")
707 }
708 if u.Password == "" {
709 return fmt.Errorf("password required for MONGO-CR")
710 }
711 if u.AuthMechanismProperties != nil {
712 return fmt.Errorf("MONGO-CR cannot have mechanism properties")
713 }
714 case "mongodb-x509":
715 if u.Password != "" {
716 return fmt.Errorf("password cannot be specified for MONGO-X509")
717 }
718 if u.AuthMechanismProperties != nil {
719 return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
720 }
721 case "mongodb-aws":
722 if u.Username != "" && u.Password == "" {
723 return fmt.Errorf("username without password is invalid for MONGODB-AWS")
724 }
725 if u.Username == "" && u.Password != "" {
726 return fmt.Errorf("password without username is invalid for MONGODB-AWS")
727 }
728 var token bool
729 for k := range u.AuthMechanismProperties {
730 if k != "AWS_SESSION_TOKEN" {
731 return fmt.Errorf("invalid auth property for MONGODB-AWS")
732 }
733 token = true
734 }
735 if token && u.Username == "" && u.Password == "" {
736 return fmt.Errorf("token without username and password is invalid for MONGODB-AWS")
737 }
738 case "gssapi":
739 if u.Username == "" {
740 return fmt.Errorf("username required for GSSAPI")
741 }
742 for k := range u.AuthMechanismProperties {
743 if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" && k != "SERVICE_HOST" {
744 return fmt.Errorf("invalid auth property for GSSAPI")
745 }
746 }
747 case "plain":
748 if u.Username == "" {
749 return fmt.Errorf("username required for PLAIN")
750 }
751 if u.Password == "" {
752 return fmt.Errorf("password required for PLAIN")
753 }
754 if u.AuthMechanismProperties != nil {
755 return fmt.Errorf("PLAIN cannot have mechanism properties")
756 }
757 case "scram-sha-1":
758 if u.Username == "" {
759 return fmt.Errorf("username required for SCRAM-SHA-1")
760 }
761 if u.Password == "" {
762 return fmt.Errorf("password required for SCRAM-SHA-1")
763 }
764 if u.AuthMechanismProperties != nil {
765 return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
766 }
767 case "scram-sha-256":
768 if u.Username == "" {
769 return fmt.Errorf("username required for SCRAM-SHA-256")
770 }
771 if u.Password == "" {
772 return fmt.Errorf("password required for SCRAM-SHA-256")
773 }
774 if u.AuthMechanismProperties != nil {
775 return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
776 }
777 case "":
778 if u.UsernameSet && u.Username == "" {
779 return fmt.Errorf("username required if URI contains user info")
780 }
781 default:
782 return fmt.Errorf("invalid auth mechanism")
783 }
784 return nil
785 }
786
787 func (u *ConnString) validateSSL() error {
788 if !u.SSL {
789 return nil
790 }
791
792 if u.SSLClientCertificateKeyFileSet {
793 if u.SSLCertificateFileSet || u.SSLPrivateKeyFileSet {
794 return errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided " +
795 "along with tlsCertificateFile or tlsPrivateKeyFile")
796 }
797 return nil
798 }
799 if u.SSLCertificateFileSet && !u.SSLPrivateKeyFileSet {
800 return errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")
801 }
802 if u.SSLPrivateKeyFileSet && !u.SSLCertificateFileSet {
803 return errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")
804 }
805
806 if u.SSLInsecureSet && u.SSLDisableOCSPEndpointCheckSet {
807 return errors.New("the sslInsecure/tlsInsecure URI option cannot be provided along with " +
808 "tlsDisableOCSPEndpointCheck ")
809 }
810 return nil
811 }
812
// sanitizeHost URL-unescapes one host entry from a connection string and
// validates its optional port.
//
//   - An empty host is returned unchanged.
//   - A host without a port is accepted.
//   - When a port is present it must be an integer in [1, 65535].
//
// On success it returns the unescaped host (including any port). On error it
// returns "".
func sanitizeHost(host string) (string, error) {
	if host == "" {
		return host, nil
	}
	unescaped, err := url.QueryUnescape(host)
	if err != nil {
		return "", fmt.Errorf("invalid host %q: %w", host, err)
	}

	_, port, err := net.SplitHostPort(unescaped)
	if err != nil {
		// A bare host (no port) is valid. net.SplitHostPort reports it as a
		// *net.AddrError with this exact message; any other error is fatal.
		var addrErr *net.AddrError
		if !errors.As(err, &addrErr) || addrErr.Err != "missing port in address" {
			return "", err
		}
	}

	if port == "" {
		return unescaped, nil
	}

	portNum, err := strconv.Atoi(port)
	if err != nil {
		return "", fmt.Errorf("port must be an integer: %w", err)
	}
	if portNum < 1 || portNum > 65535 {
		return "", fmt.Errorf("port must be in the range [1, 65535]")
	}
	return unescaped, nil
}
842
843
844
// ConnectMode records how the connect option asked the driver to connect:
// AutoConnect ("automatic", the zero value) or SingleConnect ("direct").
type ConnectMode uint8

// Compile-time assertion that ConnectMode implements fmt.Stringer.
var _ fmt.Stringer = ConnectMode(0)

// The ConnectMode values.
const (
	AutoConnect   ConnectMode = 0
	SingleConnect ConnectMode = 1
)

// String returns "automatic" for AutoConnect, "direct" for SingleConnect, and
// "unknown" for any other value.
func (c ConnectMode) String() string {
	if c == AutoConnect {
		return "automatic"
	}
	if c == SingleConnect {
		return "direct"
	}
	return "unknown"
}
866
867
// SchemeMongoDB is the scheme of a standard connection string, whose hosts
// are listed explicitly.
const SchemeMongoDB = "mongodb"

// SchemeMongoDBSRV is the scheme of a connection string whose hosts and extra
// options are looked up through the parser's DNS resolver.
const SchemeMongoDBSRV = "mongodb+srv"
872
// parser turns connection strings into ConnString values.
type parser struct {
	// dnsResolver is used for mongodb+srv URIs to fetch TXT-record options and
	// the SRV host list. When nil, both lookups are skipped.
	dnsResolver *dns.Resolver
}
876
877 func (p *parser) parse(original string) (*ConnString, error) {
878 connStr := &ConnString{}
879 connStr.Original = original
880 uri := original
881
882 var err error
883 if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") {
884 connStr.Scheme = SchemeMongoDBSRV
885
886 uri = uri[len(SchemeMongoDBSRV)+3:]
887 } else if strings.HasPrefix(uri, SchemeMongoDB+"://") {
888 connStr.Scheme = SchemeMongoDB
889
890 uri = uri[len(SchemeMongoDB)+3:]
891 } else {
892 return nil, errors.New(`scheme must be "mongodb" or "mongodb+srv"`)
893 }
894
895 if idx := strings.Index(uri, "@"); idx != -1 {
896 userInfo := uri[:idx]
897 uri = uri[idx+1:]
898
899 username := userInfo
900 var password string
901
902 if idx := strings.Index(userInfo, ":"); idx != -1 {
903 username = userInfo[:idx]
904 password = userInfo[idx+1:]
905 connStr.PasswordSet = true
906 }
907
908
909 if strings.Contains(username, "/") {
910 return nil, fmt.Errorf("unescaped slash in username")
911 }
912 connStr.Username, err = url.PathUnescape(username)
913 if err != nil {
914 return nil, fmt.Errorf("invalid username: %w", err)
915 }
916 connStr.UsernameSet = true
917
918
919 if strings.Contains(password, ":") {
920 return nil, fmt.Errorf("unescaped colon in password")
921 }
922 if strings.Contains(password, "/") {
923 return nil, fmt.Errorf("unescaped slash in password")
924 }
925 connStr.Password, err = url.PathUnescape(password)
926 if err != nil {
927 return nil, fmt.Errorf("invalid password: %w", err)
928 }
929 }
930
931
932 hosts := uri
933 if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
934 if uri[idx] == '@' {
935 return nil, fmt.Errorf("unescaped @ sign in user info")
936 }
937 if uri[idx] == '?' {
938 return nil, fmt.Errorf("must have a / before the query ?")
939 }
940 hosts = uri[:idx]
941 }
942
943 for _, host := range strings.Split(hosts, ",") {
944 host, err = sanitizeHost(host)
945 if err != nil {
946 return nil, fmt.Errorf("invalid host %q: %w", host, err)
947 }
948 if host != "" {
949 connStr.RawHosts = append(connStr.RawHosts, host)
950 }
951 }
952 connStr.Hosts = connStr.RawHosts
953 uri = uri[len(hosts):]
954 extractedDatabase, err := extractDatabaseFromURI(uri)
955 if err != nil {
956 return nil, err
957 }
958
959 uri = extractedDatabase.uri
960 connStr.Database = extractedDatabase.db
961
962
963 connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
964 if err != nil {
965 return nil, err
966 }
967
968
969 var connectionArgsFromTXT []string
970 if connStr.Scheme == SchemeMongoDBSRV && p.dnsResolver != nil {
971 connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
972 if err != nil {
973 return nil, err
974 }
975
976
977 connStr.SSL = true
978 connStr.SSLSet = true
979 }
980
981
982 connectionArgPairs := make([]string, 0, len(connectionArgsFromTXT)+len(connectionArgsFromQueryString))
983 connectionArgPairs = append(connectionArgPairs, connectionArgsFromTXT...)
984 connectionArgPairs = append(connectionArgPairs, connectionArgsFromQueryString...)
985
986 err = connStr.addOptions(connectionArgPairs)
987 if err != nil {
988 return nil, err
989 }
990
991
992 if connStr.Scheme == SchemeMongoDBSRV && p.dnsResolver != nil {
993 parsedHosts, err := p.dnsResolver.ParseHosts(hosts, connStr.SRVServiceName, true)
994 if err != nil {
995 return connStr, err
996 }
997
998
999
1000 if connStr.SRVMaxHosts > 0 && connStr.SRVMaxHosts < len(parsedHosts) {
1001 random.Shuffle(len(parsedHosts), func(i, j int) {
1002 parsedHosts[i], parsedHosts[j] = parsedHosts[j], parsedHosts[i]
1003 })
1004 parsedHosts = parsedHosts[:connStr.SRVMaxHosts]
1005 }
1006
1007 var hosts []string
1008 for _, host := range parsedHosts {
1009 host, err = sanitizeHost(host)
1010 if err != nil {
1011 return connStr, fmt.Errorf("invalid host %q: %w", host, err)
1012 }
1013 if host != "" {
1014 hosts = append(hosts, host)
1015 }
1016 }
1017 connStr.Hosts = hosts
1018 }
1019 if len(connStr.Hosts) == 0 {
1020 return nil, fmt.Errorf("must have at least 1 host")
1021 }
1022
1023 err = connStr.setDefaultAuthParams(extractedDatabase.db)
1024 if err != nil {
1025 return nil, err
1026 }
1027
1028
1029 if connStr.WTimeoutSetFromOption {
1030 connStr.WTimeoutSet = true
1031 }
1032
1033 return connStr, nil
1034 }
1035
1036
1037
1038 func IsValidServerMonitoringMode(mode string) bool {
1039 return mode == ServerMonitoringModeAuto ||
1040 mode == ServerMonitoringModeStream ||
1041 mode == ServerMonitoringModePoll
1042 }
1043
// extractQueryArgsFromURI splits the query part of a URI into raw option
// pairs. uri must be empty or start with '?'. Pairs may be separated by '&' or
// ';', and empty pairs are dropped. An empty uri or a bare "?" yields nil.
func extractQueryArgsFromURI(uri string) ([]string, error) {
	if uri == "" {
		return nil, nil
	}
	if !strings.HasPrefix(uri, "?") {
		return nil, errors.New("must have a ? separator between path and query")
	}

	query := strings.TrimPrefix(uri, "?")
	if query == "" {
		return nil, nil
	}

	isSeparator := func(r rune) bool {
		switch r {
		case '&', ';':
			return true
		default:
			return false
		}
	}
	return strings.FieldsFunc(query, isSeparator), nil
}
1060
// extractedDatabase is the result of extractDatabaseFromURI.
type extractedDatabase struct {
	// uri is the remainder after the database name: empty or starting with '?'.
	uri string
	// db is the URL-unescaped database name.
	db string
}

// extractDatabaseFromURI consumes the optional "/database" segment that
// follows the host list. uri must be empty or start with '/'. The database
// name runs up to the first '?' and is unescaped with url.QueryUnescape, so a
// '+' becomes a space. An empty uri or a bare "/" yields a zero
// extractedDatabase.
func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
	var result extractedDatabase
	if uri == "" {
		return result, nil
	}
	if uri[0] != '/' {
		return result, errors.New("must have a / separator between hosts and path")
	}

	rest := uri[1:]
	if rest == "" {
		return result, nil
	}

	end := strings.IndexByte(rest, '?')
	if end == -1 {
		end = len(rest)
	}
	rawDB := rest[:end]

	db, err := url.QueryUnescape(rawDB)
	if err != nil {
		return extractedDatabase{}, fmt.Errorf("invalid database %q: %w", rawDB, err)
	}

	result.db = db
	result.uri = rest[end:]
	return result, nil
}
1101
View as plain text