1
2
3
4
5
6
7 package dns
8
9 import (
10 "errors"
11 "fmt"
12 "net"
13 "runtime"
14 "strings"
15 )
16
17
// Resolver performs the DNS lookups needed to expand an SRV connection string.
// Both lookups are exported function fields, so callers can substitute their
// own implementations.
type Resolver struct {
	// LookupSRV has the signature of net.LookupSRV: it takes the service,
	// protocol and name, and returns the canonical name plus the SRV records.
	LookupSRV func(string, string, string) (string, []*net.SRV, error)
	// LookupTXT has the signature of net.LookupTXT: it returns the TXT
	// strings published for a name.
	LookupTXT func(string) ([]string, error)
}

// DefaultResolver resolves records through the standard library's
// net.LookupSRV and net.LookupTXT.
var DefaultResolver = &Resolver{
	LookupSRV: net.LookupSRV,
	LookupTXT: net.LookupTXT,
}
26
27
28 func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
29 parsedHosts := strings.Split(host, ",")
30
31 if len(parsedHosts) != 1 {
32 return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
33 }
34 return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
35 }
36
37
38 func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
39 var connectionArgsFromTXT []string
40
41
42
43 recordsFromTXT, _ := r.LookupTXT(host)
44
45
46
47
48 if runtime.GOOS == "windows" {
49 recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
50 }
51
52 if len(recordsFromTXT) > 1 {
53 return nil, errors.New("multiple records from TXT not supported")
54 }
55 if len(recordsFromTXT) > 0 {
56 connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
57
58 err := validateTXTResult(connectionArgsFromTXT)
59 if err != nil {
60 return nil, err
61 }
62 }
63
64 return connectionArgsFromTXT, nil
65 }
66
67 func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
68 var err error
69
70 _, _, err = net.SplitHostPort(host)
71
72 if err == nil {
73
74
75 return nil, fmt.Errorf("URI with srv must not include a port number")
76 }
77
78
79 if srvName == "" {
80 srvName = "mongodb"
81 }
82 _, addresses, err := r.LookupSRV(srvName, "tcp", host)
83 if err != nil && strings.Contains(err.Error(), "cannot unmarshal DNS message") {
84 return nil, fmt.Errorf("see https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#hdr-Potential_DNS_Issues: %w", err)
85 } else if err != nil {
86 return nil, err
87 }
88
89 trimmedHost := strings.TrimSuffix(host, ".")
90
91 parsedHosts := make([]string, 0, len(addresses))
92 for _, address := range addresses {
93 trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
94 err := validateSRVResult(trimmedAddressTarget, trimmedHost)
95 if err != nil {
96 if stopOnErr {
97 return nil, err
98 }
99 continue
100 }
101 parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
102 }
103 return parsedHosts, nil
104 }
105
// validateSRVResult checks that recordFromSRV, a target returned by an SRV
// lookup, is acceptable for the queried inputHostName. Names are lower-cased
// and compared label by label. The record must have at least two labels, must
// have at least as many labels as inputHostName, and must end with
// inputHostName's parent domain (all of its labels except the first).
//
// NOTE(review): when inputHostName has a single label its parent domain is
// empty, so any record with two or more labels is accepted — confirm this is
// the intended behavior for short hostnames such as "localhost".
func validateSRVResult(recordFromSRV, inputHostName string) error {
	hostLabels := strings.Split(strings.ToLower(inputHostName), ".")
	recordLabels := strings.Split(strings.ToLower(recordFromSRV), ".")

	switch {
	case len(recordLabels) < 2:
		return errors.New("DNS name must contain at least 2 labels")
	case len(recordLabels) < len(hostLabels):
		return errors.New("Domain suffix from SRV record not matched input domain")
	}

	// Walk both label lists from the right, stopping before the host's first
	// label: only the parent domain has to match.
	for h, rec := len(hostLabels)-1, len(recordLabels)-1; h > 0; h, rec = h-1, rec-1 {
		if hostLabels[h] != recordLabels[rec] {
			return errors.New("Domain suffix from SRV record not matched input domain")
		}
	}
	return nil
}
127
// allowedTXTOptions holds, in lower case, the only option names a TXT record
// is permitted to set.
var allowedTXTOptions = map[string]struct{}{
	"authsource":   {},
	"replicaset":   {},
	"loadbalanced": {},
}

// validateTXTResult checks every "key=value" entry taken from a TXT record.
// An entry without '=' is rejected as invalid. The key (the text before the
// first '=') is matched case-insensitively against allowedTXTOptions, and an
// unknown key is reported using its original spelling.
func validateTXTResult(paramsFromTXT []string) error {
	for _, param := range paramsFromTXT {
		eq := strings.Index(param, "=")
		if eq < 0 {
			return errors.New("Invalid TXT record")
		}
		name := param[:eq]
		if _, allowed := allowedTXTOptions[strings.ToLower(name)]; !allowed {
			return fmt.Errorf("Cannot specify option '%s' in TXT record", name)
		}
	}
	return nil
}
147
View as plain text