1
2
3
4
5 package wycheproof
6
7 import (
8 "bytes"
9 "crypto/ecdh"
10 "fmt"
11 "testing"
12 )
13
14 func TestECDHStdLib(t *testing.T) {
15 type ECDHTestVector struct {
16
17 Comment string `json:"comment,omitempty"`
18
19 Flags []string `json:"flags,omitempty"`
20
21 Private string `json:"private,omitempty"`
22
23 Public string `json:"public,omitempty"`
24
25 Result string `json:"result,omitempty"`
26
27 Shared string `json:"shared,omitempty"`
28
29 TcID int `json:"tcId,omitempty"`
30 }
31
32 type ECDHTestGroup struct {
33 Curve string `json:"curve,omitempty"`
34 Tests []*ECDHTestVector `json:"tests,omitempty"`
35 }
36
37 type Root struct {
38 TestGroups []*ECDHTestGroup `json:"testGroups,omitempty"`
39 }
40
41 flagsShouldPass := map[string]bool{
42
43 "CompressedPoint": false,
44
45 "UnnamedCurve": false,
46
47 "WrongOrder": false,
48 "UnusedParam": false,
49
50
51 "Twist": true,
52 "SmallPublicKey": false,
53 "LowOrderPublic": false,
54 "ZeroSharedSecret": false,
55 "NonCanonicalPublic": true,
56 }
57
58
59
60 curveToCurve := map[string]ecdh.Curve{
61 "secp256r1": ecdh.P256(),
62 "secp384r1": ecdh.P384(),
63 "secp521r1": ecdh.P521(),
64 "curve25519": ecdh.X25519(),
65 }
66
67 curveToKeySize := map[string]int{
68 "secp256r1": 32,
69 "secp384r1": 48,
70 "secp521r1": 66,
71 "curve25519": 32,
72 }
73
74 for _, f := range []string{
75 "ecdh_secp256r1_ecpoint_test.json",
76 "ecdh_secp384r1_ecpoint_test.json",
77 "ecdh_secp521r1_ecpoint_test.json",
78 "x25519_test.json",
79 } {
80 var root Root
81 readTestVector(t, f, &root)
82 for _, tg := range root.TestGroups {
83 if _, ok := curveToCurve[tg.Curve]; !ok {
84 continue
85 }
86 for _, tt := range tg.Tests {
87 tg, tt := tg, tt
88 t.Run(fmt.Sprintf("%s/%d", tg.Curve, tt.TcID), func(t *testing.T) {
89 t.Logf("Type: %v", tt.Result)
90 t.Logf("Flags: %q", tt.Flags)
91 t.Log(tt.Comment)
92
93 shouldPass := shouldPass(tt.Result, tt.Flags, flagsShouldPass)
94
95 curve := curveToCurve[tg.Curve]
96 p := decodeHex(tt.Public)
97 pub, err := curve.NewPublicKey(p)
98 if err != nil {
99 if shouldPass {
100 t.Errorf("NewPublicKey: %v", err)
101 }
102 return
103 }
104
105 privBytes := decodeHex(tt.Private)
106 if len(privBytes) != curveToKeySize[tg.Curve] {
107 t.Skipf("non-standard key size %d", len(privBytes))
108 }
109
110 priv, err := curve.NewPrivateKey(privBytes)
111 if err != nil {
112 if shouldPass {
113 t.Errorf("NewPrivateKey: %v", err)
114 }
115 return
116 }
117
118 shared := decodeHex(tt.Shared)
119 x, err := priv.ECDH(pub)
120 if err != nil {
121 if tg.Curve == "curve25519" && !shouldPass {
122
123
124 return
125 }
126 t.Fatalf("ECDH: %v", err)
127 }
128
129 if bytes.Equal(shared, x) != shouldPass {
130 if shouldPass {
131 t.Errorf("ECDH = %x, want %x", shared, x)
132 } else {
133 t.Errorf("ECDH = %x, want anything else", shared)
134 }
135 }
136 })
137 }
138 }
139 }
140 }
141
View as plain text