1
2
3
4
5
6
7 package topology
8
9 import (
10 "crypto/tls"
11 "fmt"
12 "net/http"
13 "strings"
14 "time"
15
16 "go.mongodb.org/mongo-driver/event"
17 "go.mongodb.org/mongo-driver/internal/logger"
18 "go.mongodb.org/mongo-driver/mongo/description"
19 "go.mongodb.org/mongo-driver/mongo/options"
20 "go.mongodb.org/mongo-driver/x/mongo/driver"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
22 "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
23 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
24 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
25 )
26
// defaultServerSelectionTimeout is the server selection timeout NewConfig
// assigns when the client options do not set ServerSelectionTimeout.
const defaultServerSelectionTimeout time.Duration = 30 * time.Second
28
29
// Config holds the settings used to build a topology. NewConfig populates it
// from an *options.ClientOptions.
type Config struct {
	// Mode is the monitoring mode; NewConfig sets SingleMode when the client
	// options request a direct connection.
	Mode MonitorMode
	// ReplicaSetName is copied from the client options' ReplicaSet, if set.
	ReplicaSetName string
	// SeedList is the initial host list; NewConfig falls back to
	// "localhost:27017" when no hosts are supplied.
	SeedList []string
	// ServerOpts are the ServerOption values built by NewConfig (they also
	// carry the connection options).
	ServerOpts []ServerOption
	// URI is the connection string reported by the client options.
	URI string
	// ServerSelectionTimeout defaults to defaultServerSelectionTimeout in
	// NewConfig unless the client options override it.
	ServerSelectionTimeout time.Duration
	// ServerMonitor is the server event monitor from the client options, if any.
	ServerMonitor *event.ServerMonitor
	// SRVMaxHosts is copied from the client options' SRVMaxHosts, if set.
	SRVMaxHosts int
	// SRVServiceName is copied from the client options' SRVServiceName, if set.
	SRVServiceName string
	// LoadBalanced is copied from the client options' LoadBalanced, if set.
	LoadBalanced bool
	// logger is the internal logger built from the client's LoggerOptions.
	logger *logger.Logger
}
43
44
45 func ConvertToDriverAPIOptions(s *options.ServerAPIOptions) *driver.ServerAPIOptions {
46 driverOpts := driver.NewServerAPIOptions(string(s.ServerAPIVersion))
47 if s.Strict != nil {
48 driverOpts.SetStrict(*s.Strict)
49 }
50 if s.DeprecationErrors != nil {
51 driverOpts.SetDeprecationErrors(*s.DeprecationErrors)
52 }
53 return driverOpts
54 }
55
56 func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) {
57 if opts == nil {
58 opts = options.Logger()
59 }
60
61 componentLevels := make(map[logger.Component]logger.Level)
62 for component, level := range opts.ComponentLevels {
63 componentLevels[logger.Component(component)] = logger.Level(level)
64 }
65
66 log, err := logger.New(opts.Sink, opts.MaxDocumentLength, componentLevels)
67 if err != nil {
68 return nil, fmt.Errorf("error creating logger: %w", err)
69 }
70
71 return log, nil
72 }
73
74
75
76 func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) {
77 var serverAPI *driver.ServerAPIOptions
78
79 if err := co.Validate(); err != nil {
80 return nil, err
81 }
82
83 var connOpts []ConnectionOption
84 var serverOpts []ServerOption
85
86 cfgp := &Config{}
87
88
89 cfgp.ServerSelectionTimeout = defaultServerSelectionTimeout
90
91
92 cfgp.SeedList = []string{"localhost:27017"}
93
94
95
96
97
98 if co.ServerAPIOptions != nil {
99 serverAPI = ConvertToDriverAPIOptions(co.ServerAPIOptions)
100 serverOpts = append(serverOpts, WithServerAPI(func(*driver.ServerAPIOptions) *driver.ServerAPIOptions {
101 return serverAPI
102 }))
103 }
104
105 cfgp.URI = co.GetURI()
106
107 if co.SRVServiceName != nil {
108 cfgp.SRVServiceName = *co.SRVServiceName
109 }
110
111 if co.SRVMaxHosts != nil {
112 cfgp.SRVMaxHosts = *co.SRVMaxHosts
113 }
114
115
116 var appName string
117 if co.AppName != nil {
118 appName = *co.AppName
119
120 serverOpts = append(serverOpts, WithServerAppName(func(string) string {
121 return appName
122 }))
123 }
124
125 var comps []string
126 if len(co.Compressors) > 0 {
127 comps = co.Compressors
128
129 connOpts = append(connOpts, WithCompressors(
130 func(compressors []string) []string {
131 return append(compressors, comps...)
132 },
133 ))
134
135 for _, comp := range comps {
136 switch comp {
137 case "zlib":
138 connOpts = append(connOpts, WithZlibLevel(func(level *int) *int {
139 return co.ZlibLevel
140 }))
141 case "zstd":
142 connOpts = append(connOpts, WithZstdLevel(func(level *int) *int {
143 return co.ZstdLevel
144 }))
145 }
146 }
147
148 serverOpts = append(serverOpts, WithCompressionOptions(
149 func(opts ...string) []string { return append(opts, comps...) },
150 ))
151 }
152
153 var loadBalanced bool
154 if co.LoadBalanced != nil {
155 loadBalanced = *co.LoadBalanced
156 }
157
158
159 var handshaker = func(driver.Handshaker) driver.Handshaker {
160 return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(clock).
161 ServerAPI(serverAPI).LoadBalanced(loadBalanced)
162 }
163
164 if co.Auth != nil {
165 cred := &auth.Cred{
166 Username: co.Auth.Username,
167 Password: co.Auth.Password,
168 PasswordSet: co.Auth.PasswordSet,
169 Props: co.Auth.AuthMechanismProperties,
170 Source: co.Auth.AuthSource,
171 }
172 mechanism := co.Auth.AuthMechanism
173
174 if len(cred.Source) == 0 {
175 switch strings.ToUpper(mechanism) {
176 case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN:
177 cred.Source = "$external"
178 default:
179 cred.Source = "admin"
180 }
181 }
182
183 authenticator, err := auth.CreateAuthenticator(mechanism, cred)
184 if err != nil {
185 return nil, err
186 }
187
188 handshakeOpts := &auth.HandshakeOptions{
189 AppName: appName,
190 Authenticator: authenticator,
191 Compressors: comps,
192 ServerAPI: serverAPI,
193 LoadBalanced: loadBalanced,
194 ClusterClock: clock,
195 HTTPClient: co.HTTPClient,
196 }
197
198 if mechanism == "" {
199
200 handshakeOpts.DBUser = cred.Source + "." + cred.Username
201 }
202 if co.AuthenticateToAnything != nil && *co.AuthenticateToAnything {
203
204 handshakeOpts.PerformAuthentication = func(serv description.Server) bool {
205 return true
206 }
207 }
208
209 handshaker = func(driver.Handshaker) driver.Handshaker {
210 return auth.Handshaker(nil, handshakeOpts)
211 }
212 }
213 connOpts = append(connOpts, WithHandshaker(handshaker))
214
215 if co.ConnectTimeout != nil {
216 serverOpts = append(serverOpts, WithHeartbeatTimeout(
217 func(time.Duration) time.Duration { return *co.ConnectTimeout },
218 ))
219 connOpts = append(connOpts, WithConnectTimeout(
220 func(time.Duration) time.Duration { return *co.ConnectTimeout },
221 ))
222 }
223
224 if co.Dialer != nil {
225 connOpts = append(connOpts, WithDialer(
226 func(Dialer) Dialer { return co.Dialer },
227 ))
228 }
229
230 if co.Direct != nil && *co.Direct {
231 cfgp.Mode = SingleMode
232 }
233
234
235 if co.HeartbeatInterval != nil {
236 serverOpts = append(serverOpts, WithHeartbeatInterval(
237 func(time.Duration) time.Duration { return *co.HeartbeatInterval },
238 ))
239 }
240
241 cfgp.SeedList = []string{"localhost:27017"}
242 if len(co.Hosts) > 0 {
243 cfgp.SeedList = co.Hosts
244 }
245
246
247 if co.MaxConnIdleTime != nil {
248 serverOpts = append(serverOpts, WithConnectionPoolMaxIdleTime(
249 func(time.Duration) time.Duration { return *co.MaxConnIdleTime },
250 ))
251 }
252
253 if co.MaxPoolSize != nil {
254 serverOpts = append(
255 serverOpts,
256 WithMaxConnections(func(uint64) uint64 { return *co.MaxPoolSize }),
257 )
258 }
259
260 if co.MinPoolSize != nil {
261 serverOpts = append(
262 serverOpts,
263 WithMinConnections(func(uint64) uint64 { return *co.MinPoolSize }),
264 )
265 }
266
267 if co.MaxConnecting != nil {
268 serverOpts = append(
269 serverOpts,
270 WithMaxConnecting(func(uint64) uint64 { return *co.MaxConnecting }),
271 )
272 }
273
274 if co.PoolMonitor != nil {
275 serverOpts = append(
276 serverOpts,
277 WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return co.PoolMonitor }),
278 )
279 }
280
281 if co.Monitor != nil {
282 connOpts = append(connOpts, WithMonitor(
283 func(*event.CommandMonitor) *event.CommandMonitor { return co.Monitor },
284 ))
285 }
286
287 if co.ServerMonitor != nil {
288 serverOpts = append(
289 serverOpts,
290 WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return co.ServerMonitor }),
291 )
292 cfgp.ServerMonitor = co.ServerMonitor
293 }
294
295 if co.ReplicaSet != nil {
296 cfgp.ReplicaSetName = *co.ReplicaSet
297 }
298
299 if co.ServerSelectionTimeout != nil {
300 cfgp.ServerSelectionTimeout = *co.ServerSelectionTimeout
301 }
302
303 if co.SocketTimeout != nil {
304 connOpts = append(
305 connOpts,
306 WithReadTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
307 WithWriteTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
308 )
309 }
310
311 if co.TLSConfig != nil {
312 connOpts = append(connOpts, WithTLSConfig(
313 func(*tls.Config) *tls.Config {
314 return co.TLSConfig
315 },
316 ))
317 }
318
319
320 if co.HTTPClient != nil {
321 connOpts = append(connOpts, WithHTTPClient(
322 func(*http.Client) *http.Client {
323 return co.HTTPClient
324 },
325 ))
326 }
327
328
329 ocspCache := ocsp.NewCache()
330 connOpts = append(
331 connOpts,
332 WithOCSPCache(func(ocsp.Cache) ocsp.Cache { return ocspCache }),
333 )
334
335
336 if co.DisableOCSPEndpointCheck != nil {
337 connOpts = append(
338 connOpts,
339 WithDisableOCSPEndpointCheck(func(bool) bool { return *co.DisableOCSPEndpointCheck }),
340 )
341 }
342
343
344 if co.LoadBalanced != nil {
345 cfgp.LoadBalanced = *co.LoadBalanced
346
347 serverOpts = append(
348 serverOpts,
349 WithServerLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
350 )
351 connOpts = append(
352 connOpts,
353 WithConnectionLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
354 )
355 }
356
357 lgr, err := newLogger(co.LoggerOptions)
358 if err != nil {
359 return nil, err
360 }
361
362 serverOpts = append(
363 serverOpts,
364 withLogger(func() *logger.Logger { return lgr }),
365 withServerMonitoringMode(co.ServerMonitoringMode),
366 )
367
368 cfgp.logger = lgr
369
370 serverOpts = append(
371 serverOpts,
372 WithClock(func(*session.ClusterClock) *session.ClusterClock { return clock }),
373 WithConnectionOptions(func(...ConnectionOption) []ConnectionOption { return connOpts }))
374
375 cfgp.ServerOpts = serverOpts
376
377 return cfgp, nil
378 }
379
View as plain text