...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/auth/auth.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/auth

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package auth
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"net/http"
    14  
    15  	"go.mongodb.org/mongo-driver/mongo/address"
    16  	"go.mongodb.org/mongo-driver/mongo/description"
    17  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    20  )
    21  
    22  // AuthenticatorFactory constructs an authenticator.
    23  type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
    24  
    25  var authFactories = make(map[string]AuthenticatorFactory)
    26  
    27  func init() {
    28  	RegisterAuthenticatorFactory("", newDefaultAuthenticator)
    29  	RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
    30  	RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
    31  	RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
    32  	RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
    33  	RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
    34  	RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
    35  	RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
    36  }
    37  
    38  // CreateAuthenticator creates an authenticator.
    39  func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
    40  	if f, ok := authFactories[name]; ok {
    41  		return f(cred)
    42  	}
    43  
    44  	return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
    45  }
    46  
    47  // RegisterAuthenticatorFactory registers the authenticator factory.
    48  func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
    49  	authFactories[name] = factory
    50  }
    51  
// HandshakeOptions packages options that can be passed to the Handshaker()
// function. DBUser is optional but must be of the form <dbname.username>;
// if non-empty, then the connection will do SASL mechanism negotiation.
type HandshakeOptions struct {
	// AppName is passed to the hello command built in GetHandshakeInformation.
	AppName               string
	// Authenticator authenticates the connection in FinishHandshake. If it
	// also implements SpeculativeAuthenticator, its first message is sent
	// speculatively with the hello. A nil Authenticator disables auth.
	Authenticator         Authenticator
	// Compressors is passed to the hello command.
	Compressors           []string
	// DBUser is passed to the hello command's SASLSupportedMechs option.
	DBUser                string
	// PerformAuthentication decides, from the server description, whether the
	// connection should be authenticated. When nil, every server kind except
	// replica-set arbiters is authenticated.
	PerformAuthentication func(description.Server) bool
	// ClusterClock is passed to the hello command and to the auth Config.
	ClusterClock          *session.ClusterClock
	// ServerAPI is passed to the hello command and to the auth Config.
	ServerAPI             *driver.ServerAPIOptions
	// LoadBalanced is passed to the hello command.
	LoadBalanced          bool
	// HTTPClient is copied into the auth Config; presumably used by mechanisms
	// that fetch credentials over HTTP — TODO confirm against authenticators.
	HTTPClient            *http.Client
}
    66  
// authHandshaker is a driver.Handshaker that runs the initial hello handshake,
// optionally with a speculative authentication attempt, and then authenticates
// the connection in FinishHandshake. It stores per-handshake state, so an
// instance presumably should not be shared between connections — TODO confirm.
type authHandshaker struct {
	// wrapped, when non-nil, performs GetHandshakeInformation in place of this
	// handshaker and has its FinishHandshake invoked after authentication.
	wrapped driver.Handshaker
	// options configures the hello command and authentication.
	options *HandshakeOptions

	// handshakeInfo holds the result of the hello run by this handshaker. It
	// stays zero-valued when wrapped performs the hello instead.
	handshakeInfo driver.HandshakeInformation
	// conversation is the speculative-auth conversation started during hello.
	// It stays nil when the Authenticator is not a SpeculativeAuthenticator.
	conversation  SpeculativeConversation
}

// Compile-time check that *authHandshaker implements driver.Handshaker.
var _ driver.Handshaker = (*authHandshaker)(nil)
    76  
    77  // GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
    78  // connection.
    79  func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
    80  	if ah.wrapped != nil {
    81  		return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
    82  	}
    83  
    84  	op := operation.NewHello().
    85  		AppName(ah.options.AppName).
    86  		Compressors(ah.options.Compressors).
    87  		SASLSupportedMechs(ah.options.DBUser).
    88  		ClusterClock(ah.options.ClusterClock).
    89  		ServerAPI(ah.options.ServerAPI).
    90  		LoadBalanced(ah.options.LoadBalanced)
    91  
    92  	if ah.options.Authenticator != nil {
    93  		if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
    94  			var err error
    95  			ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
    96  			if err != nil {
    97  				return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
    98  			}
    99  
   100  			firstMsg, err := ah.conversation.FirstMessage()
   101  			if err != nil {
   102  				return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
   103  			}
   104  
   105  			op = op.SpeculativeAuthenticate(firstMsg)
   106  		}
   107  	}
   108  
   109  	var err error
   110  	ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
   111  	if err != nil {
   112  		return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
   113  	}
   114  	return ah.handshakeInfo, nil
   115  }
   116  
   117  // FinishHandshake performs authentication for conn if necessary.
   118  func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
   119  	performAuth := ah.options.PerformAuthentication
   120  	if performAuth == nil {
   121  		performAuth = func(serv description.Server) bool {
   122  			// Authentication is possible against all server types except arbiters
   123  			return serv.Kind != description.RSArbiter
   124  		}
   125  	}
   126  
   127  	desc := conn.Description()
   128  	if performAuth(desc) && ah.options.Authenticator != nil {
   129  		cfg := &Config{
   130  			Description:   desc,
   131  			Connection:    conn,
   132  			ClusterClock:  ah.options.ClusterClock,
   133  			HandshakeInfo: ah.handshakeInfo,
   134  			ServerAPI:     ah.options.ServerAPI,
   135  			HTTPClient:    ah.options.HTTPClient,
   136  		}
   137  
   138  		if err := ah.authenticate(ctx, cfg); err != nil {
   139  			return newAuthError("auth error", err)
   140  		}
   141  	}
   142  
   143  	if ah.wrapped == nil {
   144  		return nil
   145  	}
   146  	return ah.wrapped.FinishHandshake(ctx, conn)
   147  }
   148  
   149  func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error {
   150  	// If the initial hello reply included a response to the speculative authentication attempt, we only need to
   151  	// conduct the remainder of the conversation.
   152  	if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
   153  		// Defensively ensure that the server did not include a response if speculative auth was not attempted.
   154  		if ah.conversation == nil {
   155  			return errors.New("speculative auth was not attempted but the server included a response")
   156  		}
   157  		return ah.conversation.Finish(ctx, cfg, speculativeResponse)
   158  	}
   159  
   160  	// If the server does not support speculative authentication or the first attempt was not successful, we need to
   161  	// perform authentication from scratch.
   162  	return ah.options.Authenticator.Auth(ctx, cfg)
   163  }
   164  
   165  // Handshaker creates a connection handshaker for the given authenticator.
   166  func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
   167  	return &authHandshaker{
   168  		wrapped: h,
   169  		options: options,
   170  	}
   171  }
   172  
// Config holds the information necessary to perform an authentication attempt.
// The auth handshaker's FinishHandshake builds one from the connection and its
// HandshakeOptions before calling the Authenticator.
type Config struct {
	// Description is the server description of the connection being authenticated.
	Description   description.Server
	// Connection is the connection to authenticate.
	Connection    driver.Connection
	// ClusterClock is taken from HandshakeOptions.ClusterClock.
	ClusterClock  *session.ClusterClock
	// HandshakeInfo is the result of the initial hello; it is zero-valued when
	// a wrapped handshaker performed the hello.
	HandshakeInfo driver.HandshakeInformation
	// ServerAPI is taken from HandshakeOptions.ServerAPI.
	ServerAPI     *driver.ServerAPIOptions
	// HTTPClient is taken from HandshakeOptions.HTTPClient.
	HTTPClient    *http.Client
}
   182  
   183  // Authenticator handles authenticating a connection.
   184  type Authenticator interface {
   185  	// Auth authenticates the connection.
   186  	Auth(context.Context, *Config) error
   187  }
   188  
   189  func newAuthError(msg string, inner error) error {
   190  	return &Error{
   191  		message: msg,
   192  		inner:   inner,
   193  	}
   194  }
   195  
   196  func newError(err error, mech string) error {
   197  	return &Error{
   198  		message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
   199  		inner:   err,
   200  	}
   201  }
   202  
   203  // Error is an error that occurred during authentication.
   204  type Error struct {
   205  	message string
   206  	inner   error
   207  }
   208  
   209  func (e *Error) Error() string {
   210  	if e.inner == nil {
   211  		return e.message
   212  	}
   213  	return fmt.Sprintf("%s: %s", e.message, e.inner)
   214  }
   215  
   216  // Inner returns the wrapped error.
   217  func (e *Error) Inner() error {
   218  	return e.inner
   219  }
   220  
   221  // Unwrap returns the underlying error.
   222  func (e *Error) Unwrap() error {
   223  	return e.inner
   224  }
   225  
   226  // Message returns the message.
   227  func (e *Error) Message() string {
   228  	return e.message
   229  }
   230  

View as plain text