...

Source file src/github.com/jackc/pgx/v5/pgconn/auth_scram.go

Documentation: github.com/jackc/pgx/v5/pgconn

     1  // SCRAM-SHA-256 authentication
     2  //
     3  // Resources:
     4  //   https://tools.ietf.org/html/rfc5802
     5  //   https://tools.ietf.org/html/rfc8265
     6  //   https://www.postgresql.org/docs/current/sasl-authentication.html
     7  //
     8  // Inspiration drawn from other implementations:
     9  //   https://github.com/lib/pq/pull/608
    10  //   https://github.com/lib/pq/pull/788
    11  //   https://github.com/lib/pq/pull/833
    12  
    13  package pgconn
    14  
    15  import (
    16  	"bytes"
    17  	"crypto/hmac"
    18  	"crypto/rand"
    19  	"crypto/sha256"
    20  	"encoding/base64"
    21  	"errors"
    22  	"fmt"
    23  	"strconv"
    24  
    25  	"github.com/jackc/pgx/v5/pgproto3"
    26  	"golang.org/x/crypto/pbkdf2"
    27  	"golang.org/x/text/secure/precis"
    28  )
    29  
    30  const clientNonceLen = 18
    31  
    32  // Perform SCRAM authentication.
    33  func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
    34  	sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
    35  	if err != nil {
    36  		return err
    37  	}
    38  
    39  	// Send client-first-message in a SASLInitialResponse
    40  	saslInitialResponse := &pgproto3.SASLInitialResponse{
    41  		AuthMechanism: "SCRAM-SHA-256",
    42  		Data:          sc.clientFirstMessage(),
    43  	}
    44  	c.frontend.Send(saslInitialResponse)
    45  	err = c.flushWithPotentialWriteReadDeadlock()
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	// Receive server-first-message payload in an AuthenticationSASLContinue.
    51  	saslContinue, err := c.rxSASLContinue()
    52  	if err != nil {
    53  		return err
    54  	}
    55  	err = sc.recvServerFirstMessage(saslContinue.Data)
    56  	if err != nil {
    57  		return err
    58  	}
    59  
    60  	// Send client-final-message in a SASLResponse
    61  	saslResponse := &pgproto3.SASLResponse{
    62  		Data: []byte(sc.clientFinalMessage()),
    63  	}
    64  	c.frontend.Send(saslResponse)
    65  	err = c.flushWithPotentialWriteReadDeadlock()
    66  	if err != nil {
    67  		return err
    68  	}
    69  
    70  	// Receive server-final-message payload in an AuthenticationSASLFinal.
    71  	saslFinal, err := c.rxSASLFinal()
    72  	if err != nil {
    73  		return err
    74  	}
    75  	return sc.recvServerFinalMessage(saslFinal.Data)
    76  }
    77  
    78  func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
    79  	msg, err := c.receiveMessage()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	switch m := msg.(type) {
    84  	case *pgproto3.AuthenticationSASLContinue:
    85  		return m, nil
    86  	case *pgproto3.ErrorResponse:
    87  		return nil, ErrorResponseToPgError(m)
    88  	}
    89  
    90  	return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
    91  }
    92  
    93  func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
    94  	msg, err := c.receiveMessage()
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	switch m := msg.(type) {
    99  	case *pgproto3.AuthenticationSASLFinal:
   100  		return m, nil
   101  	case *pgproto3.ErrorResponse:
   102  		return nil, ErrorResponseToPgError(m)
   103  	}
   104  
   105  	return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
   106  }
   107  
   108  type scramClient struct {
   109  	serverAuthMechanisms []string
   110  	password             []byte
   111  	clientNonce          []byte
   112  
   113  	clientFirstMessageBare []byte
   114  
   115  	serverFirstMessage   []byte
   116  	clientAndServerNonce []byte
   117  	salt                 []byte
   118  	iterations           int
   119  
   120  	saltedPassword []byte
   121  	authMessage    []byte
   122  }
   123  
   124  func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
   125  	sc := &scramClient{
   126  		serverAuthMechanisms: serverAuthMechanisms,
   127  	}
   128  
   129  	// Ensure server supports SCRAM-SHA-256
   130  	hasScramSHA256 := false
   131  	for _, mech := range sc.serverAuthMechanisms {
   132  		if mech == "SCRAM-SHA-256" {
   133  			hasScramSHA256 = true
   134  			break
   135  		}
   136  	}
   137  	if !hasScramSHA256 {
   138  		return nil, errors.New("server does not support SCRAM-SHA-256")
   139  	}
   140  
   141  	// precis.OpaqueString is equivalent to SASLprep for password.
   142  	var err error
   143  	sc.password, err = precis.OpaqueString.Bytes([]byte(password))
   144  	if err != nil {
   145  		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
   146  		sc.password = []byte(password)
   147  	}
   148  
   149  	buf := make([]byte, clientNonceLen)
   150  	_, err = rand.Read(buf)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
   155  	base64.RawStdEncoding.Encode(sc.clientNonce, buf)
   156  
   157  	return sc, nil
   158  }
   159  
   160  func (sc *scramClient) clientFirstMessage() []byte {
   161  	sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
   162  	return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
   163  }
   164  
   165  func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
   166  	sc.serverFirstMessage = serverFirstMessage
   167  	buf := serverFirstMessage
   168  	if !bytes.HasPrefix(buf, []byte("r=")) {
   169  		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
   170  	}
   171  	buf = buf[2:]
   172  
   173  	idx := bytes.IndexByte(buf, ',')
   174  	if idx == -1 {
   175  		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
   176  	}
   177  	sc.clientAndServerNonce = buf[:idx]
   178  	buf = buf[idx+1:]
   179  
   180  	if !bytes.HasPrefix(buf, []byte("s=")) {
   181  		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
   182  	}
   183  	buf = buf[2:]
   184  
   185  	idx = bytes.IndexByte(buf, ',')
   186  	if idx == -1 {
   187  		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
   188  	}
   189  	saltStr := buf[:idx]
   190  	buf = buf[idx+1:]
   191  
   192  	if !bytes.HasPrefix(buf, []byte("i=")) {
   193  		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
   194  	}
   195  	buf = buf[2:]
   196  	iterationsStr := buf
   197  
   198  	var err error
   199  	sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
   200  	if err != nil {
   201  		return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
   202  	}
   203  
   204  	sc.iterations, err = strconv.Atoi(string(iterationsStr))
   205  	if err != nil || sc.iterations <= 0 {
   206  		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
   207  	}
   208  
   209  	if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
   210  		return errors.New("invalid SCRAM nonce: did not start with client nonce")
   211  	}
   212  
   213  	if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
   214  		return errors.New("invalid SCRAM nonce: did not include server nonce")
   215  	}
   216  
   217  	return nil
   218  }
   219  
   220  func (sc *scramClient) clientFinalMessage() string {
   221  	clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
   222  
   223  	sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
   224  	sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
   225  
   226  	clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
   227  
   228  	return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
   229  }
   230  
   231  func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
   232  	if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
   233  		return errors.New("invalid SCRAM server-final-message received from server")
   234  	}
   235  
   236  	serverSignature := serverFinalMessage[2:]
   237  
   238  	if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
   239  		return errors.New("invalid SCRAM ServerSignature received from server")
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func computeHMAC(key, msg []byte) []byte {
   246  	mac := hmac.New(sha256.New, key)
   247  	mac.Write(msg)
   248  	return mac.Sum(nil)
   249  }
   250  
   251  func computeClientProof(saltedPassword, authMessage []byte) []byte {
   252  	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
   253  	storedKey := sha256.Sum256(clientKey)
   254  	clientSignature := computeHMAC(storedKey[:], authMessage)
   255  
   256  	clientProof := make([]byte, len(clientSignature))
   257  	for i := 0; i < len(clientSignature); i++ {
   258  		clientProof[i] = clientKey[i] ^ clientSignature[i]
   259  	}
   260  
   261  	buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
   262  	base64.StdEncoding.Encode(buf, clientProof)
   263  	return buf
   264  }
   265  
   266  func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
   267  	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
   268  	serverSignature := computeHMAC(serverKey, authMessage)
   269  	buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
   270  	base64.StdEncoding.Encode(buf, serverSignature)
   271  	return buf
   272  }
   273  

View as plain text