1
2
3
4
5
6 package main
7
8 import (
9 "bytes"
10 "fmt"
11 "go/format"
12 "os"
13 "path"
14 "strings"
15 "text/template"
16 )
17
// Instance holds the per-mode values that the generator hands to the
// templates as execution data.
type Instance struct {
	// Name is the mode name, e.g. "Kyber512"; Pkg and Impl are derived from it.
	Name string

	// K and Eta1 are numeric parameters of the mode (presumably the Kyber
	// parameters k and eta1 -- confirm against the templates).
	K, Eta1 int

	// CiphertextSize is the ciphertext size of the mode (presumably in
	// bytes -- confirm against the templates).
	CiphertextSize int

	// DU and DV are numeric parameters of the mode (presumably the Kyber
	// compression parameters d_u and d_v -- confirm against the templates).
	DU, DV int
}

// Pkg returns the mode name in lower case. The generator uses it as the
// directory of the mode's package.
func (inst Instance) Pkg() string {
	pkg := strings.ToLower(inst.Name)
	return pkg
}

// Impl returns the mode name prefixed with "impl".
func (inst Instance) Impl() string {
	const prefix = "impl"
	return prefix + inst.Name
}
34
35 var (
36 Instances = []Instance{
37 {
38 Name: "Kyber512",
39 Eta1: 3,
40 K: 2,
41 CiphertextSize: 768,
42 DU: 10,
43 DV: 4,
44 },
45 {
46 Name: "Kyber768",
47 Eta1: 2,
48 K: 3,
49 CiphertextSize: 1088,
50 DU: 10,
51 DV: 4,
52 },
53 {
54 Name: "Kyber1024",
55 Eta1: 2,
56 K: 4,
57 CiphertextSize: 1568,
58 DU: 11,
59 DV: 5,
60 },
61 }
62 TemplateWarning = "// Code generated from"
63 )
64
65 func main() {
66 generatePackageFiles()
67 generateParamsFiles()
68 generateSourceFiles()
69 }
70
71
72 func generateParamsFiles() {
73 tl, err := template.ParseFiles("templates/params.templ.go")
74 if err != nil {
75 panic(err)
76 }
77
78 for _, mode := range Instances {
79 buf := new(bytes.Buffer)
80 err := tl.Execute(buf, mode)
81 if err != nil {
82 panic(err)
83 }
84
85
86 code, err := format.Source(buf.Bytes())
87 if err != nil {
88 panic("error formating code")
89 }
90
91 res := string(code)
92 offset := strings.Index(res, TemplateWarning)
93 if offset == -1 {
94 panic("Missing template warning in params.templ.go")
95 }
96 err = os.WriteFile(mode.Pkg()+"/internal/params.go",
97 []byte(res[offset:]), 0o644)
98 if err != nil {
99 panic(err)
100 }
101 }
102 }
103
104
105 func generatePackageFiles() {
106 tl, err := template.ParseFiles("templates/pkg.templ.go")
107 if err != nil {
108 panic(err)
109 }
110
111 for _, mode := range Instances {
112 buf := new(bytes.Buffer)
113 err := tl.Execute(buf, mode)
114 if err != nil {
115 panic(err)
116 }
117
118 res := string(buf.Bytes())
119 offset := strings.Index(res, TemplateWarning)
120 if offset == -1 {
121 panic("Missing template warning in pkg.templ.go")
122 }
123 err = os.WriteFile(mode.Pkg()+"/kyber.go", []byte(res[offset:]), 0o644)
124 if err != nil {
125 panic(err)
126 }
127 }
128 }
129
130
131 func generateSourceFiles() {
132 files := make(map[string][]byte)
133
134
135 ignored := func(x string) bool {
136 return x == "params.go" || x == "params_test.go"
137 }
138
139 fs, err := os.ReadDir("kyber512/internal")
140 if err != nil {
141 panic(err)
142 }
143
144
145 for _, f := range fs {
146 name := f.Name()
147 if ignored(name) {
148 continue
149 }
150 files[name], err = os.ReadFile(path.Join("kyber512/internal", name))
151 if err != nil {
152 panic(err)
153 }
154 }
155
156
157 for _, mode := range Instances {
158 if mode.Name == "Kyber512" {
159 continue
160 }
161
162 fs, err = os.ReadDir(path.Join(mode.Pkg(), "internal"))
163 for _, f := range fs {
164 name := f.Name()
165 fn := path.Join(mode.Pkg(), "internal", name)
166 if ignored(name) {
167 continue
168 }
169 _, ok := files[name]
170 if !ok {
171 fmt.Printf("Removing superfluous file: %s\n", fn)
172 err = os.Remove(fn)
173 if err != nil {
174 panic(err)
175 }
176 }
177 if f.IsDir() {
178 panic(fmt.Sprintf("%s: is a directory", fn))
179 }
180 if f.Type()&os.ModeSymlink != 0 {
181 fmt.Printf("Removing symlink: %s\n", fn)
182 err = os.Remove(fn)
183 if err != nil {
184 panic(err)
185 }
186 }
187 }
188 for name, expected := range files {
189 fn := path.Join(mode.Pkg(), "internal", name)
190 expected = []byte(fmt.Sprintf(
191 "%s kyber512/internal/%s by gen.go\n\n%s",
192 TemplateWarning,
193 name,
194 string(expected),
195 ))
196 got, err := os.ReadFile(fn)
197 if err == nil {
198 if bytes.Equal(got, expected) {
199 continue
200 }
201 }
202 fmt.Printf("Updating %s\n", fn)
203 err = os.WriteFile(fn, expected, 0o644)
204 if err != nil {
205 panic(err)
206 }
207 }
208 }
209 }
210
View as plain text