1
2
3
4
5
6
7 package auth
8
9 import (
10 "context"
11 "fmt"
12
13 "go.mongodb.org/mongo-driver/bson"
14 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
15 "go.mongodb.org/mongo-driver/x/mongo/driver"
16 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
17 )
18
19
// SaslClient is the mechanism-specific half of a SASL conversation. The
// conversation driver asks it for the opening message, feeds it each server
// payload, and asks whether it considers the exchange finished.
type SaslClient interface {
	// Start returns the mechanism name and the initial payload sent in
	// saslStart, or an error if the exchange cannot begin.
	Start() (mechanism string, payload []byte, err error)
	// Next consumes a payload received from the server and returns the
	// payload to send back in the following saslContinue.
	Next(challenge []byte) (response []byte, err error)
	// Completed reports whether the client side of the exchange is done.
	Completed() bool
}
25
26
27 type SaslClientCloser interface {
28 SaslClient
29 Close()
30 }
31
32
33 type ExtraOptionsSaslClient interface {
34 StartCommandOptions() bsoncore.Document
35 }
36
37
38
39 type saslConversation struct {
40 client SaslClient
41 source string
42 mechanism string
43 speculative bool
44 }
45
46 var _ SpeculativeConversation = (*saslConversation)(nil)
47
48 func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation {
49 authSource := source
50 if authSource == "" {
51 authSource = defaultAuthDB
52 }
53 return &saslConversation{
54 client: client,
55 source: authSource,
56 speculative: speculative,
57 }
58 }
59
60
61
62 func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) {
63 var payload []byte
64 var err error
65 sc.mechanism, payload, err = sc.client.Start()
66 if err != nil {
67 return nil, err
68 }
69
70 saslCmdElements := [][]byte{
71 bsoncore.AppendInt32Element(nil, "saslStart", 1),
72 bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism),
73 bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
74 }
75 if sc.speculative {
76
77
78
79 saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source))
80 }
81 if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok {
82 optionsDoc := extraOptionsClient.StartCommandOptions()
83 saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc))
84 }
85
86 return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil
87 }
88
// saslResponse is the subset of a saslStart/saslContinue reply that the
// conversation loop inspects.
type saslResponse struct {
	// ConversationID identifies the exchange; it is echoed in saslContinue.
	ConversationID int `bson:"conversationId"`
	// Code is checked by Finish; a non-zero value aborts the conversation.
	Code int `bson:"code"`
	// Done reports that the server considers the exchange complete.
	Done bool `bson:"done"`
	// Payload is the server's SASL data, passed to SaslClient.Next.
	Payload []byte `bson:"payload"`
}
95
96
// Finish completes a SASL conversation whose saslStart reply is
// firstResponse. It repeatedly hands the server's payload to the client and
// sends the client's answer in a saslContinue command on cfg.Connection,
// targeting sc.source. It stops successfully once the server reports done and
// the client reports Completed. If the client implements SaslClientCloser,
// Close is called when Finish returns.
//
// All failures are wrapped with newError and tagged with sc.mechanism.
func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error {
	if closer, ok := sc.client.(SaslClientCloser); ok {
		defer closer.Close()
	}

	var saslResp saslResponse
	if err := bson.Unmarshal(firstResponse, &saslResp); err != nil {
		return newError(fmt.Errorf("unmarshal error: %w", err), sc.mechanism)
	}

	// The conversation ID from the first reply is reused for every saslContinue.
	cid := saslResp.ConversationID
	for {
		if saslResp.Code != 0 {
			// There is no Go error in hand here, so report the server's code
			// instead of wrapping a nil error.
			return newError(fmt.Errorf("server returned SASL error code %d", saslResp.Code), sc.mechanism)
		}

		if saslResp.Done && sc.client.Completed() {
			return nil
		}

		payload, err := sc.client.Next(saslResp.Payload)
		if err != nil {
			return newError(err, sc.mechanism)
		}

		// The client may only become Completed after consuming the server's
		// final payload; in that case no further round trip is needed.
		if saslResp.Done && sc.client.Completed() {
			return nil
		}

		doc := bsoncore.BuildDocumentFromElements(nil,
			bsoncore.AppendInt32Element(nil, "saslContinue", 1),
			bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
			bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
		)
		saslContinueCmd := operation.NewCommand(doc).
			Database(sc.source).
			Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
			ClusterClock(cfg.ClusterClock).
			ServerAPI(cfg.ServerAPI)

		if err := saslContinueCmd.Execute(ctx); err != nil {
			return newError(err, sc.mechanism)
		}

		// bson.Unmarshal does not clear struct fields that are absent from the
		// document, so reset first; otherwise a stale Payload/Code from the
		// previous reply could be fed back into the loop.
		saslResp = saslResponse{}
		if err := bson.Unmarshal(saslContinueCmd.Result(), &saslResp); err != nil {
			return newError(fmt.Errorf("unmarshal error: %w", err), sc.mechanism)
		}
	}
}
154
155
// ConductSaslConversation authenticates cfg.Connection by running a complete,
// non-speculative SASL conversation driven by client. It sends saslStart and
// then lets saslConversation.Finish handle the saslContinue rounds.
//
// authSource names the authentication database; when empty, defaultAuthDB
// is used (via newSaslConversation). Errors are wrapped with newError and
// tagged with the mechanism reported by the client.
func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error {
	// Non-speculative: FirstMessage omits the "db" field and saslStart is
	// sent as its own command below.
	conversation := newSaslConversation(client, authSource, false)

	saslStartDoc, err := conversation.FirstMessage()
	if err != nil {
		return newError(err, conversation.mechanism)
	}

	// Use conversation.source rather than the raw authSource. It applies the
	// defaultAuthDB fallback, and it is the database Finish sends every
	// saslContinue to, so the whole exchange targets one database.
	saslStartCmd := operation.NewCommand(saslStartDoc).
		Database(conversation.source).
		Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
		ClusterClock(cfg.ClusterClock).
		ServerAPI(cfg.ServerAPI)
	if err := saslStartCmd.Execute(ctx); err != nil {
		return newError(err, conversation.mechanism)
	}

	return conversation.Finish(ctx, cfg, saslStartCmd.Result())
}
175
View as plain text