...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/dns/dns.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/dns

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package dns
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"runtime"
    14  	"strings"
    15  )
    16  
    17  // Resolver resolves DNS records.
    18  type Resolver struct {
    19  	// Holds the functions to use for DNS lookups
    20  	LookupSRV func(string, string, string) (string, []*net.SRV, error)
    21  	LookupTXT func(string) ([]string, error)
    22  }
    23  
    24  // DefaultResolver is a Resolver that uses the default Resolver from the net package.
    25  var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}
    26  
    27  // ParseHosts uses the srv string and service name to get the hosts.
    28  func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
    29  	parsedHosts := strings.Split(host, ",")
    30  
    31  	if len(parsedHosts) != 1 {
    32  		return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
    33  	}
    34  	return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
    35  }
    36  
    37  // GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
    38  func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
    39  	var connectionArgsFromTXT []string
    40  
    41  	// error ignored because not finding a TXT record should not be
    42  	// considered an error.
    43  	recordsFromTXT, _ := r.LookupTXT(host)
    44  
    45  	// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
    46  	// It will currently incorrectly concatenate multiple TXT records to one
    47  	// on windows.
    48  	if runtime.GOOS == "windows" {
    49  		recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
    50  	}
    51  
    52  	if len(recordsFromTXT) > 1 {
    53  		return nil, errors.New("multiple records from TXT not supported")
    54  	}
    55  	if len(recordsFromTXT) > 0 {
    56  		connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
    57  
    58  		err := validateTXTResult(connectionArgsFromTXT)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  	}
    63  
    64  	return connectionArgsFromTXT, nil
    65  }
    66  
    67  func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
    68  	var err error
    69  
    70  	_, _, err = net.SplitHostPort(host)
    71  
    72  	if err == nil {
    73  		// we were able to successfully extract a port from the host,
    74  		// but should not be able to when using SRV
    75  		return nil, fmt.Errorf("URI with srv must not include a port number")
    76  	}
    77  
    78  	// default to "mongodb" as service name if not supplied
    79  	if srvName == "" {
    80  		srvName = "mongodb"
    81  	}
    82  	_, addresses, err := r.LookupSRV(srvName, "tcp", host)
    83  	if err != nil && strings.Contains(err.Error(), "cannot unmarshal DNS message") {
    84  		return nil, fmt.Errorf("see https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#hdr-Potential_DNS_Issues: %w", err)
    85  	} else if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	trimmedHost := strings.TrimSuffix(host, ".")
    90  
    91  	parsedHosts := make([]string, 0, len(addresses))
    92  	for _, address := range addresses {
    93  		trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
    94  		err := validateSRVResult(trimmedAddressTarget, trimmedHost)
    95  		if err != nil {
    96  			if stopOnErr {
    97  				return nil, err
    98  			}
    99  			continue
   100  		}
   101  		parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
   102  	}
   103  	return parsedHosts, nil
   104  }
   105  
   106  func validateSRVResult(recordFromSRV, inputHostName string) error {
   107  	separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".")
   108  	separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".")
   109  	if len(separatedRecord) < 2 {
   110  		return errors.New("DNS name must contain at least 2 labels")
   111  	}
   112  	if len(separatedRecord) < len(separatedInputDomain) {
   113  		return errors.New("Domain suffix from SRV record not matched input domain")
   114  	}
   115  
   116  	inputDomainSuffix := separatedInputDomain[1:]
   117  	domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
   118  
   119  	recordDomainSuffix := separatedRecord[domainSuffixOffset:]
   120  	for ix, label := range inputDomainSuffix {
   121  		if label != recordDomainSuffix[ix] {
   122  			return errors.New("Domain suffix from SRV record not matched input domain")
   123  		}
   124  	}
   125  	return nil
   126  }
   127  
   128  var allowedTXTOptions = map[string]struct{}{
   129  	"authsource":   {},
   130  	"replicaset":   {},
   131  	"loadbalanced": {},
   132  }
   133  
   134  func validateTXTResult(paramsFromTXT []string) error {
   135  	for _, param := range paramsFromTXT {
   136  		kv := strings.SplitN(param, "=", 2)
   137  		if len(kv) != 2 {
   138  			return errors.New("Invalid TXT record")
   139  		}
   140  		key := strings.ToLower(kv[0])
   141  		if _, ok := allowedTXTOptions[key]; !ok {
   142  			return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
   143  		}
   144  	}
   145  	return nil
   146  }
   147  

View as plain text