...

Source file src/github.com/jackc/pgx/v5/pgconn/krb5.go

Documentation: github.com/jackc/pgx/v5/pgconn

     1  package pgconn
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"github.com/jackc/pgx/v5/pgproto3"
     8  )
     9  
    10  // NewGSSFunc creates a GSS authentication provider, for use with
    11  // RegisterGSSProvider.
    12  type NewGSSFunc func() (GSS, error)
    13  
    14  var newGSS NewGSSFunc
    15  
    16  // RegisterGSSProvider registers a GSS authentication provider. For example, if
    17  // you need to use Kerberos to authenticate with your server, add this to your
    18  // main package:
    19  //
    20  //	import "github.com/otan/gopgkrb5"
    21  //
    22  //	func init() {
    23  //		pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
    24  //	}
    25  func RegisterGSSProvider(newGSSArg NewGSSFunc) {
    26  	newGSS = newGSSArg
    27  }
    28  
    29  // GSS provides GSSAPI authentication (e.g., Kerberos).
    30  type GSS interface {
    31  	GetInitToken(host string, service string) ([]byte, error)
    32  	GetInitTokenFromSPN(spn string) ([]byte, error)
    33  	Continue(inToken []byte) (done bool, outToken []byte, err error)
    34  }
    35  
    36  func (c *PgConn) gssAuth() error {
    37  	if newGSS == nil {
    38  		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
    39  	}
    40  	cli, err := newGSS()
    41  	if err != nil {
    42  		return err
    43  	}
    44  
    45  	var nextData []byte
    46  	if c.config.KerberosSpn != "" {
    47  		// Use the supplied SPN if provided.
    48  		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
    49  	} else {
    50  		// Allow the kerberos service name to be overridden
    51  		service := "postgres"
    52  		if c.config.KerberosSrvName != "" {
    53  			service = c.config.KerberosSrvName
    54  		}
    55  		nextData, err = cli.GetInitToken(c.config.Host, service)
    56  	}
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	for {
    62  		gssResponse := &pgproto3.GSSResponse{
    63  			Data: nextData,
    64  		}
    65  		c.frontend.Send(gssResponse)
    66  		err = c.flushWithPotentialWriteReadDeadlock()
    67  		if err != nil {
    68  			return err
    69  		}
    70  		resp, err := c.rxGSSContinue()
    71  		if err != nil {
    72  			return err
    73  		}
    74  		var done bool
    75  		done, nextData, err = cli.Continue(resp.Data)
    76  		if err != nil {
    77  			return err
    78  		}
    79  		if done {
    80  			break
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
    87  	msg, err := c.receiveMessage()
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	switch m := msg.(type) {
    93  	case *pgproto3.AuthenticationGSSContinue:
    94  		return m, nil
    95  	case *pgproto3.ErrorResponse:
    96  		return nil, ErrorResponseToPgError(m)
    97  	}
    98  
    99  	return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
   100  }
   101  

View as plain text