...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/auth/scram_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/auth

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package auth
     8  
     9  import (
    10  	"context"
    11  	"testing"
    12  
    13  	"go.mongodb.org/mongo-driver/internal/assert"
    14  	"go.mongodb.org/mongo-driver/mongo/description"
    15  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    16  	"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
    17  )
    18  
    19  const (
    20  	scramSha1Nonce   = "fyko+d2lbbFgONRv9qkxdawL"
    21  	scramSha256Nonce = "rOprNGfwEbeRWgbNEkqO"
    22  )
    23  
    24  var (
    25  	scramSha1ShortPayloads = [][]byte{
    26  		[]byte("r=fyko+d2lbbFgONRv9qkxdawLHo+Vgk7qvUOKUwuWLIWg4l/9SraGMHEE,s=rQ9ZY3MntBeuP3E1TDVC4w==,i=10000"),
    27  		[]byte("v=UMWeI25JD1yNYZRMpZ4VHvhZ9e0="),
    28  	}
    29  	scramSha256ShortPayloads = [][]byte{
    30  		[]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"),
    31  		[]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="),
    32  	}
    33  	scramSha1LongPayloads   = append(scramSha1ShortPayloads, []byte{})
    34  	scramSha256LongPayloads = append(scramSha256ShortPayloads, []byte{})
    35  )
    36  
    37  func TestSCRAM(t *testing.T) {
    38  	t.Run("conversation", func(t *testing.T) {
    39  		testCases := []struct {
    40  			name                  string
    41  			createAuthenticatorFn func(*Cred) (Authenticator, error)
    42  			payloads              [][]byte
    43  			nonce                 string
    44  		}{
    45  			{"scram-sha-1 short conversation", newScramSHA1Authenticator, scramSha1ShortPayloads, scramSha1Nonce},
    46  			{"scram-sha-256 short conversation", newScramSHA256Authenticator, scramSha256ShortPayloads, scramSha256Nonce},
    47  			{"scram-sha-1 long conversation", newScramSHA1Authenticator, scramSha1LongPayloads, scramSha1Nonce},
    48  			{"scram-sha-256 long conversation", newScramSHA256Authenticator, scramSha256LongPayloads, scramSha256Nonce},
    49  		}
    50  		for _, tc := range testCases {
    51  			t.Run(tc.name, func(t *testing.T) {
    52  				authenticator, err := tc.createAuthenticatorFn(&Cred{
    53  					Username: "user",
    54  					Password: "pencil",
    55  					Source:   "admin",
    56  				})
    57  				assert.Nil(t, err, "error creating authenticator: %v", err)
    58  				sa, _ := authenticator.(*ScramAuthenticator)
    59  				sa.client = sa.client.WithNonceGenerator(func() string {
    60  					return tc.nonce
    61  				})
    62  
    63  				responses := make(chan []byte, len(tc.payloads))
    64  				writeReplies(responses, createSCRAMConversation(tc.payloads)...)
    65  
    66  				desc := description.Server{
    67  					WireVersion: &description.VersionRange{
    68  						Max: 21,
    69  					},
    70  				}
    71  				conn := &drivertest.ChannelConn{
    72  					Written:  make(chan []byte, len(tc.payloads)),
    73  					ReadResp: responses,
    74  					Desc:     desc,
    75  				}
    76  
    77  				err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn})
    78  				assert.Nil(t, err, "Auth error: %v\n", err)
    79  
    80  				// Verify that the first command sent is saslStart.
    81  				assert.True(t, len(conn.Written) > 1, "wire messages were written to the connection")
    82  				startCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
    83  				assert.Nil(t, err, "error parsing wire message: %v", err)
    84  				cmdName := startCmd.Index(0).Key()
    85  				assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName)
    86  
    87  				// Verify that the saslStart command always has {options: {skipEmptyExchange: true}}
    88  				optionsVal, err := startCmd.LookupErr("options")
    89  				assert.Nil(t, err, "no options found in saslStart command")
    90  				optionsDoc := optionsVal.Document()
    91  				assert.Equal(t, optionsDoc, scramStartOptions, "expected options %v, got %v", scramStartOptions, optionsDoc)
    92  			})
    93  		}
    94  	})
    95  }
    96  
    97  func createSCRAMConversation(payloads [][]byte) []bsoncore.Document {
    98  	responses := make([]bsoncore.Document, len(payloads))
    99  	for idx, payload := range payloads {
   100  		res := createSCRAMServerResponse(payload, idx == len(payloads)-1)
   101  		responses[idx] = res
   102  	}
   103  	return responses
   104  }
   105  
   106  func createSCRAMServerResponse(payload []byte, done bool) bsoncore.Document {
   107  	return bsoncore.BuildDocumentFromElements(nil,
   108  		bsoncore.AppendInt32Element(nil, "conversationId", 1),
   109  		bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
   110  		bsoncore.AppendBooleanElement(nil, "done", done),
   111  		bsoncore.AppendInt32Element(nil, "ok", 1),
   112  	)
   113  }
   114  
   115  func writeReplies(c chan []byte, docs ...bsoncore.Document) {
   116  	for _, doc := range docs {
   117  		reply := drivertest.MakeReply(doc)
   118  		c <- reply
   119  	}
   120  }
   121  

View as plain text