1
15
16 package gmtls
17
18 import "bytes"
19
20 type clientHelloMsg struct {
21 raw []byte
22 vers uint16
23 random []byte
24 sessionId []byte
25 cipherSuites []uint16
26 compressionMethods []uint8
27 nextProtoNeg bool
28 serverName string
29 ocspStapling bool
30 scts bool
31 supportedCurves []CurveID
32 supportedPoints []uint8
33 ticketSupported bool
34 sessionTicket []uint8
35 supportedSignatureAlgorithms []SignatureScheme
36 secureRenegotiation []byte
37 secureRenegotiationSupported bool
38 alpnProtocols []string
39 }
40
41 func (m *clientHelloMsg) equal(i interface{}) bool {
42 m1, ok := i.(*clientHelloMsg)
43 if !ok {
44 return false
45 }
46
47 return bytes.Equal(m.raw, m1.raw) &&
48 m.vers == m1.vers &&
49 bytes.Equal(m.random, m1.random) &&
50 bytes.Equal(m.sessionId, m1.sessionId) &&
51 eqUint16s(m.cipherSuites, m1.cipherSuites) &&
52 bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
53 m.nextProtoNeg == m1.nextProtoNeg &&
54 m.serverName == m1.serverName &&
55 m.ocspStapling == m1.ocspStapling &&
56 m.scts == m1.scts &&
57 eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
58 bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
59 m.ticketSupported == m1.ticketSupported &&
60 bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
61 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
62 m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
63 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
64 eqStrings(m.alpnProtocols, m1.alpnProtocols)
65 }
66
// marshal serializes m as a ClientHello handshake message: a 4-byte
// handshake header (type byte plus 24-bit body length) followed by the body.
// The result is cached in m.raw; if m.raw is already non-nil it is returned
// unchanged without re-encoding.
//
// Extensions are written in a fixed order (NPN, server name, status
// request, supported curves, point formats, session ticket, signature
// algorithms, renegotiation info, ALPN, SCT). The 2-byte extensions block
// length is omitted entirely when no extension is enabled.
//
// It panics if any ALPN protocol name is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Fixed body: version (2), random (32), session ID length (1) + ID,
	// cipher suites length (2) + suites, compression methods length (1) +
	// methods.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	// First pass: count enabled extensions and the size of their bodies.
	// The 4-byte type+length header of every extension is added once at the
	// end via 4*numExtensions.
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		// Sent with an empty body.
		numExtensions++
	}
	if m.ocspStapling {
		// Status type (1) + two uint16 length fields (2 + 2).
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// List length (2) + name type (1) + name length (2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket bytes are the whole extension body (possibly empty).
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// Protocol list length (2) + one length byte per protocol + names.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// Sent with an empty body.
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		// Plus 2 bytes for the extensions block length itself.
		length += 2 + extensionsLength
	}

	// Second pass: write the fields. make zero-fills x, so bytes that must
	// be zero (empty lengths, the SNI name type) are simply not written.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// The suites length in bytes is 2*len; >>7 and <<1 give its high and
	// low bytes.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The extension length (z[2:4]) stays zero.
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Server name list: list length (2), name type (z[2], left zero,
		// which is the type unmarshal accepts as the server name), name
		// length (2), then the name bytes.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		// Status type 1 (OCSP per RFC 6066; presumably the same value as
		// statusTypeOCSP). The four bytes after it stay zero, i.e. two
		// empty length-prefixed fields.
		z[4] = 1
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// The inner list length excludes its own 2-byte prefix.
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// One-byte inner length; assumes fewer than 256 point formats.
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Inner list length, excluding its own 2-byte prefix.
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		// Only the low length bytes are written; assumes secureRenegotiation
		// is shorter than 255 bytes.
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths[0:2] is the extension length and lengths[2:4] the protocol
		// list length; both are back-filled after the names are written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// Empty body; the length bytes stay zero.
		z = z[4:]
	}

	m.raw = x

	return x
}
307
// unmarshal parses data into m and reports whether it is well formed. data
// is a complete ClientHello handshake message, including its 4-byte header.
//
// The header's type byte and 24-bit length are not checked here. Once data
// passes the minimum-length check, m.raw is set to it. random, sessionId,
// compressionMethods, sessionTicket and secureRenegotiation are sub-slices
// of data rather than copies, so later changes to data are visible through
// m. Unknown extensions are skipped. On failure m may be partially updated.
//
// NOTE(review): supportedCurves, supportedPoints and secureRenegotiation are
// not reset before the extensions are parsed, so values from an earlier use
// of m survive if the corresponding extension is absent.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	// 4 (header) + 2 (version) + 32 (random) + 1 (session ID length) +
	// 2 (cipher suites length) + 1 (compression methods length).
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 2 {
		return false
	}

	// cipherSuiteLen counts bytes; each suite is a uint16, so it must be
	// even.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return false
	}
	numCipherSuites := cipherSuiteLen / 2
	m.cipherSuites = make([]uint16, numCipherSuites)
	for i := 0; i < numCipherSuites; i++ {
		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
		// The renegotiation SCSV signals support just like the extension.
		if m.cipherSuites[i] == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
	}
	data = data[2+cipherSuiteLen:]
	if len(data) < 1 {
		return false
	}
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return false
	}
	m.compressionMethods = data[1 : 1+compressionMethodsLen]

	data = data[1+compressionMethodsLen:]

	// Reset extension-derived fields so that absent extensions read as
	// "not offered".
	m.nextProtoNeg = false
	m.serverName = ""
	m.ocspStapling = false
	m.ticketSupported = false
	m.sessionTicket = nil
	m.supportedSignatureAlgorithms = nil
	m.alpnProtocols = nil
	m.scts = false

	if len(data) == 0 {
		// The extensions block is optional.
		return true
	}
	if len(data) < 2 {
		return false
	}

	// The extensions block must exactly fill the rest of the message.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if extensionsLength != len(data) {
		return false
	}

	for len(data) != 0 {
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		// Each case sees the extension body at data[:length]; the loop
		// advances past it after the switch.
		switch extension {
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return false
			}
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return false
			}
			for len(d) > 0 {
				if len(d) < 3 {
					return false
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return false
				}
				// Take the first entry of name type 0 and stop scanning
				// (this break exits the for loop, not the switch).
				if nameType == 0 {
					m.serverName = string(d[:nameLen])
					break
				}
				d = d[nameLen:]
			}
		case extensionNextProtoNeg:
			if length > 0 {
				return false
			}
			m.nextProtoNeg = true
		case extensionStatusRequest:
			// Only an OCSP status request enables stapling; the rest of the
			// body is ignored.
			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
		case extensionSupportedCurves:
			// 2-byte list length followed by uint16 curve IDs.
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l%2 == 1 || length != l+2 {
				return false
			}
			numCurves := l / 2
			m.supportedCurves = make([]CurveID, numCurves)
			d := data[2:]
			for i := 0; i < numCurves; i++ {
				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
				d = d[2:]
			}
		case extensionSupportedPoints:
			// 1-byte list length followed by one byte per point format.
			if length < 1 {
				return false
			}
			l := int(data[0])
			if length != l+1 {
				return false
			}
			m.supportedPoints = make([]uint8, l)
			copy(m.supportedPoints, data[1:])
		case extensionSessionTicket:
			// The whole (possibly empty) body is the ticket.
			m.ticketSupported = true
			m.sessionTicket = data[:length]
		case extensionSignatureAlgorithms:
			// 2-byte list length followed by uint16 signature schemes.
			if length < 2 || length&1 != 0 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			n := l / 2
			d := data[2:]
			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
			for i := range m.supportedSignatureAlgorithms {
				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
				d = d[2:]
			}
		case extensionRenegotiationInfo:
			// 1-byte length followed by exactly that many bytes.
			if length == 0 {
				return false
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// 2-byte list length followed by non-empty names, each with a
			// 1-byte length prefix.
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			d := data[2:length]
			for len(d) != 0 {
				stringLen := int(d[0])
				d = d[1:]
				if stringLen == 0 || stringLen > len(d) {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
				d = d[stringLen:]
			}
		case extensionSCT:
			// Must have an empty body.
			m.scts = true
			if length != 0 {
				return false
			}
		}
		data = data[length:]
	}

	return true
}
508
// serverHelloMsg holds the fields of a ServerHello handshake message.
//
// raw caches the complete wire encoding, including the 4-byte handshake
// header: marshal returns it unchanged when it is non-nil, and unmarshal
// sets it to its input.
type serverHelloMsg struct {
	raw []byte

	// Fixed ServerHello body fields.
	vers              uint16
	random            []byte
	sessionId         []byte
	cipherSuite       uint16
	compressionMethod uint8

	// Extension-derived fields. nextProtos is the NPN protocol list, scts
	// holds each signed certificate timestamp as raw bytes, and
	// alpnProtocol is the single selected ALPN protocol (empty if none).
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	scts                         [][]byte
	ticketSupported              bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string
}
525
526 func (m *serverHelloMsg) equal(i interface{}) bool {
527 m1, ok := i.(*serverHelloMsg)
528 if !ok {
529 return false
530 }
531
532 if len(m.scts) != len(m1.scts) {
533 return false
534 }
535 for i, sct := range m.scts {
536 if !bytes.Equal(sct, m1.scts[i]) {
537 return false
538 }
539 }
540
541 return bytes.Equal(m.raw, m1.raw) &&
542 m.vers == m1.vers &&
543 bytes.Equal(m.random, m1.random) &&
544 bytes.Equal(m.sessionId, m1.sessionId) &&
545 m.cipherSuite == m1.cipherSuite &&
546 m.compressionMethod == m1.compressionMethod &&
547 m.nextProtoNeg == m1.nextProtoNeg &&
548 eqStrings(m.nextProtos, m1.nextProtos) &&
549 m.ocspStapling == m1.ocspStapling &&
550 m.ticketSupported == m1.ticketSupported &&
551 m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
552 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
553 m.alpnProtocol == m1.alpnProtocol
554 }
555
556 func (m *serverHelloMsg) marshal() []byte {
557 if m.raw != nil {
558 return m.raw
559 }
560
561 length := 38 + len(m.sessionId)
562 numExtensions := 0
563 extensionsLength := 0
564
565 nextProtoLen := 0
566 if m.nextProtoNeg {
567 numExtensions++
568 for _, v := range m.nextProtos {
569 nextProtoLen += len(v)
570 }
571 nextProtoLen += len(m.nextProtos)
572 extensionsLength += nextProtoLen
573 }
574 if m.ocspStapling {
575 numExtensions++
576 }
577 if m.ticketSupported {
578 numExtensions++
579 }
580 if m.secureRenegotiationSupported {
581 extensionsLength += 1 + len(m.secureRenegotiation)
582 numExtensions++
583 }
584 if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
585 if alpnLen >= 256 {
586 panic("invalid ALPN protocol")
587 }
588 extensionsLength += 2 + 1 + alpnLen
589 numExtensions++
590 }
591 sctLen := 0
592 if len(m.scts) > 0 {
593 for _, sct := range m.scts {
594 sctLen += len(sct) + 2
595 }
596 extensionsLength += 2 + sctLen
597 numExtensions++
598 }
599
600 if numExtensions > 0 {
601 extensionsLength += 4 * numExtensions
602 length += 2 + extensionsLength
603 }
604
605 x := make([]byte, 4+length)
606 x[0] = typeServerHello
607 x[1] = uint8(length >> 16)
608 x[2] = uint8(length >> 8)
609 x[3] = uint8(length)
610 x[4] = uint8(m.vers >> 8)
611 x[5] = uint8(m.vers)
612 copy(x[6:38], m.random)
613 x[38] = uint8(len(m.sessionId))
614 copy(x[39:39+len(m.sessionId)], m.sessionId)
615 z := x[39+len(m.sessionId):]
616 z[0] = uint8(m.cipherSuite >> 8)
617 z[1] = uint8(m.cipherSuite)
618 z[2] = m.compressionMethod
619
620 z = z[3:]
621 if numExtensions > 0 {
622 z[0] = byte(extensionsLength >> 8)
623 z[1] = byte(extensionsLength)
624 z = z[2:]
625 }
626 if m.nextProtoNeg {
627 z[0] = byte(extensionNextProtoNeg >> 8)
628 z[1] = byte(extensionNextProtoNeg & 0xff)
629 z[2] = byte(nextProtoLen >> 8)
630 z[3] = byte(nextProtoLen)
631 z = z[4:]
632
633 for _, v := range m.nextProtos {
634 l := len(v)
635 if l > 255 {
636 l = 255
637 }
638 z[0] = byte(l)
639 copy(z[1:], []byte(v[0:l]))
640 z = z[1+l:]
641 }
642 }
643 if m.ocspStapling {
644 z[0] = byte(extensionStatusRequest >> 8)
645 z[1] = byte(extensionStatusRequest)
646 z = z[4:]
647 }
648 if m.ticketSupported {
649 z[0] = byte(extensionSessionTicket >> 8)
650 z[1] = byte(extensionSessionTicket)
651 z = z[4:]
652 }
653 if m.secureRenegotiationSupported {
654 z[0] = byte(extensionRenegotiationInfo >> 8)
655 z[1] = byte(extensionRenegotiationInfo & 0xff)
656 z[2] = 0
657 z[3] = byte(len(m.secureRenegotiation) + 1)
658 z[4] = byte(len(m.secureRenegotiation))
659 z = z[5:]
660 copy(z, m.secureRenegotiation)
661 z = z[len(m.secureRenegotiation):]
662 }
663 if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
664 z[0] = byte(extensionALPN >> 8)
665 z[1] = byte(extensionALPN & 0xff)
666 l := 2 + 1 + alpnLen
667 z[2] = byte(l >> 8)
668 z[3] = byte(l)
669 l -= 2
670 z[4] = byte(l >> 8)
671 z[5] = byte(l)
672 l -= 1
673 z[6] = byte(l)
674 copy(z[7:], []byte(m.alpnProtocol))
675 z = z[7+alpnLen:]
676 }
677 if sctLen > 0 {
678 z[0] = byte(extensionSCT >> 8)
679 z[1] = byte(extensionSCT)
680 l := sctLen + 2
681 z[2] = byte(l >> 8)
682 z[3] = byte(l)
683 z[4] = byte(sctLen >> 8)
684 z[5] = byte(sctLen)
685
686 z = z[6:]
687 for _, sct := range m.scts {
688 z[0] = byte(len(sct) >> 8)
689 z[1] = byte(len(sct))
690 copy(z[2:], sct)
691 z = z[len(sct)+2:]
692 }
693 }
694
695 m.raw = x
696
697 return x
698 }
699
// unmarshal parses data into m and reports whether it is well formed. data
// is a complete ServerHello handshake message, including its 4-byte header.
//
// The header's type byte and 24-bit length are not checked here. Once data
// passes the minimum-length check, m.raw is set to it. random, sessionId,
// secureRenegotiation and each element of scts are sub-slices of data rather
// than copies. Unknown extensions are skipped. On failure m may be partially
// updated.
//
// NOTE(review): secureRenegotiation and secureRenegotiationSupported are not
// reset before the extensions are parsed, so values from an earlier use of m
// survive if the renegotiation_info extension is absent.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	// 4 (header) + 2 (version) + 32 (random) + 1 (session ID length) +
	// 2 (cipher suite) + 1 (compression method).
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 3 {
		return false
	}
	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
	m.compressionMethod = data[2]
	data = data[3:]

	// Reset extension-derived fields so that absent extensions read as
	// "not present".
	m.nextProtoNeg = false
	m.nextProtos = nil
	m.ocspStapling = false
	m.scts = nil
	m.ticketSupported = false
	m.alpnProtocol = ""

	if len(data) == 0 {
		// The extensions block is optional.
		return true
	}
	if len(data) < 2 {
		return false
	}

	// The extensions block must exactly fill the rest of the message.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) != extensionsLength {
		return false
	}

	for len(data) != 0 {
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		// Each case sees the extension body at data[:length]; the loop
		// advances past it after the switch.
		switch extension {
		case extensionNextProtoNeg:
			// A sequence of non-empty names, each with a 1-byte length
			// prefix (no outer list length).
			m.nextProtoNeg = true
			d := data[:length]
			for len(d) > 0 {
				l := int(d[0])
				d = d[1:]
				if l == 0 || l > len(d) {
					return false
				}
				m.nextProtos = append(m.nextProtos, string(d[:l]))
				d = d[l:]
			}
		case extensionStatusRequest:
			// Must have an empty body.
			if length > 0 {
				return false
			}
			m.ocspStapling = true
		case extensionSessionTicket:
			// Must have an empty body.
			if length > 0 {
				return false
			}
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			// 1-byte length followed by exactly that many bytes.
			if length == 0 {
				return false
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// 2-byte list length, then exactly one non-empty protocol name
			// with a 1-byte length prefix.
			d := data[:length]
			if len(d) < 3 {
				return false
			}
			l := int(d[0])<<8 | int(d[1])
			if l != len(d)-2 {
				return false
			}
			d = d[2:]
			l = int(d[0])
			if l != len(d)-1 {
				return false
			}
			d = d[1:]
			if len(d) == 0 {
				// An empty protocol name is rejected.
				return false
			}
			m.alpnProtocol = string(d)
		case extensionSCT:
			d := data[:length]

			// 2-byte non-zero list length, then non-empty SCTs each with a
			// 2-byte length prefix.
			if len(d) < 2 {
				return false
			}
			l := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != l || l == 0 {
				return false
			}

			m.scts = make([][]byte, 0, 3)
			for len(d) != 0 {
				if len(d) < 2 {
					return false
				}
				sctLen := int(d[0])<<8 | int(d[1])
				d = d[2:]
				if sctLen == 0 || len(d) < sctLen {
					return false
				}
				m.scts = append(m.scts, d[:sctLen])
				d = d[sctLen:]
			}
		}
		data = data[length:]
	}

	return true
}
839
840 type certificateMsg struct {
841 raw []byte
842 certificates [][]byte
843 }
844
845 func (m *certificateMsg) equal(i interface{}) bool {
846 m1, ok := i.(*certificateMsg)
847 if !ok {
848 return false
849 }
850
851 return bytes.Equal(m.raw, m1.raw) &&
852 eqByteSlices(m.certificates, m1.certificates)
853 }
854
855 func (m *certificateMsg) marshal() (x []byte) {
856 if m.raw != nil {
857 return m.raw
858 }
859
860 var i int
861 for _, slice := range m.certificates {
862 i += len(slice)
863 }
864
865 length := 3 + 3*len(m.certificates) + i
866 x = make([]byte, 4+length)
867 x[0] = typeCertificate
868 x[1] = uint8(length >> 16)
869 x[2] = uint8(length >> 8)
870 x[3] = uint8(length)
871
872 certificateOctets := length - 3
873 x[4] = uint8(certificateOctets >> 16)
874 x[5] = uint8(certificateOctets >> 8)
875 x[6] = uint8(certificateOctets)
876
877 y := x[7:]
878 for _, slice := range m.certificates {
879 y[0] = uint8(len(slice) >> 16)
880 y[1] = uint8(len(slice) >> 8)
881 y[2] = uint8(len(slice))
882 copy(y[3:], slice)
883 y = y[3+len(slice):]
884 }
885
886 m.raw = x
887 return
888 }
889
890 func (m *certificateMsg) unmarshal(data []byte) bool {
891 if len(data) < 7 {
892 return false
893 }
894
895 m.raw = data
896 certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
897 if uint32(len(data)) != certsLen+7 {
898 return false
899 }
900
901 numCerts := 0
902 d := data[7:]
903 for certsLen > 0 {
904 if len(d) < 4 {
905 return false
906 }
907 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
908 if uint32(len(d)) < 3+certLen {
909 return false
910 }
911 d = d[3+certLen:]
912 certsLen -= 3 + certLen
913 numCerts++
914 }
915
916 m.certificates = make([][]byte, numCerts)
917 d = data[7:]
918 for i := 0; i < numCerts; i++ {
919 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
920 m.certificates[i] = d[3 : 3+certLen]
921 d = d[3+certLen:]
922 }
923
924 return true
925 }
926
927 type serverKeyExchangeMsg struct {
928 raw []byte
929 key []byte
930 }
931
932 func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
933 m1, ok := i.(*serverKeyExchangeMsg)
934 if !ok {
935 return false
936 }
937
938 return bytes.Equal(m.raw, m1.raw) &&
939 bytes.Equal(m.key, m1.key)
940 }
941
942 func (m *serverKeyExchangeMsg) marshal() []byte {
943 if m.raw != nil {
944 return m.raw
945 }
946 length := len(m.key)
947 x := make([]byte, length+4)
948 x[0] = typeServerKeyExchange
949 x[1] = uint8(length >> 16)
950 x[2] = uint8(length >> 8)
951 x[3] = uint8(length)
952 copy(x[4:], m.key)
953
954 m.raw = x
955 return x
956 }
957
958 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
959 m.raw = data
960 if len(data) < 4 {
961 return false
962 }
963 m.key = data[4:]
964 return true
965 }
966
967 type certificateStatusMsg struct {
968 raw []byte
969 statusType uint8
970 response []byte
971 }
972
973 func (m *certificateStatusMsg) equal(i interface{}) bool {
974 m1, ok := i.(*certificateStatusMsg)
975 if !ok {
976 return false
977 }
978
979 return bytes.Equal(m.raw, m1.raw) &&
980 m.statusType == m1.statusType &&
981 bytes.Equal(m.response, m1.response)
982 }
983
984 func (m *certificateStatusMsg) marshal() []byte {
985 if m.raw != nil {
986 return m.raw
987 }
988
989 var x []byte
990 if m.statusType == statusTypeOCSP {
991 x = make([]byte, 4+4+len(m.response))
992 x[0] = typeCertificateStatus
993 l := len(m.response) + 4
994 x[1] = byte(l >> 16)
995 x[2] = byte(l >> 8)
996 x[3] = byte(l)
997 x[4] = statusTypeOCSP
998
999 l -= 4
1000 x[5] = byte(l >> 16)
1001 x[6] = byte(l >> 8)
1002 x[7] = byte(l)
1003 copy(x[8:], m.response)
1004 } else {
1005 x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
1006 }
1007
1008 m.raw = x
1009 return x
1010 }
1011
1012 func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1013 m.raw = data
1014 if len(data) < 5 {
1015 return false
1016 }
1017 m.statusType = data[4]
1018
1019 m.response = nil
1020 if m.statusType == statusTypeOCSP {
1021 if len(data) < 8 {
1022 return false
1023 }
1024 respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
1025 if uint32(len(data)) != 4+4+respLen {
1026 return false
1027 }
1028 m.response = data[8:]
1029 }
1030 return true
1031 }
1032
1033 type serverHelloDoneMsg struct{}
1034
1035 func (m *serverHelloDoneMsg) equal(i interface{}) bool {
1036 _, ok := i.(*serverHelloDoneMsg)
1037 return ok
1038 }
1039
1040 func (m *serverHelloDoneMsg) marshal() []byte {
1041 x := make([]byte, 4)
1042 x[0] = typeServerHelloDone
1043 return x
1044 }
1045
1046 func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1047 return len(data) == 4
1048 }
1049
1050 type clientKeyExchangeMsg struct {
1051 raw []byte
1052 ciphertext []byte
1053 }
1054
1055 func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1056 m1, ok := i.(*clientKeyExchangeMsg)
1057 if !ok {
1058 return false
1059 }
1060
1061 return bytes.Equal(m.raw, m1.raw) &&
1062 bytes.Equal(m.ciphertext, m1.ciphertext)
1063 }
1064
1065 func (m *clientKeyExchangeMsg) marshal() []byte {
1066 if m.raw != nil {
1067 return m.raw
1068 }
1069 length := len(m.ciphertext)
1070 x := make([]byte, length+4)
1071 x[0] = typeClientKeyExchange
1072 x[1] = uint8(length >> 16)
1073 x[2] = uint8(length >> 8)
1074 x[3] = uint8(length)
1075 copy(x[4:], m.ciphertext)
1076
1077 m.raw = x
1078 return x
1079 }
1080
1081 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1082 m.raw = data
1083 if len(data) < 4 {
1084 return false
1085 }
1086 l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1087 if l != len(data)-4 {
1088 return false
1089 }
1090 m.ciphertext = data[4:]
1091 return true
1092 }
1093
1094 type finishedMsg struct {
1095 raw []byte
1096 verifyData []byte
1097 }
1098
1099 func (m *finishedMsg) equal(i interface{}) bool {
1100 m1, ok := i.(*finishedMsg)
1101 if !ok {
1102 return false
1103 }
1104
1105 return bytes.Equal(m.raw, m1.raw) &&
1106 bytes.Equal(m.verifyData, m1.verifyData)
1107 }
1108
1109 func (m *finishedMsg) marshal() (x []byte) {
1110 if m.raw != nil {
1111 return m.raw
1112 }
1113
1114 x = make([]byte, 4+len(m.verifyData))
1115 x[0] = typeFinished
1116 x[3] = byte(len(m.verifyData))
1117 copy(x[4:], m.verifyData)
1118 m.raw = x
1119 return
1120 }
1121
1122 func (m *finishedMsg) unmarshal(data []byte) bool {
1123 m.raw = data
1124 if len(data) < 4 {
1125 return false
1126 }
1127 m.verifyData = data[4:]
1128 return true
1129 }
1130
1131 type nextProtoMsg struct {
1132 raw []byte
1133 proto string
1134 }
1135
1136 func (m *nextProtoMsg) equal(i interface{}) bool {
1137 m1, ok := i.(*nextProtoMsg)
1138 if !ok {
1139 return false
1140 }
1141
1142 return bytes.Equal(m.raw, m1.raw) &&
1143 m.proto == m1.proto
1144 }
1145
1146 func (m *nextProtoMsg) marshal() []byte {
1147 if m.raw != nil {
1148 return m.raw
1149 }
1150 l := len(m.proto)
1151 if l > 255 {
1152 l = 255
1153 }
1154
1155 padding := 32 - (l+2)%32
1156 length := l + padding + 2
1157 x := make([]byte, length+4)
1158 x[0] = typeNextProtocol
1159 x[1] = uint8(length >> 16)
1160 x[2] = uint8(length >> 8)
1161 x[3] = uint8(length)
1162
1163 y := x[4:]
1164 y[0] = byte(l)
1165 copy(y[1:], []byte(m.proto[0:l]))
1166 y = y[1+l:]
1167 y[0] = byte(padding)
1168
1169 m.raw = x
1170
1171 return x
1172 }
1173
1174 func (m *nextProtoMsg) unmarshal(data []byte) bool {
1175 m.raw = data
1176
1177 if len(data) < 5 {
1178 return false
1179 }
1180 data = data[4:]
1181 protoLen := int(data[0])
1182 data = data[1:]
1183 if len(data) < protoLen {
1184 return false
1185 }
1186 m.proto = string(data[0:protoLen])
1187 data = data[protoLen:]
1188
1189 if len(data) < 1 {
1190 return false
1191 }
1192 paddingLen := int(data[0])
1193 data = data[1:]
1194 if len(data) != paddingLen {
1195 return false
1196 }
1197
1198 return true
1199 }
1200
1201 type certificateRequestMsg struct {
1202 raw []byte
1203
1204
1205
1206 hasSignatureAndHash bool
1207
1208 certificateTypes []byte
1209 supportedSignatureAlgorithms []SignatureScheme
1210 certificateAuthorities [][]byte
1211 }
1212
1213 func (m *certificateRequestMsg) equal(i interface{}) bool {
1214 m1, ok := i.(*certificateRequestMsg)
1215 if !ok {
1216 return false
1217 }
1218
1219 return bytes.Equal(m.raw, m1.raw) &&
1220 bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1221 eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1222 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
1223 }
1224
1225 func (m *certificateRequestMsg) marshal() (x []byte) {
1226 if m.raw != nil {
1227 return m.raw
1228 }
1229
1230
1231 length := 1 + len(m.certificateTypes) + 2
1232 casLength := 0
1233 for _, ca := range m.certificateAuthorities {
1234 casLength += 2 + len(ca)
1235 }
1236 length += casLength
1237
1238 if m.hasSignatureAndHash {
1239 length += 2 +2*len(m.supportedSignatureAlgorithms)
1240 }
1241
1242 x = make([]byte, 4+length)
1243 x[0] = typeCertificateRequest
1244 x[1] = uint8(length >> 16)
1245 x[2] = uint8(length >> 8)
1246 x[3] = uint8(length)
1247
1248 x[4] = uint8(len(m.certificateTypes))
1249
1250 copy(x[5:], m.certificateTypes)
1251 y := x[5+len(m.certificateTypes):]
1252
1253 if m.hasSignatureAndHash {
1254 n := len(m.supportedSignatureAlgorithms) * 2
1255 y[0] = uint8(n >> 8)
1256 y[1] = uint8(n)
1257 y = y[2:]
1258 for _, sigAlgo := range m.supportedSignatureAlgorithms {
1259 y[0] = uint8(sigAlgo >> 8)
1260 y[1] = uint8(sigAlgo)
1261 y = y[2:]
1262 }
1263 }
1264
1265 y[0] = uint8(casLength >> 8)
1266 y[1] = uint8(casLength)
1267 y = y[2:]
1268 for _, ca := range m.certificateAuthorities {
1269 y[0] = uint8(len(ca) >> 8)
1270 y[1] = uint8(len(ca))
1271 y = y[2:]
1272 copy(y, ca)
1273 y = y[len(ca):]
1274 }
1275
1276 m.raw = x
1277 return
1278 }
1279
1280 func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1281 m.raw = data
1282
1283 if len(data) < 5 {
1284 return false
1285 }
1286
1287 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1288 if uint32(len(data))-4 != length {
1289 return false
1290 }
1291
1292 numCertTypes := int(data[4])
1293 data = data[5:]
1294 if numCertTypes == 0 || len(data) <= numCertTypes {
1295 return false
1296 }
1297
1298 m.certificateTypes = make([]byte, numCertTypes)
1299 if copy(m.certificateTypes, data) != numCertTypes {
1300 return false
1301 }
1302
1303 data = data[numCertTypes:]
1304
1305 if m.hasSignatureAndHash {
1306 if len(data) < 2 {
1307 return false
1308 }
1309 sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1310 data = data[2:]
1311 if sigAndHashLen&1 != 0 {
1312 return false
1313 }
1314 if len(data) < int(sigAndHashLen) {
1315 return false
1316 }
1317 numSigAlgos := sigAndHashLen / 2
1318 m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
1319 for i := range m.supportedSignatureAlgorithms {
1320 m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1321 data = data[2:]
1322 }
1323 }
1324
1325 if len(data) < 2 {
1326 return false
1327 }
1328 casLength := uint16(data[0])<<8 | uint16(data[1])
1329 data = data[2:]
1330 if len(data) < int(casLength) {
1331 return false
1332 }
1333 cas := make([]byte, casLength)
1334 copy(cas, data)
1335 data = data[casLength:]
1336
1337 m.certificateAuthorities = nil
1338 for len(cas) > 0 {
1339 if len(cas) < 2 {
1340 return false
1341 }
1342 caLen := uint16(cas[0])<<8 | uint16(cas[1])
1343 cas = cas[2:]
1344
1345 if len(cas) < int(caLen) {
1346 return false
1347 }
1348
1349 m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1350 cas = cas[caLen:]
1351 }
1352
1353 return len(data) == 0
1354 }
1355
1356 type certificateVerifyMsg struct {
1357 raw []byte
1358 hasSignatureAndHash bool
1359 signatureAlgorithm SignatureScheme
1360 signature []byte
1361 }
1362
1363 func (m *certificateVerifyMsg) equal(i interface{}) bool {
1364 m1, ok := i.(*certificateVerifyMsg)
1365 if !ok {
1366 return false
1367 }
1368
1369 return bytes.Equal(m.raw, m1.raw) &&
1370 m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1371 m.signatureAlgorithm == m1.signatureAlgorithm &&
1372 bytes.Equal(m.signature, m1.signature)
1373 }
1374
1375 func (m *certificateVerifyMsg) marshal() (x []byte) {
1376 if m.raw != nil {
1377 return m.raw
1378 }
1379
1380
1381 siglength := len(m.signature)
1382 length := 2 + siglength
1383 if m.hasSignatureAndHash {
1384 length += 2
1385 }
1386 x = make([]byte, 4+length)
1387 x[0] = typeCertificateVerify
1388 x[1] = uint8(length >> 16)
1389 x[2] = uint8(length >> 8)
1390 x[3] = uint8(length)
1391 y := x[4:]
1392 if m.hasSignatureAndHash {
1393 y[0] = uint8(m.signatureAlgorithm >> 8)
1394 y[1] = uint8(m.signatureAlgorithm)
1395 y = y[2:]
1396 }
1397 y[0] = uint8(siglength >> 8)
1398 y[1] = uint8(siglength)
1399 copy(y[2:], m.signature)
1400
1401 m.raw = x
1402
1403 return
1404 }
1405
1406 func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1407 m.raw = data
1408
1409 if len(data) < 6 {
1410 return false
1411 }
1412
1413 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1414 if uint32(len(data))-4 != length {
1415 return false
1416 }
1417
1418 data = data[4:]
1419 if m.hasSignatureAndHash {
1420 m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1421 data = data[2:]
1422 }
1423
1424 if len(data) < 2 {
1425 return false
1426 }
1427 siglength := int(data[0])<<8 + int(data[1])
1428 data = data[2:]
1429 if len(data) != siglength {
1430 return false
1431 }
1432
1433 m.signature = data
1434
1435 return true
1436 }
1437
1438 type newSessionTicketMsg struct {
1439 raw []byte
1440 ticket []byte
1441 }
1442
1443 func (m *newSessionTicketMsg) equal(i interface{}) bool {
1444 m1, ok := i.(*newSessionTicketMsg)
1445 if !ok {
1446 return false
1447 }
1448
1449 return bytes.Equal(m.raw, m1.raw) &&
1450 bytes.Equal(m.ticket, m1.ticket)
1451 }
1452
1453 func (m *newSessionTicketMsg) marshal() (x []byte) {
1454 if m.raw != nil {
1455 return m.raw
1456 }
1457
1458
1459 ticketLen := len(m.ticket)
1460 length := 2 + 4 + ticketLen
1461 x = make([]byte, 4+length)
1462 x[0] = typeNewSessionTicket
1463 x[1] = uint8(length >> 16)
1464 x[2] = uint8(length >> 8)
1465 x[3] = uint8(length)
1466 x[8] = uint8(ticketLen >> 8)
1467 x[9] = uint8(ticketLen)
1468 copy(x[10:], m.ticket)
1469
1470 m.raw = x
1471
1472 return
1473 }
1474
1475 func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1476 m.raw = data
1477
1478 if len(data) < 10 {
1479 return false
1480 }
1481
1482 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1483 if uint32(len(data))-4 != length {
1484 return false
1485 }
1486
1487 ticketLen := int(data[8])<<8 + int(data[9])
1488 if len(data)-10 != ticketLen {
1489 return false
1490 }
1491
1492 m.ticket = data[10:]
1493
1494 return true
1495 }
1496
1497 type helloRequestMsg struct {
1498 }
1499
1500 func (*helloRequestMsg) marshal() []byte {
1501 return []byte{typeHelloRequest, 0, 0, 0}
1502 }
1503
1504 func (*helloRequestMsg) unmarshal(data []byte) bool {
1505 return len(data) == 4
1506 }
1507
// eqUint16s reports whether x and y have the same length and identical
// elements in the same order. A nil slice equals an empty one.
func eqUint16s(x, y []uint16) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	// Advance while elements match; equal slices run the index to n.
	i := 0
	for i < n && x[i] == y[i] {
		i++
	}
	return i == n
}
1519
1520 func eqCurveIDs(x, y []CurveID) bool {
1521 if len(x) != len(y) {
1522 return false
1523 }
1524 for i, v := range x {
1525 if y[i] != v {
1526 return false
1527 }
1528 }
1529 return true
1530 }
1531
// eqStrings reports whether x and y contain the same strings in the same
// order. A nil slice equals an empty one.
func eqStrings(x, y []string) bool {
	if len(x) != len(y) {
		return false
	}
	for i := 0; i < len(x); i++ {
		if x[i] != y[i] {
			return false
		}
	}
	return true
}
1543
// eqByteSlices reports whether x and y hold the same number of byte slices
// with pairwise-equal contents (compared with bytes.Equal, so nil equals
// empty at both levels).
func eqByteSlices(x, y [][]byte) bool {
	if len(x) != len(y) {
		return false
	}
	for i := range x {
		if !bytes.Equal(x[i], y[i]) {
			return false
		}
	}
	return true
}
1555
1556 func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
1557 if len(x) != len(y) {
1558 return false
1559 }
1560 for i, v := range x {
1561 if v != y[i] {
1562 return false
1563 }
1564 }
1565 return true
1566 }
1567
View as plain text