1
2
3
4
5
6
7 package auth
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "net/http"
14
15 "go.mongodb.org/mongo-driver/mongo/address"
16 "go.mongodb.org/mongo-driver/mongo/description"
17 "go.mongodb.org/mongo-driver/x/mongo/driver"
18 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
19 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
20 )
21
22
23 type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
24
25 var authFactories = make(map[string]AuthenticatorFactory)
26
27 func init() {
28 RegisterAuthenticatorFactory("", newDefaultAuthenticator)
29 RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
30 RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
31 RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
32 RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
33 RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
34 RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
35 RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
36 }
37
38
39 func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
40 if f, ok := authFactories[name]; ok {
41 return f(cred)
42 }
43
44 return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
45 }
46
47
48 func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
49 authFactories[name] = factory
50 }
51
52
53
54
55 type HandshakeOptions struct {
56 AppName string
57 Authenticator Authenticator
58 Compressors []string
59 DBUser string
60 PerformAuthentication func(description.Server) bool
61 ClusterClock *session.ClusterClock
62 ServerAPI *driver.ServerAPIOptions
63 LoadBalanced bool
64 HTTPClient *http.Client
65 }
66
67 type authHandshaker struct {
68 wrapped driver.Handshaker
69 options *HandshakeOptions
70
71 handshakeInfo driver.HandshakeInformation
72 conversation SpeculativeConversation
73 }
74
75 var _ driver.Handshaker = (*authHandshaker)(nil)
76
77
78
79 func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
80 if ah.wrapped != nil {
81 return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
82 }
83
84 op := operation.NewHello().
85 AppName(ah.options.AppName).
86 Compressors(ah.options.Compressors).
87 SASLSupportedMechs(ah.options.DBUser).
88 ClusterClock(ah.options.ClusterClock).
89 ServerAPI(ah.options.ServerAPI).
90 LoadBalanced(ah.options.LoadBalanced)
91
92 if ah.options.Authenticator != nil {
93 if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
94 var err error
95 ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
96 if err != nil {
97 return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
98 }
99
100 firstMsg, err := ah.conversation.FirstMessage()
101 if err != nil {
102 return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
103 }
104
105 op = op.SpeculativeAuthenticate(firstMsg)
106 }
107 }
108
109 var err error
110 ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
111 if err != nil {
112 return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
113 }
114 return ah.handshakeInfo, nil
115 }
116
117
118 func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
119 performAuth := ah.options.PerformAuthentication
120 if performAuth == nil {
121 performAuth = func(serv description.Server) bool {
122
123 return serv.Kind != description.RSArbiter
124 }
125 }
126
127 desc := conn.Description()
128 if performAuth(desc) && ah.options.Authenticator != nil {
129 cfg := &Config{
130 Description: desc,
131 Connection: conn,
132 ClusterClock: ah.options.ClusterClock,
133 HandshakeInfo: ah.handshakeInfo,
134 ServerAPI: ah.options.ServerAPI,
135 HTTPClient: ah.options.HTTPClient,
136 }
137
138 if err := ah.authenticate(ctx, cfg); err != nil {
139 return newAuthError("auth error", err)
140 }
141 }
142
143 if ah.wrapped == nil {
144 return nil
145 }
146 return ah.wrapped.FinishHandshake(ctx, conn)
147 }
148
149 func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error {
150
151
152 if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
153
154 if ah.conversation == nil {
155 return errors.New("speculative auth was not attempted but the server included a response")
156 }
157 return ah.conversation.Finish(ctx, cfg, speculativeResponse)
158 }
159
160
161
162 return ah.options.Authenticator.Auth(ctx, cfg)
163 }
164
165
166 func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
167 return &authHandshaker{
168 wrapped: h,
169 options: options,
170 }
171 }
172
173
174 type Config struct {
175 Description description.Server
176 Connection driver.Connection
177 ClusterClock *session.ClusterClock
178 HandshakeInfo driver.HandshakeInformation
179 ServerAPI *driver.ServerAPIOptions
180 HTTPClient *http.Client
181 }
182
183
184 type Authenticator interface {
185
186 Auth(context.Context, *Config) error
187 }
188
// newAuthError returns an *Error with the given message. It wraps inner,
// which may be nil.
func newAuthError(msg string, inner error) error {
	return &Error{message: msg, inner: inner}
}

// newError wraps err in an *Error whose message names the mechanism that
// failed. The mechanism is formatted with %q so that unusual names (empty,
// or containing quotes or control characters) are escaped and render
// unambiguously. Ordinary names render exactly as "NAME".
func newError(err error, mech string) error {
	return &Error{
		message: fmt.Sprintf("unable to authenticate using mechanism %q", mech),
		inner:   err,
	}
}

// Error is the error type built by newAuthError and newError. It carries a
// message and an optional wrapped cause.
type Error struct {
	message string
	inner   error
}

// Error implements the error interface. It returns the message alone when
// there is no wrapped error, and "message: inner" otherwise.
func (e *Error) Error() string {
	if e.inner == nil {
		return e.message
	}
	return fmt.Sprintf("%s: %s", e.message, e.inner)
}

// Inner returns the wrapped error, or nil if there is none.
func (e *Error) Inner() error {
	return e.inner
}

// Unwrap returns the wrapped error so that errors.Is and errors.As can
// inspect the cause chain.
func (e *Error) Unwrap() error {
	return e.inner
}

// Message returns the error's own message, without the wrapped error's text.
func (e *Error) Message() string {
	return e.message
}
230
View as plain text