...
1 package pgconn
2
3 import (
4 "errors"
5 "fmt"
6
7 "github.com/jackc/pgx/v5/pgproto3"
8 )
9
10
11
// GSS is the client side of a GSSAPI security context, as used by gssAuth
// during the startup authentication exchange. An example provider is
// github.com/otan/gopgkrb5.
type GSS interface {
	// GetInitToken returns the first token to send to the server for the given
	// host and service name.
	GetInitToken(host, service string) ([]byte, error)
	// GetInitTokenFromSPN returns the first token to send to the server for an
	// explicitly configured service principal name.
	GetInitTokenFromSPN(spn string) ([]byte, error)
	// Continue consumes a token received from the server. It reports whether
	// the context is established and returns the next token for the server.
	Continue(inToken []byte) (done bool, outToken []byte, err error)
}

type (
	// NewGSSFunc constructs a fresh GSS client for one authentication attempt.
	NewGSSFunc func() (GSS, error)
)

// newGSS is the constructor installed by RegisterGSSProvider. A nil value
// means no provider is registered, and gssAuth reports an error.
var newGSS NewGSSFunc

// RegisterGSSProvider installs the constructor that gssAuth uses to create a
// GSS client. Each call replaces the previous registration, and passing nil
// removes it.
//
// The registration is stored in a plain package variable with no locking.
// Register the provider before opening any connection, typically from an
// init function in your main package.
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
	newGSS = newGSSArg
}
35
36 func (c *PgConn) gssAuth() error {
37 if newGSS == nil {
38 return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
39 }
40 cli, err := newGSS()
41 if err != nil {
42 return err
43 }
44
45 var nextData []byte
46 if c.config.KerberosSpn != "" {
47
48 nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
49 } else {
50
51 service := "postgres"
52 if c.config.KerberosSrvName != "" {
53 service = c.config.KerberosSrvName
54 }
55 nextData, err = cli.GetInitToken(c.config.Host, service)
56 }
57 if err != nil {
58 return err
59 }
60
61 for {
62 gssResponse := &pgproto3.GSSResponse{
63 Data: nextData,
64 }
65 c.frontend.Send(gssResponse)
66 err = c.flushWithPotentialWriteReadDeadlock()
67 if err != nil {
68 return err
69 }
70 resp, err := c.rxGSSContinue()
71 if err != nil {
72 return err
73 }
74 var done bool
75 done, nextData, err = cli.Continue(resp.Data)
76 if err != nil {
77 return err
78 }
79 if done {
80 break
81 }
82 }
83 return nil
84 }
85
86 func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
87 msg, err := c.receiveMessage()
88 if err != nil {
89 return nil, err
90 }
91
92 switch m := msg.(type) {
93 case *pgproto3.AuthenticationGSSContinue:
94 return m, nil
95 case *pgproto3.ErrorResponse:
96 return nil, ErrorResponseToPgError(m)
97 }
98
99 return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
100 }
101
View as plain text