1
2 package expander
3
4 import (
5 "crypto"
6 "encoding/binary"
7 "errors"
8 "io"
9
10 "github.com/cloudflare/circl/xof"
11 )
12
// Expander derives a pseudorandom byte string of a caller-chosen length from
// an input message. Implementations may panic when the requested length
// cannot be produced.
type Expander interface {
	// Expand returns length pseudorandom bytes derived from in.
	Expand(in []byte, length uint) (pseudo []byte)
}
18
// expanderMD implements Expander with a fixed-output hash function using the
// expand_message_xmd construction (RFC 9380, Section 5.3.1).
type expanderMD struct {
	h   crypto.Hash // hash function used for every block
	dst []byte      // domain separation tag; private copy owned by the expander
}

// NewExpanderMD returns an expander that uses hash h and domain separation
// tag dst. dst is copied, so later modifications of the caller's slice do not
// change the expander's output.
func NewExpanderMD(h crypto.Hash, dst []byte) *expanderMD {
	// Defensive copy: the tag is read on every Expand call, so aliasing the
	// caller's backing array would let outside mutations alter the output.
	return &expanderMD{h: h, dst: append([]byte(nil), dst...)}
}
28
29 func (e *expanderMD) calcDSTPrime() []byte {
30 var dstPrime []byte
31 if l := len(e.dst); l > maxDSTLength {
32 H := e.h.New()
33 mustWrite(H, longDSTPrefix[:])
34 mustWrite(H, e.dst)
35 dstPrime = H.Sum(nil)
36 } else {
37 dstPrime = make([]byte, l, l+1)
38 copy(dstPrime, e.dst)
39 }
40 return append(dstPrime, byte(len(dstPrime)))
41 }
42
43 func (e *expanderMD) Expand(in []byte, n uint) []byte {
44 H := e.h.New()
45 bLen := uint(H.Size())
46 ell := (n + (bLen - 1)) / bLen
47 if ell > 255 {
48 panic(errorLongOutput)
49 }
50
51 zPad := make([]byte, H.BlockSize())
52 libStr := []byte{0, 0}
53 libStr[0] = byte((n >> 8) & 0xFF)
54 libStr[1] = byte(n & 0xFF)
55 dstPrime := e.calcDSTPrime()
56
57 mustWrite(H, zPad)
58 mustWrite(H, in)
59 mustWrite(H, libStr)
60 mustWrite(H, []byte{0})
61 mustWrite(H, dstPrime)
62 b0 := H.Sum(nil)
63
64 H.Reset()
65 mustWrite(H, b0)
66 mustWrite(H, []byte{1})
67 mustWrite(H, dstPrime)
68 bi := H.Sum(nil)
69 pseudo := append([]byte{}, bi...)
70 for i := uint(2); i <= ell; i++ {
71 H.Reset()
72 for i := range b0 {
73 bi[i] ^= b0[i]
74 }
75 mustWrite(H, bi)
76 mustWrite(H, []byte{byte(i)})
77 mustWrite(H, dstPrime)
78 bi = H.Sum(nil)
79 pseudo = append(pseudo, bi...)
80 }
81 return pseudo[0:n]
82 }
83
84
85 type expanderXOF struct {
86 id xof.ID
87 kSecLevel uint
88 dst []byte
89 }
90
91
92
93
94 func NewExpanderXOF(id xof.ID, kSecLevel uint, dst []byte) *expanderXOF {
95 return &expanderXOF{id, kSecLevel, dst}
96 }
97
98
99 func (e *expanderXOF) Expand(in []byte, n uint) []byte {
100 bLen := []byte{0, 0}
101 binary.BigEndian.PutUint16(bLen, uint16(n))
102 pseudo := make([]byte, n)
103 dstPrime := e.calcDSTPrime()
104
105 H := e.id.New()
106 mustWrite(H, in)
107 mustWrite(H, bLen)
108 mustWrite(H, dstPrime)
109 mustReadFull(H, pseudo)
110 return pseudo
111 }
112
113 func (e *expanderXOF) calcDSTPrime() []byte {
114 var dstPrime []byte
115 if l := len(e.dst); l > maxDSTLength {
116 H := e.id.New()
117 mustWrite(H, longDSTPrefix[:])
118 mustWrite(H, e.dst)
119 max := ((2 * e.kSecLevel) + 7) / 8
120 dstPrime = make([]byte, max, max+1)
121 mustReadFull(H, dstPrime)
122 } else {
123 dstPrime = make([]byte, l, l+1)
124 copy(dstPrime, e.dst)
125 }
126 return append(dstPrime, byte(len(dstPrime)))
127 }
128
// mustWrite writes all of b to w and panics if the write fails or is short.
// A short write that reports no error panics with io.ErrShortWrite instead of
// panic(nil), so the failure always carries a meaningful, non-nil value.
func mustWrite(w io.Writer, b []byte) {
	n, err := w.Write(b)
	if err == nil && n != len(b) {
		err = io.ErrShortWrite
	}
	if err != nil {
		panic(err)
	}
}
134
// mustReadFull fills b completely from r and panics with the read error if r
// fails or runs out of data before b is full.
func mustReadFull(r io.Reader, b []byte) {
	got, err := io.ReadFull(r, b)
	if err == nil && got == len(b) {
		return
	}
	panic(err)
}
140
// maxDSTLength is the longest domain separation tag, in bytes, that the
// expanders use verbatim; longer tags are first hashed down.
const maxDSTLength = 255

var (
	// longDSTPrefix is absorbed before an oversized tag when computing its
	// shortened replacement.
	longDSTPrefix = [...]byte{
		'H', '2', 'C', '-',
		'O', 'V', 'E', 'R', 'S', 'I', 'Z', 'E', '-',
		'D', 'S', 'T', '-',
	}

	// errorLongOutput is the panic value used when a requested output length
	// is too large.
	errorLongOutput = errors.New("requested too many bytes")
)
148
View as plain text