1 package oprf
2
3 import (
4 "bytes"
5 "encoding"
6 "encoding/binary"
7 "encoding/hex"
8 "encoding/json"
9 "fmt"
10 "io"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/cloudflare/circl/group"
16 "github.com/cloudflare/circl/internal/test"
17 "github.com/cloudflare/circl/zk/dleq"
18 )
19
// vector is one entry of the JSON test-vector file decoded by readFile.
// The string fields that the tests consume are hex-encoded (they are decoded
// with toBytes / toListBytes / flattenList); list-valued fields hold several
// hex strings separated by commas.
type vector struct {
	// Identifier names the suite and is passed to GetSuite.
	Identifier string `json:"identifier"`
	// Mode selects which parties SetUpParties builds (BaseMode,
	// VerifiableMode or PartialObliviousMode).
	Mode Mode `json:"mode"`
	// Hash is not read by the code in this file — presumably the suite's hash
	// function name; TODO confirm.
	Hash string `json:"hash"`
	// PkSm is not read by the code in this file — presumably the hex-encoded
	// server public key; TODO confirm.
	PkSm string `json:"pkSm"`
	// SkSm is the expected hex-encoded private key produced by DeriveKey.
	SkSm string `json:"skSm"`
	// Seed is the hex-encoded seed passed to DeriveKey.
	Seed string `json:"seed"`
	// KeyInfo is the hex-encoded info passed to DeriveKey.
	KeyInfo string `json:"keyInfo"`
	// GroupDST is not read by the code in this file — presumably the group's
	// domain-separation tag; TODO confirm.
	GroupDST string `json:"groupDST"`
	// Vectors are the individual evaluation cases for this suite/mode.
	Vectors []struct {
		// Batch is not read by the code in this file — presumably the
		// number of inputs in the case; TODO confirm.
		Batch int `json:"Batch"`
		// Blind is a comma-separated list of hex-encoded blinding scalars.
		Blind string `json:"Blind"`
		// Info is the hex-encoded public info (used only in
		// PartialObliviousMode).
		Info string `json:"Info"`
		// BlindedElement is a comma-separated list of hex-encoded elements
		// the client is expected to send.
		BlindedElement string `json:"BlindedElement"`
		// EvaluationElement is a comma-separated list of hex-encoded
		// elements the server is expected to return.
		EvaluationElement string `json:"EvaluationElement"`
		// Proof holds the expected hex-encoded proof bytes and the
		// hex-encoded scalar used as the prover's randomness.
		Proof struct {
			Proof string `json:"proof"`
			R string `json:"r"`
		} `json:"Proof"`
		// Input is a comma-separated list of hex-encoded client inputs.
		Input string `json:"Input"`
		// Output is a comma-separated list of hex-encoded expected outputs.
		Output string `json:"Output"`
	} `json:"vectors"`
}
43
44 func toBytes(t *testing.T, s, errMsg string) []byte {
45 t.Helper()
46 bytes, err := hex.DecodeString(s)
47 test.CheckNoErr(t, err, "decoding "+errMsg)
48
49 return bytes
50 }
51
52 func toListBytes(t *testing.T, s, errMsg string) [][]byte {
53 t.Helper()
54 strs := strings.Split(s, ",")
55 out := make([][]byte, len(strs))
56 for i := range strs {
57 out[i] = toBytes(t, strs[i], errMsg)
58 }
59
60 return out
61 }
62
63 func flattenList(t *testing.T, s, errMsg string) []byte {
64 t.Helper()
65 strs := strings.Split(s, ",")
66 out := []byte{0, 0}
67 binary.BigEndian.PutUint16(out, uint16(len(strs)))
68 for i := range strs {
69 out = append(out, toBytes(t, strs[i], errMsg)...)
70 }
71
72 return out
73 }
74
75 func toScalar(t *testing.T, g group.Group, s, errMsg string) group.Scalar {
76 t.Helper()
77 r := g.NewScalar()
78 rBytes := toBytes(t, s, errMsg)
79 err := r.UnmarshalBinary(rBytes)
80 test.CheckNoErr(t, err, errMsg)
81
82 return r
83 }
84
85 func readFile(t *testing.T, fileName string) []vector {
86 t.Helper()
87 jsonFile, err := os.Open(fileName)
88 if err != nil {
89 t.Fatalf("File %v can not be opened. Error: %v", fileName, err)
90 }
91 defer jsonFile.Close()
92 input, err := io.ReadAll(jsonFile)
93 if err != nil {
94 t.Fatalf("File %v can not be read. Error: %v", fileName, err)
95 }
96
97 var v []vector
98 err = json.Unmarshal(input, &v)
99 if err != nil {
100 t.Fatalf("File %v can not be loaded. Error: %v", fileName, err)
101 }
102
103 return v
104 }
105
106 func (v *vector) SetUpParties(t *testing.T) (id params, s commonServer, c commonClient) {
107 suite, err := GetSuite(v.Identifier)
108 test.CheckNoErr(t, err, "suite id")
109 seed := toBytes(t, v.Seed, "seed for key derivation")
110 keyInfo := toBytes(t, v.KeyInfo, "info for key derivation")
111 privateKey, err := DeriveKey(suite, v.Mode, seed, keyInfo)
112 test.CheckNoErr(t, err, "deriving key")
113
114 got, err := privateKey.MarshalBinary()
115 test.CheckNoErr(t, err, "serializing private key")
116 want := toBytes(t, v.SkSm, "private key")
117 v.compareBytes(t, got, want)
118
119 switch v.Mode {
120 case BaseMode:
121 s = NewServer(suite, privateKey)
122 c = NewClient(suite)
123 case VerifiableMode:
124 s = NewVerifiableServer(suite, privateKey)
125 c = NewVerifiableClient(suite, s.PublicKey())
126 case PartialObliviousMode:
127 var info []byte
128 s = &s1{NewPartialObliviousServer(suite, privateKey), info}
129 c = &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
130 }
131
132 return suite.(params), s, c
133 }
134
135 func (v *vector) compareLists(t *testing.T, got, want [][]byte) {
136 t.Helper()
137 for i := range got {
138 if !bytes.Equal(got[i], want[i]) {
139 test.ReportError(t, got[i], want[i], v.Identifier, v.Mode, i)
140 }
141 }
142 }
143
144 func (v *vector) compareBytes(t *testing.T, got, want []byte) {
145 t.Helper()
146 if !bytes.Equal(got, want) {
147 test.ReportError(t, got, want, v.Identifier, v.Mode)
148 }
149 }
150
// test runs every case in v.Vectors against the parties built by
// SetUpParties. For each case it checks, in order:
//   - the client's blinded elements;
//   - the server's evaluated elements;
//   - in verifiable and partially-oblivious modes, the DLEQ proof bytes;
//   - the finalized outputs;
//   - that FullEvaluate and VerifyFinalize agree with the expected outputs.
//
// TestVectors passes it to t.Run.
func (v *vector) test(t *testing.T) {
	// NOTE(review): the local name params shadows the params type within
	// this function.
	params, server, client := v.SetUpParties(t)

	for i, vi := range v.Vectors {
		// In partially-oblivious mode both wrappers carry the case's public
		// info; it is set here because it varies per case.
		if v.Mode == PartialObliviousMode {
			info := toBytes(t, vi.Info, "info")
			ss := server.(*s1)
			cc := client.(*c1)
			ss.info = info
			cc.info = info
		}

		inputs := toListBytes(t, vi.Input, "input")
		blindsBytes := toListBytes(t, vi.Blind, "blind")

		// Use the vector's fixed blinds (instead of random ones) so the
		// blinded elements can be compared byte-for-byte below.
		blinds := make([]Blind, len(blindsBytes))
		for j := range blindsBytes {
			blinds[j] = params.group.NewScalar()
			err := blinds[j].UnmarshalBinary(blindsBytes[j])
			test.CheckNoErr(t, err, "invalid blind")
		}

		finData, evalReq, err := client.blind(inputs, blinds)
		test.CheckNoErr(t, err, "invalid client request")
		evalReqBytes, err := elementsMarshalBinary(params.group, evalReq.Elements)
		test.CheckNoErr(t, err, "bad serialization")
		v.compareBytes(t, evalReqBytes, flattenList(t, vi.BlindedElement, "blindedElement"))

		eval, err := server.Evaluate(evalReq)
		test.CheckNoErr(t, err, "invalid evaluation")
		elemBytes, err := elementsMarshalBinary(params.group, eval.Elements)
		test.CheckNoErr(t, err, "invalid evaluations marshaling")
		v.compareBytes(t, elemBytes, flattenList(t, vi.EvaluationElement, "evaluation"))

		if v.Mode == VerifiableMode || v.Mode == PartialObliviousMode {
			// The proof is regenerated with the vector's fixed scalar r so its
			// serialization can be compared with vi.Proof.Proof.
			randomness := toScalar(t, params.group, vi.Proof.R, "invalid proof random scalar")
			var proof encoding.BinaryMarshaler
			switch v.Mode {
			case VerifiableMode:
				ss := server.(VerifiableServer)
				prover := dleq.Prover{Params: ss.getDLEQParams()}
				proof, err = prover.ProveBatchWithRandomness(
					ss.privateKey.k,
					ss.params.group.Generator(),
					server.PublicKey().e,
					evalReq.Elements,
					eval.Elements,
					randomness)
			case PartialObliviousMode:
				ss := server.(*s1)
				// NOTE(review): the second and third results of secretFromInfo
				// are discarded; if one of them is an error it is silently
				// ignored here — confirm and consider checking it.
				keyProof, _, _ := ss.secretFromInfo(ss.info)
				prover := dleq.Prover{Params: ss.getDLEQParams()}
				// The element lists are passed in the opposite order from the
				// verifiable case (evaluated, then blinded). Presumably this
				// follows the POPRF proof in RFC 9497, where the server
				// evaluates with the inverse of the tweaked key — confirm.
				proof, err = prover.ProveBatchWithRandomness(
					keyProof,
					ss.params.group.Generator(),
					ss.params.group.NewElement().MulGen(keyProof),
					eval.Elements,
					evalReq.Elements,
					randomness)
			}
			test.CheckNoErr(t, err, "failed proof generation")
			proofBytes, errr := proof.MarshalBinary()
			test.CheckNoErr(t, errr, "failed proof marshaling")
			v.compareBytes(t, proofBytes, toBytes(t, vi.Proof.Proof, "proof"))
		}

		outputs, err := client.Finalize(finData, eval)
		test.CheckNoErr(t, err, "invalid finalize")
		expectedOutputs := toListBytes(t, vi.Output, "output")
		v.compareLists(t,
			outputs,
			expectedOutputs,
		)

		// Cross-check: the server's non-oblivious evaluation of each input
		// must match the expected output and pass VerifyFinalize.
		// NOTE(review): expectedOutputs[j] panics if the vector lists fewer
		// outputs than inputs.
		for j := range inputs {
			output, err := server.FullEvaluate(inputs[j])
			test.CheckNoErr(t, err, "invalid full evaluate")
			got := output
			want := expectedOutputs[j]
			if !bytes.Equal(got, want) {
				test.ReportError(t, got, want, v.Identifier, v.Mode, i, j)
			}

			test.CheckOk(server.VerifyFinalize(inputs[j], output), "verify finalize", t)
		}
	}
}
238
239 func TestVectors(t *testing.T) {
240
241
242
243 v := readFile(t, "testdata/allVectors.json")
244
245 for i := range v {
246 suite, err := GetSuite(v[i].Identifier)
247 if err != nil {
248 t.Logf(v[i].Identifier + " not supported yet")
249 continue
250 }
251 t.Run(fmt.Sprintf("%v/Mode%v", suite, v[i].Mode), v[i].test)
252 }
253 }
254
View as plain text