1
2
3
4
5
6
7 package auth
8
9 import (
10 "bytes"
11 "context"
12 "crypto/rand"
13 "encoding/base64"
14 "errors"
15 "fmt"
16 "net/http"
17 "strings"
18 "time"
19
20 "go.mongodb.org/mongo-driver/bson"
21 "go.mongodb.org/mongo-driver/bson/primitive"
22 "go.mongodb.org/mongo-driver/internal/aws/credentials"
23 v4signer "go.mongodb.org/mongo-driver/internal/aws/signer/v4"
24 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25 )
26
// clientState identifies how far an awsConversation has progressed. Step
// moves a conversation through the values below in declaration order.
type clientState int

// Conversation states, numbered 0 through 3 in the order Step visits them.
const (
	// clientStarting is the zero value: no message has been produced yet.
	clientStarting = clientState(iota)
	// clientFirst is entered when Step produces the first client message.
	clientFirst
	// clientFinal is entered when Step produces the final client message.
	clientFinal
	// clientDone is entered on the Step call after clientFinal; Done reports
	// true once this state is reached.
	clientDone
)
35
// awsConversation holds the client-side state of a single authentication
// exchange that is driven one message at a time through Step.
type awsConversation struct {
	// state is the current position in the exchange; Step advances it.
	state clientState
	// valid is set to true by Step when the exchange reaches clientDone.
	valid bool
	// nonce is the 32-byte random value generated by firstMsg. finalMsg
	// requires the server's nonce to start with these bytes.
	nonce []byte
	// credentials is the source finalMsg fetches credentials from and hands
	// to the request signer.
	credentials *credentials.Credentials
}
42
// serverMessage is the structure finalMsg decodes the server's challenge
// document into.
type serverMessage struct {
	// Nonce is read from the "s" field. finalMsg requires binary subtype 0x00,
	// a length of responceNonceLength bytes, and the client nonce as a prefix.
	Nonce primitive.Binary `bson:"s"`
	// Host is read from the "h" field. finalMsg derives the signing region
	// from it via getRegion and uses it as the Host of the signed request.
	Host string `bson:"h"`
}
47
const (
	// amzDateFormat is the time layout finalMsg uses for the X-Amz-Date header.
	amzDateFormat = "20060102T150405Z"
	// defaultRegion is what getRegion returns for "sts.amazonaws.com" and for
	// hosts that contain no "." separator.
	defaultRegion = "us-east-1"
	// maxHostLength is the longest host, in bytes, that getRegion accepts.
	maxHostLength = 255
	// responceNonceLength is the exact byte length finalMsg requires of the
	// server nonce. NOTE(review): "responce" is a misspelling of "response";
	// the name is kept because finalMsg refers to it.
	responceNonceLength = 64
)
54
55
56
57
58
59
60 func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
61 switch ac.state {
62 case clientStarting:
63 ac.state = clientFirst
64 response = ac.firstMsg()
65 case clientFirst:
66 ac.state = clientFinal
67 response, err = ac.finalMsg(challenge)
68 case clientFinal:
69 ac.state = clientDone
70 ac.valid = true
71 default:
72 response, err = nil, errors.New("Conversation already completed")
73 }
74 return
75 }
76
77
// Done reports whether the conversation has reached the clientDone state,
// which Step enters on the call made while the state is clientFinal.
func (ac *awsConversation) Done() bool {
	return ac.state == clientDone
}
81
82
83
84
// Valid returns the conversation's valid flag. Step sets the flag to true at
// the same moment it moves the state to clientDone; until then it is false.
func (ac *awsConversation) Valid() bool {
	return ac.valid
}
88
89 func getRegion(host string) (string, error) {
90 region := defaultRegion
91
92 if len(host) == 0 {
93 return "", errors.New("invalid STS host: empty")
94 }
95 if len(host) > maxHostLength {
96 return "", errors.New("invalid STS host: too large")
97 }
98
99 if host == "sts.amazonaws.com" {
100 return region, nil
101 }
102 if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
103 return "", errors.New("invalid STS host: empty part")
104 }
105
106
107 parts := strings.Split(host, ".")
108 if len(parts) >= 2 {
109 region = parts[1]
110 }
111
112 return region, nil
113 }
114
115 func (ac *awsConversation) firstMsg() []byte {
116
117 ac.nonce = make([]byte, 32)
118 _, _ = rand.Read(ac.nonce)
119
120 idx, msg := bsoncore.AppendDocumentStart(nil)
121 msg = bsoncore.AppendInt32Element(msg, "p", 110)
122 msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
123 msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
124 return msg
125 }
126
127 func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
128 var sm serverMessage
129 err := bson.Unmarshal(s1, &sm)
130 if err != nil {
131 return nil, err
132 }
133
134
135 if sm.Nonce.Subtype != 0x00 {
136 return nil, errors.New("server reply contained unexpected binary subtype")
137 }
138 if len(sm.Nonce.Data) != responceNonceLength {
139 return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
140 }
141 if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
142 return nil, errors.New("server nonce did not extend client nonce")
143 }
144
145 region, err := getRegion(sm.Host)
146 if err != nil {
147 return nil, err
148 }
149
150 creds, err := ac.credentials.GetWithContext(context.Background())
151 if err != nil {
152 return nil, err
153 }
154
155 currentTime := time.Now().UTC()
156 body := "Action=GetCallerIdentity&Version=2011-06-15"
157
158
159 req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
160 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
161 req.Header.Set("Content-Length", "43")
162 req.Host = sm.Host
163 req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
164 if len(creds.SessionToken) > 0 {
165 req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
166 }
167 req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
168 req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
169
170
171 signer := v4signer.NewSigner(ac.credentials)
172
173
174 _, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
175 if err != nil {
176 return nil, err
177 }
178
179
180 idx, msg := bsoncore.AppendDocumentStart(nil)
181 msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
182 msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
183 if len(creds.SessionToken) > 0 {
184 msg = bsoncore.AppendStringElement(msg, "t", creds.SessionToken)
185 }
186 msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
187
188 return msg, nil
189 }
190
View as plain text