...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/auth/aws_conv.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/auth

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package auth
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/rand"
    13  	"encoding/base64"
    14  	"errors"
    15  	"fmt"
    16  	"net/http"
    17  	"strings"
    18  	"time"
    19  
    20  	"go.mongodb.org/mongo-driver/bson"
    21  	"go.mongodb.org/mongo-driver/bson/primitive"
    22  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    23  	v4signer "go.mongodb.org/mongo-driver/internal/aws/signer/v4"
    24  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    25  )
    26  
    27  type clientState int
    28  
    29  const (
    30  	clientStarting clientState = iota
    31  	clientFirst
    32  	clientFinal
    33  	clientDone
    34  )
    35  
// awsConversation holds the client-side state of one AWS authentication
// conversation. Callers drive it with Step and inspect it with Done and Valid.
type awsConversation struct {
	// state is the current position in the conversation; Step advances it.
	state       clientState
	// valid is set by Step when the conversation reaches its final step.
	valid       bool
	// nonce is the random client nonce generated by firstMsg; finalMsg
	// requires the server's nonce to begin with these bytes.
	nonce       []byte
	// credentials supplies the AWS credentials used to sign the STS request
	// in finalMsg.
	credentials *credentials.Credentials
}
    42  
// serverMessage is the decoded form of the server's reply to the first
// client message; finalMsg unmarshals the challenge into it.
type serverMessage struct {
	// Nonce (BSON key "s") is the server nonce. finalMsg requires binary
	// subtype 0x00, a fixed length, and the client nonce as its prefix.
	Nonce primitive.Binary `bson:"s"`
	// Host (BSON key "h") is the STS host name. getRegion derives the signing
	// region from it, and finalMsg uses it as the signed request's Host.
	Host  string           `bson:"h"`
}
    48  const (
    49  	amzDateFormat       = "20060102T150405Z"
    50  	defaultRegion       = "us-east-1"
    51  	maxHostLength       = 255
    52  	responceNonceLength = 64
    53  )
    54  
    55  // Step takes a string provided from a server (or just an empty string for the
    56  // very first conversation step) and attempts to move the authentication
    57  // conversation forward.  It returns a string to be sent to the server or an
    58  // error if the server message is invalid.  Calling Step after a conversation
    59  // completes is also an error.
    60  func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
    61  	switch ac.state {
    62  	case clientStarting:
    63  		ac.state = clientFirst
    64  		response = ac.firstMsg()
    65  	case clientFirst:
    66  		ac.state = clientFinal
    67  		response, err = ac.finalMsg(challenge)
    68  	case clientFinal:
    69  		ac.state = clientDone
    70  		ac.valid = true
    71  	default:
    72  		response, err = nil, errors.New("Conversation already completed")
    73  	}
    74  	return
    75  }
    76  
    77  // Done returns true if the conversation is completed or has errored.
    78  func (ac *awsConversation) Done() bool {
    79  	return ac.state == clientDone
    80  }
    81  
    82  // Valid returns true if the conversation successfully authenticated with the
    83  // server, including counter-validation that the server actually has the
    84  // user's stored credentials.
    85  func (ac *awsConversation) Valid() bool {
    86  	return ac.valid
    87  }
    88  
    89  func getRegion(host string) (string, error) {
    90  	region := defaultRegion
    91  
    92  	if len(host) == 0 {
    93  		return "", errors.New("invalid STS host: empty")
    94  	}
    95  	if len(host) > maxHostLength {
    96  		return "", errors.New("invalid STS host: too large")
    97  	}
    98  	// The implicit region for sts.amazonaws.com is us-east-1
    99  	if host == "sts.amazonaws.com" {
   100  		return region, nil
   101  	}
   102  	if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
   103  		return "", errors.New("invalid STS host: empty part")
   104  	}
   105  
   106  	// If the host has multiple parts, the second part is the region
   107  	parts := strings.Split(host, ".")
   108  	if len(parts) >= 2 {
   109  		region = parts[1]
   110  	}
   111  
   112  	return region, nil
   113  }
   114  
   115  func (ac *awsConversation) firstMsg() []byte {
   116  	// Values are cached for use in final message parameters
   117  	ac.nonce = make([]byte, 32)
   118  	_, _ = rand.Read(ac.nonce)
   119  
   120  	idx, msg := bsoncore.AppendDocumentStart(nil)
   121  	msg = bsoncore.AppendInt32Element(msg, "p", 110)
   122  	msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
   123  	msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
   124  	return msg
   125  }
   126  
   127  func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
   128  	var sm serverMessage
   129  	err := bson.Unmarshal(s1, &sm)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	// Check nonce prefix
   135  	if sm.Nonce.Subtype != 0x00 {
   136  		return nil, errors.New("server reply contained unexpected binary subtype")
   137  	}
   138  	if len(sm.Nonce.Data) != responceNonceLength {
   139  		return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
   140  	}
   141  	if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
   142  		return nil, errors.New("server nonce did not extend client nonce")
   143  	}
   144  
   145  	region, err := getRegion(sm.Host)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	creds, err := ac.credentials.GetWithContext(context.Background())
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	currentTime := time.Now().UTC()
   156  	body := "Action=GetCallerIdentity&Version=2011-06-15"
   157  
   158  	// Create http.Request
   159  	req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
   160  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   161  	req.Header.Set("Content-Length", "43")
   162  	req.Host = sm.Host
   163  	req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
   164  	if len(creds.SessionToken) > 0 {
   165  		req.Header.Set("X-Amz-Security-Token", creds.SessionToken)
   166  	}
   167  	req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
   168  	req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
   169  
   170  	// Create signer with credentials
   171  	signer := v4signer.NewSigner(ac.credentials)
   172  
   173  	// Get signed header
   174  	_, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	// create message
   180  	idx, msg := bsoncore.AppendDocumentStart(nil)
   181  	msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
   182  	msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
   183  	if len(creds.SessionToken) > 0 {
   184  		msg = bsoncore.AppendStringElement(msg, "t", creds.SessionToken)
   185  	}
   186  	msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
   187  
   188  	return msg, nil
   189  }
   190  

View as plain text