...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/connstring

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package connstring // import "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal/randutil"
    19  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    20  	"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
    21  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    22  )
    23  
    24  const (
    25  	// ServerMonitoringModeAuto indicates that the client will behave like "poll"
    26  	// mode when running on a FaaS (Function as a Service) platform, or like
    27  	// "stream" mode otherwise. The client detects its execution environment by
    28  	// following the rules for generating the "client.env" handshake metadata field
    29  	// as specified in the MongoDB Handshake specification. This is the default
    30  	// mode.
    31  	ServerMonitoringModeAuto = "auto"
    32  
    33  	// ServerMonitoringModePoll indicates that the client will periodically check
    34  	// the server using a hello or legacy hello command and then sleep for
    35  	// heartbeatFrequencyMS milliseconds before running another check.
    36  	ServerMonitoringModePoll = "poll"
    37  
    38  	// ServerMonitoringModeStream indicates that the client will use a streaming
    39  	// protocol when the server supports it. The streaming protocol optimally
    40  	// reduces the time it takes for a client to discover server state changes.
    41  	ServerMonitoringModeStream = "stream"
    42  )
    43  
    44  var (
    45  	// ErrLoadBalancedWithMultipleHosts is returned when loadBalanced=true is
    46  	// specified in a URI with multiple hosts.
    47  	ErrLoadBalancedWithMultipleHosts = errors.New(
    48  		"loadBalanced cannot be set to true if multiple hosts are specified")
    49  
    50  	// ErrLoadBalancedWithReplicaSet is returned when loadBalanced=true is
    51  	// specified in a URI with the replicaSet option.
    52  	ErrLoadBalancedWithReplicaSet = errors.New(
    53  		"loadBalanced cannot be set to true if a replica set name is specified")
    54  
    55  	// ErrLoadBalancedWithDirectConnection is returned when loadBalanced=true is
    56  	// specified in a URI with the directConnection option.
    57  	ErrLoadBalancedWithDirectConnection = errors.New(
    58  		"loadBalanced cannot be set to true if the direct connection option is specified")
    59  
    60  	// ErrSRVMaxHostsWithReplicaSet is returned when srvMaxHosts > 0 is
    61  	// specified in a URI with the replicaSet option.
    62  	ErrSRVMaxHostsWithReplicaSet = errors.New(
    63  		"srvMaxHosts cannot be a positive value if a replica set name is specified")
    64  
    65  	// ErrSRVMaxHostsWithLoadBalanced is returned when srvMaxHosts > 0 is
    66  	// specified in a URI with loadBalanced=true.
    67  	ErrSRVMaxHostsWithLoadBalanced = errors.New(
    68  		"srvMaxHosts cannot be a positive value if loadBalanced is set to true")
    69  )
    70  
// random is a package-global pseudo-random number generator. It is created
// once, when the package is initialized.
// NOTE(review): it is not referenced in this part of the file. Presumably it is
// used elsewhere in the parser (e.g. to pick a random subset of SRV hosts when
// srvMaxHosts is set) — confirm. The "Locked" in randutil.NewLockedRand
// suggests the generator is safe for concurrent use; verify against randutil.
var random = randutil.NewLockedRand()
    73  
    74  // ParseAndValidate parses the provided URI into a ConnString object.
    75  // It check that all values are valid.
    76  func ParseAndValidate(s string) (*ConnString, error) {
    77  	connStr, err := Parse(s)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	err = connStr.Validate()
    82  	if err != nil {
    83  		return nil, fmt.Errorf("error validating uri: %w", err)
    84  	}
    85  	return connStr, nil
    86  }
    87  
    88  // Parse parses the provided URI into a ConnString object
    89  // but does not check that all values are valid. Use `ConnString.Validate()`
    90  // to run the validation checks separately.
    91  func Parse(s string) (*ConnString, error) {
    92  	p := parser{dnsResolver: dns.DefaultResolver}
    93  	connStr, err := p.parse(s)
    94  	if err != nil {
    95  		return nil, fmt.Errorf("error parsing uri: %w", err)
    96  	}
    97  	return connStr, err
    98  }
    99  
// ConnString represents a connection string to mongodb.
//
// Each recognized URI option is stored in a typed field. For many options a
// companion "...Set" field exists; addOptions sets it to true when the option
// appears in the URI. This lets callers tell an explicit zero value apart from
// an omitted option (for example ConnectTimeout/ConnectTimeoutSet).
//
// Some fields have no companion flag, e.g. AppName, ReadConcernLevel,
// ReadPreference, ReplicaSet, ServerMonitoringMode, SRVMaxHosts and
// SRVServiceName. For those, the zero value is the only indication that the
// option was absent.
type ConnString struct {
	// Original is the value returned by String. Presumably it is the URI as
	// passed to Parse; it is not assigned in this part of the file.
	Original                           string
	AppName                            string
	AuthMechanism                      string
	// AuthMechanismProperties holds the key:value pairs of the comma-separated
	// authMechanismProperties option. For GSSAPI, setDefaultAuthParams fills in
	// a default SERVICE_NAME of "mongodb".
	AuthMechanismProperties            map[string]string
	AuthMechanismPropertiesSet         bool
	AuthSource                         string
	AuthSourceSet                      bool
	// Compressors is the compressors option split on commas.
	Compressors                        []string
	// Connect is SingleConnect when connect=direct was supplied.
	Connect                            ConnectMode
	ConnectSet                         bool
	DirectConnection                   bool
	DirectConnectionSet                bool
	ConnectTimeout                     time.Duration
	ConnectTimeoutSet                  bool
	// Database is presumably the database named in the URI path; it is not
	// assigned in this part of the file — TODO confirm.
	Database                           string
	// HeartbeatInterval is set from either heartbeatIntervalMS or
	// heartbeatFrequencyMS.
	HeartbeatInterval                  time.Duration
	HeartbeatIntervalSet               bool
	// Hosts is the host list that Validate checks. NOTE(review): RawHosts
	// presumably holds the hosts before any processing such as SRV
	// resolution — confirm.
	Hosts                              []string
	// J is the journal option.
	J                                  bool
	JSet                               bool
	LoadBalanced                       bool
	LoadBalancedSet                    bool
	LocalThreshold                     time.Duration
	LocalThresholdSet                  bool
	// MaxConnIdleTime is set from the maxIdleTimeMS option.
	MaxConnIdleTime                    time.Duration
	MaxConnIdleTimeSet                 bool
	MaxPoolSize                        uint64
	MaxPoolSizeSet                     bool
	MinPoolSize                        uint64
	MinPoolSizeSet                     bool
	MaxConnecting                      uint64
	MaxConnectingSet                   bool
	Password                           string
	PasswordSet                        bool
	RawHosts                           []string
	ReadConcernLevel                   string
	ReadPreference                     string
	// ReadPreferenceTagSets receives one entry per readPreferenceTags option.
	// An empty option value appends an empty map, which acts as a wildcard.
	ReadPreferenceTagSets              []map[string]string
	RetryWrites                        bool
	RetryWritesSet                     bool
	RetryReads                         bool
	RetryReadsSet                      bool
	// MaxStaleness is set from maxStaleness or maxStalenessSeconds (seconds).
	MaxStaleness                       time.Duration
	MaxStalenessSet                    bool
	ReplicaSet                         string
	// Scheme is compared against SchemeMongoDBSRV. srvMaxHosts and
	// srvServiceName are only accepted for that scheme.
	Scheme                             string
	// ServerMonitoringMode is accepted only if IsValidServerMonitoringMode
	// approves it (see the ServerMonitoringMode* constants).
	ServerMonitoringMode               string
	ServerSelectionTimeout             time.Duration
	ServerSelectionTimeoutSet          bool
	SocketTimeout                      time.Duration
	SocketTimeoutSet                   bool
	SRVMaxHosts                        int
	SRVServiceName                     string
	// The SSL* fields are populated from both the legacy ssl* option names and
	// their tls* equivalents. Supplying a certificate, key or CA file option
	// also sets SSL and SSLSet to true.
	SSL                                bool
	SSLSet                             bool
	SSLClientCertificateKeyFile        string
	SSLClientCertificateKeyFileSet     bool
	// SSLClientCertificateKeyPassword returns the value of the
	// sslClientCertificateKeyPassword / tlsCertificateKeyFilePassword option.
	SSLClientCertificateKeyPassword    func() string
	SSLClientCertificateKeyPasswordSet bool
	SSLCertificateFile                 string
	SSLCertificateFileSet              bool
	SSLPrivateKeyFile                  string
	SSLPrivateKeyFileSet               bool
	SSLInsecure                        bool
	SSLInsecureSet                     bool
	SSLCaFile                          string
	SSLCaFileSet                       bool
	SSLDisableOCSPEndpointCheck        bool
	SSLDisableOCSPEndpointCheckSet     bool
	// Timeout is set from the timeoutMS option.
	Timeout                            time.Duration
	TimeoutSet                         bool
	// WString holds a non-numeric w value. A numeric w is stored in WNumber
	// with WNumberSet=true, and WString is cleared.
	WString                            string
	WNumber                            int
	WNumberSet                         bool
	Username                           string
	UsernameSet                        bool
	// ZlibLevel and ZstdLevel hold the compression levels. A supplied value
	// of -1 is replaced with the corresponding wiremessage default level.
	ZlibLevel                          int
	ZlibLevelSet                       bool
	ZstdLevel                          int
	ZstdLevelSet                       bool

	// WTimeout is set from wtimeoutMS, which also sets WTimeoutSet. It can
	// also come from the legacy wtimeout option, but only if wtimeoutMS was not
	// already seen; that path does not set WTimeoutSet.
	// NOTE(review): WTimeoutSetFromOption is not assigned in this part of the
	// file — confirm its meaning where it is set.
	WTimeout              time.Duration
	WTimeoutSet           bool
	WTimeoutSetFromOption bool

	// Options maps every option key processed by addOptions (lower-cased,
	// known or unknown) to its decoded values, in order of appearance.
	Options        map[string][]string
	// UnknownOptions holds only the keys that addOptions did not recognize.
	UnknownOptions map[string][]string
}
   190  
   191  func (u *ConnString) String() string {
   192  	return u.Original
   193  }
   194  
   195  // HasAuthParameters returns true if this ConnString has any authentication parameters set and therefore represents
   196  // a request for authentication.
   197  func (u *ConnString) HasAuthParameters() bool {
   198  	// Check all auth parameters except for AuthSource because an auth source without other credentials is semantically
   199  	// valid and must not be interpreted as a request for authentication.
   200  	return u.AuthMechanism != "" || u.AuthMechanismProperties != nil || u.UsernameSet || u.PasswordSet
   201  }
   202  
   203  // Validate checks that the Auth and SSL parameters are valid values.
   204  func (u *ConnString) Validate() error {
   205  	var err error
   206  
   207  	if err = u.validateAuth(); err != nil {
   208  		return err
   209  	}
   210  
   211  	if err = u.validateSSL(); err != nil {
   212  		return err
   213  	}
   214  
   215  	// Check for invalid write concern (i.e. w=0 and j=true)
   216  	if u.WNumberSet && u.WNumber == 0 && u.JSet && u.J {
   217  		return writeconcern.ErrInconsistent
   218  	}
   219  
   220  	// Check for invalid use of direct connections.
   221  	if (u.ConnectSet && u.Connect == SingleConnect) ||
   222  		(u.DirectConnectionSet && u.DirectConnection) {
   223  		if len(u.Hosts) > 1 {
   224  			return errors.New("a direct connection cannot be made if multiple hosts are specified")
   225  		}
   226  		if u.Scheme == SchemeMongoDBSRV {
   227  			return errors.New("a direct connection cannot be made if an SRV URI is used")
   228  		}
   229  		if u.LoadBalancedSet && u.LoadBalanced {
   230  			return ErrLoadBalancedWithDirectConnection
   231  		}
   232  	}
   233  
   234  	// Validation for load-balanced mode.
   235  	if u.LoadBalancedSet && u.LoadBalanced {
   236  		if len(u.Hosts) > 1 {
   237  			return ErrLoadBalancedWithMultipleHosts
   238  		}
   239  		if u.ReplicaSet != "" {
   240  			return ErrLoadBalancedWithReplicaSet
   241  		}
   242  	}
   243  
   244  	// Check for invalid use of SRVMaxHosts.
   245  	if u.SRVMaxHosts > 0 {
   246  		if u.ReplicaSet != "" {
   247  			return ErrSRVMaxHostsWithReplicaSet
   248  		}
   249  		if u.LoadBalanced {
   250  			return ErrSRVMaxHostsWithLoadBalanced
   251  		}
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  func (u *ConnString) setDefaultAuthParams(dbName string) error {
   258  	// We do this check here rather than in validateAuth because this function is called as part of parsing and sets
   259  	// the value of AuthSource if authentication is enabled.
   260  	if u.AuthSourceSet && u.AuthSource == "" {
   261  		return errors.New("authSource must be non-empty when supplied in a URI")
   262  	}
   263  
   264  	switch strings.ToLower(u.AuthMechanism) {
   265  	case "plain":
   266  		if u.AuthSource == "" {
   267  			u.AuthSource = dbName
   268  			if u.AuthSource == "" {
   269  				u.AuthSource = "$external"
   270  			}
   271  		}
   272  	case "gssapi":
   273  		if u.AuthMechanismProperties == nil {
   274  			u.AuthMechanismProperties = map[string]string{
   275  				"SERVICE_NAME": "mongodb",
   276  			}
   277  		} else if v, ok := u.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
   278  			u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
   279  		}
   280  		fallthrough
   281  	case "mongodb-aws", "mongodb-x509":
   282  		if u.AuthSource == "" {
   283  			u.AuthSource = "$external"
   284  		} else if u.AuthSource != "$external" {
   285  			return fmt.Errorf("auth source must be $external")
   286  		}
   287  	case "mongodb-cr":
   288  		fallthrough
   289  	case "scram-sha-1":
   290  		fallthrough
   291  	case "scram-sha-256":
   292  		if u.AuthSource == "" {
   293  			u.AuthSource = dbName
   294  			if u.AuthSource == "" {
   295  				u.AuthSource = "admin"
   296  			}
   297  		}
   298  	case "":
   299  		// Only set auth source if there is a request for authentication via non-empty credentials.
   300  		if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) {
   301  			u.AuthSource = dbName
   302  			if u.AuthSource == "" {
   303  				u.AuthSource = "admin"
   304  			}
   305  		}
   306  	default:
   307  		return fmt.Errorf("invalid auth mechanism")
   308  	}
   309  	return nil
   310  }
   311  
   312  func (u *ConnString) addOptions(connectionArgPairs []string) error {
   313  	var tlsssl *bool // used to determine if tls and ssl options are both specified and set differently.
   314  	for _, pair := range connectionArgPairs {
   315  		kv := strings.SplitN(pair, "=", 2)
   316  		if len(kv) != 2 || kv[0] == "" {
   317  			return fmt.Errorf("invalid option")
   318  		}
   319  
   320  		key, err := url.QueryUnescape(kv[0])
   321  		if err != nil {
   322  			return fmt.Errorf("invalid option key %q: %w", kv[0], err)
   323  		}
   324  
   325  		value, err := url.QueryUnescape(kv[1])
   326  		if err != nil {
   327  			return fmt.Errorf("invalid option value %q: %w", kv[1], err)
   328  		}
   329  
   330  		lowerKey := strings.ToLower(key)
   331  		switch lowerKey {
   332  		case "appname":
   333  			u.AppName = value
   334  		case "authmechanism":
   335  			u.AuthMechanism = value
   336  		case "authmechanismproperties":
   337  			u.AuthMechanismProperties = make(map[string]string)
   338  			pairs := strings.Split(value, ",")
   339  			for _, pair := range pairs {
   340  				kv := strings.SplitN(pair, ":", 2)
   341  				if len(kv) != 2 || kv[0] == "" {
   342  					return fmt.Errorf("invalid authMechanism property")
   343  				}
   344  				u.AuthMechanismProperties[kv[0]] = kv[1]
   345  			}
   346  			u.AuthMechanismPropertiesSet = true
   347  		case "authsource":
   348  			u.AuthSource = value
   349  			u.AuthSourceSet = true
   350  		case "compressors":
   351  			compressors := strings.Split(value, ",")
   352  			if len(compressors) < 1 {
   353  				return fmt.Errorf("must have at least 1 compressor")
   354  			}
   355  			u.Compressors = compressors
   356  		case "connect":
   357  			switch strings.ToLower(value) {
   358  			case "automatic":
   359  			case "direct":
   360  				u.Connect = SingleConnect
   361  			default:
   362  				return fmt.Errorf("invalid 'connect' value: %q", value)
   363  			}
   364  			if u.DirectConnectionSet {
   365  				expectedValue := u.Connect == SingleConnect // directConnection should be true if connect=direct
   366  				if u.DirectConnection != expectedValue {
   367  					return fmt.Errorf("options connect=%q and directConnection=%v conflict", value, u.DirectConnection)
   368  				}
   369  			}
   370  
   371  			u.ConnectSet = true
   372  		case "directconnection":
   373  			switch strings.ToLower(value) {
   374  			case "true":
   375  				u.DirectConnection = true
   376  			case "false":
   377  			default:
   378  				return fmt.Errorf("invalid 'directConnection' value: %q", value)
   379  			}
   380  
   381  			if u.ConnectSet {
   382  				expectedValue := AutoConnect
   383  				if u.DirectConnection {
   384  					expectedValue = SingleConnect
   385  				}
   386  
   387  				if u.Connect != expectedValue {
   388  					return fmt.Errorf("options connect=%q and directConnection=%q conflict", u.Connect, value)
   389  				}
   390  			}
   391  			u.DirectConnectionSet = true
   392  		case "connecttimeoutms":
   393  			n, err := strconv.Atoi(value)
   394  			if err != nil || n < 0 {
   395  				return fmt.Errorf("invalid value for %q: %q", key, value)
   396  			}
   397  			u.ConnectTimeout = time.Duration(n) * time.Millisecond
   398  			u.ConnectTimeoutSet = true
   399  		case "heartbeatintervalms", "heartbeatfrequencyms":
   400  			n, err := strconv.Atoi(value)
   401  			if err != nil || n < 0 {
   402  				return fmt.Errorf("invalid value for %q: %q", key, value)
   403  			}
   404  			u.HeartbeatInterval = time.Duration(n) * time.Millisecond
   405  			u.HeartbeatIntervalSet = true
   406  		case "journal":
   407  			switch value {
   408  			case "true":
   409  				u.J = true
   410  			case "false":
   411  				u.J = false
   412  			default:
   413  				return fmt.Errorf("invalid value for %q: %q", key, value)
   414  			}
   415  
   416  			u.JSet = true
   417  		case "loadbalanced":
   418  			switch value {
   419  			case "true":
   420  				u.LoadBalanced = true
   421  			case "false":
   422  				u.LoadBalanced = false
   423  			default:
   424  				return fmt.Errorf("invalid value for %q: %q", key, value)
   425  			}
   426  
   427  			u.LoadBalancedSet = true
   428  		case "localthresholdms":
   429  			n, err := strconv.Atoi(value)
   430  			if err != nil || n < 0 {
   431  				return fmt.Errorf("invalid value for %q: %q", key, value)
   432  			}
   433  			u.LocalThreshold = time.Duration(n) * time.Millisecond
   434  			u.LocalThresholdSet = true
   435  		case "maxidletimems":
   436  			n, err := strconv.Atoi(value)
   437  			if err != nil || n < 0 {
   438  				return fmt.Errorf("invalid value for %q: %q", key, value)
   439  			}
   440  			u.MaxConnIdleTime = time.Duration(n) * time.Millisecond
   441  			u.MaxConnIdleTimeSet = true
   442  		case "maxpoolsize":
   443  			n, err := strconv.Atoi(value)
   444  			if err != nil || n < 0 {
   445  				return fmt.Errorf("invalid value for %q: %q", key, value)
   446  			}
   447  			u.MaxPoolSize = uint64(n)
   448  			u.MaxPoolSizeSet = true
   449  		case "minpoolsize":
   450  			n, err := strconv.Atoi(value)
   451  			if err != nil || n < 0 {
   452  				return fmt.Errorf("invalid value for %q: %q", key, value)
   453  			}
   454  			u.MinPoolSize = uint64(n)
   455  			u.MinPoolSizeSet = true
   456  		case "maxconnecting":
   457  			n, err := strconv.Atoi(value)
   458  			if err != nil || n < 0 {
   459  				return fmt.Errorf("invalid value for %q: %q", key, value)
   460  			}
   461  			u.MaxConnecting = uint64(n)
   462  			u.MaxConnectingSet = true
   463  		case "readconcernlevel":
   464  			u.ReadConcernLevel = value
   465  		case "readpreference":
   466  			u.ReadPreference = value
   467  		case "readpreferencetags":
   468  			if value == "" {
   469  				// If "readPreferenceTags=" is supplied, append an empty map to tag sets to
   470  				// represent a wild-card.
   471  				u.ReadPreferenceTagSets = append(u.ReadPreferenceTagSets, map[string]string{})
   472  				break
   473  			}
   474  
   475  			tags := make(map[string]string)
   476  			items := strings.Split(value, ",")
   477  			for _, item := range items {
   478  				parts := strings.Split(item, ":")
   479  				if len(parts) != 2 {
   480  					return fmt.Errorf("invalid value for %q: %q", key, value)
   481  				}
   482  				tags[parts[0]] = parts[1]
   483  			}
   484  			u.ReadPreferenceTagSets = append(u.ReadPreferenceTagSets, tags)
   485  		case "maxstaleness", "maxstalenessseconds":
   486  			n, err := strconv.Atoi(value)
   487  			if err != nil || n < 0 {
   488  				return fmt.Errorf("invalid value for %q: %q", key, value)
   489  			}
   490  			u.MaxStaleness = time.Duration(n) * time.Second
   491  			u.MaxStalenessSet = true
   492  		case "replicaset":
   493  			u.ReplicaSet = value
   494  		case "retrywrites":
   495  			switch value {
   496  			case "true":
   497  				u.RetryWrites = true
   498  			case "false":
   499  				u.RetryWrites = false
   500  			default:
   501  				return fmt.Errorf("invalid value for %q: %q", key, value)
   502  			}
   503  
   504  			u.RetryWritesSet = true
   505  		case "retryreads":
   506  			switch value {
   507  			case "true":
   508  				u.RetryReads = true
   509  			case "false":
   510  				u.RetryReads = false
   511  			default:
   512  				return fmt.Errorf("invalid value for %q: %q", key, value)
   513  			}
   514  
   515  			u.RetryReadsSet = true
   516  		case "servermonitoringmode":
   517  			if !IsValidServerMonitoringMode(value) {
   518  				return fmt.Errorf("invalid value for %q: %q", key, value)
   519  			}
   520  
   521  			u.ServerMonitoringMode = value
   522  		case "serverselectiontimeoutms":
   523  			n, err := strconv.Atoi(value)
   524  			if err != nil || n < 0 {
   525  				return fmt.Errorf("invalid value for %q: %q", key, value)
   526  			}
   527  			u.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
   528  			u.ServerSelectionTimeoutSet = true
   529  		case "sockettimeoutms":
   530  			n, err := strconv.Atoi(value)
   531  			if err != nil || n < 0 {
   532  				return fmt.Errorf("invalid value for %q: %q", key, value)
   533  			}
   534  			u.SocketTimeout = time.Duration(n) * time.Millisecond
   535  			u.SocketTimeoutSet = true
   536  		case "srvmaxhosts":
   537  			// srvMaxHosts can only be set on URIs with the "mongodb+srv" scheme
   538  			if u.Scheme != SchemeMongoDBSRV {
   539  				return fmt.Errorf("cannot specify srvMaxHosts on non-SRV URI")
   540  			}
   541  
   542  			n, err := strconv.Atoi(value)
   543  			if err != nil || n < 0 {
   544  				return fmt.Errorf("invalid value for %q: %q", key, value)
   545  			}
   546  			u.SRVMaxHosts = n
   547  		case "srvservicename":
   548  			// srvServiceName can only be set on URIs with the "mongodb+srv" scheme
   549  			if u.Scheme != SchemeMongoDBSRV {
   550  				return fmt.Errorf("cannot specify srvServiceName on non-SRV URI")
   551  			}
   552  
   553  			// srvServiceName must be between 1 and 62 characters according to
   554  			// our specification. Empty service names are not valid, and the service
   555  			// name (including prepended underscore) should not exceed the 63 character
   556  			// limit for DNS query subdomains.
   557  			if len(value) < 1 || len(value) > 62 {
   558  				return fmt.Errorf("srvServiceName value must be between 1 and 62 characters")
   559  			}
   560  			u.SRVServiceName = value
   561  		case "ssl", "tls":
   562  			switch value {
   563  			case "true":
   564  				u.SSL = true
   565  			case "false":
   566  				u.SSL = false
   567  			default:
   568  				return fmt.Errorf("invalid value for %q: %q", key, value)
   569  			}
   570  			if tlsssl == nil {
   571  				tlsssl = new(bool)
   572  				*tlsssl = u.SSL
   573  			} else if *tlsssl != u.SSL {
   574  				return errors.New("tls and ssl options, when both specified, must be equivalent")
   575  			}
   576  
   577  			u.SSLSet = true
   578  		case "sslclientcertificatekeyfile", "tlscertificatekeyfile":
   579  			u.SSL = true
   580  			u.SSLSet = true
   581  			u.SSLClientCertificateKeyFile = value
   582  			u.SSLClientCertificateKeyFileSet = true
   583  		case "sslclientcertificatekeypassword", "tlscertificatekeyfilepassword":
   584  			u.SSLClientCertificateKeyPassword = func() string { return value }
   585  			u.SSLClientCertificateKeyPasswordSet = true
   586  		case "tlscertificatefile":
   587  			u.SSL = true
   588  			u.SSLSet = true
   589  			u.SSLCertificateFile = value
   590  			u.SSLCertificateFileSet = true
   591  		case "tlsprivatekeyfile":
   592  			u.SSL = true
   593  			u.SSLSet = true
   594  			u.SSLPrivateKeyFile = value
   595  			u.SSLPrivateKeyFileSet = true
   596  		case "sslinsecure", "tlsinsecure":
   597  			switch value {
   598  			case "true":
   599  				u.SSLInsecure = true
   600  			case "false":
   601  				u.SSLInsecure = false
   602  			default:
   603  				return fmt.Errorf("invalid value for %q: %q", key, value)
   604  			}
   605  
   606  			u.SSLInsecureSet = true
   607  		case "sslcertificateauthorityfile", "tlscafile":
   608  			u.SSL = true
   609  			u.SSLSet = true
   610  			u.SSLCaFile = value
   611  			u.SSLCaFileSet = true
   612  		case "timeoutms":
   613  			n, err := strconv.Atoi(value)
   614  			if err != nil || n < 0 {
   615  				return fmt.Errorf("invalid value for %q: %q", key, value)
   616  			}
   617  			u.Timeout = time.Duration(n) * time.Millisecond
   618  			u.TimeoutSet = true
   619  		case "tlsdisableocspendpointcheck":
   620  			u.SSL = true
   621  			u.SSLSet = true
   622  
   623  			switch value {
   624  			case "true":
   625  				u.SSLDisableOCSPEndpointCheck = true
   626  			case "false":
   627  				u.SSLDisableOCSPEndpointCheck = false
   628  			default:
   629  				return fmt.Errorf("invalid value for %q: %q", key, value)
   630  			}
   631  			u.SSLDisableOCSPEndpointCheckSet = true
   632  		case "w":
   633  			if w, err := strconv.Atoi(value); err == nil {
   634  				if w < 0 {
   635  					return fmt.Errorf("invalid value for %q: %q", key, value)
   636  				}
   637  
   638  				u.WNumber = w
   639  				u.WNumberSet = true
   640  				u.WString = ""
   641  				break
   642  			}
   643  
   644  			u.WString = value
   645  			u.WNumberSet = false
   646  
   647  		case "wtimeoutms":
   648  			n, err := strconv.Atoi(value)
   649  			if err != nil || n < 0 {
   650  				return fmt.Errorf("invalid value for %q: %q", key, value)
   651  			}
   652  			u.WTimeout = time.Duration(n) * time.Millisecond
   653  			u.WTimeoutSet = true
   654  		case "wtimeout":
   655  			// Defer to wtimeoutms, but not to a manually-set option.
   656  			if u.WTimeoutSet {
   657  				break
   658  			}
   659  			n, err := strconv.Atoi(value)
   660  			if err != nil || n < 0 {
   661  				return fmt.Errorf("invalid value for %q: %q", key, value)
   662  			}
   663  			u.WTimeout = time.Duration(n) * time.Millisecond
   664  		case "zlibcompressionlevel":
   665  			level, err := strconv.Atoi(value)
   666  			if err != nil || (level < -1 || level > 9) {
   667  				return fmt.Errorf("invalid value for %q: %q", key, value)
   668  			}
   669  
   670  			if level == -1 {
   671  				level = wiremessage.DefaultZlibLevel
   672  			}
   673  			u.ZlibLevel = level
   674  			u.ZlibLevelSet = true
   675  		case "zstdcompressionlevel":
   676  			const maxZstdLevel = 22 // https://github.com/facebook/zstd/blob/a880ca239b447968493dd2fed3850e766d6305cc/contrib/linux-kernel/lib/zstd/compress.c#L3291
   677  			level, err := strconv.Atoi(value)
   678  			if err != nil || (level < -1 || level > maxZstdLevel) {
   679  				return fmt.Errorf("invalid value for %q: %q", key, value)
   680  			}
   681  
   682  			if level == -1 {
   683  				level = wiremessage.DefaultZstdLevel
   684  			}
   685  			u.ZstdLevel = level
   686  			u.ZstdLevelSet = true
   687  		default:
   688  			if u.UnknownOptions == nil {
   689  				u.UnknownOptions = make(map[string][]string)
   690  			}
   691  			u.UnknownOptions[lowerKey] = append(u.UnknownOptions[lowerKey], value)
   692  		}
   693  
   694  		if u.Options == nil {
   695  			u.Options = make(map[string][]string)
   696  		}
   697  		u.Options[lowerKey] = append(u.Options[lowerKey], value)
   698  	}
   699  	return nil
   700  }
   701  
   702  func (u *ConnString) validateAuth() error {
   703  	switch strings.ToLower(u.AuthMechanism) {
   704  	case "mongodb-cr":
   705  		if u.Username == "" {
   706  			return fmt.Errorf("username required for MONGO-CR")
   707  		}
   708  		if u.Password == "" {
   709  			return fmt.Errorf("password required for MONGO-CR")
   710  		}
   711  		if u.AuthMechanismProperties != nil {
   712  			return fmt.Errorf("MONGO-CR cannot have mechanism properties")
   713  		}
   714  	case "mongodb-x509":
   715  		if u.Password != "" {
   716  			return fmt.Errorf("password cannot be specified for MONGO-X509")
   717  		}
   718  		if u.AuthMechanismProperties != nil {
   719  			return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
   720  		}
   721  	case "mongodb-aws":
   722  		if u.Username != "" && u.Password == "" {
   723  			return fmt.Errorf("username without password is invalid for MONGODB-AWS")
   724  		}
   725  		if u.Username == "" && u.Password != "" {
   726  			return fmt.Errorf("password without username is invalid for MONGODB-AWS")
   727  		}
   728  		var token bool
   729  		for k := range u.AuthMechanismProperties {
   730  			if k != "AWS_SESSION_TOKEN" {
   731  				return fmt.Errorf("invalid auth property for MONGODB-AWS")
   732  			}
   733  			token = true
   734  		}
   735  		if token && u.Username == "" && u.Password == "" {
   736  			return fmt.Errorf("token without username and password is invalid for MONGODB-AWS")
   737  		}
   738  	case "gssapi":
   739  		if u.Username == "" {
   740  			return fmt.Errorf("username required for GSSAPI")
   741  		}
   742  		for k := range u.AuthMechanismProperties {
   743  			if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" && k != "SERVICE_HOST" {
   744  				return fmt.Errorf("invalid auth property for GSSAPI")
   745  			}
   746  		}
   747  	case "plain":
   748  		if u.Username == "" {
   749  			return fmt.Errorf("username required for PLAIN")
   750  		}
   751  		if u.Password == "" {
   752  			return fmt.Errorf("password required for PLAIN")
   753  		}
   754  		if u.AuthMechanismProperties != nil {
   755  			return fmt.Errorf("PLAIN cannot have mechanism properties")
   756  		}
   757  	case "scram-sha-1":
   758  		if u.Username == "" {
   759  			return fmt.Errorf("username required for SCRAM-SHA-1")
   760  		}
   761  		if u.Password == "" {
   762  			return fmt.Errorf("password required for SCRAM-SHA-1")
   763  		}
   764  		if u.AuthMechanismProperties != nil {
   765  			return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
   766  		}
   767  	case "scram-sha-256":
   768  		if u.Username == "" {
   769  			return fmt.Errorf("username required for SCRAM-SHA-256")
   770  		}
   771  		if u.Password == "" {
   772  			return fmt.Errorf("password required for SCRAM-SHA-256")
   773  		}
   774  		if u.AuthMechanismProperties != nil {
   775  			return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
   776  		}
   777  	case "":
   778  		if u.UsernameSet && u.Username == "" {
   779  			return fmt.Errorf("username required if URI contains user info")
   780  		}
   781  	default:
   782  		return fmt.Errorf("invalid auth mechanism")
   783  	}
   784  	return nil
   785  }
   786  
   787  func (u *ConnString) validateSSL() error {
   788  	if !u.SSL {
   789  		return nil
   790  	}
   791  
   792  	if u.SSLClientCertificateKeyFileSet {
   793  		if u.SSLCertificateFileSet || u.SSLPrivateKeyFileSet {
   794  			return errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided " +
   795  				"along with tlsCertificateFile or tlsPrivateKeyFile")
   796  		}
   797  		return nil
   798  	}
   799  	if u.SSLCertificateFileSet && !u.SSLPrivateKeyFileSet {
   800  		return errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")
   801  	}
   802  	if u.SSLPrivateKeyFileSet && !u.SSLCertificateFileSet {
   803  		return errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")
   804  	}
   805  
   806  	if u.SSLInsecureSet && u.SSLDisableOCSPEndpointCheckSet {
   807  		return errors.New("the sslInsecure/tlsInsecure URI option cannot be provided along with " +
   808  			"tlsDisableOCSPEndpointCheck ")
   809  	}
   810  	return nil
   811  }
   812  
   813  func sanitizeHost(host string) (string, error) {
   814  	if host == "" {
   815  		return host, nil
   816  	}
   817  	unescaped, err := url.QueryUnescape(host)
   818  	if err != nil {
   819  		return "", fmt.Errorf("invalid host %q: %w", host, err)
   820  	}
   821  
   822  	_, port, err := net.SplitHostPort(unescaped)
   823  	// this is unfortunate that SplitHostPort actually requires
   824  	// a port to exist.
   825  	if err != nil {
   826  		if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
   827  			return "", err
   828  		}
   829  	}
   830  
   831  	if port != "" {
   832  		d, err := strconv.Atoi(port)
   833  		if err != nil {
   834  			return "", fmt.Errorf("port must be an integer: %w", err)
   835  		}
   836  		if d <= 0 || d >= 65536 {
   837  			return "", fmt.Errorf("port must be in the range [1, 65535]")
   838  		}
   839  	}
   840  	return unescaped, nil
   841  }
   842  
   843  // ConnectMode informs the driver on how to connect
   844  // to the server.
   845  type ConnectMode uint8
   846  
   847  var _ fmt.Stringer = ConnectMode(0)
   848  
   849  // ConnectMode constants.
   850  const (
   851  	AutoConnect ConnectMode = iota
   852  	SingleConnect
   853  )
   854  
   855  // String implements the fmt.Stringer interface.
   856  func (c ConnectMode) String() string {
   857  	switch c {
   858  	case AutoConnect:
   859  		return "automatic"
   860  	case SingleConnect:
   861  		return "direct"
   862  	default:
   863  		return "unknown"
   864  	}
   865  }
   866  
   867  // Scheme constants
   868  const (
   869  	SchemeMongoDB    = "mongodb"
   870  	SchemeMongoDBSRV = "mongodb+srv"
   871  )
   872  
   873  type parser struct {
   874  	dnsResolver *dns.Resolver
   875  }
   876  
   877  func (p *parser) parse(original string) (*ConnString, error) {
   878  	connStr := &ConnString{}
   879  	connStr.Original = original
   880  	uri := original
   881  
   882  	var err error
   883  	if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") {
   884  		connStr.Scheme = SchemeMongoDBSRV
   885  		// remove the scheme
   886  		uri = uri[len(SchemeMongoDBSRV)+3:]
   887  	} else if strings.HasPrefix(uri, SchemeMongoDB+"://") {
   888  		connStr.Scheme = SchemeMongoDB
   889  		// remove the scheme
   890  		uri = uri[len(SchemeMongoDB)+3:]
   891  	} else {
   892  		return nil, errors.New(`scheme must be "mongodb" or "mongodb+srv"`)
   893  	}
   894  
   895  	if idx := strings.Index(uri, "@"); idx != -1 {
   896  		userInfo := uri[:idx]
   897  		uri = uri[idx+1:]
   898  
   899  		username := userInfo
   900  		var password string
   901  
   902  		if idx := strings.Index(userInfo, ":"); idx != -1 {
   903  			username = userInfo[:idx]
   904  			password = userInfo[idx+1:]
   905  			connStr.PasswordSet = true
   906  		}
   907  
   908  		// Validate and process the username.
   909  		if strings.Contains(username, "/") {
   910  			return nil, fmt.Errorf("unescaped slash in username")
   911  		}
   912  		connStr.Username, err = url.PathUnescape(username)
   913  		if err != nil {
   914  			return nil, fmt.Errorf("invalid username: %w", err)
   915  		}
   916  		connStr.UsernameSet = true
   917  
   918  		// Validate and process the password.
   919  		if strings.Contains(password, ":") {
   920  			return nil, fmt.Errorf("unescaped colon in password")
   921  		}
   922  		if strings.Contains(password, "/") {
   923  			return nil, fmt.Errorf("unescaped slash in password")
   924  		}
   925  		connStr.Password, err = url.PathUnescape(password)
   926  		if err != nil {
   927  			return nil, fmt.Errorf("invalid password: %w", err)
   928  		}
   929  	}
   930  
   931  	// fetch the hosts field
   932  	hosts := uri
   933  	if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
   934  		if uri[idx] == '@' {
   935  			return nil, fmt.Errorf("unescaped @ sign in user info")
   936  		}
   937  		if uri[idx] == '?' {
   938  			return nil, fmt.Errorf("must have a / before the query ?")
   939  		}
   940  		hosts = uri[:idx]
   941  	}
   942  
   943  	for _, host := range strings.Split(hosts, ",") {
   944  		host, err = sanitizeHost(host)
   945  		if err != nil {
   946  			return nil, fmt.Errorf("invalid host %q: %w", host, err)
   947  		}
   948  		if host != "" {
   949  			connStr.RawHosts = append(connStr.RawHosts, host)
   950  		}
   951  	}
   952  	connStr.Hosts = connStr.RawHosts
   953  	uri = uri[len(hosts):]
   954  	extractedDatabase, err := extractDatabaseFromURI(uri)
   955  	if err != nil {
   956  		return nil, err
   957  	}
   958  
   959  	uri = extractedDatabase.uri
   960  	connStr.Database = extractedDatabase.db
   961  
   962  	// grab connection arguments from URI
   963  	connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
   964  	if err != nil {
   965  		return nil, err
   966  	}
   967  
   968  	// grab connection arguments from TXT record and enable SSL if "mongodb+srv://"
   969  	var connectionArgsFromTXT []string
   970  	if connStr.Scheme == SchemeMongoDBSRV && p.dnsResolver != nil {
   971  		connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
   972  		if err != nil {
   973  			return nil, err
   974  		}
   975  
   976  		// SSL is enabled by default for SRV, but can be manually disabled with "ssl=false".
   977  		connStr.SSL = true
   978  		connStr.SSLSet = true
   979  	}
   980  
   981  	// add connection arguments from URI and TXT records to connstring
   982  	connectionArgPairs := make([]string, 0, len(connectionArgsFromTXT)+len(connectionArgsFromQueryString))
   983  	connectionArgPairs = append(connectionArgPairs, connectionArgsFromTXT...)
   984  	connectionArgPairs = append(connectionArgPairs, connectionArgsFromQueryString...)
   985  
   986  	err = connStr.addOptions(connectionArgPairs)
   987  	if err != nil {
   988  		return nil, err
   989  	}
   990  
   991  	// do SRV lookup if "mongodb+srv://"
   992  	if connStr.Scheme == SchemeMongoDBSRV && p.dnsResolver != nil {
   993  		parsedHosts, err := p.dnsResolver.ParseHosts(hosts, connStr.SRVServiceName, true)
   994  		if err != nil {
   995  			return connStr, err
   996  		}
   997  
   998  		// If p.SRVMaxHosts is non-zero and is less than the number of hosts, randomly
   999  		// select SRVMaxHosts hosts from parsedHosts.
  1000  		if connStr.SRVMaxHosts > 0 && connStr.SRVMaxHosts < len(parsedHosts) {
  1001  			random.Shuffle(len(parsedHosts), func(i, j int) {
  1002  				parsedHosts[i], parsedHosts[j] = parsedHosts[j], parsedHosts[i]
  1003  			})
  1004  			parsedHosts = parsedHosts[:connStr.SRVMaxHosts]
  1005  		}
  1006  
  1007  		var hosts []string
  1008  		for _, host := range parsedHosts {
  1009  			host, err = sanitizeHost(host)
  1010  			if err != nil {
  1011  				return connStr, fmt.Errorf("invalid host %q: %w", host, err)
  1012  			}
  1013  			if host != "" {
  1014  				hosts = append(hosts, host)
  1015  			}
  1016  		}
  1017  		connStr.Hosts = hosts
  1018  	}
  1019  	if len(connStr.Hosts) == 0 {
  1020  		return nil, fmt.Errorf("must have at least 1 host")
  1021  	}
  1022  
  1023  	err = connStr.setDefaultAuthParams(extractedDatabase.db)
  1024  	if err != nil {
  1025  		return nil, err
  1026  	}
  1027  
  1028  	// If WTimeout was set from manual options passed in, set WTImeoutSet to true.
  1029  	if connStr.WTimeoutSetFromOption {
  1030  		connStr.WTimeoutSet = true
  1031  	}
  1032  
  1033  	return connStr, nil
  1034  }
  1035  
  1036  // IsValidServerMonitoringMode will return true if the given string matches a
  1037  // valid server monitoring mode.
  1038  func IsValidServerMonitoringMode(mode string) bool {
  1039  	return mode == ServerMonitoringModeAuto ||
  1040  		mode == ServerMonitoringModeStream ||
  1041  		mode == ServerMonitoringModePoll
  1042  }
  1043  
  1044  func extractQueryArgsFromURI(uri string) ([]string, error) {
  1045  	if len(uri) == 0 {
  1046  		return nil, nil
  1047  	}
  1048  
  1049  	if uri[0] != '?' {
  1050  		return nil, errors.New("must have a ? separator between path and query")
  1051  	}
  1052  
  1053  	uri = uri[1:]
  1054  	if len(uri) == 0 {
  1055  		return nil, nil
  1056  	}
  1057  	return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil
  1058  
  1059  }
  1060  
  1061  type extractedDatabase struct {
  1062  	uri string
  1063  	db  string
  1064  }
  1065  
  1066  // extractDatabaseFromURI is a helper function to retrieve information about
  1067  // the database from the passed in URI. It accepts as an argument the currently
  1068  // parsed URI and returns the remainder of the uri, the database it found,
  1069  // and any error it encounters while parsing.
  1070  func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
  1071  	if len(uri) == 0 {
  1072  		return extractedDatabase{}, nil
  1073  	}
  1074  
  1075  	if uri[0] != '/' {
  1076  		return extractedDatabase{}, errors.New("must have a / separator between hosts and path")
  1077  	}
  1078  
  1079  	uri = uri[1:]
  1080  	if len(uri) == 0 {
  1081  		return extractedDatabase{}, nil
  1082  	}
  1083  
  1084  	database := uri
  1085  	if idx := strings.IndexRune(uri, '?'); idx != -1 {
  1086  		database = uri[:idx]
  1087  	}
  1088  
  1089  	escapedDatabase, err := url.QueryUnescape(database)
  1090  	if err != nil {
  1091  		return extractedDatabase{}, fmt.Errorf("invalid database %q: %w", database, err)
  1092  	}
  1093  
  1094  	uri = uri[len(database):]
  1095  
  1096  	return extractedDatabase{
  1097  		uri: uri,
  1098  		db:  escapedDatabase,
  1099  	}, nil
  1100  }
  1101  

View as plain text