...
1
15
16 package gmtls
17
18 import (
19 "bytes"
20 "crypto/aes"
21 "crypto/cipher"
22 "crypto/hmac"
23 "crypto/sha256"
24 "crypto/subtle"
25 "errors"
26 "io"
27 )
28
29
30
// sessionState holds the values that marshal writes into, and unmarshal reads
// back from, the plaintext of a session ticket.
type sessionState struct {
	vers, cipherSuite uint16
	masterSecret      []byte
	certificates      [][]byte

	// usedOldKey is not part of the serialized form. decryptTicket sets it
	// when the ticket was protected with a key other than the first entry
	// of the configured ticket keys.
	usedOldKey bool
}

// equal reports whether i is a *sessionState with the same version, cipher
// suite, master secret and certificate list as s. The usedOldKey flag does
// not take part in the comparison.
func (s *sessionState) equal(i interface{}) bool {
	other, isState := i.(*sessionState)
	switch {
	case !isState:
		return false
	case s.vers != other.vers, s.cipherSuite != other.cipherSuite:
		return false
	case !bytes.Equal(s.masterSecret, other.masterSecret):
		return false
	case len(s.certificates) != len(other.certificates):
		return false
	}

	for idx, cert := range s.certificates {
		if !bytes.Equal(cert, other.certificates[idx]) {
			return false
		}
	}
	return true
}

// marshal serializes s using big-endian length prefixes:
//
//	uint16 vers
//	uint16 cipherSuite
//	uint16 len(masterSecret), followed by masterSecret
//	uint16 len(certificates)
//	per certificate: uint32 length, followed by the certificate bytes
//
// usedOldKey is not written.
func (s *sessionState) marshal() []byte {
	// Fixed part: version, cipher suite, secret length and certificate
	// count, two bytes each.
	size := 2 + 2 + 2 + len(s.masterSecret) + 2
	for _, cert := range s.certificates {
		size += 4 + len(cert)
	}

	secretLen := len(s.masterSecret)
	certCount := len(s.certificates)

	out := make([]byte, 0, size)
	out = append(out,
		byte(s.vers>>8), byte(s.vers),
		byte(s.cipherSuite>>8), byte(s.cipherSuite),
		byte(secretLen>>8), byte(secretLen),
	)
	out = append(out, s.masterSecret...)
	out = append(out, byte(certCount>>8), byte(certCount))

	for _, cert := range s.certificates {
		n := len(cert)
		out = append(out, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
		out = append(out, cert...)
	}

	return out
}

// unmarshal parses data, in the format produced by marshal, into s and
// reports whether data was a complete, well-formed encoding with no trailing
// bytes. The master secret and certificates stored in s are sub-slices of
// data, not copies. Fields are assigned as they are parsed, so s may be
// partially updated when false is returned.
func (s *sessionState) unmarshal(data []byte) bool {
	// The smallest valid encoding is the eight bytes of fixed fields.
	if len(data) < 8 {
		return false
	}

	s.vers = uint16(data[1]) | uint16(data[0])<<8
	s.cipherSuite = uint16(data[3]) | uint16(data[2])<<8

	secretLen := int(data[5]) | int(data[4])<<8
	rest := data[6:]
	if secretLen > len(rest) {
		return false
	}
	s.masterSecret, rest = rest[:secretLen], rest[secretLen:]

	if len(rest) < 2 {
		return false
	}
	certCount := int(rest[1]) | int(rest[0])<<8
	rest = rest[2:]

	s.certificates = make([][]byte, certCount)
	for idx := 0; idx < certCount; idx++ {
		if len(rest) < 4 {
			return false
		}
		n := int(rest[0])<<24 | int(rest[1])<<16 | int(rest[2])<<8 | int(rest[3])
		rest = rest[4:]
		// n can come out negative where int is 32 bits wide and the top
		// bit of the length prefix is set.
		if n < 0 || n > len(rest) {
			return false
		}
		s.certificates[idx], rest = rest[:n], rest[n:]
	}

	// Anything left over means the encoding was not exact.
	return len(rest) == 0
}
142
143 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
144 serialized := state.marshal()
145 encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
146 keyName := encrypted[:ticketKeyNameLen]
147 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
148 macBytes := encrypted[len(encrypted)-sha256.Size:]
149
150 if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
151 return nil, err
152 }
153 key := c.config.ticketKeys()[0]
154 copy(keyName, key.keyName[:])
155 block, err := aes.NewCipher(key.aesKey[:])
156 if err != nil {
157 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
158 }
159 cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
160
161 mac := hmac.New(sha256.New, key.hmacKey[:])
162 mac.Write(encrypted[:len(encrypted)-sha256.Size])
163 mac.Sum(macBytes[:0])
164
165 return encrypted, nil
166 }
167
168 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
169 if c.config.SessionTicketsDisabled ||
170 len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
171 return nil, false
172 }
173
174 keyName := encrypted[:ticketKeyNameLen]
175 iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
176 macBytes := encrypted[len(encrypted)-sha256.Size:]
177
178 keys := c.config.ticketKeys()
179 keyIndex := -1
180 for i, candidateKey := range keys {
181 if bytes.Equal(keyName, candidateKey.keyName[:]) {
182 keyIndex = i
183 break
184 }
185 }
186
187 if keyIndex == -1 {
188 return nil, false
189 }
190 key := &keys[keyIndex]
191
192 mac := hmac.New(sha256.New, key.hmacKey[:])
193 mac.Write(encrypted[:len(encrypted)-sha256.Size])
194 expected := mac.Sum(nil)
195
196 if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
197 return nil, false
198 }
199
200 block, err := aes.NewCipher(key.aesKey[:])
201 if err != nil {
202 return nil, false
203 }
204 ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
205 plaintext := ciphertext
206 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
207
208 state := &sessionState{usedOldKey: keyIndex > 0}
209 ok := state.unmarshal(plaintext)
210 return state, ok
211 }
212
View as plain text