...
1
2
3
4
5
6
7 package topology
8
9 import (
10 "context"
11 "crypto/tls"
12 "net"
13 "net/http"
14 "time"
15
16 "go.mongodb.org/mongo-driver/bson/primitive"
17 "go.mongodb.org/mongo-driver/event"
18 "go.mongodb.org/mongo-driver/internal/httputil"
19 "go.mongodb.org/mongo-driver/x/mongo/driver"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
21 )
22
23
type (
	// Dialer opens network connections on behalf of a connection. Its single
	// method has the same signature as (*net.Dialer).DialContext, so a
	// *net.Dialer can be used directly wherever a Dialer is required.
	Dialer interface {
		DialContext(ctx context.Context, network, address string) (net.Conn, error)
	}

	// DialerFunc lets a plain function with the DialContext signature be used
	// as a Dialer.
	DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
)

// Compile-time assertion that DialerFunc implements Dialer.
var _ Dialer = DialerFunc(nil)

// DialContext invokes the underlying function with the same context, network
// and address, and returns whatever connection and error it produces.
func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return f(ctx, network, address)
}
35
36
37
38
39
40 var DefaultDialer Dialer = &net.Dialer{}
41
42
43
44
45 type Handshaker = driver.Handshaker
46
47
48 type generationNumberFn func(serviceID *primitive.ObjectID) uint64
49
50 type connectionConfig struct {
51 connectTimeout time.Duration
52 dialer Dialer
53 handshaker Handshaker
54 idleTimeout time.Duration
55 cmdMonitor *event.CommandMonitor
56 readTimeout time.Duration
57 writeTimeout time.Duration
58 tlsConfig *tls.Config
59 httpClient *http.Client
60 compressors []string
61 zlibLevel *int
62 zstdLevel *int
63 ocspCache ocsp.Cache
64 disableOCSPEndpointCheck bool
65 tlsConnectionSource tlsConnectionSource
66 loadBalanced bool
67 getGenerationFn generationNumberFn
68 }
69
70 func newConnectionConfig(opts ...ConnectionOption) *connectionConfig {
71 cfg := &connectionConfig{
72 connectTimeout: 30 * time.Second,
73 dialer: nil,
74 tlsConnectionSource: defaultTLSConnectionSource,
75 httpClient: httputil.DefaultHTTPClient,
76 }
77
78 for _, opt := range opts {
79 if opt == nil {
80 continue
81 }
82 opt(cfg)
83 }
84
85 if cfg.dialer == nil {
86
87
88 cfg.dialer = &net.Dialer{}
89 }
90
91 return cfg
92 }
93
94
95 type ConnectionOption func(*connectionConfig)
96
97 func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
98 return func(c *connectionConfig) {
99 c.tlsConnectionSource = fn(c.tlsConnectionSource)
100 }
101 }
102
103
104 func WithCompressors(fn func([]string) []string) ConnectionOption {
105 return func(c *connectionConfig) {
106 c.compressors = fn(c.compressors)
107 }
108 }
109
110
111
112 func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
113 return func(c *connectionConfig) {
114 c.connectTimeout = fn(c.connectTimeout)
115 }
116 }
117
118
119 func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
120 return func(c *connectionConfig) {
121 c.dialer = fn(c.dialer)
122 }
123 }
124
125
126
127 func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
128 return func(c *connectionConfig) {
129 c.handshaker = fn(c.handshaker)
130 }
131 }
132
133
134 func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
135 return func(c *connectionConfig) {
136 c.idleTimeout = fn(c.idleTimeout)
137 }
138 }
139
140
141 func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
142 return func(c *connectionConfig) {
143 c.readTimeout = fn(c.readTimeout)
144 }
145 }
146
147
148 func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
149 return func(c *connectionConfig) {
150 c.writeTimeout = fn(c.writeTimeout)
151 }
152 }
153
154
155 func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
156 return func(c *connectionConfig) {
157 c.tlsConfig = fn(c.tlsConfig)
158 }
159 }
160
161
162 func WithHTTPClient(fn func(*http.Client) *http.Client) ConnectionOption {
163 return func(c *connectionConfig) {
164 c.httpClient = fn(c.httpClient)
165 }
166 }
167
168
169 func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
170 return func(c *connectionConfig) {
171 c.cmdMonitor = fn(c.cmdMonitor)
172 }
173 }
174
175
176 func WithZlibLevel(fn func(*int) *int) ConnectionOption {
177 return func(c *connectionConfig) {
178 c.zlibLevel = fn(c.zlibLevel)
179 }
180 }
181
182
183 func WithZstdLevel(fn func(*int) *int) ConnectionOption {
184 return func(c *connectionConfig) {
185 c.zstdLevel = fn(c.zstdLevel)
186 }
187 }
188
189
190 func WithOCSPCache(fn func(ocsp.Cache) ocsp.Cache) ConnectionOption {
191 return func(c *connectionConfig) {
192 c.ocspCache = fn(c.ocspCache)
193 }
194 }
195
196
197
198
199 func WithDisableOCSPEndpointCheck(fn func(bool) bool) ConnectionOption {
200 return func(c *connectionConfig) {
201 c.disableOCSPEndpointCheck = fn(c.disableOCSPEndpointCheck)
202 }
203 }
204
205
206 func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
207 return func(c *connectionConfig) {
208 c.loadBalanced = fn(c.loadBalanced)
209 }
210 }
211
212 func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
213 return func(c *connectionConfig) {
214 c.getGenerationFn = fn(c.getGenerationFn)
215 }
216 }
217
View as plain text