1
2
3
4
5
6 package main
7
8 import (
9 "bytes"
10 "fmt"
11 "go/format"
12 "os"
13 "path"
14 "strings"
15 "text/template"
16
17 "github.com/cloudflare/circl/sign/dilithium/internal/common/params"
18 )
19
// Mode describes one Dilithium parameter set. A Mode value is the data passed
// to each code template, and its Name determines the generated identifiers
// and package directory through the Mode, Impl and Pkg methods.
type Mode struct {
	// Name is the mode's display name, e.g. "Dilithium2" or "Dilithium2-AES".
	Name string
	// UseAES is true for the "-AES" variants.
	UseAES bool
	// K and L are presumably the dimensions of the public matrix (K rows,
	// L columns). NOTE(review): confirm against the specification.
	K, L int
	// Eta is a scheme parameter (presumably the secret-coefficient bound).
	// DoubleEtaBits matches the bit width needed for values in [0, 2*Eta]
	// in every entry of Modes (3 for Eta=2, 4 for Eta=4). Gamma1Bits is
	// presumably log2(gamma1). Omega, Tau and Gamma2 are used as-is by the
	// templates.
	Eta, DoubleEtaBits, Omega, Tau, Gamma1Bits, Gamma2 int
}
32
// Pkg returns the package name of the mode, which the generator also uses as
// its output directory. It is the lowercased Mode() string, e.g. "mode2" or
// "mode3aes".
func (m Mode) Pkg() string {
	ident := m.Mode()
	return strings.ToLower(ident)
}
36
37 func (m Mode) Impl() string {
38 return "impl" + m.Mode()
39 }
40
// Mode derives the mode's identifier from its Name. It replaces "Dilithium"
// with "Mode" and drops the hyphen of "-AES", so "Dilithium3-AES" becomes
// "Mode3AES".
func (m Mode) Mode() string {
	// The two patterns share no characters, so one single-pass replacer
	// yields the same result as applying the substitutions one after another.
	r := strings.NewReplacer("Dilithium", "Mode", "-AES", "AES")
	return r.Replace(m.Name)
}
45
var (
	// Modes lists every parameter set the generator emits code for. Each
	// entry is handed to the templates as their data value, and its Pkg()
	// names the directory that receives the generated files. Each base mode
	// appears twice: once plain and once as an "-AES" variant that sets
	// UseAES and is otherwise identical.
	//
	// NOTE(review): these values appear to mirror the Dilithium parameter
	// table (k, l, eta, omega, tau, log2(gamma1), gamma2). Check them against
	// the specification before changing any of them.
	Modes = []Mode{
		{
			Name:          "Dilithium2",
			UseAES:        false,
			K:             4,
			L:             4,
			Eta:           2,
			DoubleEtaBits: 3,
			Omega:         80,
			Tau:           39,
			Gamma1Bits:    17,
			Gamma2:        (params.Q - 1) / 88,
		},
		{
			Name:          "Dilithium2-AES",
			UseAES:        true,
			K:             4,
			L:             4,
			Eta:           2,
			DoubleEtaBits: 3,
			Omega:         80,
			Tau:           39,
			Gamma1Bits:    17,
			Gamma2:        (params.Q - 1) / 88,
		},
		{
			// Dilithium3 is also the hand-maintained source mode: its
			// internal/ files are copied into every other mode.
			Name:          "Dilithium3",
			UseAES:        false,
			K:             6,
			L:             5,
			Eta:           4,
			DoubleEtaBits: 4,
			Omega:         55,
			Tau:           49,
			Gamma1Bits:    19,
			Gamma2:        (params.Q - 1) / 32,
		},
		{
			Name:          "Dilithium3-AES",
			UseAES:        true,
			K:             6,
			L:             5,
			Eta:           4,
			DoubleEtaBits: 4,
			Omega:         55,
			Tau:           49,
			Gamma1Bits:    19,
			Gamma2:        (params.Q - 1) / 32,
		},
		{
			Name:          "Dilithium5",
			UseAES:        false,
			K:             8,
			L:             7,
			Eta:           2,
			DoubleEtaBits: 3,
			Omega:         75,
			Tau:           60,
			Gamma1Bits:    19,
			Gamma2:        (params.Q - 1) / 32,
		},
		{
			Name:          "Dilithium5-AES",
			UseAES:        true,
			K:             8,
			L:             7,
			Eta:           2,
			DoubleEtaBits: 3,
			Omega:         75,
			Tau:           60,
			Gamma1Bits:    19,
			Gamma2:        (params.Q - 1) / 32,
		},
	}
	// TemplateWarning is the "Code generated" marker. Rendered template output
	// is cut so it starts at this marker, and files mirrored from
	// mode3/internal are written with it as their first line.
	TemplateWarning = "// Code generated from"
)
123
124 func main() {
125 generateModePackageFiles()
126 generateModeToplevelFiles()
127 generateParamsFiles()
128 generateSourceFiles()
129 }
130
131
// generateParamsFiles renders templates/params.templ.go once per entry in
// Modes, formats the result with go/format, and writes it to
// <pkg>/internal/params.go. Everything before TemplateWarning in the
// formatted output is dropped, so the written file starts at that marker.
//
// It panics on any template, formatting or I/O error, or if the marker is
// missing from the output.
func generateParamsFiles() {
	tl, err := template.ParseFiles("templates/params.templ.go")
	if err != nil {
		panic(err)
	}

	for _, mode := range Modes {
		buf := new(bytes.Buffer)
		if err := tl.Execute(buf, mode); err != nil {
			panic(err)
		}

		code, err := format.Source(buf.Bytes())
		if err != nil {
			// Keep the underlying error: it reports where the rendered
			// template output fails to parse, which a fixed message hides.
			panic(fmt.Errorf("formatting params.go for %s: %w", mode.Name, err))
		}

		res := string(code)
		offset := strings.Index(res, TemplateWarning)
		if offset == -1 {
			panic("Missing template warning in params.templ.go")
		}
		err = os.WriteFile(mode.Pkg()+"/internal/params.go",
			[]byte(res[offset:]), 0o644)
		if err != nil {
			panic(err)
		}
	}
}
163
164
165 func generateModeToplevelFiles() {
166 tl, err := template.ParseFiles("templates/mode.templ.go")
167 if err != nil {
168 panic(err)
169 }
170
171 for _, mode := range Modes {
172 buf := new(bytes.Buffer)
173 err := tl.Execute(buf, mode)
174 if err != nil {
175 panic(err)
176 }
177
178 res := string(buf.Bytes())
179 offset := strings.Index(res, TemplateWarning)
180 if offset == -1 {
181 panic("Missing template warning in mode.templ.go")
182 }
183 err = os.WriteFile(mode.Pkg()+".go", []byte(res[offset:]), 0o644)
184 if err != nil {
185 panic(err)
186 }
187 }
188 }
189
190
191 func generateModePackageFiles() {
192 tl, err := template.ParseFiles("templates/modePkg.templ.go")
193 if err != nil {
194 panic(err)
195 }
196
197 for _, mode := range Modes {
198 buf := new(bytes.Buffer)
199 err := tl.Execute(buf, mode)
200 if err != nil {
201 panic(err)
202 }
203
204 res := string(buf.Bytes())
205 offset := strings.Index(res, TemplateWarning)
206 if offset == -1 {
207 panic("Missing template warning in modePkg.templ.go")
208 }
209 err = os.WriteFile(mode.Pkg()+"/dilithium.go", []byte(res[offset:]), 0o644)
210 if err != nil {
211 panic(err)
212 }
213 }
214 }
215
216
217 func generateSourceFiles() {
218 files := make(map[string][]byte)
219
220
221 ignored := func(x string) bool {
222 return x == "params.go" || x == "params_test.go" ||
223 strings.HasSuffix(x, ".swp")
224 }
225
226 fs, err := os.ReadDir("mode3/internal")
227 if err != nil {
228 panic(err)
229 }
230
231
232 for _, f := range fs {
233 name := f.Name()
234 if ignored(name) {
235 continue
236 }
237 files[name], err = os.ReadFile(path.Join("mode3/internal", name))
238 if err != nil {
239 panic(err)
240 }
241 }
242
243
244 for _, mode := range Modes {
245 if mode.Name == "Dilithium3" {
246 continue
247 }
248
249 fs, err = os.ReadDir(path.Join(mode.Pkg(), "internal"))
250 for _, f := range fs {
251 name := f.Name()
252 fn := path.Join(mode.Pkg(), "internal", name)
253 if ignored(name) {
254 continue
255 }
256 _, ok := files[name]
257 if !ok {
258 fmt.Printf("Removing superfluous file: %s\n", fn)
259 err = os.Remove(fn)
260 if err != nil {
261 panic(err)
262 }
263 }
264 if f.IsDir() {
265 panic(fmt.Sprintf("%s: is a directory", fn))
266 }
267 if f.Type()&os.ModeSymlink != 0 {
268 fmt.Printf("Removing symlink: %s\n", fn)
269 err = os.Remove(fn)
270 if err != nil {
271 panic(err)
272 }
273 }
274 }
275 for name, expected := range files {
276 fn := path.Join(mode.Pkg(), "internal", name)
277 expected = []byte(fmt.Sprintf(
278 "%s mode3/internal/%s by gen.go\n\n%s",
279 TemplateWarning,
280 name,
281 string(expected),
282 ))
283 got, err := os.ReadFile(fn)
284 if err == nil {
285 if bytes.Equal(got, expected) {
286 continue
287 }
288 }
289 fmt.Printf("Updating %s\n", fn)
290 err = os.WriteFile(fn, expected, 0o644)
291 if err != nil {
292 panic(err)
293 }
294 }
295 }
296 }
297
View as plain text