...

Source file src/github.com/twmb/franz-go/pkg/sasl/scram/scram.go

Documentation: github.com/twmb/franz-go/pkg/sasl/scram

     1  // Package scram provides SCRAM-SHA-256 and SCRAM-SHA-512 sasl authentication
     2  // as specified in RFC5802.
     3  package scram
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/hmac"
     9  	"crypto/rand"
    10  	"crypto/sha256"
    11  	"crypto/sha512"
    12  	"encoding/base64"
    13  	"errors"
    14  	"fmt"
    15  	"hash"
    16  	"strconv"
    17  	"strings"
    18  
    19  	"golang.org/x/crypto/pbkdf2"
    20  
    21  	"github.com/twmb/franz-go/pkg/sasl"
    22  )
    23  
// Auth contains the credentials and options used for a single SCRAM
// authentication.
//
// This client may add fields to this struct in the future if Kafka adds more
// extensions to SCRAM.
type Auth struct {
	// Zid is an optional authorization ID to use in authenticating. If
	// non-empty, it is sent in the GS2 header as "a=<zid>", with "=" and
	// "," escaped.
	Zid string

	// User is the username to use for authentication. It must be
	// non-empty.
	//
	// Note that this package does not attempt to "prepare" the username
	// for authentication; this package assumes that the incoming username
	// has already been prepared / does not need preparing. The only
	// transformation applied is escaping "=" and "," as the wire format
	// requires.
	//
	// Preparing (SASLprep) maps and Unicode-normalizes the string and
	// rejects prohibited characters; doing so is likely not necessary.
	User string

	// Pass is the password to use for authentication. It must be
	// non-empty. Like User, it is used as-is, without any preparation.
	Pass string

	// Nonce, if provided, is the raw nonce to use for authentication. If not
	// provided, this package uses 20 bytes read with crypto/rand. In either
	// case the bytes are base64 encoded (without padding) before being sent.
	Nonce []byte

	// IsToken, if true, suffixes the "tokenauth=true" extra attribute to
	// the initial authentication message.
	//
	// Set this to true if the user and pass are from a delegation token.
	IsToken bool

	_ struct{} // require explicit field initialization
}
    57  
    58  // AsSha256Mechanism returns a sasl mechanism that will use 'a' as credentials
    59  // for all sasl sessions.
    60  //
    61  // This is a shortcut for using the Sha256 function and is useful when you do
    62  // not need to live-rotate credentials.
    63  func (a Auth) AsSha256Mechanism() sasl.Mechanism {
    64  	return Sha256(func(context.Context) (Auth, error) {
    65  		return a, nil
    66  	})
    67  }
    68  
    69  // AsSha512Mechanism returns a sasl mechanism that will use 'a' as credentials
    70  // for all sasl sessions.
    71  //
    72  // This is a shortcut for using the Sha512 function and is useful when you do
    73  // not need to live-rotate credentials.
    74  func (a Auth) AsSha512Mechanism() sasl.Mechanism {
    75  	return Sha512(func(context.Context) (Auth, error) {
    76  		return a, nil
    77  	})
    78  }
    79  
    80  // Sha256 returns a SCRAM-SHA-256 sasl mechanism that will call authFn
    81  // whenever authentication is needed. The returned Auth is used for a single
    82  // session.
    83  func Sha256(authFn func(context.Context) (Auth, error)) sasl.Mechanism {
    84  	return scram{authFn, sha256.New, "SCRAM-SHA-256"}
    85  }
    86  
    87  // Sha512 returns a SCRAM-SHA-512 sasl mechanism that will call authFn
    88  // whenever authentication is needed. The returned Auth is used for a single
    89  // session.
    90  func Sha512(authFn func(context.Context) (Auth, error)) sasl.Mechanism {
    91  	return scram{authFn, sha512.New, "SCRAM-SHA-512"}
    92  }
    93  
// scram is the sasl.Mechanism returned by Sha256 and Sha512; it is
// parameterized by the hash constructor and the mechanism name.
type scram struct {
	// authFn is called once per Authenticate to obtain credentials.
	authFn func(context.Context) (Auth, error)
	// newhash constructs the hash used for H, HMAC and PBKDF2
	// (sha256.New or sha512.New in this file).
	newhash func() hash.Hash
	// name is the mechanism name reported by Name, e.g. "SCRAM-SHA-256".
	name string
}
    99  
   100  var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
   101  
   102  func (s scram) Name() string { return s.name }
   103  func (s scram) Authenticate(ctx context.Context, _ string) (sasl.Session, []byte, error) {
   104  	auth, err := s.authFn(ctx)
   105  	if err != nil {
   106  		return nil, nil, err
   107  	}
   108  	if auth.User == "" || auth.Pass == "" {
   109  		return nil, nil, errors.New(s.name + " user and pass must be non-empty")
   110  	}
   111  	if len(auth.Nonce) == 0 {
   112  		buf := make([]byte, 20)
   113  		if _, err = rand.Read(buf); err != nil {
   114  			return nil, nil, err
   115  		}
   116  		auth.Nonce = buf
   117  	}
   118  
   119  	auth.Nonce = []byte(base64.RawStdEncoding.EncodeToString(auth.Nonce))
   120  
   121  	clientFirstMsgBare := make([]byte, 0, 100)
   122  	clientFirstMsgBare = append(clientFirstMsgBare, "n="...)
   123  	clientFirstMsgBare = append(clientFirstMsgBare, escaper.Replace(auth.User)...)
   124  	clientFirstMsgBare = append(clientFirstMsgBare, ",r="...)
   125  	clientFirstMsgBare = append(clientFirstMsgBare, auth.Nonce...)
   126  	if auth.IsToken {
   127  		clientFirstMsgBare = append(clientFirstMsgBare, ",tokenauth=true"...) // KIP-48
   128  	}
   129  
   130  	gs2Header := "n," // no channel binding
   131  	if auth.Zid != "" {
   132  		gs2Header += "a=" + escaper.Replace(auth.Zid)
   133  	}
   134  	gs2Header += ","
   135  	clientFirstMsg := append([]byte(gs2Header), clientFirstMsgBare...)
   136  	return &session{
   137  		step:    0,
   138  		auth:    auth,
   139  		newhash: s.newhash,
   140  
   141  		clientFirstMsgBare: clientFirstMsgBare,
   142  	}, clientFirstMsg, nil
   143  }
   144  
// session holds the client-side state of one in-progress SCRAM exchange. It
// is created by scram.Authenticate and advanced by Challenge.
type session struct {
	// step counts the Challenge calls made so far.
	step int
	// auth holds the credentials; by this point Nonce contains the
	// base64-encoded client nonce.
	auth Auth
	// newhash constructs the hash used for H, HMAC and PBKDF2.
	newhash func() hash.Hash

	// clientFirstMsgBare is the client-first-message-bare; it is the first
	// component of the SCRAM AuthMessage.
	clientFirstMsgBare []byte
	// expServerSignature is the base64 ServerSignature that the
	// server-final-message must carry; it is set by authenticateClient.
	expServerSignature []byte
}
   153  
   154  func (s *session) Challenge(resp []byte) (bool, []byte, error) {
   155  	step := s.step
   156  	s.step++
   157  	switch step {
   158  	case 0:
   159  		response, err := s.authenticateClient(resp)
   160  		return false, response, err
   161  	case 1:
   162  		err := s.verifyServer(resp)
   163  		return err == nil, nil, err
   164  	default:
   165  		return false, nil, fmt.Errorf("challenge / response should be done, but still going at %d", step)
   166  	}
   167  }
   168  
   169  // server-first-message = [reserved-mext ","] nonce "," salt "," iteration-count ["," extensions]
   170  // we ignore extensions
   171  func (s *session) authenticateClient(serverFirstMsg []byte) ([]byte, error) {
   172  	kvs := bytes.Split(serverFirstMsg, []byte(","))
   173  	if len(kvs) < 3 {
   174  		return nil, fmt.Errorf("got %d kvs != exp min 3", len(kvs))
   175  	}
   176  
   177  	// NONCE
   178  	if !bytes.HasPrefix(kvs[0], []byte("r=")) {
   179  		return nil, fmt.Errorf("unexpected kv %q where nonce expected", kvs[0])
   180  	}
   181  	serverNonce := kvs[0][2:]
   182  	if !bytes.HasPrefix(serverNonce, s.auth.Nonce) {
   183  		return nil, errors.New("server did not reply with nonce beginning with client nonce")
   184  	}
   185  
   186  	// SALT
   187  	if !bytes.HasPrefix(kvs[1], []byte("s=")) {
   188  		return nil, fmt.Errorf("unexpected kv %q where salt expected", kvs[1])
   189  	}
   190  	salt, err := base64.StdEncoding.DecodeString(string(kvs[1][2:]))
   191  	if err != nil {
   192  		return nil, fmt.Errorf("server salt %q decode err: %v", kvs[1][2:], err)
   193  	}
   194  
   195  	// ITERATIONS
   196  	if !bytes.HasPrefix(kvs[2], []byte("i=")) {
   197  		return nil, fmt.Errorf("unexpected kv %q where iterations expected", kvs[2])
   198  	}
   199  	iters, err := strconv.Atoi(string(kvs[2][2:]))
   200  	if err != nil {
   201  		return nil, fmt.Errorf("server iterations %q parse err: %v", kvs[2][2:], err)
   202  	}
   203  	if iters < 4096 {
   204  		return nil, fmt.Errorf("server iterations %d less than minimum 4096", iters)
   205  	}
   206  
   207  	//////////////////
   208  	// CALCULATIONS //
   209  	//////////////////
   210  
   211  	h := s.newhash()
   212  	saltedPassword := pbkdf2.Key([]byte(s.auth.Pass), salt, iters, h.Size(), s.newhash) // SaltedPassword := Hi(Normalize(password), salt, i)
   213  
   214  	mac := hmac.New(s.newhash, saltedPassword)
   215  	if _, err = mac.Write([]byte("Client Key")); err != nil {
   216  		return nil, fmt.Errorf("hmac err: %v", err)
   217  	}
   218  	clientKey := mac.Sum(nil) // ClientKey := HMAC(SaltedPassword, "Client Key")
   219  	if _, err = h.Write(clientKey); err != nil {
   220  		return nil, fmt.Errorf("sha err: %v", err)
   221  	}
   222  	storedKey := h.Sum(nil) // StoredKey := H(ClientKey)
   223  
   224  	// biws is `n,,` base64 encoded; we do not use a channel
   225  	clientFinalMsgWithoutProof := append([]byte("c=biws,r="), serverNonce...)
   226  	authMsg := append(s.clientFirstMsgBare, ',')             // AuthMsg := client-first-message-bare + "," +
   227  	authMsg = append(authMsg, serverFirstMsg...)             //            server-first-message +
   228  	authMsg = append(authMsg, ',')                           //            "," +
   229  	authMsg = append(authMsg, clientFinalMsgWithoutProof...) //            client-final-message-without-proof
   230  
   231  	mac = hmac.New(s.newhash, storedKey)
   232  	if _, err = mac.Write(authMsg); err != nil {
   233  		return nil, fmt.Errorf("hmac err: %v", err)
   234  	}
   235  	clientSignature := mac.Sum(nil) // ClientSignature := HMAC(StoredKey, AuthMessage)
   236  
   237  	clientProof := clientSignature
   238  	for i, c := range clientKey {
   239  		clientProof[i] ^= c // ClientProof := ClientKey XOR ClientSignature
   240  	}
   241  
   242  	mac = hmac.New(s.newhash, saltedPassword)
   243  	if _, err = mac.Write([]byte("Server Key")); err != nil {
   244  		return nil, fmt.Errorf("hmac err: %v", err)
   245  	}
   246  	serverKey := mac.Sum(nil) // ServerKey := HMAC(SaltedPassword, "Server Key")
   247  	mac = hmac.New(s.newhash, serverKey)
   248  	if _, err = mac.Write(authMsg); err != nil {
   249  		return nil, fmt.Errorf("hmac err: %v", err)
   250  	}
   251  	s.expServerSignature = []byte(base64.StdEncoding.EncodeToString(mac.Sum(nil))) // ServerSignature := HMAC(ServerKey, AuthMessage)
   252  
   253  	clientFinalMsg := append(clientFinalMsgWithoutProof, ",p="...)
   254  	clientFinalMsg = append(clientFinalMsg, base64.StdEncoding.EncodeToString(clientProof)...)
   255  	return clientFinalMsg, nil
   256  }
   257  
   258  func (s *session) verifyServer(serverFinalMsg []byte) error {
   259  	kvs := bytes.Split(serverFinalMsg, []byte(","))
   260  	if len(kvs) < 1 {
   261  		return errors.New("received no kvs, even though this should be impossible")
   262  	}
   263  
   264  	kv := kvs[0]
   265  	if isErr := bytes.HasPrefix(kv, []byte("e=")); isErr {
   266  		return fmt.Errorf("server sent authentication error %q", kv[2:])
   267  	}
   268  	if !bytes.HasPrefix(kv, []byte("v=")) {
   269  		return fmt.Errorf("server sent unexpected first kv %q", kv)
   270  	}
   271  	if !bytes.Equal(s.expServerSignature, kv[2:]) {
   272  		return fmt.Errorf("server signature mismatch; got %q != exp %q", kv[2:], s.expServerSignature)
   273  	}
   274  	return nil
   275  }
   276  

View as plain text