1
2
3 package gmtls
4
5 import (
6 "crypto"
7 "crypto/ecdsa"
8 "crypto/rsa"
9 "errors"
10 "fmt"
11 "io"
12 "sync/atomic"
13 "time"
14 )
15
16
17
18 func (c *Conn) serverHandshakeAutoSwitch() error {
19
20
21 c.config.serverInitOnce.Do(func() { c.config.serverInit(nil) })
22
23 msg, err := c.readHandshake()
24 if err != nil {
25 return err
26 }
27
28 clientHello, ok := msg.(*clientHelloMsg)
29 if !ok {
30 _ = c.sendAlert(alertUnexpectedMessage)
31 return unexpectedMessageError("Client Hello Msg", msg)
32 }
33
34
35
36
37
38 switch clientHello.vers {
39 case VersionGMSSL:
40
41 hs := &serverHandshakeStateGM{
42 c: c,
43 clientHello: clientHello,
44 }
45
46 isResume, err := processClientHelloGM(c, hs)
47 if err != nil {
48 return err
49 }
50
51 return runServerHandshakeGM(c, hs, isResume)
52 case VersionSSL30, VersionTLS10, VersionTLS11, VersionTLS12:
53
54
55 hs := &serverHandshakeState{
56 c: c,
57 clientHello: clientHello,
58 }
59
60 isResume, err := processClientHello(c, hs)
61 if err != nil {
62 return err
63 }
64
65 return runServerHandshake(c, hs, isResume)
66 default:
67 _ = c.sendAlert(alertProtocolVersion)
68 return fmt.Errorf("tls: mix server handshake unsupport client protocol version: %X", clientHello.vers)
69 }
70 }
71
72
73
74
75
76
77
78
79
80
81 func processClientHelloGM(c *Conn, hs *serverHandshakeStateGM) (isResume bool, err error) {
82 if c.config.GetConfigForClient != nil {
83 if newConfig, err := c.config.GetConfigForClient(hs.clientHelloInfo()); err != nil {
84 _ = c.sendAlert(alertInternalError)
85 return false, err
86 } else if newConfig != nil {
87 newConfig.serverInitOnce.Do(func() { newConfig.serverInit(c.config) })
88 c.config = newConfig
89 }
90 }
91 var ok bool
92 c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
93 if !ok {
94 _ = c.sendAlert(alertProtocolVersion)
95 return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
96 }
97 c.haveVers = true
98
99 hs.hello = new(serverHelloMsg)
100
101 foundCompression := false
102
103 for _, compression := range hs.clientHello.compressionMethods {
104 if compression == compressionNone {
105 foundCompression = true
106 break
107 }
108 }
109
110 if !foundCompression {
111 _ = c.sendAlert(alertHandshakeFailure)
112 return false, errors.New("tls: client does not support uncompressed connections")
113 }
114
115 hs.hello.vers = c.vers
116 hs.hello.random = make([]byte, 32)
117 _, err = io.ReadFull(c.config.rand(), hs.hello.random)
118 if err != nil {
119 _ = c.sendAlert(alertInternalError)
120 return false, err
121 }
122
123
124 gmtRandom(&(hs.hello.random))
125
126 if len(hs.clientHello.secureRenegotiation) != 0 {
127 _ = c.sendAlert(alertHandshakeFailure)
128 return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
129 }
130
131 hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
132 hs.hello.compressionMethod = compressionNone
133 if len(hs.clientHello.serverName) > 0 {
134 c.serverName = hs.clientHello.serverName
135 }
136
137 if len(hs.clientHello.alpnProtocols) > 0 {
138 if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback {
139 hs.hello.alpnProtocol = selectedProto
140 c.clientProtocol = selectedProto
141 }
142 } else {
143
144
145
146
147 if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
148 hs.hello.nextProtoNeg = true
149 hs.hello.nextProtos = c.config.NextProtos
150 }
151 }
152
153
154
155
156 sigCert, err := c.config.getCertificate(hs.clientHelloInfo())
157 if err != nil {
158 _ = c.sendAlert(alertInternalError)
159 return false, err
160 }
161 encCert, err := c.config.GetKECertificate(hs.clientHelloInfo())
162 if err != nil {
163 _ = c.sendAlert(alertInternalError)
164 return false, err
165 }
166
167 if encCert == nil || sigCert == nil {
168 _ = c.sendAlert(alertInternalError)
169 return false, fmt.Errorf("tls: amount of server certificates must be greater than 2, which will sign and encipher respectively")
170 }
171
172 hs.cert = []Certificate{*sigCert, *encCert}
173
174 if hs.clientHello.scts {
175 hs.hello.scts = hs.cert[0].SignedCertificateTimestamps
176 }
177
178 if hs.checkForResumption() {
179 return true, nil
180 }
181
182 var preferenceList, supportedList []uint16
183 if c.config.PreferServerCipherSuites {
184 preferenceList = getCipherSuites(c.config)
185 supportedList = hs.clientHello.cipherSuites
186 } else {
187 preferenceList = hs.clientHello.cipherSuites
188 supportedList = getCipherSuites(c.config)
189 }
190
191 for _, id := range preferenceList {
192 if hs.setCipherSuite(id, supportedList, c.vers) {
193 break
194 }
195 }
196
197 if hs.suite == nil {
198 _ = c.sendAlert(alertHandshakeFailure)
199 return false, errors.New("tls: no cipher suite supported by both client and server")
200 }
201
202
203 for _, id := range hs.clientHello.cipherSuites {
204 if id == TLS_FALLBACK_SCSV {
205
206 if hs.clientHello.vers < c.config.maxVersion() {
207 _ = c.sendAlert(alertInappropriateFallback)
208 return false, errors.New("tls: client using inappropriate protocol fallback")
209 }
210 break
211 }
212 }
213 return false, nil
214 }
215
216
217
218
219
220
221
222
223
224
225 func runServerHandshakeGM(c *Conn, hs *serverHandshakeStateGM, isResume bool) error {
226
227 c.buffering = true
228 if isResume {
229
230 if err := hs.doResumeHandshake(); err != nil {
231 return err
232 }
233 if err := hs.establishKeys(); err != nil {
234 return err
235 }
236
237
238
239 if hs.hello.ticketSupported {
240 if err := hs.sendSessionTicket(); err != nil {
241 return err
242 }
243 }
244 if err := hs.sendFinished(c.serverFinished[:]); err != nil {
245 return err
246 }
247 if _, err := c.flush(); err != nil {
248 return err
249 }
250 c.clientFinishedIsFirst = false
251 if err := hs.readFinished(nil); err != nil {
252 return err
253 }
254 c.didResume = true
255 } else {
256
257
258 if err := hs.doFullHandshake(); err != nil {
259 return err
260 }
261 if err := hs.establishKeys(); err != nil {
262 return err
263 }
264 if err := hs.readFinished(c.clientFinished[:]); err != nil {
265 return err
266 }
267 c.clientFinishedIsFirst = true
268 c.buffering = true
269 if err := hs.sendSessionTicket(); err != nil {
270 return err
271 }
272 if err := hs.sendFinished(nil); err != nil {
273 return err
274 }
275 if _, err := c.flush(); err != nil {
276 return err
277 }
278 }
279
280 c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
281 atomic.StoreUint32(&c.handshakeStatus, 1)
282
283 return nil
284 }
285
286
287
288
289
290
291
292
293
294
295 func processClientHello(c *Conn, hs *serverHandshakeState) (bool, error) {
296 if c.config.GetConfigForClient != nil {
297 if newConfig, err := c.config.GetConfigForClient(hs.clientHelloInfo()); err != nil {
298 _ = c.sendAlert(alertInternalError)
299 return false, err
300 } else if newConfig != nil {
301 newConfig.serverInitOnce.Do(func() { newConfig.serverInit(c.config) })
302 c.config = newConfig
303 }
304 }
305 var ok bool
306 var err error
307 c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
308 if !ok {
309 _ = c.sendAlert(alertProtocolVersion)
310 return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
311 }
312 c.haveVers = true
313
314 hs.hello = new(serverHelloMsg)
315
316 supportedCurve := false
317 preferredCurves := c.config.curvePreferences()
318 Curves:
319 for _, curve := range hs.clientHello.supportedCurves {
320 for _, supported := range preferredCurves {
321 if supported == curve {
322 supportedCurve = true
323 break Curves
324 }
325 }
326 }
327
328 supportedPointFormat := false
329 for _, pointFormat := range hs.clientHello.supportedPoints {
330 if pointFormat == pointFormatUncompressed {
331 supportedPointFormat = true
332 break
333 }
334 }
335 hs.ellipticOk = supportedCurve && supportedPointFormat
336
337 foundCompression := false
338
339 for _, compression := range hs.clientHello.compressionMethods {
340 if compression == compressionNone {
341 foundCompression = true
342 break
343 }
344 }
345
346 if !foundCompression {
347 _ = c.sendAlert(alertHandshakeFailure)
348 return false, errors.New("tls: client does not support uncompressed connections")
349 }
350
351 hs.hello.vers = c.vers
352 hs.hello.random = make([]byte, 32)
353 _, err = io.ReadFull(c.config.rand(), hs.hello.random)
354 if err != nil {
355 _ = c.sendAlert(alertInternalError)
356 return false, err
357 }
358
359 if len(hs.clientHello.secureRenegotiation) != 0 {
360 _ = c.sendAlert(alertHandshakeFailure)
361 return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
362 }
363
364 hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
365 hs.hello.compressionMethod = compressionNone
366 if len(hs.clientHello.serverName) > 0 {
367 c.serverName = hs.clientHello.serverName
368 }
369
370 if len(hs.clientHello.alpnProtocols) > 0 {
371 if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback {
372 hs.hello.alpnProtocol = selectedProto
373 c.clientProtocol = selectedProto
374 }
375 } else {
376
377
378
379
380 if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
381 hs.hello.nextProtoNeg = true
382 hs.hello.nextProtos = c.config.NextProtos
383 }
384 }
385
386 hs.cert, err = c.config.getCertificate(hs.clientHelloInfo())
387 if err != nil {
388 _ = c.sendAlert(alertInternalError)
389 return false, err
390 }
391 if hs.clientHello.scts {
392 hs.hello.scts = hs.cert.SignedCertificateTimestamps
393 }
394
395 if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
396 switch priv.Public().(type) {
397 case *ecdsa.PublicKey:
398 hs.ecdsaOk = true
399 case *rsa.PublicKey:
400 hs.rsaSignOk = true
401 default:
402 _ = c.sendAlert(alertInternalError)
403 return false, fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
404 }
405 }
406 if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
407 switch priv.Public().(type) {
408 case *rsa.PublicKey:
409 hs.rsaDecryptOk = true
410 default:
411 _ = c.sendAlert(alertInternalError)
412 return false, fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
413 }
414 }
415
416 if hs.checkForResumption() {
417 return true, nil
418 }
419
420 var preferenceList, supportedList []uint16
421 if c.config.PreferServerCipherSuites {
422 preferenceList = c.config.cipherSuites()
423 supportedList = hs.clientHello.cipherSuites
424 } else {
425 preferenceList = hs.clientHello.cipherSuites
426 supportedList = c.config.cipherSuites()
427 }
428
429 for _, id := range preferenceList {
430 if hs.setCipherSuite(id, supportedList, c.vers) {
431 break
432 }
433 }
434
435 if hs.suite == nil {
436 _ = c.sendAlert(alertHandshakeFailure)
437 return false, errors.New("tls: no cipher suite supported by both client and server")
438 }
439
440
441 for _, id := range hs.clientHello.cipherSuites {
442 if id == TLS_FALLBACK_SCSV {
443
444 if hs.clientHello.vers < c.config.maxVersion() {
445 _ = c.sendAlert(alertInappropriateFallback)
446 return false, errors.New("tls: client using inappropriate protocol fallback")
447 }
448 break
449 }
450 }
451
452 return false, nil
453 }
454
455
456
457
458
459
460
461
462
463
464 func runServerHandshake(c *Conn, hs *serverHandshakeState, isResume bool) error {
465
466 c.buffering = true
467 if isResume {
468
469 if err := hs.doResumeHandshake(); err != nil {
470 return err
471 }
472 if err := hs.establishKeys(); err != nil {
473 return err
474 }
475
476
477
478 if hs.hello.ticketSupported {
479 if err := hs.sendSessionTicket(); err != nil {
480 return err
481 }
482 }
483 if err := hs.sendFinished(c.serverFinished[:]); err != nil {
484 return err
485 }
486 if _, err := c.flush(); err != nil {
487 return err
488 }
489 c.clientFinishedIsFirst = false
490 if err := hs.readFinished(nil); err != nil {
491 return err
492 }
493 c.didResume = true
494 } else {
495
496
497 if err := hs.doFullHandshake(); err != nil {
498 return err
499 }
500 if err := hs.establishKeys(); err != nil {
501 return err
502 }
503 if err := hs.readFinished(c.clientFinished[:]); err != nil {
504 return err
505 }
506 c.clientFinishedIsFirst = true
507 c.buffering = true
508 if err := hs.sendSessionTicket(); err != nil {
509 return err
510 }
511 if err := hs.sendFinished(nil); err != nil {
512 return err
513 }
514 if _, err := c.flush(); err != nil {
515 return err
516 }
517 }
518
519 c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
520 atomic.StoreUint32(&c.handshakeStatus, 1)
521 return nil
522 }
523
524
// gmtRandom writes the current Unix time, truncated to 32 bits, into the
// first four bytes of *raw in big-endian order and returns that timestamp.
// Bytes from index 4 onwards are left untouched.
//
// If raw is nil or *raw holds fewer than four bytes, nothing is written and
// only the timestamp is returned; previously this case panicked.
func gmtRandom(raw *[]byte) uint32 {
	stamp := uint32(time.Now().Unix())
	if raw == nil || len(*raw) < 4 {
		// Not enough room for the 4-byte timestamp: leave the buffer alone.
		return stamp
	}

	rd := *raw
	rd[0] = byte(stamp >> 24)
	rd[1] = byte(stamp >> 16)
	rd[2] = byte(stamp >> 8)
	rd[3] = byte(stamp)
	return stamp
}
534
View as plain text