...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/auth/speculative_scram_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/auth

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package auth
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"testing"
    13  
    14  	"go.mongodb.org/mongo-driver/bson"
    15  	"go.mongodb.org/mongo-driver/internal/assert"
    16  	"go.mongodb.org/mongo-driver/internal/handshake"
    17  	"go.mongodb.org/mongo-driver/mongo/address"
    18  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
    20  )
    21  
    22  var (
    23  	// The base elements for a hello response.
    24  	handshakeHelloElements = [][]byte{
    25  		bsoncore.AppendInt32Element(nil, "ok", 1),
    26  		bsoncore.AppendBooleanElement(nil, handshake.LegacyHelloLowercase, true),
    27  		bsoncore.AppendInt32Element(nil, "maxBsonObjectSize", 16777216),
    28  		bsoncore.AppendInt32Element(nil, "maxMessageSizeBytes", 48000000),
    29  		bsoncore.AppendInt32Element(nil, "minWireVersion", 0),
    30  		bsoncore.AppendInt32Element(nil, "maxWireVersion", 6),
    31  	}
    32  	// The first payload sent by the driver for SCRAM-SHA-1/256 authentication.
    33  	firstScramSha1ClientPayload   = []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL")
    34  	firstScramSha256ClientPayload = []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO")
    35  )
    36  
    37  func TestSpeculativeSCRAM(t *testing.T) {
    38  	cred := &Cred{
    39  		Username:    "user",
    40  		Password:    "pencil",
    41  		PasswordSet: true,
    42  		Source:      "admin",
    43  	}
    44  
    45  	t.Run("speculative response included", func(t *testing.T) {
    46  		// Tests for SCRAM-SHA1 and SCRAM-SHA-256 when the hello response contains a reply to the speculative
    47  		// authentication attempt. The driver should only send a saslContinue after the hello to complete
    48  		// authentication.
    49  
    50  		testCases := []struct {
    51  			name               string
    52  			mechanism          string
    53  			firstClientPayload []byte
    54  			payloads           [][]byte
    55  			nonce              string
    56  		}{
    57  			{"SCRAM-SHA-1", "SCRAM-SHA-1", firstScramSha1ClientPayload, scramSha1ShortPayloads, scramSha1Nonce},
    58  			{"SCRAM-SHA-256", "SCRAM-SHA-256", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
    59  			{"Default", "", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
    60  		}
    61  
    62  		for _, tc := range testCases {
    63  			t.Run(tc.name, func(t *testing.T) {
    64  				// Create a SCRAM authenticator and overwrite the nonce generator to make the conversation
    65  				// deterministic.
    66  				authenticator, err := CreateAuthenticator(tc.mechanism, cred)
    67  				assert.Nil(t, err, "CreateAuthenticator error: %v", err)
    68  				setNonce(t, authenticator, tc.nonce)
    69  
    70  				// Create a Handshaker and fake connection to authenticate.
    71  				handshaker := Handshaker(nil, &HandshakeOptions{
    72  					Authenticator: authenticator,
    73  					DBUser:        "admin.user",
    74  				})
    75  				responses := make(chan []byte, len(tc.payloads))
    76  				writeReplies(responses, createSpeculativeSCRAMHandshake(tc.payloads)...)
    77  
    78  				conn := &drivertest.ChannelConn{
    79  					Written:  make(chan []byte, len(tc.payloads)),
    80  					ReadResp: responses,
    81  				}
    82  
    83  				// Do both parts of the handshake.
    84  				info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
    85  				assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
    86  				assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
    87  				conn.Desc = info.Description // Set conn.Desc so the new description will be used for the authentication.
    88  
    89  				err = handshaker.FinishHandshake(context.Background(), conn)
    90  				assert.Nil(t, err, "FinishHandshake error: %v", err)
    91  				assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
    92  
    93  				// Assert that the driver sent hello with the speculative authentication message.
    94  				assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
    95  					len(tc.payloads), (conn.Written))
    96  				helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
    97  				assert.Nil(t, err, "error parsing hello command: %v", err)
    98  				assertCommandName(t, helloCmd, handshake.LegacyHello)
    99  
   100  				// Assert that the correct document was sent for speculative authentication.
   101  				authDocVal, err := helloCmd.LookupErr("speculativeAuthenticate")
   102  				assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(helloCmd))
   103  				authDoc := authDocVal.Document()
   104  				sentMechanism := tc.mechanism
   105  				if sentMechanism == "" {
   106  					sentMechanism = "SCRAM-SHA-256"
   107  				}
   108  
   109  				expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
   110  					bsoncore.AppendInt32Element(nil, "saslStart", 1),
   111  					bsoncore.AppendStringElement(nil, "mechanism", sentMechanism),
   112  					bsoncore.AppendBinaryElement(nil, "payload", 0x00, tc.firstClientPayload),
   113  					bsoncore.AppendStringElement(nil, "db", "admin"),
   114  					bsoncore.AppendDocumentElement(nil, "options", bsoncore.BuildDocumentFromElements(nil,
   115  						bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
   116  					)),
   117  				)
   118  				assert.True(t, bytes.Equal(expectedAuthDoc, authDoc),
   119  					"expected speculative auth document %s, got %s",
   120  					bson.Raw(expectedAuthDoc),
   121  					authDoc,
   122  				)
   123  
   124  				// Assert that the last command sent in the handshake is saslContinue.
   125  
   126  				saslContinueCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
   127  				assert.Nil(t, err, "error parsing saslContinue command: %v", err)
   128  				assertCommandName(t, saslContinueCmd, "saslContinue")
   129  			})
   130  		}
   131  	})
   132  	t.Run("speculative response not included", func(t *testing.T) {
   133  		// Tests for SCRAM-SHA-1 and SCRAM-SHA-256 when the hello response does not contain a reply to the
   134  		// speculative authentication attempt. The driver should send both saslStart and saslContinue after the initial
   135  		// hello.
   136  
   137  		// There is no test for the default mechanism because we can't control the nonce used for the actual
   138  		// authentication attempt after the speculative attempt fails.
   139  
   140  		testCases := []struct {
   141  			mechanism string
   142  			payloads  [][]byte
   143  			nonce     string
   144  		}{
   145  			{"SCRAM-SHA-1", scramSha1ShortPayloads, scramSha1Nonce},
   146  			{"SCRAM-SHA-256", scramSha256ShortPayloads, scramSha256Nonce},
   147  		}
   148  
   149  		for _, tc := range testCases {
   150  			t.Run(tc.mechanism, func(t *testing.T) {
   151  				authenticator, err := CreateAuthenticator(tc.mechanism, cred)
   152  				assert.Nil(t, err, "CreateAuthenticator error: %v", err)
   153  				setNonce(t, authenticator, tc.nonce)
   154  
   155  				handshaker := Handshaker(nil, &HandshakeOptions{
   156  					Authenticator: authenticator,
   157  					DBUser:        "admin.user",
   158  				})
   159  				numResponses := len(tc.payloads) + 1 // +1 for hello response
   160  				responses := make(chan []byte, numResponses)
   161  				writeReplies(responses, createRegularSCRAMHandshake(tc.payloads)...)
   162  
   163  				conn := &drivertest.ChannelConn{
   164  					Written:  make(chan []byte, numResponses),
   165  					ReadResp: responses,
   166  				}
   167  
   168  				info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
   169  				assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
   170  				assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
   171  					bson.Raw(info.SpeculativeAuthenticate))
   172  				conn.Desc = info.Description
   173  
   174  				err = handshaker.FinishHandshake(context.Background(), conn)
   175  				assert.Nil(t, err, "FinishHandshake error: %v", err)
   176  				assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
   177  
   178  				assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
   179  					numResponses, len(conn.Written))
   180  				hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
   181  				assert.Nil(t, err, "error parsing hello command: %v", err)
   182  				assertCommandName(t, hello, handshake.LegacyHello)
   183  				_, err = hello.LookupErr("speculativeAuthenticate")
   184  				assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
   185  
   186  				saslStart, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
   187  				assert.Nil(t, err, "error parsing saslStart command: %v", err)
   188  				assertCommandName(t, saslStart, "saslStart")
   189  
   190  				saslContinue, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
   191  				assert.Nil(t, err, "error parsing saslContinue command: %v", err)
   192  				assertCommandName(t, saslContinue, "saslContinue")
   193  			})
   194  		}
   195  	})
   196  }
   197  
   198  func setNonce(t *testing.T, authenticator Authenticator, nonce string) {
   199  	t.Helper()
   200  	nonceGenerator := func() string {
   201  		return nonce
   202  	}
   203  
   204  	switch converted := authenticator.(type) {
   205  	case *ScramAuthenticator:
   206  		converted.client = converted.client.WithNonceGenerator(nonceGenerator)
   207  	case *DefaultAuthenticator:
   208  		sa := converted.speculativeAuthenticator.(*ScramAuthenticator)
   209  		sa.client = sa.client.WithNonceGenerator(nonceGenerator)
   210  	default:
   211  		t.Fatalf("invalid authenticator type %T", authenticator)
   212  	}
   213  }
   214  
   215  // createSpeculativeSCRAMHandshake creates the server replies for a successful speculative SCRAM authentication attempt.
   216  // There are two replies:
   217  //
   218  // 1. hello reply containing a "speculativeAuthenticate" document.
   219  // 2. saslContinue reply with done:true
   220  func createSpeculativeSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
   221  	firstAuthResponse := createSCRAMServerResponse(payloads[0], false)
   222  	firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", firstAuthResponse)
   223  	hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
   224  
   225  	responses := []bsoncore.Document{hello}
   226  	for idx := 1; idx < len(payloads); idx++ {
   227  		responses = append(responses, createSCRAMServerResponse(payloads[idx], idx == len(payloads)-1))
   228  	}
   229  	return responses
   230  }
   231  
   232  // createRegularSCRAMHandshake creates the server replies for a handshake + SCRAM authentication attempt. There are
   233  // three replies:
   234  //
   235  // 1. hello reply
   236  // 2. saslStart reply with done:false
   237  // 3. saslContinue reply with done:true
   238  func createRegularSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
   239  	hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
   240  	responses := []bsoncore.Document{hello}
   241  
   242  	for idx, payload := range payloads {
   243  		responses = append(responses, createSCRAMServerResponse(payload, idx == len(payloads)-1))
   244  	}
   245  	return responses
   246  }
   247  
   248  func assertCommandName(t *testing.T, cmd bsoncore.Document, expectedName string) {
   249  	t.Helper()
   250  
   251  	actualName := cmd.Index(0).Key()
   252  	assert.Equal(t, expectedName, actualName, "expected command name '%s', got '%s'", expectedName, actualName)
   253  }
   254  

View as plain text