...

Source file src/github.com/Microsoft/go-winio/tools/mkwinsyscall/mkwinsyscall.go

Documentation: github.com/Microsoft/go-winio/tools/mkwinsyscall

     1  //go:build windows
     2  
     3  // Copyright 2013 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package main
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"errors"
    13  	"flag"
    14  	"fmt"
    15  	"go/format"
    16  	"go/parser"
    17  	"go/token"
    18  	"io"
    19  	"log"
    20  	"os"
    21  	"path/filepath"
    22  	"runtime"
    23  	"sort"
    24  	"strconv"
    25  	"strings"
    26  	"text/template"
    27  
    28  	"golang.org/x/sys/windows"
    29  )
    30  
// Names the generator compares against while parsing //sys lines and while
// deciding how to qualify identifiers in the generated code.
const (
	// Package names; windowsdot and syscalldot compare packageName against
	// these to decide whether a "windows." or "syscall." qualifier is needed.
	pkgSyscall = "syscall"
	pkgWindows = "windows"

	// common types.

	tBool    = "bool"
	tBoolPtr = "*bool"
	tError   = "error"
	tString  = "string"

	// error variable names.
	// varErr is the conventional "err error" result; a single error result
	// named varErrNTStatus (or its lower-case form) is decoded as an NTStatus,
	// and one named varErrHR is decoded as an HRESULT (see Rets.SetErrorCode).

	varErr         = "err"
	varErrNTStatus = "ntStatus"
	varErrHR       = "hr"
)
    48  
// Command-line flags controlling what is parsed and how code is generated.
var (
	filename       = flag.String("output", "", "output file name (standard output if omitted)")
	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
	winio          = flag.Bool("winio", false, `import this package ("github.com/Microsoft/go-winio")`)
	utf16          = flag.Bool("utf16", true, "encode string arguments as UTF-16 for syscalls not ending in 'A' or 'W'")
	sortdecls      = flag.Bool("sort", true, "sort DLL and function declarations")
)
    57  
    58  func trim(s string) string {
    59  	return strings.Trim(s, " \t")
    60  }
    61  
    62  func endsIn(s string, c byte) bool {
    63  	return len(s) >= 1 && s[len(s)-1] == c
    64  }
    65  
// packageName is the package clause name of the most recently parsed input
// file (parseFile assigns it); the generated file is emitted into this package.
var packageName string

// packagename returns packageName. It is exposed to srcTemplate as the
// "packagename" template function.
func packagename() string {
	return packageName
}
    71  
    72  func windowsdot() string {
    73  	if packageName == pkgWindows {
    74  		return ""
    75  	}
    76  	return pkgWindows + "."
    77  }
    78  
    79  func syscalldot() string {
    80  	if packageName == pkgSyscall {
    81  		return ""
    82  	}
    83  	return pkgSyscall + "."
    84  }
    85  
// Param is function parameter.
//
// Name and Type are the parameter name and Go type as written in the //sys
// line (e.g. "buf" and "[]byte"). fn is the owning function (set by
// extractParams). tmpVarIdx is the N of the "_pN" temporary assigned by
// tmpVar; extractParams initializes it to -1, meaning "not assigned yet".
type Param struct {
	Name      string
	Type      string
	fn        *Fn
	tmpVarIdx int
}
    93  
    94  // tmpVar returns temp variable name that will be used to represent p during syscall.
    95  func (p *Param) tmpVar() string {
    96  	if p.tmpVarIdx < 0 {
    97  		p.tmpVarIdx = p.fn.curTmpVarIdx
    98  		p.fn.curTmpVarIdx++
    99  	}
   100  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
   101  }
   102  
   103  // BoolTmpVarCode returns source code for bool temp variable.
   104  func (p *Param) BoolTmpVarCode() string {
   105  	const code = `var %[1]s uint32
   106  	if %[2]s {
   107  		%[1]s = 1
   108  	}`
   109  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   110  }
   111  
   112  // BoolPointerTmpVarCode returns source code for bool temp variable.
   113  func (p *Param) BoolPointerTmpVarCode() string {
   114  	const code = `var %[1]s uint32
   115  	if *%[2]s {
   116  		%[1]s = 1
   117  	}`
   118  	return fmt.Sprintf(code, p.tmpVar(), p.Name)
   119  }
   120  
   121  // SliceTmpVarCode returns source code for slice temp variable.
   122  func (p *Param) SliceTmpVarCode() string {
   123  	const code = `var %s *%s
   124  	if len(%s) > 0 {
   125  		%s = &%s[0]
   126  	}`
   127  	tmp := p.tmpVar()
   128  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   129  }
   130  
   131  // StringTmpVarCode returns source code for string temp variable.
   132  func (p *Param) StringTmpVarCode() string {
   133  	errvar := p.fn.Rets.ErrorVarName()
   134  	if errvar == "" {
   135  		errvar = "_"
   136  	}
   137  	tmp := p.tmpVar()
   138  	const code = `var %s %s
   139  	%s, %s = %s(%s)`
   140  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   141  	if errvar == "-" {
   142  		return s
   143  	}
   144  	const morecode = `
   145  	if %s != nil {
   146  		return
   147  	}`
   148  	return s + fmt.Sprintf(morecode, errvar)
   149  }
   150  
   151  // TmpVarCode returns source code for temp variable.
   152  func (p *Param) TmpVarCode() string {
   153  	switch {
   154  	case p.Type == tBool:
   155  		return p.BoolTmpVarCode()
   156  	case p.Type == tBoolPtr:
   157  		return p.BoolPointerTmpVarCode()
   158  	case strings.HasPrefix(p.Type, "[]"):
   159  		return p.SliceTmpVarCode()
   160  	default:
   161  		return ""
   162  	}
   163  }
   164  
   165  // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable.
   166  func (p *Param) TmpVarReadbackCode() string {
   167  	switch {
   168  	case p.Type == tBoolPtr:
   169  		return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar())
   170  	default:
   171  		return ""
   172  	}
   173  }
   174  
   175  // TmpVarHelperCode returns source code for helper's temp variable.
   176  func (p *Param) TmpVarHelperCode() string {
   177  	if p.Type != "string" {
   178  		return ""
   179  	}
   180  	return p.StringTmpVarCode()
   181  }
   182  
   183  // SyscallArgList returns source code fragments representing p parameter
   184  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   185  // the first element and length.
   186  func (p *Param) SyscallArgList() []string {
   187  	t := p.HelperType()
   188  	var s string
   189  	switch {
   190  	case t == tBoolPtr:
   191  		s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar())
   192  	case t[0] == '*':
   193  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   194  	case t == tBool:
   195  		s = p.tmpVar()
   196  	case strings.HasPrefix(t, "[]"):
   197  		return []string{
   198  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   199  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   200  		}
   201  	default:
   202  		s = p.Name
   203  	}
   204  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   205  }
   206  
   207  // IsError determines if p parameter is used to return error.
   208  func (p *Param) IsError() bool {
   209  	return p.Name == varErr && p.Type == tError
   210  }
   211  
   212  // HelperType returns type of parameter p used in helper function.
   213  func (p *Param) HelperType() string {
   214  	if p.Type == tString {
   215  		return p.fn.StrconvType()
   216  	}
   217  	return p.Type
   218  }
   219  
   220  // join concatenates parameters ps into a string with sep separator.
   221  // Each parameter is converted into string by applying fn to it
   222  // before conversion.
   223  func join(ps []*Param, fn func(*Param) string, sep string) string {
   224  	if len(ps) == 0 {
   225  		return ""
   226  	}
   227  	a := make([]string, 0)
   228  	for _, p := range ps {
   229  		a = append(a, fn(p))
   230  	}
   231  	return strings.Join(a, sep)
   232  }
   233  
// Rets describes function return parameters.
//
// Name and Type describe the first declared result that is not the "err error"
// result; both are empty when there is none. ReturnsError records that an
// "err error" result was declared. FailCond holds the optional bracketed fail
// condition from the //sys line, in which "failretval" is a placeholder for
// the raw return value. fnMaybeAbsent is set when the DLL function name ended
// in '?'.
type Rets struct {
	Name          string
	Type          string
	ReturnsError  bool
	FailCond      string
	fnMaybeAbsent bool
}
   242  
   243  // ErrorVarName returns error variable name for r.
   244  func (r *Rets) ErrorVarName() string {
   245  	if r.ReturnsError {
   246  		return varErr
   247  	}
   248  	if r.Type == tError {
   249  		return r.Name
   250  	}
   251  	return ""
   252  }
   253  
   254  // ToParams converts r into slice of *Param.
   255  func (r *Rets) ToParams() []*Param {
   256  	ps := make([]*Param, 0)
   257  	if len(r.Name) > 0 {
   258  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   259  	}
   260  	if r.ReturnsError {
   261  		ps = append(ps, &Param{Name: varErr, Type: tError})
   262  	}
   263  	return ps
   264  }
   265  
   266  // List returns source code of syscall return parameters.
   267  func (r *Rets) List() string {
   268  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   269  	if len(s) > 0 {
   270  		s = "(" + s + ")"
   271  	} else if r.fnMaybeAbsent {
   272  		s = "(err error)"
   273  	}
   274  	return s
   275  }
   276  
   277  // PrintList returns source code of trace printing part correspondent
   278  // to syscall return values.
   279  func (r *Rets) PrintList() string {
   280  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   281  }
   282  
   283  // SetReturnValuesCode returns source code that accepts syscall return values.
   284  func (r *Rets) SetReturnValuesCode() string {
   285  	if r.Name == "" && !r.ReturnsError {
   286  		return ""
   287  	}
   288  	retvar := "r0"
   289  	if r.Name == "" {
   290  		retvar = "r1"
   291  	}
   292  	errvar := "_"
   293  	if r.ReturnsError {
   294  		errvar = "e1"
   295  	}
   296  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   297  }
   298  
   299  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   300  	const code = `if %s {
   301  		err = errnoErr(e1)
   302  	}`
   303  	cond := retvar + " == 0"
   304  	if r.FailCond != "" {
   305  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   306  	}
   307  	return fmt.Sprintf(code, cond)
   308  }
   309  
// SetErrorCode returns source code that sets return parameters.
//
// Without any results it returns "". With only an "err error" result it emits
// the failure check on r1 (see useLongHandleErrorCode). A single result of type
// error is decoded from r0 according to its name:
//   - ntStatus/ntstatus: NTStatus(r0) when r0 != 0;
//   - hr: treated as an HRESULT (failure when int32(r0) < 0); when
//     r0&0x1fff0000 == 0x00070000 only the low 16 bits are kept before
//     converting to Errno (presumably HRESULTs wrapping a Win32 error code —
//     confirm against HRESULT_FROM_WIN32);
//   - any other name: Errno(r0) when r0 != 0.
//
// Any other named result is converted from r0 (pointer via unsafe.Pointer,
// bool as r0 != 0, otherwise a plain conversion). If "err error" is also
// declared, the failure check on that result is appended.
func (r *Rets) SetErrorCode() string {
	const code = `if r0 != 0 {
		%s = %sErrno(r0)
	}`
	const ntStatus = `if r0 != 0 {
		%s = %sNTStatus(r0)
	}`
	const hrCode = `if int32(r0) < 0 {
		if r0&0x1fff0000 == 0x00070000 {
			r0 &= 0xffff
		}
		%s = %sErrno(r0)
	}`

	if r.Name == "" && !r.ReturnsError {
		return ""
	}
	if r.Name == "" {
		// Only "err error" is declared: the syscall's r1 decides failure.
		return r.useLongHandleErrorCode("r1")
	}
	if r.Type == tError {
		switch r.Name {
		case varErrNTStatus, strings.ToLower(varErrNTStatus): // allow ntstatus to work
			return fmt.Sprintf(ntStatus, r.Name, windowsdot())
		case varErrHR:
			return fmt.Sprintf(hrCode, r.Name, syscalldot())
		default:
			return fmt.Sprintf(code, r.Name, syscalldot())
		}
	}

	var s string
	switch {
	case r.Type[0] == '*':
		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
	case r.Type == tBool:
		s = fmt.Sprintf("%s = r0 != 0", r.Name)
	default:
		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
	}
	if !r.ReturnsError {
		return s
	}
	// The failure check is applied to the converted named result.
	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
}
   356  
// Fn describes syscall function.
//
// Name is the Go function name. Params and Rets hold its parameters and
// results as parsed from the //sys line. PrintTrace (from the -trace flag)
// requests a trace print after the call. dllname and dllfuncname hold the
// optional "= dll.Func" part (DLLName and DLLFuncName supply the defaults).
// src keeps the trimmed //sys text for error messages.
type Fn struct {
	Name        string
	Params      []*Param
	Rets        *Rets
	PrintTrace  bool
	dllname     string
	dllfuncname string
	src         string
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // ensure tmp variables have unique names
}
   369  
   370  // extractParams parses s to extract function parameters.
   371  func extractParams(s string, f *Fn) ([]*Param, error) {
   372  	s = trim(s)
   373  	if s == "" {
   374  		return nil, nil
   375  	}
   376  	a := strings.Split(s, ",")
   377  	ps := make([]*Param, len(a))
   378  	for i := range ps {
   379  		s2 := trim(a[i])
   380  		b := strings.Split(s2, " ")
   381  		if len(b) != 2 {
   382  			b = strings.Split(s2, "\t")
   383  			if len(b) != 2 {
   384  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   385  			}
   386  		}
   387  		ps[i] = &Param{
   388  			Name:      trim(b[0]),
   389  			Type:      trim(b[1]),
   390  			fn:        f,
   391  			tmpVarIdx: -1,
   392  		}
   393  	}
   394  	return ps, nil
   395  }
   396  
   397  // extractSection extracts text out of string s starting after start
   398  // and ending just before end. found return value will indicate success,
   399  // and prefix, body and suffix will contain correspondent parts of string s.
   400  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   401  	s = trim(s)
   402  	if strings.HasPrefix(s, string(start)) {
   403  		// no prefix
   404  		body = s[1:]
   405  	} else {
   406  		a := strings.SplitN(s, string(start), 2)
   407  		if len(a) != 2 {
   408  			return "", "", s, false
   409  		}
   410  		prefix = a[0]
   411  		body = a[1]
   412  	}
   413  	a := strings.SplitN(body, string(end), 2)
   414  	if len(a) != 2 {
   415  		return "", "", "", false
   416  	}
   417  	return prefix, a[0], a[1], true
   418  }
   419  
   420  // newFn parses string s and return created function Fn.
   421  func newFn(s string) (*Fn, error) {
   422  	s = trim(s)
   423  	f := &Fn{
   424  		Rets:       &Rets{},
   425  		src:        s,
   426  		PrintTrace: *printTraceFlag,
   427  	}
   428  	// function name and args
   429  	prefix, body, s, found := extractSection(s, '(', ')')
   430  	if !found || prefix == "" {
   431  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   432  	}
   433  	f.Name = prefix
   434  	var err error
   435  	f.Params, err = extractParams(body, f)
   436  	if err != nil {
   437  		return nil, err
   438  	}
   439  	// return values
   440  	_, body, s, found = extractSection(s, '(', ')')
   441  	if found {
   442  		r, err := extractParams(body, f)
   443  		if err != nil {
   444  			return nil, err
   445  		}
   446  		switch len(r) {
   447  		case 0:
   448  		case 1:
   449  			if r[0].IsError() {
   450  				f.Rets.ReturnsError = true
   451  			} else {
   452  				f.Rets.Name = r[0].Name
   453  				f.Rets.Type = r[0].Type
   454  			}
   455  		case 2:
   456  			if !r[1].IsError() {
   457  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   458  			}
   459  			f.Rets.ReturnsError = true
   460  			f.Rets.Name = r[0].Name
   461  			f.Rets.Type = r[0].Type
   462  		default:
   463  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   464  		}
   465  	}
   466  	// fail condition
   467  	_, body, s, found = extractSection(s, '[', ']')
   468  	if found {
   469  		f.Rets.FailCond = body
   470  	}
   471  	// dll and dll function names
   472  	s = trim(s)
   473  	if s == "" {
   474  		return f, nil
   475  	}
   476  	if !strings.HasPrefix(s, "=") {
   477  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   478  	}
   479  	s = trim(s[1:])
   480  	if i := strings.LastIndex(s, "."); i >= 0 {
   481  		f.dllname = s[:i]
   482  		f.dllfuncname = s[i+1:]
   483  	} else {
   484  		f.dllfuncname = s
   485  	}
   486  	if f.dllfuncname == "" {
   487  		return nil, fmt.Errorf("function name is not specified in %q", s)
   488  	}
   489  	if n := f.dllfuncname; endsIn(n, '?') {
   490  		f.dllfuncname = n[:len(n)-1]
   491  		f.Rets.fnMaybeAbsent = true
   492  	}
   493  	return f, nil
   494  }
   495  
   496  // DLLName returns DLL name for function f.
   497  func (f *Fn) DLLName() string {
   498  	if f.dllname == "" {
   499  		return "kernel32"
   500  	}
   501  	return f.dllname
   502  }
   503  
   504  // DLLVar returns a valid Go identifier that represents DLLName.
   505  func (f *Fn) DLLVar() string {
   506  	id := strings.Map(func(r rune) rune {
   507  		switch r {
   508  		case '.', '-':
   509  			return '_'
   510  		default:
   511  			return r
   512  		}
   513  	}, f.DLLName())
   514  	if !token.IsIdentifier(id) {
   515  		panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName()))
   516  	}
   517  	return id
   518  }
   519  
   520  // DLLFuncName returns DLL function name for function f.
   521  func (f *Fn) DLLFuncName() string {
   522  	if f.dllfuncname == "" {
   523  		return f.Name
   524  	}
   525  	return f.dllfuncname
   526  }
   527  
   528  // ParamList returns source code for function f parameters.
   529  func (f *Fn) ParamList() string {
   530  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   531  }
   532  
   533  // HelperParamList returns source code for helper function f parameters.
   534  func (f *Fn) HelperParamList() string {
   535  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
   536  }
   537  
   538  // ParamPrintList returns source code of trace printing part correspondent
   539  // to syscall input parameters.
   540  func (f *Fn) ParamPrintList() string {
   541  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   542  }
   543  
   544  // ParamCount return number of syscall parameters for function f.
   545  func (f *Fn) ParamCount() int {
   546  	n := 0
   547  	for _, p := range f.Params {
   548  		n += len(p.SyscallArgList())
   549  	}
   550  	return n
   551  }
   552  
   553  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   554  // to use. It returns parameter count for correspondent SyscallX function.
   555  func (f *Fn) SyscallParamCount() int {
   556  	n := f.ParamCount()
   557  	switch {
   558  	case n <= 3:
   559  		return 3
   560  	case n <= 6:
   561  		return 6
   562  	case n <= 9:
   563  		return 9
   564  	case n <= 12:
   565  		return 12
   566  	case n <= 15:
   567  		return 15
   568  	default:
   569  		panic("too many arguments to system call")
   570  	}
   571  }
   572  
   573  // Syscall determines which SyscallX function to use for function f.
   574  func (f *Fn) Syscall() string {
   575  	c := f.SyscallParamCount()
   576  	if c == 3 {
   577  		return syscalldot() + "Syscall"
   578  	}
   579  	return syscalldot() + "Syscall" + strconv.Itoa(c)
   580  }
   581  
   582  // SyscallParamList returns source code for SyscallX parameters for function f.
   583  func (f *Fn) SyscallParamList() string {
   584  	a := make([]string, 0)
   585  	for _, p := range f.Params {
   586  		a = append(a, p.SyscallArgList()...)
   587  	}
   588  	for len(a) < f.SyscallParamCount() {
   589  		a = append(a, "0")
   590  	}
   591  	return strings.Join(a, ", ")
   592  }
   593  
   594  // HelperCallParamList returns source code of call into function f helper.
   595  func (f *Fn) HelperCallParamList() string {
   596  	a := make([]string, 0, len(f.Params))
   597  	for _, p := range f.Params {
   598  		s := p.Name
   599  		if p.Type == tString {
   600  			s = p.tmpVar()
   601  		}
   602  		a = append(a, s)
   603  	}
   604  	return strings.Join(a, ", ")
   605  }
   606  
   607  // MaybeAbsent returns source code for handling functions that are possibly unavailable.
   608  func (f *Fn) MaybeAbsent() string {
   609  	if !f.Rets.fnMaybeAbsent {
   610  		return ""
   611  	}
   612  	const code = `%[1]s = proc%[2]s.Find()
   613  	if %[1]s != nil {
   614  		return
   615  	}`
   616  	errorVar := f.Rets.ErrorVarName()
   617  	if errorVar == "" {
   618  		errorVar = varErr
   619  	}
   620  	return fmt.Sprintf(code, errorVar, f.DLLFuncName())
   621  }
   622  
   623  // IsUTF16 is true, if f is W (UTF-16) function and false for all A (ASCII) functions.
   624  // Functions ending in neither will default to UTF-16, unless the `-utf16` flag is set
   625  // to `false`.
   626  func (f *Fn) IsUTF16() bool {
   627  	s := f.DLLFuncName()
   628  	return endsIn(s, 'W') || (*utf16 && !endsIn(s, 'A'))
   629  }
   630  
   631  // StrconvFunc returns name of Go string to OS string function for f.
   632  func (f *Fn) StrconvFunc() string {
   633  	if f.IsUTF16() {
   634  		return syscalldot() + "UTF16PtrFromString"
   635  	}
   636  	return syscalldot() + "BytePtrFromString"
   637  }
   638  
   639  // StrconvType returns Go type name used for OS string for f.
   640  func (f *Fn) StrconvType() string {
   641  	if f.IsUTF16() {
   642  		return "*uint16"
   643  	}
   644  	return "*byte"
   645  }
   646  
   647  // HasStringParam is true, if f has at least one string parameter.
   648  // Otherwise it is false.
   649  func (f *Fn) HasStringParam() bool {
   650  	for _, p := range f.Params {
   651  		if p.Type == tString {
   652  			return true
   653  		}
   654  	}
   655  	return false
   656  }
   657  
   658  // HelperName returns name of function f helper.
   659  func (f *Fn) HelperName() string {
   660  	if !f.HasStringParam() {
   661  		return f.Name
   662  	}
   663  	return "_" + f.Name
   664  }
   665  
// DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when
// naming a variable that refers to the DLL.
//
// Name is the DLL name without the ".dll" suffix (the "newlazydll" template
// function appends it). Var is used as the suffix of the generated "mod"
// variable.
type DLL struct {
	Name string
	Var  string
}
   672  
// Source files and functions.
//
// Funcs holds every parsed //sys function. DLLFuncNames holds one function
// per distinct DLL function name and drives the proc variable declarations.
// Files lists the parsed file names. StdLibImports and ExternalImports are
// the sorted import paths emitted as two separate groups.
type Source struct {
	Funcs           []*Fn
	DLLFuncNames    []*Fn
	Files           []string
	StdLibImports   []string
	ExternalImports []string
}
   681  
   682  func (src *Source) Import(pkg string) {
   683  	src.StdLibImports = append(src.StdLibImports, pkg)
   684  	sort.Strings(src.StdLibImports)
   685  }
   686  
   687  func (src *Source) ExternalImport(pkg string) {
   688  	src.ExternalImports = append(src.ExternalImports, pkg)
   689  	sort.Strings(src.ExternalImports)
   690  }
   691  
   692  // ParseFiles parses files listed in fs and extracts all syscall
   693  // functions listed in sys comments. It returns source files
   694  // and functions collection *Source if successful.
   695  func ParseFiles(fs []string) (*Source, error) {
   696  	src := &Source{
   697  		Funcs: make([]*Fn, 0),
   698  		Files: make([]string, 0),
   699  		StdLibImports: []string{
   700  			"unsafe",
   701  		},
   702  		ExternalImports: make([]string, 0),
   703  	}
   704  	for _, file := range fs {
   705  		if err := src.ParseFile(file); err != nil {
   706  			return nil, err
   707  		}
   708  	}
   709  	src.DLLFuncNames = make([]*Fn, 0, len(src.Funcs))
   710  	uniq := make(map[string]bool, len(src.Funcs))
   711  	for _, fn := range src.Funcs {
   712  		name := fn.DLLFuncName()
   713  		if !uniq[name] {
   714  			src.DLLFuncNames = append(src.DLLFuncNames, fn)
   715  			uniq[name] = true
   716  		}
   717  	}
   718  	return src, nil
   719  }
   720  
   721  // DLLs return dll names for a source set src.
   722  func (src *Source) DLLs() []DLL {
   723  	uniq := make(map[string]bool)
   724  	r := make([]DLL, 0)
   725  	for _, f := range src.Funcs {
   726  		id := f.DLLVar()
   727  		if _, found := uniq[id]; !found {
   728  			uniq[id] = true
   729  			r = append(r, DLL{f.DLLName(), id})
   730  		}
   731  	}
   732  	if *sortdecls {
   733  		sort.Slice(r, func(i, j int) bool {
   734  			return r[i].Var < r[j].Var
   735  		})
   736  	}
   737  	return r
   738  }
   739  
   740  // ParseFile adds additional file (or files, if path is a glob pattern) path to a source set src.
   741  func (src *Source) ParseFile(path string) error {
   742  	file, err := os.Open(path)
   743  	if err == nil {
   744  		defer file.Close()
   745  		return src.parseFile(file)
   746  	} else if !(errors.Is(err, os.ErrNotExist) || errors.Is(err, windows.ERROR_INVALID_NAME)) {
   747  		return err
   748  	}
   749  
   750  	paths, err := filepath.Glob(path)
   751  	if err != nil {
   752  		return err
   753  	}
   754  
   755  	for _, path := range paths {
   756  		file, err := os.Open(path)
   757  		if err != nil {
   758  			return err
   759  		}
   760  		err = src.parseFile(file)
   761  		file.Close()
   762  		if err != nil {
   763  			return err
   764  		}
   765  	}
   766  
   767  	return nil
   768  }
   769  
   770  func (src *Source) parseFile(file *os.File) error {
   771  	s := bufio.NewScanner(file)
   772  	for s.Scan() {
   773  		t := trim(s.Text())
   774  		if len(t) < 7 {
   775  			continue
   776  		}
   777  		if !strings.HasPrefix(t, "//sys") {
   778  			continue
   779  		}
   780  		t = t[5:]
   781  		if !(t[0] == ' ' || t[0] == '\t') {
   782  			continue
   783  		}
   784  		f, err := newFn(t[1:])
   785  		if err != nil {
   786  			return err
   787  		}
   788  		src.Funcs = append(src.Funcs, f)
   789  	}
   790  	if err := s.Err(); err != nil {
   791  		return err
   792  	}
   793  	src.Files = append(src.Files, file.Name())
   794  	if *sortdecls {
   795  		sort.Slice(src.Funcs, func(i, j int) bool {
   796  			fi, fj := src.Funcs[i], src.Funcs[j]
   797  			if fi.DLLName() == fj.DLLName() {
   798  				return fi.DLLFuncName() < fj.DLLFuncName()
   799  			}
   800  			return fi.DLLName() < fj.DLLName()
   801  		})
   802  	}
   803  
   804  	// get package name
   805  	fset := token.NewFileSet()
   806  	_, err := file.Seek(0, 0)
   807  	if err != nil {
   808  		return err
   809  	}
   810  	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
   811  	if err != nil {
   812  		return err
   813  	}
   814  	packageName = pkg.Name.Name
   815  
   816  	return nil
   817  }
   818  
   819  // IsStdRepo reports whether src is part of standard library.
   820  func (src *Source) IsStdRepo() (bool, error) {
   821  	if len(src.Files) == 0 {
   822  		return false, errors.New("no input files provided")
   823  	}
   824  	abspath, err := filepath.Abs(src.Files[0])
   825  	if err != nil {
   826  		return false, err
   827  	}
   828  	goroot := runtime.GOROOT()
   829  	if runtime.GOOS == "windows" {
   830  		abspath = strings.ToLower(abspath)
   831  		goroot = strings.ToLower(goroot)
   832  	}
   833  	sep := string(os.PathSeparator)
   834  	if !strings.HasSuffix(goroot, sep) {
   835  		goroot += sep
   836  	}
   837  	return strings.HasPrefix(abspath, goroot), nil
   838  }
   839  
   840  // Generate output source file from a source set src.
   841  func (src *Source) Generate(w io.Writer) error {
   842  	const (
   843  		pkgStd         = iota // any package in std library
   844  		pkgXSysWindows        // x/sys/windows package
   845  		pkgOther
   846  	)
   847  	isStdRepo, err := src.IsStdRepo()
   848  	if err != nil {
   849  		return err
   850  	}
   851  	var pkgtype int
   852  	switch {
   853  	case isStdRepo:
   854  		pkgtype = pkgStd
   855  	case packageName == "windows":
   856  		// TODO: this needs better logic than just using package name
   857  		pkgtype = pkgXSysWindows
   858  	default:
   859  		pkgtype = pkgOther
   860  	}
   861  	if *systemDLL {
   862  		switch pkgtype {
   863  		case pkgStd:
   864  			src.Import("internal/syscall/windows/sysdll")
   865  		case pkgXSysWindows:
   866  		default:
   867  			src.ExternalImport("golang.org/x/sys/windows")
   868  		}
   869  	}
   870  	if *winio {
   871  		src.ExternalImport("github.com/Microsoft/go-winio")
   872  	}
   873  	if packageName != "syscall" {
   874  		src.Import("syscall")
   875  	}
   876  	funcMap := template.FuncMap{
   877  		"packagename": packagename,
   878  		"syscalldot":  syscalldot,
   879  		"newlazydll": func(dll string) string {
   880  			arg := "\"" + dll + ".dll\""
   881  			if !*systemDLL {
   882  				return syscalldot() + "NewLazyDLL(" + arg + ")"
   883  			}
   884  			if strings.HasPrefix(dll, "api_") || strings.HasPrefix(dll, "ext_") {
   885  				arg = strings.Replace(arg, "_", "-", -1)
   886  			}
   887  			switch pkgtype {
   888  			case pkgStd:
   889  				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
   890  			case pkgXSysWindows:
   891  				return "NewLazySystemDLL(" + arg + ")"
   892  			default:
   893  				return "windows.NewLazySystemDLL(" + arg + ")"
   894  			}
   895  		},
   896  	}
   897  	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
   898  	err = t.Execute(w, src)
   899  	if err != nil {
   900  		return errors.New("Failed to execute template: " + err.Error())
   901  	}
   902  	return nil
   903  }
   904  
   905  func writeTempSourceFile(data []byte) (string, error) {
   906  	f, err := os.CreateTemp("", "mkwinsyscall-generated-*.go")
   907  	if err != nil {
   908  		return "", err
   909  	}
   910  	_, err = f.Write(data)
   911  	if closeErr := f.Close(); err == nil {
   912  		err = closeErr
   913  	}
   914  	if err != nil {
   915  		os.Remove(f.Name()) // best effort
   916  		return "", err
   917  	}
   918  	return f.Name(), nil
   919  }
   920  
   921  func usage() {
   922  	fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n")
   923  	flag.PrintDefaults()
   924  	os.Exit(1)
   925  }
   926  
   927  func main() {
   928  	flag.Usage = usage
   929  	flag.Parse()
   930  	if len(flag.Args()) <= 0 {
   931  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   932  		usage()
   933  	}
   934  
   935  	src, err := ParseFiles(flag.Args())
   936  	if err != nil {
   937  		log.Fatal(err)
   938  	}
   939  
   940  	var buf bytes.Buffer
   941  	if err := src.Generate(&buf); err != nil {
   942  		log.Fatal(err)
   943  	}
   944  
   945  	data, err := format.Source(buf.Bytes())
   946  	if err != nil {
   947  		log.Printf("failed to format source: %v", err)
   948  		f, err := writeTempSourceFile(buf.Bytes())
   949  		if err != nil {
   950  			log.Fatalf("failed to write unformatted source to file: %v", err)
   951  		}
   952  		log.Fatalf("for diagnosis, wrote unformatted source to %v", f)
   953  	}
   954  	if *filename == "" {
   955  		_, err = os.Stdout.Write(data)
   956  	} else {
   957  		//nolint:gosec // G306: code file, no need for wants 0600
   958  		err = os.WriteFile(*filename, data, 0644)
   959  	}
   960  	if err != nil {
   961  		log.Fatal(err)
   962  	}
   963  }
   964  
// TODO: use println instead to print in the following template

// srcTemplate is the text/template source executed by Source.Generate with a
// *Source as data. Its "main" template emits the generated file: the build
// constraint and generated-code header, the package clause, the two import
// groups, and the errnoErr helper. It also emits one lazy DLL variable per
// entry of .DLLs and one proc variable per entry of .DLLFuncNames. For every
// function in .Funcs it emits an optional string-converting wrapper
// ("helperbody", only when the function has string parameters), followed by
// the function that performs the syscall ("funcbody"). The remaining named
// templates are the pieces those bodies are built from.
const srcTemplate = `
{{define "main"}} //go:build windows

// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.

package {{packagename}}

import (
{{range .StdLibImports}}"{{.}}"
{{end}}

{{range .ExternalImports}}"{{.}}"
{{end}}
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
	errERROR_EINVAL error     = {{syscalldot}}EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e {{syscalldot}}Errno) error {
	switch e {
	case 0:
		return errERROR_EINVAL
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.Var}} = {{newlazydll .Name}}
{{end}}{{end}}

{{define "funcnames"}}{{range .DLLFuncNames}}	proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}

{{define "funcbody"}}
func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "maybeabsent" .}}	{{template "tmpvars" .}}	{{template "syscall" .}}	{{template "tmpvarsreadback" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}

{{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}}
{{end}}{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}

{{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}}
{{.TmpVarReadbackCode}}{{end}}{{end}}{{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`
  1060  

View as plain text