...

Source file src/github.com/cloudflare/circl/oprf/vectors_test.go

Documentation: github.com/cloudflare/circl/oprf

     1  package oprf
     2  
     3  import (
     4  	"bytes"
     5  	"encoding"
     6  	"encoding/binary"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/cloudflare/circl/group"
    16  	"github.com/cloudflare/circl/internal/test"
    17  	"github.com/cloudflare/circl/zk/dleq"
    18  )
    19  
    20  type vector struct {
    21  	Identifier string `json:"identifier"`
    22  	Mode       Mode   `json:"mode"`
    23  	Hash       string `json:"hash"`
    24  	PkSm       string `json:"pkSm"`
    25  	SkSm       string `json:"skSm"`
    26  	Seed       string `json:"seed"`
    27  	KeyInfo    string `json:"keyInfo"`
    28  	GroupDST   string `json:"groupDST"`
    29  	Vectors    []struct {
    30  		Batch             int    `json:"Batch"`
    31  		Blind             string `json:"Blind"`
    32  		Info              string `json:"Info"`
    33  		BlindedElement    string `json:"BlindedElement"`
    34  		EvaluationElement string `json:"EvaluationElement"`
    35  		Proof             struct {
    36  			Proof string `json:"proof"`
    37  			R     string `json:"r"`
    38  		} `json:"Proof"`
    39  		Input  string `json:"Input"`
    40  		Output string `json:"Output"`
    41  	} `json:"vectors"`
    42  }
    43  
    44  func toBytes(t *testing.T, s, errMsg string) []byte {
    45  	t.Helper()
    46  	bytes, err := hex.DecodeString(s)
    47  	test.CheckNoErr(t, err, "decoding "+errMsg)
    48  
    49  	return bytes
    50  }
    51  
    52  func toListBytes(t *testing.T, s, errMsg string) [][]byte {
    53  	t.Helper()
    54  	strs := strings.Split(s, ",")
    55  	out := make([][]byte, len(strs))
    56  	for i := range strs {
    57  		out[i] = toBytes(t, strs[i], errMsg)
    58  	}
    59  
    60  	return out
    61  }
    62  
    63  func flattenList(t *testing.T, s, errMsg string) []byte {
    64  	t.Helper()
    65  	strs := strings.Split(s, ",")
    66  	out := []byte{0, 0}
    67  	binary.BigEndian.PutUint16(out, uint16(len(strs)))
    68  	for i := range strs {
    69  		out = append(out, toBytes(t, strs[i], errMsg)...)
    70  	}
    71  
    72  	return out
    73  }
    74  
    75  func toScalar(t *testing.T, g group.Group, s, errMsg string) group.Scalar {
    76  	t.Helper()
    77  	r := g.NewScalar()
    78  	rBytes := toBytes(t, s, errMsg)
    79  	err := r.UnmarshalBinary(rBytes)
    80  	test.CheckNoErr(t, err, errMsg)
    81  
    82  	return r
    83  }
    84  
    85  func readFile(t *testing.T, fileName string) []vector {
    86  	t.Helper()
    87  	jsonFile, err := os.Open(fileName)
    88  	if err != nil {
    89  		t.Fatalf("File %v can not be opened. Error: %v", fileName, err)
    90  	}
    91  	defer jsonFile.Close()
    92  	input, err := io.ReadAll(jsonFile)
    93  	if err != nil {
    94  		t.Fatalf("File %v can not be read. Error: %v", fileName, err)
    95  	}
    96  
    97  	var v []vector
    98  	err = json.Unmarshal(input, &v)
    99  	if err != nil {
   100  		t.Fatalf("File %v can not be loaded. Error: %v", fileName, err)
   101  	}
   102  
   103  	return v
   104  }
   105  
   106  func (v *vector) SetUpParties(t *testing.T) (id params, s commonServer, c commonClient) {
   107  	suite, err := GetSuite(v.Identifier)
   108  	test.CheckNoErr(t, err, "suite id")
   109  	seed := toBytes(t, v.Seed, "seed for key derivation")
   110  	keyInfo := toBytes(t, v.KeyInfo, "info for key derivation")
   111  	privateKey, err := DeriveKey(suite, v.Mode, seed, keyInfo)
   112  	test.CheckNoErr(t, err, "deriving key")
   113  
   114  	got, err := privateKey.MarshalBinary()
   115  	test.CheckNoErr(t, err, "serializing private key")
   116  	want := toBytes(t, v.SkSm, "private key")
   117  	v.compareBytes(t, got, want)
   118  
   119  	switch v.Mode {
   120  	case BaseMode:
   121  		s = NewServer(suite, privateKey)
   122  		c = NewClient(suite)
   123  	case VerifiableMode:
   124  		s = NewVerifiableServer(suite, privateKey)
   125  		c = NewVerifiableClient(suite, s.PublicKey())
   126  	case PartialObliviousMode:
   127  		var info []byte
   128  		s = &s1{NewPartialObliviousServer(suite, privateKey), info}
   129  		c = &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
   130  	}
   131  
   132  	return suite.(params), s, c
   133  }
   134  
   135  func (v *vector) compareLists(t *testing.T, got, want [][]byte) {
   136  	t.Helper()
   137  	for i := range got {
   138  		if !bytes.Equal(got[i], want[i]) {
   139  			test.ReportError(t, got[i], want[i], v.Identifier, v.Mode, i)
   140  		}
   141  	}
   142  }
   143  
   144  func (v *vector) compareBytes(t *testing.T, got, want []byte) {
   145  	t.Helper()
   146  	if !bytes.Equal(got, want) {
   147  		test.ReportError(t, got, want, v.Identifier, v.Mode)
   148  	}
   149  }
   150  
   151  func (v *vector) test(t *testing.T) {
   152  	params, server, client := v.SetUpParties(t)
   153  
   154  	for i, vi := range v.Vectors {
   155  		if v.Mode == PartialObliviousMode {
   156  			info := toBytes(t, vi.Info, "info")
   157  			ss := server.(*s1)
   158  			cc := client.(*c1)
   159  			ss.info = info
   160  			cc.info = info
   161  		}
   162  
   163  		inputs := toListBytes(t, vi.Input, "input")
   164  		blindsBytes := toListBytes(t, vi.Blind, "blind")
   165  
   166  		blinds := make([]Blind, len(blindsBytes))
   167  		for j := range blindsBytes {
   168  			blinds[j] = params.group.NewScalar()
   169  			err := blinds[j].UnmarshalBinary(blindsBytes[j])
   170  			test.CheckNoErr(t, err, "invalid blind")
   171  		}
   172  
   173  		finData, evalReq, err := client.blind(inputs, blinds)
   174  		test.CheckNoErr(t, err, "invalid client request")
   175  		evalReqBytes, err := elementsMarshalBinary(params.group, evalReq.Elements)
   176  		test.CheckNoErr(t, err, "bad serialization")
   177  		v.compareBytes(t, evalReqBytes, flattenList(t, vi.BlindedElement, "blindedElement"))
   178  
   179  		eval, err := server.Evaluate(evalReq)
   180  		test.CheckNoErr(t, err, "invalid evaluation")
   181  		elemBytes, err := elementsMarshalBinary(params.group, eval.Elements)
   182  		test.CheckNoErr(t, err, "invalid evaluations marshaling")
   183  		v.compareBytes(t, elemBytes, flattenList(t, vi.EvaluationElement, "evaluation"))
   184  
   185  		if v.Mode == VerifiableMode || v.Mode == PartialObliviousMode {
   186  			randomness := toScalar(t, params.group, vi.Proof.R, "invalid proof random scalar")
   187  			var proof encoding.BinaryMarshaler
   188  			switch v.Mode {
   189  			case VerifiableMode:
   190  				ss := server.(VerifiableServer)
   191  				prover := dleq.Prover{Params: ss.getDLEQParams()}
   192  				proof, err = prover.ProveBatchWithRandomness(
   193  					ss.privateKey.k,
   194  					ss.params.group.Generator(),
   195  					server.PublicKey().e,
   196  					evalReq.Elements,
   197  					eval.Elements,
   198  					randomness)
   199  			case PartialObliviousMode:
   200  				ss := server.(*s1)
   201  				keyProof, _, _ := ss.secretFromInfo(ss.info)
   202  				prover := dleq.Prover{Params: ss.getDLEQParams()}
   203  				proof, err = prover.ProveBatchWithRandomness(
   204  					keyProof,
   205  					ss.params.group.Generator(),
   206  					ss.params.group.NewElement().MulGen(keyProof),
   207  					eval.Elements,
   208  					evalReq.Elements,
   209  					randomness)
   210  			}
   211  			test.CheckNoErr(t, err, "failed proof generation")
   212  			proofBytes, errr := proof.MarshalBinary()
   213  			test.CheckNoErr(t, errr, "failed proof marshaling")
   214  			v.compareBytes(t, proofBytes, toBytes(t, vi.Proof.Proof, "proof"))
   215  		}
   216  
   217  		outputs, err := client.Finalize(finData, eval)
   218  		test.CheckNoErr(t, err, "invalid finalize")
   219  		expectedOutputs := toListBytes(t, vi.Output, "output")
   220  		v.compareLists(t,
   221  			outputs,
   222  			expectedOutputs,
   223  		)
   224  
   225  		for j := range inputs {
   226  			output, err := server.FullEvaluate(inputs[j])
   227  			test.CheckNoErr(t, err, "invalid full evaluate")
   228  			got := output
   229  			want := expectedOutputs[j]
   230  			if !bytes.Equal(got, want) {
   231  				test.ReportError(t, got, want, v.Identifier, v.Mode, i, j)
   232  			}
   233  
   234  			test.CheckOk(server.VerifyFinalize(inputs[j], output), "verify finalize", t)
   235  		}
   236  	}
   237  }
   238  
   239  func TestVectors(t *testing.T) {
   240  	// Draft published at https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-voprf-10
   241  	// Test vectors at https://github.com/cfrg/draft-irtf-cfrg-voprf
   242  	// Version supported: v10
   243  	v := readFile(t, "testdata/allVectors.json")
   244  
   245  	for i := range v {
   246  		suite, err := GetSuite(v[i].Identifier)
   247  		if err != nil {
   248  			t.Logf(v[i].Identifier + " not supported yet")
   249  			continue
   250  		}
   251  		t.Run(fmt.Sprintf("%v/Mode%v", suite, v[i].Mode), v[i].test)
   252  	}
   253  }
   254  

View as plain text