...

Source file src/github.com/cloudflare/circl/sign/dilithium/gen.go

Documentation: github.com/cloudflare/circl/sign/dilithium

     1  //go:build ignore
     2  // +build ignore
     3  
     4  // Autogenerates wrappers from templates to prevent too much duplicated code
     5  // between the code for different modes.
     6  package main
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"go/format"
    12  	"os"
    13  	"path"
    14  	"strings"
    15  	"text/template"
    16  
    17  	"github.com/cloudflare/circl/sign/dilithium/internal/common/params"
    18  )
    19  
    20  type Mode struct {
    21  	Name          string
    22  	UseAES        bool
    23  	K             int
    24  	L             int
    25  	Eta           int
    26  	DoubleEtaBits int
    27  	Omega         int
    28  	Tau           int
    29  	Gamma1Bits    int
    30  	Gamma2        int
    31  }
    32  
    33  func (m Mode) Pkg() string {
    34  	return strings.ToLower(m.Mode())
    35  }
    36  
    37  func (m Mode) Impl() string {
    38  	return "impl" + m.Mode()
    39  }
    40  
    41  func (m Mode) Mode() string {
    42  	return strings.ReplaceAll(strings.ReplaceAll(m.Name,
    43  		"Dilithium", "Mode"), "-AES", "AES")
    44  }
    45  
    46  var (
    47  	Modes = []Mode{
    48  		{
    49  			Name:          "Dilithium2",
    50  			UseAES:        false,
    51  			K:             4,
    52  			L:             4,
    53  			Eta:           2,
    54  			DoubleEtaBits: 3,
    55  			Omega:         80,
    56  			Tau:           39,
    57  			Gamma1Bits:    17,
    58  			Gamma2:        (params.Q - 1) / 88,
    59  		},
    60  		{
    61  			Name:          "Dilithium2-AES",
    62  			UseAES:        true,
    63  			K:             4,
    64  			L:             4,
    65  			Eta:           2,
    66  			DoubleEtaBits: 3,
    67  			Omega:         80,
    68  			Tau:           39,
    69  			Gamma1Bits:    17,
    70  			Gamma2:        (params.Q - 1) / 88,
    71  		},
    72  		{
    73  			Name:          "Dilithium3",
    74  			UseAES:        false,
    75  			K:             6,
    76  			L:             5,
    77  			Eta:           4,
    78  			DoubleEtaBits: 4,
    79  			Omega:         55,
    80  			Tau:           49,
    81  			Gamma1Bits:    19,
    82  			Gamma2:        (params.Q - 1) / 32,
    83  		},
    84  		{
    85  			Name:          "Dilithium3-AES",
    86  			UseAES:        true,
    87  			K:             6,
    88  			L:             5,
    89  			Eta:           4,
    90  			DoubleEtaBits: 4,
    91  			Omega:         55,
    92  			Tau:           49,
    93  			Gamma1Bits:    19,
    94  			Gamma2:        (params.Q - 1) / 32,
    95  		},
    96  		{
    97  			Name:          "Dilithium5",
    98  			UseAES:        false,
    99  			K:             8,
   100  			L:             7,
   101  			Eta:           2,
   102  			DoubleEtaBits: 3,
   103  			Omega:         75,
   104  			Tau:           60,
   105  			Gamma1Bits:    19,
   106  			Gamma2:        (params.Q - 1) / 32,
   107  		},
   108  		{
   109  			Name:          "Dilithium5-AES",
   110  			UseAES:        true,
   111  			K:             8,
   112  			L:             7,
   113  			Eta:           2,
   114  			DoubleEtaBits: 3,
   115  			Omega:         75,
   116  			Tau:           60,
   117  			Gamma1Bits:    19,
   118  			Gamma2:        (params.Q - 1) / 32,
   119  		},
   120  	}
   121  	TemplateWarning = "// Code generated from"
   122  )
   123  
   124  func main() {
   125  	generateModePackageFiles()
   126  	generateModeToplevelFiles()
   127  	generateParamsFiles()
   128  	generateSourceFiles()
   129  }
   130  
   131  // Generates modeX/internal/params.go from templates/params.templ.go
   132  func generateParamsFiles() {
   133  	tl, err := template.ParseFiles("templates/params.templ.go")
   134  	if err != nil {
   135  		panic(err)
   136  	}
   137  
   138  	for _, mode := range Modes {
   139  		buf := new(bytes.Buffer)
   140  		err := tl.Execute(buf, mode)
   141  		if err != nil {
   142  			panic(err)
   143  		}
   144  
   145  		// Formating output code
   146  		code, err := format.Source(buf.Bytes())
   147  		if err != nil {
   148  			panic("error formating code")
   149  		}
   150  
   151  		res := string(code)
   152  		offset := strings.Index(res, TemplateWarning)
   153  		if offset == -1 {
   154  			panic("Missing template warning in params.templ.go")
   155  		}
   156  		err = os.WriteFile(mode.Pkg()+"/internal/params.go",
   157  			[]byte(res[offset:]), 0o644)
   158  		if err != nil {
   159  			panic(err)
   160  		}
   161  	}
   162  }
   163  
   164  // Generates modeX.go from templates/mode.templ.go
   165  func generateModeToplevelFiles() {
   166  	tl, err := template.ParseFiles("templates/mode.templ.go")
   167  	if err != nil {
   168  		panic(err)
   169  	}
   170  
   171  	for _, mode := range Modes {
   172  		buf := new(bytes.Buffer)
   173  		err := tl.Execute(buf, mode)
   174  		if err != nil {
   175  			panic(err)
   176  		}
   177  
   178  		res := string(buf.Bytes())
   179  		offset := strings.Index(res, TemplateWarning)
   180  		if offset == -1 {
   181  			panic("Missing template warning in mode.templ.go")
   182  		}
   183  		err = os.WriteFile(mode.Pkg()+".go", []byte(res[offset:]), 0o644)
   184  		if err != nil {
   185  			panic(err)
   186  		}
   187  	}
   188  }
   189  
   190  // Generates modeX/dilithium.go from templates/modePkg.templ.go
   191  func generateModePackageFiles() {
   192  	tl, err := template.ParseFiles("templates/modePkg.templ.go")
   193  	if err != nil {
   194  		panic(err)
   195  	}
   196  
   197  	for _, mode := range Modes {
   198  		buf := new(bytes.Buffer)
   199  		err := tl.Execute(buf, mode)
   200  		if err != nil {
   201  			panic(err)
   202  		}
   203  
   204  		res := string(buf.Bytes())
   205  		offset := strings.Index(res, TemplateWarning)
   206  		if offset == -1 {
   207  			panic("Missing template warning in modePkg.templ.go")
   208  		}
   209  		err = os.WriteFile(mode.Pkg()+"/dilithium.go", []byte(res[offset:]), 0o644)
   210  		if err != nil {
   211  			panic(err)
   212  		}
   213  	}
   214  }
   215  
   216  // Copies mode3 source files to other modes
   217  func generateSourceFiles() {
   218  	files := make(map[string][]byte)
   219  
   220  	// Ignore mode specific files.
   221  	ignored := func(x string) bool {
   222  		return x == "params.go" || x == "params_test.go" ||
   223  			strings.HasSuffix(x, ".swp")
   224  	}
   225  
   226  	fs, err := os.ReadDir("mode3/internal")
   227  	if err != nil {
   228  		panic(err)
   229  	}
   230  
   231  	// Read files
   232  	for _, f := range fs {
   233  		name := f.Name()
   234  		if ignored(name) {
   235  			continue
   236  		}
   237  		files[name], err = os.ReadFile(path.Join("mode3/internal", name))
   238  		if err != nil {
   239  			panic(err)
   240  		}
   241  	}
   242  
   243  	// Go over modes
   244  	for _, mode := range Modes {
   245  		if mode.Name == "Dilithium3" {
   246  			continue
   247  		}
   248  
   249  		fs, err = os.ReadDir(path.Join(mode.Pkg(), "internal"))
   250  		for _, f := range fs {
   251  			name := f.Name()
   252  			fn := path.Join(mode.Pkg(), "internal", name)
   253  			if ignored(name) {
   254  				continue
   255  			}
   256  			_, ok := files[name]
   257  			if !ok {
   258  				fmt.Printf("Removing superfluous file: %s\n", fn)
   259  				err = os.Remove(fn)
   260  				if err != nil {
   261  					panic(err)
   262  				}
   263  			}
   264  			if f.IsDir() {
   265  				panic(fmt.Sprintf("%s: is a directory", fn))
   266  			}
   267  			if f.Type()&os.ModeSymlink != 0 {
   268  				fmt.Printf("Removing symlink: %s\n", fn)
   269  				err = os.Remove(fn)
   270  				if err != nil {
   271  					panic(err)
   272  				}
   273  			}
   274  		}
   275  		for name, expected := range files {
   276  			fn := path.Join(mode.Pkg(), "internal", name)
   277  			expected = []byte(fmt.Sprintf(
   278  				"%s mode3/internal/%s by gen.go\n\n%s",
   279  				TemplateWarning,
   280  				name,
   281  				string(expected),
   282  			))
   283  			got, err := os.ReadFile(fn)
   284  			if err == nil {
   285  				if bytes.Equal(got, expected) {
   286  					continue
   287  				}
   288  			}
   289  			fmt.Printf("Updating %s\n", fn)
   290  			err = os.WriteFile(fn, expected, 0o644)
   291  			if err != nil {
   292  				panic(err)
   293  			}
   294  		}
   295  	}
   296  }
   297  

View as plain text