...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/handshake_test.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package integration
     8  
     9  import (
    10  	"context"
    11  	"os"
    12  	"reflect"
    13  	"runtime"
    14  	"testing"
    15  
    16  	"go.mongodb.org/mongo-driver/bson"
    17  	"go.mongodb.org/mongo-driver/internal/assert"
    18  	"go.mongodb.org/mongo-driver/internal/handshake"
    19  	"go.mongodb.org/mongo-driver/internal/require"
    20  	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
    21  	"go.mongodb.org/mongo-driver/version"
    22  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    23  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    24  )
    25  
    26  func TestHandshakeProse(t *testing.T) {
    27  	mt := mtest.New(t)
    28  
    29  	if len(os.Getenv("DOCKER_RUNNING")) > 0 {
    30  		t.Skip("These tests gives different results when run in Docker due to extra environment data.")
    31  	}
    32  
    33  	opts := mtest.NewOptions().
    34  		CreateCollection(false).
    35  		ClientType(mtest.Proxy)
    36  
    37  	clientMetadata := func(env bson.D) bson.D {
    38  		elems := bson.D{
    39  			{Key: "driver", Value: bson.D{
    40  				{Key: "name", Value: "mongo-go-driver"},
    41  				{Key: "version", Value: version.Driver},
    42  			}},
    43  			{Key: "os", Value: bson.D{
    44  				{Key: "type", Value: runtime.GOOS},
    45  				{Key: "architecture", Value: runtime.GOARCH},
    46  			}},
    47  		}
    48  
    49  		elems = append(elems, bson.E{Key: "platform", Value: runtime.Version()})
    50  
    51  		// If env is empty, don't include it in the metadata.
    52  		if env != nil && !reflect.DeepEqual(env, bson.D{}) {
    53  			elems = append(elems, bson.E{Key: "env", Value: env})
    54  		}
    55  
    56  		return elems
    57  	}
    58  
    59  	// Reset the environment variables to avoid environment namespace
    60  	// collision.
    61  	t.Setenv("AWS_EXECUTION_ENV", "")
    62  	t.Setenv("FUNCTIONS_WORKER_RUNTIME", "")
    63  	t.Setenv("K_SERVICE", "")
    64  	t.Setenv("VERCEL", "")
    65  	t.Setenv("AWS_REGION", "")
    66  	t.Setenv("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "")
    67  	t.Setenv("FUNCTION_MEMORY_MB", "")
    68  	t.Setenv("FUNCTION_TIMEOUT_SEC", "")
    69  	t.Setenv("FUNCTION_REGION", "")
    70  	t.Setenv("VERCEL_REGION", "")
    71  
    72  	for _, test := range []struct {
    73  		name string
    74  		env  map[string]string
    75  		want bson.D
    76  	}{
    77  		{
    78  			name: "1. valid AWS",
    79  			env: map[string]string{
    80  				"AWS_EXECUTION_ENV":               "AWS_Lambda_java8",
    81  				"AWS_REGION":                      "us-east-2",
    82  				"AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024",
    83  			},
    84  			want: clientMetadata(bson.D{
    85  				{Key: "name", Value: "aws.lambda"},
    86  				{Key: "memory_mb", Value: 1024},
    87  				{Key: "region", Value: "us-east-2"},
    88  			}),
    89  		},
    90  		{
    91  			name: "2. valid Azure",
    92  			env: map[string]string{
    93  				"FUNCTIONS_WORKER_RUNTIME": "node",
    94  			},
    95  			want: clientMetadata(bson.D{
    96  				{Key: "name", Value: "azure.func"},
    97  			}),
    98  		},
    99  		{
   100  			name: "3. valid GCP",
   101  			env: map[string]string{
   102  				"K_SERVICE":            "servicename",
   103  				"FUNCTION_MEMORY_MB":   "1024",
   104  				"FUNCTION_TIMEOUT_SEC": "60",
   105  				"FUNCTION_REGION":      "us-central1",
   106  			},
   107  			want: clientMetadata(bson.D{
   108  				{Key: "name", Value: "gcp.func"},
   109  				{Key: "memory_mb", Value: 1024},
   110  				{Key: "region", Value: "us-central1"},
   111  				{Key: "timeout_sec", Value: 60},
   112  			}),
   113  		},
   114  		{
   115  			name: "4. valid Vercel",
   116  			env: map[string]string{
   117  				"VERCEL":        "1",
   118  				"VERCEL_REGION": "cdg1",
   119  			},
   120  			want: clientMetadata(bson.D{
   121  				{Key: "name", Value: "vercel"},
   122  				{Key: "region", Value: "cdg1"},
   123  			}),
   124  		},
   125  		{
   126  			name: "5. invalid multiple providers",
   127  			env: map[string]string{
   128  				"AWS_EXECUTION_ENV":        "AWS_Lambda_java8",
   129  				"FUNCTIONS_WORKER_RUNTIME": "node",
   130  			},
   131  			want: clientMetadata(nil),
   132  		},
   133  		{
   134  			name: "6. invalid long string",
   135  			env: map[string]string{
   136  				"AWS_EXECUTION_ENV": "AWS_Lambda_java8",
   137  				"AWS_REGION": func() string {
   138  					var s string
   139  					for i := 0; i < 512; i++ {
   140  						s += "a"
   141  					}
   142  					return s
   143  				}(),
   144  			},
   145  			want: clientMetadata(bson.D{
   146  				{Key: "name", Value: "aws.lambda"},
   147  			}),
   148  		},
   149  		{
   150  			name: "7. invalid wrong types",
   151  			env: map[string]string{
   152  				"AWS_EXECUTION_ENV":               "AWS_Lambda_java8",
   153  				"AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big",
   154  			},
   155  			want: clientMetadata(bson.D{
   156  				{Key: "name", Value: "aws.lambda"},
   157  			}),
   158  		},
   159  		{
   160  			name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"",
   161  			env: map[string]string{
   162  				"AWS_EXECUTION_ENV": "EC2",
   163  			},
   164  			want: clientMetadata(nil),
   165  		},
   166  	} {
   167  		test := test
   168  
   169  		mt.RunOpts(test.name, opts, func(mt *mtest.T) {
   170  			for k, v := range test.env {
   171  				mt.Setenv(k, v)
   172  			}
   173  
   174  			// Ping the server to ensure the handshake has completed.
   175  			err := mt.Client.Ping(context.Background(), nil)
   176  			require.NoError(mt, err, "Ping error: %v", err)
   177  
   178  			messages := mt.GetProxiedMessages()
   179  			handshakeMessage := messages[:1][0]
   180  
   181  			hello := handshake.LegacyHello
   182  			if os.Getenv("REQUIRE_API_VERSION") == "true" {
   183  				hello = "hello"
   184  			}
   185  
   186  			assert.Equal(mt, hello, handshakeMessage.CommandName)
   187  
   188  			// Lookup the "client" field in the command document.
   189  			clientVal, err := handshakeMessage.Sent.Command.LookupErr("client")
   190  			require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command)
   191  
   192  			got, ok := clientVal.DocumentOK()
   193  			require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type)
   194  
   195  			wantBytes, err := bson.Marshal(test.want)
   196  			require.NoError(mt, err, "error marshaling want document: %v", err)
   197  
   198  			want := bsoncore.Document(wantBytes)
   199  			assert.Equal(mt, want, got, "want: %v, got: %v", want, got)
   200  		})
   201  	}
   202  }
   203  
   204  func TestLoadBalancedConnectionHandshake(t *testing.T) {
   205  	mt := mtest.New(t)
   206  
   207  	lbopts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
   208  		mtest.LoadBalanced)
   209  
   210  	mt.RunOpts("LB connection handshake uses OP_MSG", lbopts, func(mt *mtest.T) {
   211  		// Ping the server to ensure the handshake has completed.
   212  		err := mt.Client.Ping(context.Background(), nil)
   213  		require.NoError(mt, err, "Ping error: %v", err)
   214  
   215  		messages := mt.GetProxiedMessages()
   216  		handshakeMessage := messages[:1][0]
   217  
   218  		// Per the specifications, if loadBalanced=true, drivers MUST use the hello
   219  		// command for the initial handshake and use the OP_MSG protocol.
   220  		assert.Equal(mt, "hello", handshakeMessage.CommandName)
   221  		assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode)
   222  	})
   223  
   224  	opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies(
   225  		mtest.ReplicaSet,
   226  		mtest.Sharded,
   227  		mtest.Single,
   228  		mtest.ShardedReplicaSet)
   229  
   230  	mt.RunOpts("non-LB connection handshake uses OP_QUERY", opts, func(mt *mtest.T) {
   231  		// Ping the server to ensure the handshake has completed.
   232  		err := mt.Client.Ping(context.Background(), nil)
   233  		require.NoError(mt, err, "Ping error: %v", err)
   234  
   235  		messages := mt.GetProxiedMessages()
   236  		handshakeMessage := messages[:1][0]
   237  
   238  		want := wiremessage.OpQuery
   239  
   240  		hello := handshake.LegacyHello
   241  		if os.Getenv("REQUIRE_API_VERSION") == "true" {
   242  			hello = "hello"
   243  
   244  			// If the server API version is requested, then we should use OP_MSG
   245  			// regardless of the topology
   246  			want = wiremessage.OpMsg
   247  		}
   248  
   249  		assert.Equal(mt, hello, handshakeMessage.CommandName)
   250  		assert.Equal(mt, want, handshakeMessage.Sent.OpCode)
   251  	})
   252  }
   253  

View as plain text