1 package oprf
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "encoding"
7 "encoding/binary"
8 "fmt"
9 "testing"
10
11 "github.com/cloudflare/circl/group"
12 "github.com/cloudflare/circl/internal/test"
13 )
14
15 type commonClient interface {
16 blind(inputs [][]byte, blinds []Blind) (*FinalizeData, *EvaluationRequest, error)
17 DeterministicBlind(inputs [][]byte, blinds []Blind) (*FinalizeData, *EvaluationRequest, error)
18 Blind(inputs [][]byte) (*FinalizeData, *EvaluationRequest, error)
19 Finalize(d *FinalizeData, e *Evaluation) ([][]byte, error)
20 }
21
22 type c1 struct {
23 PartialObliviousClient
24 info []byte
25 }
26
27 func (c *c1) Finalize(f *FinalizeData, e *Evaluation) ([][]byte, error) {
28 return c.PartialObliviousClient.Finalize(f, e, c.info)
29 }
30
31 type commonServer interface {
32 Evaluate(req *EvaluationRequest) (*Evaluation, error)
33 FullEvaluate(input []byte) ([]byte, error)
34 VerifyFinalize(input, expectedOutput []byte) bool
35 PublicKey() *PublicKey
36 }
37
38 type s1 struct {
39 PartialObliviousServer
40 info []byte
41 }
42
43 func (s *s1) Evaluate(req *EvaluationRequest) (*Evaluation, error) {
44 return s.PartialObliviousServer.Evaluate(req, s.info)
45 }
46
47 func (s *s1) FullEvaluate(input []byte) ([]byte, error) {
48 return s.PartialObliviousServer.FullEvaluate(input, s.info)
49 }
50
51 func (s *s1) VerifyFinalize(input, expectedOutput []byte) bool {
52 return s.PartialObliviousServer.VerifyFinalize(input, s.info, expectedOutput)
53 }
54
55 type canMarshal interface {
56 encoding.BinaryMarshaler
57 UnmarshalBinary(id Suite, data []byte) (err error)
58 }
59
60 func testMarshal(t *testing.T, suite Suite, x, y canMarshal, name string) {
61 t.Helper()
62
63 wantBytes, err := x.MarshalBinary()
64 test.CheckNoErr(t, err, "error on marshaling "+name)
65
66 err = y.UnmarshalBinary(suite, wantBytes)
67 test.CheckNoErr(t, err, "error on unmarshaling "+name)
68
69 gotBytes, err := x.MarshalBinary()
70 test.CheckNoErr(t, err, "error on marshaling "+name)
71
72 if !bytes.Equal(gotBytes, wantBytes) {
73 test.ReportError(t, gotBytes, wantBytes)
74 }
75 }
76
77 func elementsMarshalBinary(g group.Group, e []group.Element) ([]byte, error) {
78 output := make([]byte, 2, 2+len(e)*int(g.Params().CompressedElementLength))
79 binary.BigEndian.PutUint16(output[0:2], uint16(len(e)))
80
81 for i := range e {
82 ei, err := e[i].MarshalBinaryCompress()
83 if err != nil {
84 return nil, err
85 }
86 output = append(output, ei...)
87 }
88
89 return output, nil
90 }
91
92 func testAPI(t *testing.T, server commonServer, client commonClient) {
93 t.Helper()
94
95 inputs := [][]byte{{0x00}, {0xFF}}
96 finData, evalReq, err := client.Blind(inputs)
97 test.CheckNoErr(t, err, "invalid blinding of client")
98
99 blinds := finData.CopyBlinds()
100 _, detEvalReq, err := client.DeterministicBlind(inputs, blinds)
101 test.CheckNoErr(t, err, "invalid deterministic blinding of client")
102 test.CheckOk(len(detEvalReq.Elements) == len(evalReq.Elements), "invalid number of evaluations", t)
103 for i := range evalReq.Elements {
104 test.CheckOk(evalReq.Elements[i].IsEqual(detEvalReq.Elements[i]), "invalid blinded element mismatch", t)
105 }
106
107 eval, err := server.Evaluate(evalReq)
108 test.CheckNoErr(t, err, "invalid evaluation of server")
109 test.CheckOk(eval != nil, "invalid evaluation of server: no evaluation", t)
110
111 clientOutputs, err := client.Finalize(finData, eval)
112 test.CheckNoErr(t, err, "invalid finalize of client")
113 test.CheckOk(clientOutputs != nil, "invalid finalize of client: no outputs", t)
114
115 for i := range inputs {
116 valid := server.VerifyFinalize(inputs[i], clientOutputs[i])
117 test.CheckOk(valid, "invalid verification from the server", t)
118
119 serverOutput, err := server.FullEvaluate(inputs[i])
120 test.CheckNoErr(t, err, "FullEvaluate failed")
121
122 if !bytes.Equal(serverOutput, clientOutputs[i]) {
123 test.ReportError(t, serverOutput, clientOutputs[i])
124 }
125 }
126 }
127
128 func TestAPI(t *testing.T) {
129 info := []byte("shared info")
130
131 for _, suite := range []Suite{
132 SuiteRistretto255,
133 SuiteP256,
134 SuiteP384,
135 SuiteP521,
136 } {
137 t.Run(suite.(fmt.Stringer).String(), func(t *testing.T) {
138 private, err := GenerateKey(suite, rand.Reader)
139 test.CheckNoErr(t, err, "failed private key generation")
140 testMarshal(t, suite, private, new(PrivateKey), "PrivateKey")
141 public := private.Public()
142 testMarshal(t, suite, public, new(PublicKey), "PublicKey")
143
144 t.Run("OPRF", func(t *testing.T) {
145 s := NewServer(suite, private)
146 c := NewClient(suite)
147 testAPI(t, s, c)
148 })
149
150 t.Run("VOPRF", func(t *testing.T) {
151 s := NewVerifiableServer(suite, private)
152 c := NewVerifiableClient(suite, s.PublicKey())
153 testAPI(t, s, c)
154 })
155
156 t.Run("POPRF", func(t *testing.T) {
157 s := &s1{NewPartialObliviousServer(suite, private), info}
158 c := &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
159 testAPI(t, s, c)
160 })
161 })
162 }
163 }
164
165 func TestErrors(t *testing.T) {
166 goodID := SuiteP256
167 strErrNil := "must be nil"
168 strErrK := "must fail key"
169 strErrC := "must fail client"
170 strErrS := "must fail server"
171
172 t.Run("badID", func(t *testing.T) {
173 var badID Suite
174
175 k, err := GenerateKey(badID, rand.Reader)
176 test.CheckIsErr(t, err, strErrK)
177 test.CheckOk(k == nil, strErrNil, t)
178
179 k, err = DeriveKey(badID, BaseMode, nil, nil)
180 test.CheckIsErr(t, err, strErrK)
181 test.CheckOk(k == nil, strErrNil, t)
182
183 err = new(PrivateKey).UnmarshalBinary(badID, nil)
184 test.CheckIsErr(t, err, strErrK)
185
186 err = new(PublicKey).UnmarshalBinary(badID, nil)
187 test.CheckIsErr(t, err, strErrK)
188
189 err = test.CheckPanic(func() { NewClient(badID) })
190 test.CheckNoErr(t, err, strErrC)
191
192 err = test.CheckPanic(func() { NewServer(badID, nil) })
193 test.CheckNoErr(t, err, strErrS)
194
195 err = test.CheckPanic(func() { NewVerifiableClient(badID, nil) })
196 test.CheckNoErr(t, err, strErrC)
197
198 err = test.CheckPanic(func() { NewVerifiableServer(badID, nil) })
199 test.CheckNoErr(t, err, strErrS)
200
201 err = test.CheckPanic(func() { NewPartialObliviousClient(badID, nil) })
202 test.CheckNoErr(t, err, strErrC)
203
204 err = test.CheckPanic(func() { NewPartialObliviousServer(badID, nil) })
205 test.CheckNoErr(t, err, strErrS)
206 })
207
208 t.Run("nilPubKey", func(t *testing.T) {
209 err := test.CheckPanic(func() { NewVerifiableClient(goodID, nil) })
210 test.CheckNoErr(t, err, strErrC)
211 })
212
213 t.Run("nilCalls", func(t *testing.T) {
214 c := NewClient(goodID)
215 finData, evalReq, err := c.Blind(nil)
216 test.CheckIsErr(t, err, strErrC)
217 test.CheckOk(finData == nil, strErrNil, t)
218 test.CheckOk(evalReq == nil, strErrNil, t)
219
220 var emptyEval Evaluation
221 finData, _, _ = c.Blind([][]byte{[]byte("in0"), []byte("in1")})
222 out, err := c.Finalize(finData, &emptyEval)
223 test.CheckIsErr(t, err, strErrC)
224 test.CheckOk(out == nil, strErrNil, t)
225 })
226
227 t.Run("invalidProof", func(t *testing.T) {
228 key, _ := GenerateKey(goodID, rand.Reader)
229 s := NewVerifiableServer(goodID, key)
230 c := NewVerifiableClient(goodID, key.Public())
231
232 finData, evalReq, _ := c.Blind([][]byte{[]byte("in0"), []byte("in1")})
233 _, _ = s.Evaluate(evalReq)
234 _, evalReq, _ = c.Blind([][]byte{[]byte("in0"), []byte("in2")})
235 badEV, _ := s.Evaluate(evalReq)
236 _, err := c.Finalize(finData, badEV)
237 test.CheckIsErr(t, err, strErrC)
238 })
239
240 t.Run("badKeyGen", func(t *testing.T) {
241 key, err := GenerateKey(goodID, nil)
242 test.CheckIsErr(t, err, strErrNil)
243 test.CheckOk(key == nil, strErrNil, t)
244
245 key, err = DeriveKey(goodID, Mode(8), nil, nil)
246 test.CheckIsErr(t, err, strErrK)
247 test.CheckOk(key == nil, strErrNil, t)
248 })
249 }
250
251 func Example_oprf() {
252 suite := SuiteP256
253
254 private, _ := GenerateKey(suite, rand.Reader)
255 server := NewServer(suite, private)
256
257 client := NewClient(suite)
258
259
260 inputs := [][]byte{[]byte("first input"), []byte("second input")}
261 finData, evalReq, _ := client.Blind(inputs)
262
263
264
265
266
267 evaluation, _ := server.Evaluate(evalReq)
268
269
270
271
272
273 outputs, err := client.Finalize(finData, evaluation)
274 fmt.Print(err == nil && len(inputs) == len(outputs))
275
276 }
277
278 func BenchmarkAPI(b *testing.B) {
279 for _, suite := range []Suite{
280 SuiteRistretto255,
281 SuiteP256,
282 SuiteP384,
283 SuiteP521,
284 } {
285 key, err := GenerateKey(suite, rand.Reader)
286 test.CheckNoErr(b, err, "failed key generation")
287
288 b.Run("OPRF/"+suite.Identifier(), func(b *testing.B) {
289 s := NewServer(suite, key)
290 c := NewClient(suite)
291 benchAPI(b, s, c)
292 })
293
294 b.Run("VOPRF/"+suite.Identifier(), func(b *testing.B) {
295 s := NewVerifiableServer(suite, key)
296 c := NewVerifiableClient(suite, s.PublicKey())
297 benchAPI(b, s, c)
298 })
299
300 b.Run("POPRF/"+suite.Identifier(), func(b *testing.B) {
301 info := []byte("shared info")
302 s := &s1{NewPartialObliviousServer(suite, key), info}
303 c := &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
304 benchAPI(b, s, c)
305 })
306 }
307 }
308
309 func benchAPI(b *testing.B, server commonServer, client commonClient) {
310 b.Helper()
311 inputs := [][]byte{[]byte("first input"), []byte("second input")}
312 finData, evalReq, err := client.Blind(inputs)
313 test.CheckNoErr(b, err, "failed client request")
314
315 eval, err := server.Evaluate(evalReq)
316 test.CheckNoErr(b, err, "failed server evaluate")
317
318 clientOutputs, err := client.Finalize(finData, eval)
319 test.CheckNoErr(b, err, "failed client finalize")
320
321 b.Run("Client/Request", func(b *testing.B) {
322 for i := 0; i < b.N; i++ {
323 _, _, _ = client.Blind(inputs)
324 }
325 })
326
327 b.Run("Server/Evaluate", func(b *testing.B) {
328 for i := 0; i < b.N; i++ {
329 _, _ = server.Evaluate(evalReq)
330 }
331 })
332
333 b.Run("Client/Finalize", func(b *testing.B) {
334 for i := 0; i < b.N; i++ {
335 _, _ = client.Finalize(finData, eval)
336 }
337 })
338
339 b.Run("Server/VerifyFinalize", func(b *testing.B) {
340 for i := 0; i < b.N; i++ {
341 for j := range inputs {
342 server.VerifyFinalize(inputs[j], clientOutputs[j])
343 }
344 }
345 })
346
347 b.Run("Server/FullEvaluate", func(b *testing.B) {
348 for i := 0; i < b.N; i++ {
349 for j := range inputs {
350 _, _ = server.FullEvaluate(inputs[j])
351 }
352 }
353 })
354 }
355
View as plain text