1 package hpke
2
3 import (
4 "crypto"
5 "crypto/rand"
6 "encoding/binary"
7 "io"
8
9 "github.com/cloudflare/circl/kem"
10 "golang.org/x/crypto/hkdf"
11 )
12
13 type dhKEM interface {
14 sizeDH() int
15 calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error
16 SeedSize() int
17 DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey)
18 UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
19 UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error)
20 }
21
22 type kemBase struct {
23 id KEM
24 name string
25 crypto.Hash
26 }
27
28 type dhKemBase struct {
29 kemBase
30 dhKEM
31 }
32
33 func (k kemBase) Name() string { return k.name }
34 func (k kemBase) SharedKeySize() int { return k.Hash.Size() }
35
36 func (k kemBase) getSuiteID() (sid [5]byte) {
37 sid[0], sid[1], sid[2] = 'K', 'E', 'M'
38 binary.BigEndian.PutUint16(sid[3:5], uint16(k.id))
39 return
40 }
41
42 func (k kemBase) extractExpand(dh, kemCtx []byte) []byte {
43 eaePkr := k.labeledExtract([]byte(""), []byte("eae_prk"), dh)
44 return k.labeledExpand(
45 eaePkr,
46 []byte("shared_secret"),
47 kemCtx,
48 uint16(k.Size()),
49 )
50 }
51
52 func (k kemBase) labeledExtract(salt, label, info []byte) []byte {
53 suiteID := k.getSuiteID()
54 labeledIKM := append(append(append(append(
55 make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(info)),
56 versionLabel...),
57 suiteID[:]...),
58 label...),
59 info...)
60 return hkdf.Extract(k.New, labeledIKM, salt)
61 }
62
63 func (k kemBase) labeledExpand(prk, label, info []byte, l uint16) []byte {
64 suiteID := k.getSuiteID()
65 labeledInfo := make(
66 []byte,
67 2,
68 2+len(versionLabel)+len(suiteID)+len(label)+len(info),
69 )
70 binary.BigEndian.PutUint16(labeledInfo[0:2], l)
71 labeledInfo = append(append(append(append(labeledInfo,
72 versionLabel...),
73 suiteID[:]...),
74 label...),
75 info...)
76 b := make([]byte, l)
77 rd := hkdf.Expand(k.New, prk, labeledInfo)
78 if _, err := io.ReadFull(rd, b); err != nil {
79 panic(err)
80 }
81 return b
82 }
83
84 func (k dhKemBase) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) (
85 ct []byte, ss []byte, err error,
86 ) {
87 seed := make([]byte, k.SeedSize())
88 _, err = io.ReadFull(rand.Reader, seed)
89 if err != nil {
90 return nil, nil, err
91 }
92
93 return k.authEncap(pkr, sks, seed)
94 }
95
96 func (k dhKemBase) Encapsulate(pkr kem.PublicKey) (
97 ct []byte, ss []byte, err error,
98 ) {
99 seed := make([]byte, k.SeedSize())
100 _, err = io.ReadFull(rand.Reader, seed)
101 if err != nil {
102 return nil, nil, err
103 }
104
105 return k.encap(pkr, seed)
106 }
107
108 func (k dhKemBase) AuthEncapsulateDeterministically(
109 pkr kem.PublicKey, sks kem.PrivateKey, seed []byte,
110 ) (ct, ss []byte, err error) {
111 return k.authEncap(pkr, sks, seed)
112 }
113
114 func (k dhKemBase) EncapsulateDeterministically(
115 pkr kem.PublicKey, seed []byte,
116 ) (ct, ss []byte, err error) {
117 return k.encap(pkr, seed)
118 }
119
120 func (k dhKemBase) encap(
121 pkR kem.PublicKey,
122 seed []byte,
123 ) (ct []byte, ss []byte, err error) {
124 dh := make([]byte, k.sizeDH())
125 enc, kemCtx, err := k.coreEncap(dh, pkR, seed)
126 if err != nil {
127 return nil, nil, err
128 }
129 ss = k.extractExpand(dh, kemCtx)
130 return enc, ss, nil
131 }
132
133 func (k dhKemBase) authEncap(
134 pkR kem.PublicKey,
135 skS kem.PrivateKey,
136 seed []byte,
137 ) (ct []byte, ss []byte, err error) {
138 dhLen := k.sizeDH()
139 dh := make([]byte, 2*dhLen)
140 enc, kemCtx, err := k.coreEncap(dh[:dhLen], pkR, seed)
141 if err != nil {
142 return nil, nil, err
143 }
144
145 err = k.calcDH(dh[dhLen:], skS, pkR)
146 if err != nil {
147 return nil, nil, err
148 }
149
150 pkS := skS.Public()
151 pkSm, err := pkS.MarshalBinary()
152 if err != nil {
153 return nil, nil, err
154 }
155 kemCtx = append(kemCtx, pkSm...)
156
157 ss = k.extractExpand(dh, kemCtx)
158 return enc, ss, nil
159 }
160
161 func (k dhKemBase) coreEncap(
162 dh []byte,
163 pkR kem.PublicKey,
164 seed []byte,
165 ) (enc []byte, kemCtx []byte, err error) {
166 pkE, skE := k.DeriveKeyPair(seed)
167 err = k.calcDH(dh, skE, pkR)
168 if err != nil {
169 return nil, nil, err
170 }
171
172 enc, err = pkE.MarshalBinary()
173 if err != nil {
174 return nil, nil, err
175 }
176 pkRm, err := pkR.MarshalBinary()
177 if err != nil {
178 return nil, nil, err
179 }
180 kemCtx = append(append([]byte{}, enc...), pkRm...)
181
182 return enc, kemCtx, nil
183 }
184
185 func (k dhKemBase) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) {
186 dh := make([]byte, k.sizeDH())
187 kemCtx, err := k.coreDecap(dh, skr, ct)
188 if err != nil {
189 return nil, err
190 }
191 return k.extractExpand(dh, kemCtx), nil
192 }
193
194 func (k dhKemBase) AuthDecapsulate(
195 skR kem.PrivateKey,
196 ct []byte,
197 pkS kem.PublicKey,
198 ) ([]byte, error) {
199 dhLen := k.sizeDH()
200 dh := make([]byte, 2*dhLen)
201 kemCtx, err := k.coreDecap(dh[:dhLen], skR, ct)
202 if err != nil {
203 return nil, err
204 }
205
206 err = k.calcDH(dh[dhLen:], skR, pkS)
207 if err != nil {
208 return nil, err
209 }
210
211 pkSm, err := pkS.MarshalBinary()
212 if err != nil {
213 return nil, err
214 }
215 kemCtx = append(kemCtx, pkSm...)
216 return k.extractExpand(dh, kemCtx), nil
217 }
218
219 func (k dhKemBase) coreDecap(
220 dh []byte,
221 skR kem.PrivateKey,
222 ct []byte,
223 ) ([]byte, error) {
224 pkE, err := k.UnmarshalBinaryPublicKey(ct)
225 if err != nil {
226 return nil, err
227 }
228
229 err = k.calcDH(dh, skR, pkE)
230 if err != nil {
231 return nil, err
232 }
233
234 pkR := skR.Public()
235 pkRm, err := pkR.MarshalBinary()
236 if err != nil {
237 return nil, err
238 }
239
240 return append(append([]byte{}, ct...), pkRm...), nil
241 }
242
View as plain text