1 package hpke
2
3 import (
4 "crypto/elliptic"
5 "crypto/rand"
6 "crypto/subtle"
7 "fmt"
8 "math/big"
9
10 "github.com/cloudflare/circl/kem"
11 )
12
13 type shortKEM struct {
14 dhKemBase
15 elliptic.Curve
16 }
17
18 func (s shortKEM) PrivateKeySize() int { return s.byteSize() }
19 func (s shortKEM) SeedSize() int { return s.byteSize() }
20 func (s shortKEM) CiphertextSize() int { return 1 + 2*s.byteSize() }
21 func (s shortKEM) PublicKeySize() int { return 1 + 2*s.byteSize() }
22 func (s shortKEM) EncapsulationSeedSize() int { return s.byteSize() }
23
24 func (s shortKEM) byteSize() int { return (s.Params().BitSize + 7) / 8 }
25
26 func (s shortKEM) sizeDH() int { return s.byteSize() }
27 func (s shortKEM) calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error {
28 PK := pk.(*shortKEMPubKey)
29 SK := sk.(*shortKEMPrivKey)
30 l := len(dh)
31 x, _ := s.ScalarMult(PK.x, PK.y, SK.priv)
32 if x.Sign() == 0 {
33 return ErrInvalidKEMSharedSecret
34 }
35 b := x.Bytes()
36 copy(dh[l-len(b):l], b)
37 return nil
38 }
39
40
41
42
43
44 func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
45
46
47 if len(seed) != s.SeedSize() {
48 panic(kem.ErrSeedSize)
49 }
50
51 bitmask := byte(0xFF)
52 if s.Params().BitSize == 521 {
53 bitmask = 0x01
54 }
55
56 dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
57 var bytes []byte
58 ctr := 0
59 for skBig := new(big.Int); skBig.Sign() == 0 || skBig.Cmp(s.Params().N) >= 0; ctr++ {
60 if ctr > 255 {
61 panic("derive key error")
62 }
63 bytes = s.labeledExpand(
64 dkpPrk,
65 []byte("candidate"),
66 []byte{byte(ctr)},
67 uint16(s.byteSize()),
68 )
69 bytes[0] &= bitmask
70 skBig.SetBytes(bytes)
71 }
72 l := s.PrivateKeySize()
73 sk := &shortKEMPrivKey{s, make([]byte, l), nil}
74 copy(sk.priv[l-len(bytes):], bytes)
75 return sk.Public(), sk
76 }
77
78 func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
79 sk, x, y, err := elliptic.GenerateKey(s, rand.Reader)
80 pub := &shortKEMPubKey{s, x, y}
81 return pub, &shortKEMPrivKey{s, sk, pub}, err
82 }
83
84 func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
85 l := s.PrivateKeySize()
86 if len(data) < l {
87 return nil, ErrInvalidKEMPrivateKey
88 }
89 sk := &shortKEMPrivKey{s, make([]byte, l), nil}
90 copy(sk.priv[l-len(data):l], data[:l])
91 if !sk.validate() {
92 return nil, ErrInvalidKEMPrivateKey
93 }
94
95 return sk, nil
96 }
97
98 func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
99 x, y := elliptic.Unmarshal(s, data)
100 if x == nil {
101 return nil, ErrInvalidKEMPublicKey
102 }
103 key := &shortKEMPubKey{s, x, y}
104 if !key.validate() {
105 return nil, ErrInvalidKEMPublicKey
106 }
107 return key, nil
108 }
109
110 type shortKEMPubKey struct {
111 scheme shortKEM
112 x, y *big.Int
113 }
114
115 func (k *shortKEMPubKey) String() string {
116 return fmt.Sprintf("x: %v\ny: %v", k.x.Text(16), k.y.Text(16))
117 }
118 func (k *shortKEMPubKey) Scheme() kem.Scheme { return k.scheme }
119 func (k *shortKEMPubKey) MarshalBinary() ([]byte, error) {
120 return elliptic.Marshal(k.scheme, k.x, k.y), nil
121 }
122
123 func (k *shortKEMPubKey) Equal(pk kem.PublicKey) bool {
124 k1, ok := pk.(*shortKEMPubKey)
125 return ok &&
126 k.scheme.Params().Name == k1.scheme.Params().Name &&
127 k.x.Cmp(k1.x) == 0 &&
128 k.y.Cmp(k1.y) == 0
129 }
130
131 func (k *shortKEMPubKey) validate() bool {
132 p := k.scheme.Params().P
133 notAtInfinity := k.x.Sign() > 0 && k.y.Sign() > 0
134 lessThanP := k.x.Cmp(p) < 0 && k.y.Cmp(p) < 0
135 onCurve := k.scheme.IsOnCurve(k.x, k.y)
136 return notAtInfinity && lessThanP && onCurve
137 }
138
139 type shortKEMPrivKey struct {
140 scheme shortKEM
141 priv []byte
142 pub *shortKEMPubKey
143 }
144
145 func (k *shortKEMPrivKey) String() string { return fmt.Sprintf("%x", k.priv) }
146 func (k *shortKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
147 func (k *shortKEMPrivKey) MarshalBinary() ([]byte, error) {
148 return append(make([]byte, 0, k.scheme.PrivateKeySize()), k.priv...), nil
149 }
150
151 func (k *shortKEMPrivKey) Equal(pk kem.PrivateKey) bool {
152 k1, ok := pk.(*shortKEMPrivKey)
153 return ok &&
154 k.scheme.Params().Name == k1.scheme.Params().Name &&
155 subtle.ConstantTimeCompare(k.priv, k1.priv) == 1
156 }
157
158 func (k *shortKEMPrivKey) Public() kem.PublicKey {
159 if k.pub == nil {
160 x, y := k.scheme.ScalarBaseMult(k.priv)
161 k.pub = &shortKEMPubKey{k.scheme, x, y}
162 }
163 return k.pub
164 }
165
166 func (k *shortKEMPrivKey) validate() bool {
167 n := new(big.Int).SetBytes(k.priv)
168 order := k.scheme.Curve.Params().N
169 return len(k.priv) == k.scheme.PrivateKeySize() && n.Cmp(order) < 0
170 }
171
View as plain text