...

Source file src/github.com/cloudflare/circl/pke/kyber/internal/common/poly_test.go

Documentation: github.com/cloudflare/circl/pke/kyber/internal/common

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"testing"
     8  )
     9  
    10  func (p *Poly) RandAbsLe9Q() {
    11  	max := 9 * uint32(Q)
    12  	r := randSliceUint32WithMax(uint(N), max)
    13  	for i := 0; i < N; i++ {
    14  		p[i] = int16(int32(r[i]))
    15  	}
    16  }
    17  
    18  // Returns x mod^± q
    19  func sModQ(x int16) int16 {
    20  	x = x % Q
    21  	if x >= (Q-1)/2 {
    22  		x = x - Q
    23  	}
    24  	return x
    25  }
    26  
    27  func TestDecompressMessage(t *testing.T) {
    28  	var m, m2 [PlaintextSize]byte
    29  	var p Poly
    30  	for i := 0; i < 1000; i++ {
    31  		if n, err := rand.Read(m[:]); err != nil {
    32  			t.Error(err)
    33  		} else if n != len(m) {
    34  			t.Fatal("short read from RNG")
    35  		}
    36  
    37  		p.DecompressMessage(m[:])
    38  		p.CompressMessageTo(m2[:])
    39  		if m != m2 {
    40  			t.Fatal()
    41  		}
    42  	}
    43  }
    44  
    45  func TestCompress(t *testing.T) {
    46  	for _, d := range []int{4, 5, 10, 11} {
    47  		d := d
    48  		t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
    49  			var p, q Poly
    50  			bound := (Q + (1 << uint(d))) >> uint(d+1)
    51  			buf := make([]byte, (N*d-1)/8+1)
    52  			for i := 0; i < 1000; i++ {
    53  				p.Rand()
    54  				p.CompressTo(buf, d)
    55  				q.Decompress(buf, d)
    56  				for j := 0; j < N; j++ {
    57  					diff := sModQ(p[j] - q[j])
    58  					if diff < 0 {
    59  						diff = -diff
    60  					}
    61  					if diff > bound {
    62  						t.Logf("%v\n", buf)
    63  						t.Fatalf("|%d - %d mod^± q| = %d > %d, j=%d",
    64  							p[i], q[j], diff, bound, j)
    65  					}
    66  				}
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  func TestCompressMessage(t *testing.T) {
    73  	var p Poly
    74  	var m [32]byte
    75  	ok := true
    76  	for i := 0; i < int(Q); i++ {
    77  		p[0] = int16(i)
    78  		p.CompressMessageTo(m[:])
    79  		want := byte(0)
    80  		if i >= 833 && i < 2497 {
    81  			want = 1
    82  		}
    83  		if m[0] != want {
    84  			ok = false
    85  			t.Logf("%d %d %d", i, want, m[0])
    86  		}
    87  	}
    88  	if !ok {
    89  		t.Fatal()
    90  	}
    91  }
    92  
    93  func TestMulHat(t *testing.T) {
    94  	for k := 0; k < 1000; k++ {
    95  		var a, b, p, ah, bh, ph Poly
    96  		a.RandAbsLeQ()
    97  		b.RandAbsLeQ()
    98  		b[0] = 1
    99  
   100  		ah = a
   101  		bh = b
   102  		ah.NTT()
   103  		bh.NTT()
   104  		ph.MulHat(&ah, &bh)
   105  		ph.BarrettReduce()
   106  		ph.InvNTT()
   107  
   108  		for i := 0; i < N; i++ {
   109  			for j := 0; j < N; j++ {
   110  				v := montReduce(int32(a[i]) * int32(b[j]))
   111  				k := i + j
   112  				if k >= N {
   113  					// Recall xᴺ = -1.
   114  					k -= N
   115  					v = -v
   116  				}
   117  				p[k] = barrettReduce(v + p[k])
   118  			}
   119  		}
   120  
   121  		for i := 0; i < N; i++ {
   122  			p[i] = int16((int32(p[i]) * ((1 << 16) % int32(Q))) % int32(Q))
   123  		}
   124  
   125  		p.Normalize()
   126  		ph.Normalize()
   127  		a.Normalize()
   128  		b.Normalize()
   129  
   130  		if p != ph {
   131  			t.Fatalf("%v\n%v\n%v\n%v", a, b, p, ph)
   132  		}
   133  	}
   134  }
   135  
   136  func TestAddAgainstGeneric(t *testing.T) {
   137  	for k := 0; k < 1000; k++ {
   138  		var p1, p2, a, b Poly
   139  		a.RandAbsLeQ()
   140  		b.RandAbsLeQ()
   141  		p1.Add(&a, &b)
   142  		p2.addGeneric(&a, &b)
   143  		if p1 != p2 {
   144  			t.Fatalf("Add(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   145  		}
   146  	}
   147  }
   148  
   149  func BenchmarkAdd(b *testing.B) {
   150  	var p Poly
   151  	for i := 0; i < b.N; i++ {
   152  		p.Add(&p, &p)
   153  	}
   154  }
   155  
   156  func BenchmarkAddGeneric(b *testing.B) {
   157  	var p Poly
   158  	for i := 0; i < b.N; i++ {
   159  		p.addGeneric(&p, &p)
   160  	}
   161  }
   162  
   163  func TestSubAgainstGeneric(t *testing.T) {
   164  	for k := 0; k < 1000; k++ {
   165  		var p1, p2, a, b Poly
   166  		a.RandAbsLeQ()
   167  		b.RandAbsLeQ()
   168  		p1.Sub(&a, &b)
   169  		p2.subGeneric(&a, &b)
   170  		if p1 != p2 {
   171  			t.Fatalf("Sub(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   172  		}
   173  	}
   174  }
   175  
   176  func BenchmarkSub(b *testing.B) {
   177  	var p Poly
   178  	for i := 0; i < b.N; i++ {
   179  		p.Sub(&p, &p)
   180  	}
   181  }
   182  
   183  func BenchmarkSubGeneric(b *testing.B) {
   184  	var p Poly
   185  	for i := 0; i < b.N; i++ {
   186  		p.subGeneric(&p, &p)
   187  	}
   188  }
   189  
   190  func TestMulHatAgainstGeneric(t *testing.T) {
   191  	for k := 0; k < 1000; k++ {
   192  		var p1, p2, a, b Poly
   193  		a.RandAbsLeQ()
   194  		b.RandAbsLeQ()
   195  		a2 := a
   196  		b2 := b
   197  		a2.Tangle()
   198  		b2.Tangle()
   199  		p1.MulHat(&a2, &b2)
   200  		p1.Detangle()
   201  		p2.mulHatGeneric(&a, &b)
   202  		if p1 != p2 {
   203  			t.Fatalf("MulHat(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   204  		}
   205  	}
   206  }
   207  
   208  func BenchmarkMulHat(b *testing.B) {
   209  	var p Poly
   210  	for i := 0; i < b.N; i++ {
   211  		p.MulHat(&p, &p)
   212  	}
   213  }
   214  
   215  func BenchmarkMulHatGeneric(b *testing.B) {
   216  	var p Poly
   217  	for i := 0; i < b.N; i++ {
   218  		p.mulHatGeneric(&p, &p)
   219  	}
   220  }
   221  
   222  func BenchmarkBarrettReduce(b *testing.B) {
   223  	var p Poly
   224  	for i := 0; i < b.N; i++ {
   225  		p.BarrettReduce()
   226  	}
   227  }
   228  
   229  func BenchmarkBarrettReduceGeneric(b *testing.B) {
   230  	var p Poly
   231  	for i := 0; i < b.N; i++ {
   232  		p.barrettReduceGeneric()
   233  	}
   234  }
   235  
   236  func TestBarrettReduceAgainstGeneric(t *testing.T) {
   237  	for k := 0; k < 1000; k++ {
   238  		var p1, p2, a Poly
   239  		a.RandAbsLe9Q()
   240  		p1 = a
   241  		p2 = a
   242  		p1.BarrettReduce()
   243  		p2.barrettReduceGeneric()
   244  		if p1 != p2 {
   245  			t.Fatalf("BarrettReduce(%v) = \n%v \n!= %v", a, p1, p2)
   246  		}
   247  	}
   248  }
   249  
   250  func BenchmarkNormalize(b *testing.B) {
   251  	var p Poly
   252  	for i := 0; i < b.N; i++ {
   253  		p.Normalize()
   254  	}
   255  }
   256  
   257  func BenchmarkNormalizeGeneric(b *testing.B) {
   258  	var p Poly
   259  	for i := 0; i < b.N; i++ {
   260  		p.barrettReduceGeneric()
   261  	}
   262  }
   263  
   264  func TestNormalizeAgainstGeneric(t *testing.T) {
   265  	for k := 0; k < 1000; k++ {
   266  		var p1, p2, a Poly
   267  		a.RandAbsLe9Q()
   268  		p1 = a
   269  		p2 = a
   270  		p1.Normalize()
   271  		p2.normalizeGeneric()
   272  		if p1 != p2 {
   273  			t.Fatalf("Normalize(%v) = \n%v \n!= %v", a, p1, p2)
   274  		}
   275  	}
   276  }
   277  
   278  func (p *Poly) OldCompressTo(m []byte, d int) {
   279  	switch d {
   280  	case 4:
   281  		var t [8]uint16
   282  		idx := 0
   283  		for i := 0; i < N/8; i++ {
   284  			for j := 0; j < 8; j++ {
   285  				t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
   286  					uint32(Q)) & ((1 << 4) - 1)
   287  			}
   288  			m[idx] = byte(t[0]) | byte(t[1]<<4)
   289  			m[idx+1] = byte(t[2]) | byte(t[3]<<4)
   290  			m[idx+2] = byte(t[4]) | byte(t[5]<<4)
   291  			m[idx+3] = byte(t[6]) | byte(t[7]<<4)
   292  			idx += 4
   293  		}
   294  
   295  	case 5:
   296  		var t [8]uint16
   297  		idx := 0
   298  		for i := 0; i < N/8; i++ {
   299  			for j := 0; j < 8; j++ {
   300  				t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
   301  					uint32(Q)) & ((1 << 5) - 1)
   302  			}
   303  			m[idx] = byte(t[0]) | byte(t[1]<<5)
   304  			m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
   305  			m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
   306  			m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
   307  			m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
   308  			idx += 5
   309  		}
   310  
   311  	case 10:
   312  		var t [4]uint16
   313  		idx := 0
   314  		for i := 0; i < N/4; i++ {
   315  			for j := 0; j < 4; j++ {
   316  				t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
   317  					uint32(Q)) & ((1 << 10) - 1)
   318  			}
   319  			m[idx] = byte(t[0])
   320  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
   321  			m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
   322  			m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
   323  			m[idx+4] = byte(t[3] >> 2)
   324  			idx += 5
   325  		}
   326  	case 11:
   327  		var t [8]uint16
   328  		idx := 0
   329  		for i := 0; i < N/8; i++ {
   330  			for j := 0; j < 8; j++ {
   331  				t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
   332  					uint32(Q)) & ((1 << 11) - 1)
   333  			}
   334  			m[idx] = byte(t[0])
   335  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
   336  			m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
   337  			m[idx+3] = byte(t[2] >> 2)
   338  			m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
   339  			m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
   340  			m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
   341  			m[idx+7] = byte(t[5] >> 1)
   342  			m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
   343  			m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
   344  			m[idx+10] = byte(t[7] >> 3)
   345  			idx += 11
   346  		}
   347  	default:
   348  		panic("unsupported d")
   349  	}
   350  }
   351  
   352  func TestCompressFullInputFirstCoeff(t *testing.T) {
   353  	for _, d := range []int{4, 5, 10, 11} {
   354  		d := d
   355  		t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
   356  			var p, q Poly
   357  			bound := (Q + (1 << uint(d))) >> uint(d+1)
   358  			buf := make([]byte, (N*d-1)/8+1)
   359  			buf2 := make([]byte, len(buf))
   360  			for i := int16(0); i < Q; i++ {
   361  				p[0] = i
   362  				p.CompressTo(buf, d)
   363  				p.OldCompressTo(buf2, d)
   364  				if !bytes.Equal(buf, buf2) {
   365  					t.Fatalf("%d", i)
   366  				}
   367  				q.Decompress(buf, d)
   368  				diff := sModQ(p[0] - q[0])
   369  				if diff < 0 {
   370  					diff = -diff
   371  				}
   372  				if diff > bound {
   373  					t.Logf("%v\n", buf)
   374  					t.Fatalf("|%d - %d mod^± q| = %d > %d",
   375  						p[0], q[0], diff, bound)
   376  				}
   377  			}
   378  		})
   379  	}
   380  }
   381  

View as plain text