1
2
3
4
5
6
7 package ecdh
8
9 import (
10 "bytes"
11 "crypto/rand"
12 "github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
13 "io"
14 "testing"
15
16 "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
17 )
18
19 func TestCurves(t *testing.T) {
20 for _, curve := range ecc.Curves {
21 ECDHCurve, ok := curve.Curve.(ecc.ECDHCurve)
22 if !ok {
23 continue
24 }
25
26 t.Run(ECDHCurve.GetCurveName(), func(t *testing.T) {
27 testFingerprint := make([]byte, 20)
28 _, err := io.ReadFull(rand.Reader, testFingerprint[:])
29 if err != nil {
30 t.Fatal(err)
31 }
32
33 priv := testGenerate(t, ECDHCurve)
34 testEncryptDecrypt(t, priv, curve.Oid.Bytes(), testFingerprint)
35 testValidation(t, priv)
36
37
38 priv = testGenerate(t, ECDHCurve)
39 testMarshalUnmarshal(t, priv)
40 })
41 }
42 }
43
44 func testGenerate(t *testing.T, curve ecc.ECDHCurve) *PrivateKey {
45 kdf := KDF{
46 Hash: algorithm.SHA512,
47 Cipher: algorithm.AES256,
48 }
49
50 priv, err := GenerateKey(rand.Reader, curve, kdf)
51 if err != nil {
52 t.Fatal(err)
53 }
54
55 return priv
56 }
57
58 func testEncryptDecrypt(t *testing.T, priv *PrivateKey, oid, fingerprint []byte) {
59 message := []byte("hello world")
60
61 vsG, m, err := Encrypt(rand.Reader, &priv.PublicKey, message, oid, fingerprint)
62 if err != nil {
63 t.Errorf("error encrypting: %s", err)
64 }
65
66 message2, err := Decrypt(priv, vsG, m, oid, fingerprint)
67 if err != nil {
68 t.Errorf("error decrypting: %s", err)
69 }
70
71 if !bytes.Equal(message2, message) {
72 t.Errorf("decryption failed, got: %x, want: %x", message2, message)
73 }
74 }
75
76 func testValidation(t *testing.T, priv *PrivateKey) {
77 if err := Validate(priv); err != nil {
78 t.Fatalf("valid key marked as invalid: %s", err)
79 }
80
81 priv.D[5] ^= 1
82 if err := Validate(priv); err == nil {
83 t.Fatalf("failed to detect invalid key")
84 }
85 }
86
87 func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) {
88 p := priv.MarshalPoint()
89 d := priv.MarshalByteSecret()
90
91 parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF.Hash, priv.KDF.Cipher))
92
93 if err := parsed.UnmarshalPoint(p); err != nil {
94 t.Fatalf("unable to unmarshal point: %s", err)
95 }
96
97 if err := parsed.UnmarshalByteSecret(d); err != nil {
98 t.Fatalf("unable to unmarshal integer: %s", err)
99 }
100
101 expectedD := make([]byte, len(priv.D))
102 copy(expectedD, priv.D)
103
104
105 if priv.curve.GetCurveName() == "curve25519" {
106 expectedD[0] &= 248
107 expectedD[31] &= 127
108 expectedD[31] |= 64
109 }
110
111 if !bytes.Equal(priv.Point, parsed.Point) || !bytes.Equal(expectedD, parsed.D) {
112 t.Fatal("failed to marshal/unmarshal correctly")
113 }
114 }
115
View as plain text