1
2
3
4
5
6
7 package auth
8
9 import (
10 "bytes"
11 "context"
12 "testing"
13
14 "go.mongodb.org/mongo-driver/bson"
15 "go.mongodb.org/mongo-driver/internal/assert"
16 "go.mongodb.org/mongo-driver/internal/handshake"
17 "go.mongodb.org/mongo-driver/mongo/address"
18 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19 "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
20 )
21
22 var (
23
24 handshakeHelloElements = [][]byte{
25 bsoncore.AppendInt32Element(nil, "ok", 1),
26 bsoncore.AppendBooleanElement(nil, handshake.LegacyHelloLowercase, true),
27 bsoncore.AppendInt32Element(nil, "maxBsonObjectSize", 16777216),
28 bsoncore.AppendInt32Element(nil, "maxMessageSizeBytes", 48000000),
29 bsoncore.AppendInt32Element(nil, "minWireVersion", 0),
30 bsoncore.AppendInt32Element(nil, "maxWireVersion", 6),
31 }
32
33 firstScramSha1ClientPayload = []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL")
34 firstScramSha256ClientPayload = []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO")
35 )
36
37 func TestSpeculativeSCRAM(t *testing.T) {
38 cred := &Cred{
39 Username: "user",
40 Password: "pencil",
41 PasswordSet: true,
42 Source: "admin",
43 }
44
45 t.Run("speculative response included", func(t *testing.T) {
46
47
48
49
50 testCases := []struct {
51 name string
52 mechanism string
53 firstClientPayload []byte
54 payloads [][]byte
55 nonce string
56 }{
57 {"SCRAM-SHA-1", "SCRAM-SHA-1", firstScramSha1ClientPayload, scramSha1ShortPayloads, scramSha1Nonce},
58 {"SCRAM-SHA-256", "SCRAM-SHA-256", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
59 {"Default", "", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
60 }
61
62 for _, tc := range testCases {
63 t.Run(tc.name, func(t *testing.T) {
64
65
66 authenticator, err := CreateAuthenticator(tc.mechanism, cred)
67 assert.Nil(t, err, "CreateAuthenticator error: %v", err)
68 setNonce(t, authenticator, tc.nonce)
69
70
71 handshaker := Handshaker(nil, &HandshakeOptions{
72 Authenticator: authenticator,
73 DBUser: "admin.user",
74 })
75 responses := make(chan []byte, len(tc.payloads))
76 writeReplies(responses, createSpeculativeSCRAMHandshake(tc.payloads)...)
77
78 conn := &drivertest.ChannelConn{
79 Written: make(chan []byte, len(tc.payloads)),
80 ReadResp: responses,
81 }
82
83
84 info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
85 assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
86 assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
87 conn.Desc = info.Description
88
89 err = handshaker.FinishHandshake(context.Background(), conn)
90 assert.Nil(t, err, "FinishHandshake error: %v", err)
91 assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
92
93
94 assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
95 len(tc.payloads), (conn.Written))
96 helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
97 assert.Nil(t, err, "error parsing hello command: %v", err)
98 assertCommandName(t, helloCmd, handshake.LegacyHello)
99
100
101 authDocVal, err := helloCmd.LookupErr("speculativeAuthenticate")
102 assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(helloCmd))
103 authDoc := authDocVal.Document()
104 sentMechanism := tc.mechanism
105 if sentMechanism == "" {
106 sentMechanism = "SCRAM-SHA-256"
107 }
108
109 expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
110 bsoncore.AppendInt32Element(nil, "saslStart", 1),
111 bsoncore.AppendStringElement(nil, "mechanism", sentMechanism),
112 bsoncore.AppendBinaryElement(nil, "payload", 0x00, tc.firstClientPayload),
113 bsoncore.AppendStringElement(nil, "db", "admin"),
114 bsoncore.AppendDocumentElement(nil, "options", bsoncore.BuildDocumentFromElements(nil,
115 bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
116 )),
117 )
118 assert.True(t, bytes.Equal(expectedAuthDoc, authDoc),
119 "expected speculative auth document %s, got %s",
120 bson.Raw(expectedAuthDoc),
121 authDoc,
122 )
123
124
125
126 saslContinueCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
127 assert.Nil(t, err, "error parsing saslContinue command: %v", err)
128 assertCommandName(t, saslContinueCmd, "saslContinue")
129 })
130 }
131 })
132 t.Run("speculative response not included", func(t *testing.T) {
133
134
135
136
137
138
139
140 testCases := []struct {
141 mechanism string
142 payloads [][]byte
143 nonce string
144 }{
145 {"SCRAM-SHA-1", scramSha1ShortPayloads, scramSha1Nonce},
146 {"SCRAM-SHA-256", scramSha256ShortPayloads, scramSha256Nonce},
147 }
148
149 for _, tc := range testCases {
150 t.Run(tc.mechanism, func(t *testing.T) {
151 authenticator, err := CreateAuthenticator(tc.mechanism, cred)
152 assert.Nil(t, err, "CreateAuthenticator error: %v", err)
153 setNonce(t, authenticator, tc.nonce)
154
155 handshaker := Handshaker(nil, &HandshakeOptions{
156 Authenticator: authenticator,
157 DBUser: "admin.user",
158 })
159 numResponses := len(tc.payloads) + 1
160 responses := make(chan []byte, numResponses)
161 writeReplies(responses, createRegularSCRAMHandshake(tc.payloads)...)
162
163 conn := &drivertest.ChannelConn{
164 Written: make(chan []byte, numResponses),
165 ReadResp: responses,
166 }
167
168 info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
169 assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
170 assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
171 bson.Raw(info.SpeculativeAuthenticate))
172 conn.Desc = info.Description
173
174 err = handshaker.FinishHandshake(context.Background(), conn)
175 assert.Nil(t, err, "FinishHandshake error: %v", err)
176 assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
177
178 assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
179 numResponses, len(conn.Written))
180 hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
181 assert.Nil(t, err, "error parsing hello command: %v", err)
182 assertCommandName(t, hello, handshake.LegacyHello)
183 _, err = hello.LookupErr("speculativeAuthenticate")
184 assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
185
186 saslStart, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
187 assert.Nil(t, err, "error parsing saslStart command: %v", err)
188 assertCommandName(t, saslStart, "saslStart")
189
190 saslContinue, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
191 assert.Nil(t, err, "error parsing saslContinue command: %v", err)
192 assertCommandName(t, saslContinue, "saslContinue")
193 })
194 }
195 })
196 }
197
198 func setNonce(t *testing.T, authenticator Authenticator, nonce string) {
199 t.Helper()
200 nonceGenerator := func() string {
201 return nonce
202 }
203
204 switch converted := authenticator.(type) {
205 case *ScramAuthenticator:
206 converted.client = converted.client.WithNonceGenerator(nonceGenerator)
207 case *DefaultAuthenticator:
208 sa := converted.speculativeAuthenticator.(*ScramAuthenticator)
209 sa.client = sa.client.WithNonceGenerator(nonceGenerator)
210 default:
211 t.Fatalf("invalid authenticator type %T", authenticator)
212 }
213 }
214
215
216
217
218
219
220 func createSpeculativeSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
221 firstAuthResponse := createSCRAMServerResponse(payloads[0], false)
222 firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", firstAuthResponse)
223 hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
224
225 responses := []bsoncore.Document{hello}
226 for idx := 1; idx < len(payloads); idx++ {
227 responses = append(responses, createSCRAMServerResponse(payloads[idx], idx == len(payloads)-1))
228 }
229 return responses
230 }
231
232
233
234
235
236
237
238 func createRegularSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
239 hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
240 responses := []bsoncore.Document{hello}
241
242 for idx, payload := range payloads {
243 responses = append(responses, createSCRAMServerResponse(payload, idx == len(payloads)-1))
244 }
245 return responses
246 }
247
248 func assertCommandName(t *testing.T, cmd bsoncore.Document, expectedName string) {
249 t.Helper()
250
251 actualName := cmd.Index(0).Key()
252 assert.Equal(t, expectedName, actualName, "expected command name '%s', got '%s'", expectedName, actualName)
253 }
254
View as plain text