...

Source file src/github.com/jackc/pgconn/krb5.go

Documentation: github.com/jackc/pgconn

     1  package pgconn
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"github.com/jackc/pgproto3/v2"
     8  )
     9  
    10  // NewGSSFunc creates a GSS authentication provider, for use with
    11  // RegisterGSSProvider.
    12  type NewGSSFunc func() (GSS, error)
    13  
    14  var newGSS NewGSSFunc
    15  
    16  // RegisterGSSProvider registers a GSS authentication provider. For example, if
    17  // you need to use Kerberos to authenticate with your server, add this to your
    18  // main package:
    19  //
    20  //	import "github.com/otan/gopgkrb5"
    21  //
    22  //	func init() {
    23  //		pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
    24  //	}
    25  func RegisterGSSProvider(newGSSArg NewGSSFunc) {
    26  	newGSS = newGSSArg
    27  }
    28  
    29  // GSS provides GSSAPI authentication (e.g., Kerberos).
    30  type GSS interface {
    31  	GetInitToken(host string, service string) ([]byte, error)
    32  	GetInitTokenFromSPN(spn string) ([]byte, error)
    33  	Continue(inToken []byte) (done bool, outToken []byte, err error)
    34  }
    35  
    36  func (c *PgConn) gssAuth() error {
    37  	if newGSS == nil {
    38  		return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
    39  	}
    40  	cli, err := newGSS()
    41  	if err != nil {
    42  		return err
    43  	}
    44  
    45  	var nextData []byte
    46  	if c.config.KerberosSpn != "" {
    47  		// Use the supplied SPN if provided.
    48  		nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
    49  	} else {
    50  		// Allow the kerberos service name to be overridden
    51  		service := "postgres"
    52  		if c.config.KerberosSrvName != "" {
    53  			service = c.config.KerberosSrvName
    54  		}
    55  		nextData, err = cli.GetInitToken(c.config.Host, service)
    56  	}
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	for {
    62  		gssResponse := &pgproto3.GSSResponse{
    63  			Data: nextData,
    64  		}
    65  		buf, err := gssResponse.Encode(nil)
    66  		if err != nil {
    67  			return err
    68  		}
    69  		_, err = c.conn.Write(buf)
    70  		if err != nil {
    71  			return err
    72  		}
    73  		resp, err := c.rxGSSContinue()
    74  		if err != nil {
    75  			return err
    76  		}
    77  		var done bool
    78  		done, nextData, err = cli.Continue(resp.Data)
    79  		if err != nil {
    80  			return err
    81  		}
    82  		if done {
    83  			break
    84  		}
    85  	}
    86  	return nil
    87  }
    88  
    89  func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
    90  	msg, err := c.receiveMessage()
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	switch m := msg.(type) {
    96  	case *pgproto3.AuthenticationGSSContinue:
    97  		return m, nil
    98  	case *pgproto3.ErrorResponse:
    99  		return nil, ErrorResponseToPgError(m)
   100  	}
   101  
   102  	return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
   103  }
   104  

View as plain text