1
2
3
4
5
6
7 package integration
8
import (
	"context"
	"os"
	"reflect"
	"runtime"
	"strings"
	"testing"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/internal/assert"
	"go.mongodb.org/mongo-driver/internal/handshake"
	"go.mongodb.org/mongo-driver/internal/require"
	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
	"go.mongodb.org/mongo-driver/version"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
25
26 func TestHandshakeProse(t *testing.T) {
27 mt := mtest.New(t)
28
29 if len(os.Getenv("DOCKER_RUNNING")) > 0 {
30 t.Skip("These tests gives different results when run in Docker due to extra environment data.")
31 }
32
33 opts := mtest.NewOptions().
34 CreateCollection(false).
35 ClientType(mtest.Proxy)
36
37 clientMetadata := func(env bson.D) bson.D {
38 elems := bson.D{
39 {Key: "driver", Value: bson.D{
40 {Key: "name", Value: "mongo-go-driver"},
41 {Key: "version", Value: version.Driver},
42 }},
43 {Key: "os", Value: bson.D{
44 {Key: "type", Value: runtime.GOOS},
45 {Key: "architecture", Value: runtime.GOARCH},
46 }},
47 }
48
49 elems = append(elems, bson.E{Key: "platform", Value: runtime.Version()})
50
51
52 if env != nil && !reflect.DeepEqual(env, bson.D{}) {
53 elems = append(elems, bson.E{Key: "env", Value: env})
54 }
55
56 return elems
57 }
58
59
60
61 t.Setenv("AWS_EXECUTION_ENV", "")
62 t.Setenv("FUNCTIONS_WORKER_RUNTIME", "")
63 t.Setenv("K_SERVICE", "")
64 t.Setenv("VERCEL", "")
65 t.Setenv("AWS_REGION", "")
66 t.Setenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "")
67 t.Setenv("FUNCTION_MEMORY_MB", "")
68 t.Setenv("FUNCTION_TIMEOUT_SEC", "")
69 t.Setenv("FUNCTION_REGION", "")
70 t.Setenv("VERCEL_REGION", "")
71
72 for _, test := range []struct {
73 name string
74 env map[string]string
75 want bson.D
76 }{
77 {
78 name: "1. valid AWS",
79 env: map[string]string{
80 "AWS_EXECUTION_ENV": "AWS_Lambda_java8",
81 "AWS_REGION": "us-east-2",
82 "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024",
83 },
84 want: clientMetadata(bson.D{
85 {Key: "name", Value: "aws.lambda"},
86 {Key: "memory_mb", Value: 1024},
87 {Key: "region", Value: "us-east-2"},
88 }),
89 },
90 {
91 name: "2. valid Azure",
92 env: map[string]string{
93 "FUNCTIONS_WORKER_RUNTIME": "node",
94 },
95 want: clientMetadata(bson.D{
96 {Key: "name", Value: "azure.func"},
97 }),
98 },
99 {
100 name: "3. valid GCP",
101 env: map[string]string{
102 "K_SERVICE": "servicename",
103 "FUNCTION_MEMORY_MB": "1024",
104 "FUNCTION_TIMEOUT_SEC": "60",
105 "FUNCTION_REGION": "us-central1",
106 },
107 want: clientMetadata(bson.D{
108 {Key: "name", Value: "gcp.func"},
109 {Key: "memory_mb", Value: 1024},
110 {Key: "region", Value: "us-central1"},
111 {Key: "timeout_sec", Value: 60},
112 }),
113 },
114 {
115 name: "4. valid Vercel",
116 env: map[string]string{
117 "VERCEL": "1",
118 "VERCEL_REGION": "cdg1",
119 },
120 want: clientMetadata(bson.D{
121 {Key: "name", Value: "vercel"},
122 {Key: "region", Value: "cdg1"},
123 }),
124 },
125 {
126 name: "5. invalid multiple providers",
127 env: map[string]string{
128 "AWS_EXECUTION_ENV": "AWS_Lambda_java8",
129 "FUNCTIONS_WORKER_RUNTIME": "node",
130 },
131 want: clientMetadata(nil),
132 },
133 {
134 name: "6. invalid long string",
135 env: map[string]string{
136 "AWS_EXECUTION_ENV": "AWS_Lambda_java8",
137 "AWS_REGION": func() string {
138 var s string
139 for i := 0; i < 512; i++ {
140 s += "a"
141 }
142 return s
143 }(),
144 },
145 want: clientMetadata(bson.D{
146 {Key: "name", Value: "aws.lambda"},
147 }),
148 },
149 {
150 name: "7. invalid wrong types",
151 env: map[string]string{
152 "AWS_EXECUTION_ENV": "AWS_Lambda_java8",
153 "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big",
154 },
155 want: clientMetadata(bson.D{
156 {Key: "name", Value: "aws.lambda"},
157 }),
158 },
159 {
160 name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"",
161 env: map[string]string{
162 "AWS_EXECUTION_ENV": "EC2",
163 },
164 want: clientMetadata(nil),
165 },
166 } {
167 test := test
168
169 mt.RunOpts(test.name, opts, func(mt *mtest.T) {
170 for k, v := range test.env {
171 mt.Setenv(k, v)
172 }
173
174
175 err := mt.Client.Ping(context.Background(), nil)
176 require.NoError(mt, err, "Ping error: %v", err)
177
178 messages := mt.GetProxiedMessages()
179 handshakeMessage := messages[:1][0]
180
181 hello := handshake.LegacyHello
182 if os.Getenv("REQUIRE_API_VERSION") == "true" {
183 hello = "hello"
184 }
185
186 assert.Equal(mt, hello, handshakeMessage.CommandName)
187
188
189 clientVal, err := handshakeMessage.Sent.Command.LookupErr("client")
190 require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command)
191
192 got, ok := clientVal.DocumentOK()
193 require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type)
194
195 wantBytes, err := bson.Marshal(test.want)
196 require.NoError(mt, err, "error marshaling want document: %v", err)
197
198 want := bsoncore.Document(wantBytes)
199 assert.Equal(mt, want, got, "want: %v, got: %v", want, got)
200 })
201 }
202 }
203
204 func TestLoadBalancedConnectionHandshake(t *testing.T) {
205 mt := mtest.New(t)
206
207 lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
208 mtest.LoadBalanced)
209
210 mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) {
211
212 err := mt.Client.Ping(context.Background(), nil)
213 require.NoError(mt, err, "Ping error: %v", err)
214
215 messages := mt.GetProxiedMessages()
216 handshakeMessage := messages[:1][0]
217
218
219
220 assert.Equal(mt, "hello", handshakeMessage.CommandName)
221 assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode)
222 })
223
224 opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
225 mtest.ReplicaSet,
226 mtest.Sharded,
227 mtest.Single,
228 mtest.ShardedReplicaSet)
229
230 mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) {
231
232 err := mt.Client.Ping(context.Background(), nil)
233 require.NoError(mt, err, "Ping error: %v", err)
234
235 messages := mt.GetProxiedMessages()
236 handshakeMessage := messages[:1][0]
237
238 want := wiremessage.OpQuery
239
240 hello := handshake.LegacyHello
241 if os.Getenv("REQUIRE_API_VERSION") == "true" {
242 hello = "hello"
243
244
245
246 want = wiremessage.OpMsg
247 }
248
249 assert.Equal(mt, hello, handshakeMessage.CommandName)
250 assert.Equal(mt, want, handshakeMessage.Sent.OpCode)
251 })
252 }
253
View as plain text