...

Source file src/github.com/jackc/pgconn/auth_scram.go

Documentation: github.com/jackc/pgconn

     1  // SCRAM-SHA-256 authentication
     2  //
     3  // Resources:
     4  //   https://tools.ietf.org/html/rfc5802
     5  //   https://tools.ietf.org/html/rfc8265
     6  //   https://www.postgresql.org/docs/current/sasl-authentication.html
     7  //
     8  // Inspiration drawn from other implementations:
     9  //   https://github.com/lib/pq/pull/608
    10  //   https://github.com/lib/pq/pull/788
    11  //   https://github.com/lib/pq/pull/833
    12  
    13  package pgconn
    14  
    15  import (
    16  	"bytes"
    17  	"crypto/hmac"
    18  	"crypto/rand"
    19  	"crypto/sha256"
    20  	"encoding/base64"
    21  	"errors"
    22  	"fmt"
    23  	"strconv"
    24  
    25  	"github.com/jackc/pgproto3/v2"
    26  	"golang.org/x/crypto/pbkdf2"
    27  	"golang.org/x/text/secure/precis"
    28  )
    29  
    30  const clientNonceLen = 18
    31  
    32  // Perform SCRAM authentication.
    33  func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
    34  	sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
    35  	if err != nil {
    36  		return err
    37  	}
    38  
    39  	// Send client-first-message in a SASLInitialResponse
    40  	saslInitialResponse := &pgproto3.SASLInitialResponse{
    41  		AuthMechanism: "SCRAM-SHA-256",
    42  		Data:          sc.clientFirstMessage(),
    43  	}
    44  	buf, err := saslInitialResponse.Encode(nil)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	_, err = c.conn.Write(buf)
    49  	if err != nil {
    50  		return err
    51  	}
    52  
    53  	// Receive server-first-message payload in a AuthenticationSASLContinue.
    54  	saslContinue, err := c.rxSASLContinue()
    55  	if err != nil {
    56  		return err
    57  	}
    58  	err = sc.recvServerFirstMessage(saslContinue.Data)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	// Send client-final-message in a SASLResponse
    64  	saslResponse := &pgproto3.SASLResponse{
    65  		Data: []byte(sc.clientFinalMessage()),
    66  	}
    67  	buf, err = saslResponse.Encode(nil)
    68  	if err != nil {
    69  		return err
    70  	}
    71  	_, err = c.conn.Write(buf)
    72  	if err != nil {
    73  		return err
    74  	}
    75  
    76  	// Receive server-final-message payload in a AuthenticationSASLFinal.
    77  	saslFinal, err := c.rxSASLFinal()
    78  	if err != nil {
    79  		return err
    80  	}
    81  	return sc.recvServerFinalMessage(saslFinal.Data)
    82  }
    83  
    84  func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
    85  	msg, err := c.receiveMessage()
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	switch m := msg.(type) {
    90  	case *pgproto3.AuthenticationSASLContinue:
    91  		return m, nil
    92  	case *pgproto3.ErrorResponse:
    93  		return nil, ErrorResponseToPgError(m)
    94  	}
    95  
    96  	return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
    97  }
    98  
    99  func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
   100  	msg, err := c.receiveMessage()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	switch m := msg.(type) {
   105  	case *pgproto3.AuthenticationSASLFinal:
   106  		return m, nil
   107  	case *pgproto3.ErrorResponse:
   108  		return nil, ErrorResponseToPgError(m)
   109  	}
   110  
   111  	return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
   112  }
   113  
// scramClient holds the client-side state of one SCRAM-SHA-256 exchange.
//
// The fields are filled in as the exchange progresses:
//   - newScramClient sets serverAuthMechanisms, password and clientNonce.
//   - clientFirstMessage sets clientFirstMessageBare.
//   - recvServerFirstMessage sets the server-first-message group.
//   - clientFinalMessage sets saltedPassword and authMessage, which
//     recvServerFinalMessage then uses.
type scramClient struct {
	// Mechanism names advertised by the server.
	serverAuthMechanisms []string
	// Password after precis.OpaqueString normalization, or the raw password
	// bytes when normalization rejected it.
	password []byte
	// Unpadded standard base64 encoding of clientNonceLen random bytes.
	clientNonce []byte

	// client-first-message without the leading "n,," GS2 header.
	clientFirstMessageBare []byte

	// server-first-message bytes; included verbatim in authMessage.
	serverFirstMessage []byte
	// Value of the r= attribute: clientNonce followed by the server's nonce.
	clientAndServerNonce []byte
	// Salt decoded from the s= attribute.
	salt []byte
	// Iteration count from the i= attribute, used for PBKDF2.
	iterations int

	// PBKDF2-HMAC-SHA-256 of password with salt and iterations (32 bytes).
	saltedPassword []byte
	// clientFirstMessageBare, serverFirstMessage and the client-final-message
	// without proof, joined by commas.
	authMessage []byte
}
   129  
   130  func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
   131  	sc := &scramClient{
   132  		serverAuthMechanisms: serverAuthMechanisms,
   133  	}
   134  
   135  	// Ensure server supports SCRAM-SHA-256
   136  	hasScramSHA256 := false
   137  	for _, mech := range sc.serverAuthMechanisms {
   138  		if mech == "SCRAM-SHA-256" {
   139  			hasScramSHA256 = true
   140  			break
   141  		}
   142  	}
   143  	if !hasScramSHA256 {
   144  		return nil, errors.New("server does not support SCRAM-SHA-256")
   145  	}
   146  
   147  	// precis.OpaqueString is equivalent to SASLprep for password.
   148  	var err error
   149  	sc.password, err = precis.OpaqueString.Bytes([]byte(password))
   150  	if err != nil {
   151  		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
   152  		sc.password = []byte(password)
   153  	}
   154  
   155  	buf := make([]byte, clientNonceLen)
   156  	_, err = rand.Read(buf)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
   161  	base64.RawStdEncoding.Encode(sc.clientNonce, buf)
   162  
   163  	return sc, nil
   164  }
   165  
   166  func (sc *scramClient) clientFirstMessage() []byte {
   167  	sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
   168  	return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
   169  }
   170  
   171  func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
   172  	sc.serverFirstMessage = serverFirstMessage
   173  	buf := serverFirstMessage
   174  	if !bytes.HasPrefix(buf, []byte("r=")) {
   175  		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
   176  	}
   177  	buf = buf[2:]
   178  
   179  	idx := bytes.IndexByte(buf, ',')
   180  	if idx == -1 {
   181  		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
   182  	}
   183  	sc.clientAndServerNonce = buf[:idx]
   184  	buf = buf[idx+1:]
   185  
   186  	if !bytes.HasPrefix(buf, []byte("s=")) {
   187  		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
   188  	}
   189  	buf = buf[2:]
   190  
   191  	idx = bytes.IndexByte(buf, ',')
   192  	if idx == -1 {
   193  		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
   194  	}
   195  	saltStr := buf[:idx]
   196  	buf = buf[idx+1:]
   197  
   198  	if !bytes.HasPrefix(buf, []byte("i=")) {
   199  		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
   200  	}
   201  	buf = buf[2:]
   202  	iterationsStr := buf
   203  
   204  	var err error
   205  	sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
   206  	if err != nil {
   207  		return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
   208  	}
   209  
   210  	sc.iterations, err = strconv.Atoi(string(iterationsStr))
   211  	if err != nil || sc.iterations <= 0 {
   212  		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
   213  	}
   214  
   215  	if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
   216  		return errors.New("invalid SCRAM nonce: did not start with client nonce")
   217  	}
   218  
   219  	if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
   220  		return errors.New("invalid SCRAM nonce: did not include server nonce")
   221  	}
   222  
   223  	return nil
   224  }
   225  
   226  func (sc *scramClient) clientFinalMessage() string {
   227  	clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
   228  
   229  	sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
   230  	sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
   231  
   232  	clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
   233  
   234  	return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
   235  }
   236  
   237  func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
   238  	if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
   239  		return errors.New("invalid SCRAM server-final-message received from server")
   240  	}
   241  
   242  	serverSignature := serverFinalMessage[2:]
   243  
   244  	if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
   245  		return errors.New("invalid SCRAM ServerSignature received from server")
   246  	}
   247  
   248  	return nil
   249  }
   250  
   251  func computeHMAC(key, msg []byte) []byte {
   252  	mac := hmac.New(sha256.New, key)
   253  	mac.Write(msg)
   254  	return mac.Sum(nil)
   255  }
   256  
   257  func computeClientProof(saltedPassword, authMessage []byte) []byte {
   258  	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
   259  	storedKey := sha256.Sum256(clientKey)
   260  	clientSignature := computeHMAC(storedKey[:], authMessage)
   261  
   262  	clientProof := make([]byte, len(clientSignature))
   263  	for i := 0; i < len(clientSignature); i++ {
   264  		clientProof[i] = clientKey[i] ^ clientSignature[i]
   265  	}
   266  
   267  	buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
   268  	base64.StdEncoding.Encode(buf, clientProof)
   269  	return buf
   270  }
   271  
   272  func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
   273  	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
   274  	serverSignature := computeHMAC(serverKey, authMessage)
   275  	buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
   276  	base64.StdEncoding.Encode(buf, serverSignature)
   277  	return buf
   278  }
   279  

View as plain text