1
2
3
4 package gssapi
5
6 import (
7 "bytes"
8 "encoding/binary"
9 "fmt"
10
11 "github.com/alexbrainman/sspi"
12 "github.com/alexbrainman/sspi/kerberos"
13 )
14
15
16
17 type SSPIClient struct {
18 creds *sspi.Credentials
19 ctx *kerberos.ClientContext
20 }
21
22
23 func NewSSPIClient() (*SSPIClient, error) {
24 creds, err := kerberos.AcquireCurrentUserCredentials()
25 if err != nil {
26 return nil, err
27 }
28
29 return NewSSPIClientWithCredentials(creds), nil
30 }
31
32
33 func NewSSPIClientWithCredentials(creds *sspi.Credentials) *SSPIClient {
34 return &SSPIClient{
35 creds: creds,
36 }
37 }
38
39
40
41 func NewSSPIClientWithUserCredentials(domain, username, password string) (*SSPIClient, error) {
42 creds, err := kerberos.AcquireUserCredentials(domain, username, password)
43 if err != nil {
44 return nil, err
45 }
46
47 return &SSPIClient{
48 creds: creds,
49 }, nil
50 }
51
52
53 func (c *SSPIClient) Close() error {
54 err1 := c.DeleteSecContext()
55 err2 := c.creds.Release()
56 if err1 != nil {
57 return err1
58 }
59 if err2 != nil {
60 return err2
61 }
62 return nil
63 }
64
65
66 func (c *SSPIClient) DeleteSecContext() error {
67 return c.ctx.Release()
68 }
69
70
71
72
73 func (c *SSPIClient) InitSecContext(target string, token []byte) ([]byte, bool, error) {
74 sspiFlags := uint32(sspi.ISC_REQ_INTEGRITY | sspi.ISC_REQ_CONFIDENTIALITY | sspi.ISC_REQ_MUTUAL_AUTH)
75
76 switch token {
77 case nil:
78 ctx, completed, output, err := kerberos.NewClientContextWithFlags(c.creds, target, sspiFlags)
79 if err != nil {
80 return nil, false, err
81 }
82 c.ctx = ctx
83
84 return output, !completed, nil
85 default:
86
87 completed, output, err := c.ctx.Update(token)
88 if err != nil {
89 return nil, false, err
90 }
91 if err := c.ctx.VerifyFlags(); err != nil {
92 return nil, false, fmt.Errorf("error verifying flags: %v", err)
93 }
94 return output, !completed, nil
95
96 }
97 }
98
99
100
101 func (c *SSPIClient) NegotiateSaslAuth(token []byte, authzid string) ([]byte, error) {
102
103
104
105
106
107 const KERB_WRAP_NO_ENCRYPT = 0x80000001
108
109
110 flags, inputPayload, err := c.ctx.DecryptMessage(token, 0)
111 if err != nil {
112 return nil, fmt.Errorf("error decrypting message: %w", err)
113 }
114 if flags&KERB_WRAP_NO_ENCRYPT == 0 {
115
116 return nil, fmt.Errorf("message encrypted")
117 }
118
119
120
121
122
123
124
125
126
127
128 if len(inputPayload) != 4 {
129 return nil, fmt.Errorf("bad server token")
130 }
131 if inputPayload[0] == 0x0 && !bytes.Equal(inputPayload, []byte{0x0, 0x0, 0x0, 0x0}) {
132 return nil, fmt.Errorf("bad server token")
133 }
134
135
136
137
138
139
140 selectedSec := 0
141 var maxSecMsgSize uint32
142 if selectedSec != 0 {
143 maxSecMsgSize, _, _, _, err = c.ctx.Sizes()
144 if err != nil {
145 return nil, fmt.Errorf("error getting security context max message size: %w", err)
146 }
147 }
148
149
150 inputPayload, err = c.ctx.EncryptMessage(handshakePayload(byte(selectedSec), maxSecMsgSize, []byte(authzid)), KERB_WRAP_NO_ENCRYPT, 0)
151 if err != nil {
152 return nil, fmt.Errorf("error encrypting message: %w", err)
153 }
154
155 return inputPayload, nil
156 }
157
// handshakePayload builds the cleartext reply that the client wraps and sends
// in the last step of the SASL GSSAPI handshake (RFC 4752 section 3.1).
//
// Layout:
//
//	octet 0      secLayer, the selected security layer bit mask
//	octets 1-3   maximum message size, big-endian
//	octets 4...  authzid, copied verbatim (may be empty)
//
// When secLayer is 0 the size field is 0 and maxSize is ignored. Otherwise
// maxSize is capped at 0xFFFFFF, the largest value three octets can hold.
func handshakePayload(secLayer byte, maxSize uint32, authzid []byte) []byte {
	const maxThreeOctetSize = 1<<24 - 1

	size := uint32(0)
	if secLayer != 0 {
		size = maxSize
		if size > maxThreeOctetSize {
			size = maxThreeOctetSize
		}
	}

	// size occupies only the low three octets, so the layer can be packed
	// into the high octet of the same big-endian word.
	out := make([]byte, 4, 4+len(authzid))
	binary.BigEndian.PutUint32(out, uint32(secLayer)<<24|size)
	return append(out, authzid...)
}
192
View as plain text