1
2
3
4
5
6
7 package ecdh
8
9 import (
10 "bytes"
11 "errors"
12 "io"
13
14 "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap"
15 "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
16 "github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
17 )
18
19 type KDF struct {
20 Hash algorithm.Hash
21 Cipher algorithm.Cipher
22 }
23
24 type PublicKey struct {
25 curve ecc.ECDHCurve
26 Point []byte
27 KDF
28 }
29
30 type PrivateKey struct {
31 PublicKey
32 D []byte
33 }
34
35 func NewPublicKey(curve ecc.ECDHCurve, kdfHash algorithm.Hash, kdfCipher algorithm.Cipher) *PublicKey {
36 return &PublicKey{
37 curve: curve,
38 KDF: KDF{
39 Hash: kdfHash,
40 Cipher: kdfCipher,
41 },
42 }
43 }
44
45 func NewPrivateKey(key PublicKey) *PrivateKey {
46 return &PrivateKey{
47 PublicKey: key,
48 }
49 }
50
51 func (pk *PublicKey) GetCurve() ecc.ECDHCurve {
52 return pk.curve
53 }
54
55 func (pk *PublicKey) MarshalPoint() []byte {
56 return pk.curve.MarshalBytePoint(pk.Point)
57 }
58
59 func (pk *PublicKey) UnmarshalPoint(p []byte) error {
60 pk.Point = pk.curve.UnmarshalBytePoint(p)
61 if pk.Point == nil {
62 return errors.New("ecdh: failed to parse EC point")
63 }
64 return nil
65 }
66
67 func (sk *PrivateKey) MarshalByteSecret() []byte {
68 return sk.curve.MarshalByteSecret(sk.D)
69 }
70
71 func (sk *PrivateKey) UnmarshalByteSecret(d []byte) error {
72 sk.D = sk.curve.UnmarshalByteSecret(d)
73
74 if sk.D == nil {
75 return errors.New("ecdh: failed to parse scalar")
76 }
77 return nil
78 }
79
80 func GenerateKey(rand io.Reader, c ecc.ECDHCurve, kdf KDF) (priv *PrivateKey, err error) {
81 priv = new(PrivateKey)
82 priv.PublicKey.curve = c
83 priv.PublicKey.KDF = kdf
84 priv.PublicKey.Point, priv.D, err = c.GenerateECDH(rand)
85 return
86 }
87
88 func Encrypt(random io.Reader, pub *PublicKey, msg, curveOID, fingerprint []byte) (vsG, c []byte, err error) {
89 if len(msg) > 40 {
90 return nil, nil, errors.New("ecdh: message too long")
91 }
92
93
94
95 padding := make([]byte, 40-len(msg))
96 for i := range padding {
97 padding[i] = byte(40 - len(msg))
98 }
99 m := append(msg, padding...)
100
101 ephemeral, zb, err := pub.curve.Encaps(random, pub.Point)
102 if err != nil {
103 return nil, nil, err
104 }
105
106 vsG = pub.curve.MarshalBytePoint(ephemeral)
107
108 z, err := buildKey(pub, zb, curveOID, fingerprint, false, false)
109 if err != nil {
110 return nil, nil, err
111 }
112
113 if c, err = keywrap.Wrap(z, m); err != nil {
114 return nil, nil, err
115 }
116
117 return vsG, c, nil
118
119 }
120
121 func Decrypt(priv *PrivateKey, vsG, c, curveOID, fingerprint []byte) (msg []byte, err error) {
122 var m []byte
123 zb, err := priv.PublicKey.curve.Decaps(priv.curve.UnmarshalBytePoint(vsG), priv.D)
124
125
126 for i := 0; i < 3; i++ {
127 var z []byte
128
129 z, err = buildKey(&priv.PublicKey, zb, curveOID, fingerprint, i == 1, i == 2)
130 if err != nil {
131 return nil, err
132 }
133
134
135 m, err = keywrap.Unwrap(z, c)
136 if err == nil {
137 break
138 }
139 }
140
141
142 if err != nil {
143 return nil, err
144 }
145
146
147
148 return m[:len(m)-int(m[len(m)-1])], nil
149 }
150
151 func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLeading, stripTrailing bool) ([]byte, error) {
152
153
154
155 param := new(bytes.Buffer)
156 if _, err := param.Write(curveOID); err != nil {
157 return nil, err
158 }
159 algKDF := []byte{18, 3, 1, pub.KDF.Hash.Id(), pub.KDF.Cipher.Id()}
160 if _, err := param.Write(algKDF); err != nil {
161 return nil, err
162 }
163 if _, err := param.Write([]byte("Anonymous Sender ")); err != nil {
164 return nil, err
165 }
166
167 if _, err := param.Write(fingerprint[:20]); err != nil {
168 return nil, err
169 }
170 if param.Len()-len(curveOID) != 45 {
171 return nil, errors.New("ecdh: malformed KDF Param")
172 }
173
174
175 h := pub.KDF.Hash.New()
176 if _, err := h.Write([]byte{0x0, 0x0, 0x0, 0x1}); err != nil {
177 return nil, err
178 }
179 zbLen := len(zb)
180 i := 0
181 j := zbLen - 1
182 if stripLeading {
183
184 for i < zbLen && zb[i] == 0 {
185 i++
186 }
187 }
188 if stripTrailing {
189
190
191
192 for j >= 0 && zb[j] == 0 {
193 j--
194 }
195 }
196 if _, err := h.Write(zb[i : j+1]); err != nil {
197 return nil, err
198 }
199 if _, err := h.Write(param.Bytes()); err != nil {
200 return nil, err
201 }
202 mb := h.Sum(nil)
203
204 return mb[:pub.KDF.Cipher.KeySize()], nil
205
206 }
207
208 func Validate(priv *PrivateKey) error {
209 return priv.curve.ValidateECDH(priv.Point, priv.D)
210 }
211
View as plain text