1 package hpke
2
3 import (
4 "encoding/binary"
5 "errors"
6 "fmt"
7 )
8
9 func (st state) keySchedule(ss, info, psk, pskID []byte) (*encdecContext, error) {
10 if err := st.verifyPSKInputs(psk, pskID); err != nil {
11 return nil, err
12 }
13
14 pskIDHash := st.labeledExtract(nil, []byte("psk_id_hash"), pskID)
15 infoHash := st.labeledExtract(nil, []byte("info_hash"), info)
16 keySchCtx := append(append(
17 []byte{st.modeID},
18 pskIDHash...),
19 infoHash...)
20
21 secret := st.labeledExtract(ss, []byte("secret"), psk)
22
23 Nk := uint16(st.aeadID.KeySize())
24 key := st.labeledExpand(secret, []byte("key"), keySchCtx, Nk)
25
26 aead, err := st.aeadID.New(key)
27 if err != nil {
28 return nil, err
29 }
30
31 Nn := uint16(aead.NonceSize())
32 baseNonce := st.labeledExpand(secret, []byte("base_nonce"), keySchCtx, Nn)
33 exporterSecret := st.labeledExpand(
34 secret,
35 []byte("exp"),
36 keySchCtx,
37 uint16(st.kdfID.ExtractSize()),
38 )
39
40 return &encdecContext{
41 st.Suite,
42 ss,
43 secret,
44 keySchCtx,
45 exporterSecret,
46 key,
47 baseNonce,
48 make([]byte, Nn),
49 aead,
50 make([]byte, Nn),
51 }, nil
52 }
53
54 func (st state) verifyPSKInputs(psk, pskID []byte) error {
55 gotPSK := psk != nil
56 gotPSKID := pskID != nil
57 if gotPSK != gotPSKID {
58 return errors.New("inconsistent PSK inputs")
59 }
60 switch st.modeID {
61 case modeBase | modeAuth:
62 if gotPSK {
63 return errors.New("PSK input provided when not needed")
64 }
65 case modePSK | modeAuthPSK:
66 if !gotPSK {
67 return errors.New("missing required PSK input")
68 }
69 }
70 return nil
71 }
72
73
74 func (suite Suite) Params() (KEM, KDF, AEAD) {
75 return suite.kemID, suite.kdfID, suite.aeadID
76 }
77
78 func (suite Suite) String() string {
79 return fmt.Sprintf(
80 "kem_id: %v kdf_id: %v aead_id: %v",
81 suite.kemID, suite.kdfID, suite.aeadID,
82 )
83 }
84
85 func (suite Suite) getSuiteID() (id [10]byte) {
86 id[0], id[1], id[2], id[3] = 'H', 'P', 'K', 'E'
87 binary.BigEndian.PutUint16(id[4:6], uint16(suite.kemID))
88 binary.BigEndian.PutUint16(id[6:8], uint16(suite.kdfID))
89 binary.BigEndian.PutUint16(id[8:10], uint16(suite.aeadID))
90 return
91 }
92
93 func (suite Suite) isValid() bool {
94 return suite.kemID.IsValid() &&
95 suite.kdfID.IsValid() &&
96 suite.aeadID.IsValid()
97 }
98
99 func (suite Suite) labeledExtract(salt, label, ikm []byte) []byte {
100 suiteID := suite.getSuiteID()
101 labeledIKM := append(append(append(append(
102 make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(ikm)),
103 versionLabel...),
104 suiteID[:]...),
105 label...),
106 ikm...)
107 return suite.kdfID.Extract(labeledIKM, salt)
108 }
109
110 func (suite Suite) labeledExpand(prk, label, info []byte, l uint16) []byte {
111 suiteID := suite.getSuiteID()
112 labeledInfo := make([]byte,
113 2, 2+len(versionLabel)+len(suiteID)+len(label)+len(info))
114 binary.BigEndian.PutUint16(labeledInfo[0:2], l)
115 labeledInfo = append(append(append(append(labeledInfo,
116 versionLabel...),
117 suiteID[:]...),
118 label...),
119 info...)
120 return suite.kdfID.Expand(prk, labeledInfo, uint(l))
121 }
122