| // Copyright 2013 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // +build ignore |
| |
| /* |
| mksyscall_windows generates windows system call bodies |
| |
| It parses all files specified on command line containing function |
| prototypes (like syscall_windows.go) and prints system call bodies |
| to standard output. |
| |
| The prototypes are marked by lines beginning with "//sys" and read |
| like func declarations if //sys is replaced by func, but: |
| |
| * The parameter lists must give a name for each argument. This |
| includes return parameters. |
| |
| * The parameter lists must give a type for each argument: |
| the (x, y, z int) shorthand is not allowed. |
| |
| * If the return parameter is an error number, it must be named err. |
| |
| * If go func name needs to be different from it's winapi dll name, |
| the winapi name could be specified at the end, after "=" sign, like |
| //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA |
| |
| * Each function that returns err needs to supply a condition, that |
| return value of winapi will be tested against to detect failure. |
| This would set err to windows "last-error", otherwise it will be nil. |
| The value can be provided at end of //sys declaration, like |
| //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA |
| and is [failretval==0] by default. |
| |
| Usage: |
| mksyscall_windows [flags] [path ...] |
| |
| The flags are: |
| -trace |
| Generate print statement after every syscall. |
| */ |
| package main |
| |
| import ( |
| "bufio" |
| "errors" |
| "flag" |
| "fmt" |
| "go/parser" |
| "go/token" |
| "io" |
| "log" |
| "os" |
| "strconv" |
| "strings" |
| "text/template" |
| ) |
| |
| var PrintTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") |
| |
| func trim(s string) string { |
| return strings.Trim(s, " \t") |
| } |
| |
| var packageName string |
| |
| func packagename() string { |
| return packageName |
| } |
| |
| func windowsdot() string { |
| if packageName == "windows" { |
| return "" |
| } |
| return "windows." |
| } |
| |
| // Param is function parameter |
| type Param struct { |
| Name string |
| Type string |
| fn *Fn |
| tmpVarIdx int |
| } |
| |
| // tmpVar returns temp variable name that will be used to represent p during syscall. |
| func (p *Param) tmpVar() string { |
| if p.tmpVarIdx < 0 { |
| p.tmpVarIdx = p.fn.curTmpVarIdx |
| p.fn.curTmpVarIdx++ |
| } |
| return fmt.Sprintf("_p%d", p.tmpVarIdx) |
| } |
| |
| // BoolTmpVarCode returns source code for bool temp variable. |
| func (p *Param) BoolTmpVarCode() string { |
| const code = `var %s uint32 |
| if %s { |
| %s = 1 |
| } else { |
| %s = 0 |
| }` |
| tmp := p.tmpVar() |
| return fmt.Sprintf(code, tmp, p.Name, tmp, tmp) |
| } |
| |
| // SliceTmpVarCode returns source code for slice temp variable. |
| func (p *Param) SliceTmpVarCode() string { |
| const code = `var %s *%s |
| if len(%s) > 0 { |
| %s = &%s[0] |
| }` |
| tmp := p.tmpVar() |
| return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) |
| } |
| |
| // StringTmpVarCode returns source code for string temp variable. |
| func (p *Param) StringTmpVarCode() string { |
| errvar := p.fn.Rets.ErrorVarName() |
| if errvar == "" { |
| errvar = "_" |
| } |
| tmp := p.tmpVar() |
| const code = `var %s %s |
| %s, %s = %s(%s)` |
| s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) |
| if errvar == "-" { |
| return s |
| } |
| const morecode = ` |
| if %s != nil { |
| return |
| }` |
| return s + fmt.Sprintf(morecode, errvar) |
| } |
| |
| // TmpVarCode returns source code for temp variable. |
| func (p *Param) TmpVarCode() string { |
| switch { |
| case p.Type == "string": |
| return p.StringTmpVarCode() |
| case p.Type == "bool": |
| return p.BoolTmpVarCode() |
| case strings.HasPrefix(p.Type, "[]"): |
| return p.SliceTmpVarCode() |
| default: |
| return "" |
| } |
| } |
| |
| // SyscallArgList returns source code fragments representing p parameter |
| // in syscall. Slices are translated into 2 syscall parameters: pointer to |
| // the first element and length. |
| func (p *Param) SyscallArgList() []string { |
| var s string |
| switch { |
| case p.Type[0] == '*': |
| s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) |
| case p.Type == "string": |
| s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar()) |
| case p.Type == "bool": |
| s = p.tmpVar() |
| case strings.HasPrefix(p.Type, "[]"): |
| return []string{ |
| fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), |
| fmt.Sprintf("uintptr(len(%s))", p.Name), |
| } |
| default: |
| s = p.Name |
| } |
| return []string{fmt.Sprintf("uintptr(%s)", s)} |
| } |
| |
| // IsError determines if p parameter is used to return error. |
| func (p *Param) IsError() bool { |
| return p.Name == "err" && p.Type == "error" |
| } |
| |
| // join concatenates parameters ps into a string with sep separator. |
| // Each parameter is converted into string by applying fn to it |
| // before conversion. |
| func join(ps []*Param, fn func(*Param) string, sep string) string { |
| if len(ps) == 0 { |
| return "" |
| } |
| a := make([]string, 0) |
| for _, p := range ps { |
| a = append(a, fn(p)) |
| } |
| return strings.Join(a, sep) |
| } |
| |
| // Rets describes function return parameters. |
| type Rets struct { |
| Name string |
| Type string |
| ReturnsError bool |
| FailCond string |
| } |
| |
| // ErrorVarName returns error variable name for r. |
| func (r *Rets) ErrorVarName() string { |
| if r.ReturnsError { |
| return "err" |
| } |
| if r.Type == "error" { |
| return r.Name |
| } |
| return "" |
| } |
| |
| // ToParams converts r into slice of *Param. |
| func (r *Rets) ToParams() []*Param { |
| ps := make([]*Param, 0) |
| if len(r.Name) > 0 { |
| ps = append(ps, &Param{Name: r.Name, Type: r.Type}) |
| } |
| if r.ReturnsError { |
| ps = append(ps, &Param{Name: "err", Type: "error"}) |
| } |
| return ps |
| } |
| |
| // List returns source code of syscall return parameters. |
| func (r *Rets) List() string { |
| s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") |
| if len(s) > 0 { |
| s = "(" + s + ")" |
| } |
| return s |
| } |
| |
| // PrintList returns source code of trace printing part correspondent |
| // to syscall return values. |
| func (r *Rets) PrintList() string { |
| return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) |
| } |
| |
| // SetReturnValuesCode returns source code that accepts syscall return values. |
| func (r *Rets) SetReturnValuesCode() string { |
| if r.Name == "" && !r.ReturnsError { |
| return "" |
| } |
| retvar := "r0" |
| if r.Name == "" { |
| retvar = "r1" |
| } |
| errvar := "_" |
| if r.ReturnsError { |
| errvar = "e1" |
| } |
| return fmt.Sprintf("%s, _, %s := ", retvar, errvar) |
| } |
| |
| func (r *Rets) useLongHandleErrorCode(retvar string) string { |
| const code = `if %s { |
| if e1 != 0 { |
| err = error(e1) |
| } else { |
| err = %sEINVAL |
| } |
| }` |
| cond := retvar + " == 0" |
| if r.FailCond != "" { |
| cond = strings.Replace(r.FailCond, "failretval", retvar, 1) |
| } |
| return fmt.Sprintf(code, cond, windowsdot()) |
| } |
| |
| // SetErrorCode returns source code that sets return parameters. |
| func (r *Rets) SetErrorCode() string { |
| const code = `if r0 != 0 { |
| %s = %sErrno(r0) |
| }` |
| if r.Name == "" && !r.ReturnsError { |
| return "" |
| } |
| if r.Name == "" { |
| return r.useLongHandleErrorCode("r1") |
| } |
| if r.Type == "error" { |
| return fmt.Sprintf(code, r.Name, windowsdot()) |
| } |
| s := "" |
| switch { |
| case r.Type[0] == '*': |
| s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) |
| case r.Type == "bool": |
| s = fmt.Sprintf("%s = r0 != 0", r.Name) |
| default: |
| s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) |
| } |
| if !r.ReturnsError { |
| return s |
| } |
| return s + "\n\t" + r.useLongHandleErrorCode(r.Name) |
| } |
| |
| // Fn describes a syscall function. |
| type Fn struct { |
| Name string |
| Params []*Param |
| Rets *Rets |
| PrintTrace bool |
| dllname string |
| dllfuncname string |
| src string |
| // TODO: get rid of this field and just use parameter index instead |
| curTmpVarIdx int // insure tmp variables have uniq names |
| } |
| |
| // extractParams parses s to extract function parameters. |
| func extractParams(s string, f *Fn) ([]*Param, error) { |
| s = trim(s) |
| if s == "" { |
| return nil, nil |
| } |
| a := strings.Split(s, ",") |
| ps := make([]*Param, len(a)) |
| for i := range ps { |
| s2 := trim(a[i]) |
| b := strings.Split(s2, " ") |
| if len(b) != 2 { |
| b = strings.Split(s2, "\t") |
| if len(b) != 2 { |
| return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") |
| } |
| } |
| ps[i] = &Param{ |
| Name: trim(b[0]), |
| Type: trim(b[1]), |
| fn: f, |
| tmpVarIdx: -1, |
| } |
| } |
| return ps, nil |
| } |
| |
| // extractSection extracts text out of string s starting after start |
| // and ending just before end. found return value will indicate success, |
| // and prefix, body and suffix will contain correspondent parts of string s. |
| func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { |
| s = trim(s) |
| if strings.HasPrefix(s, string(start)) { |
| // no prefix |
| body = s[1:] |
| } else { |
| a := strings.SplitN(s, string(start), 2) |
| if len(a) != 2 { |
| return "", "", s, false |
| } |
| prefix = a[0] |
| body = a[1] |
| } |
| a := strings.SplitN(body, string(end), 2) |
| if len(a) != 2 { |
| return "", "", "", false |
| } |
| return prefix, a[0], a[1], true |
| } |
| |
| // newFn parses string s and return created function Fn. |
| func newFn(s string) (*Fn, error) { |
| s = trim(s) |
| f := &Fn{ |
| Rets: &Rets{}, |
| src: s, |
| PrintTrace: *PrintTraceFlag, |
| } |
| // function name and args |
| prefix, body, s, found := extractSection(s, '(', ')') |
| if !found || prefix == "" { |
| return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") |
| } |
| f.Name = prefix |
| var err error |
| f.Params, err = extractParams(body, f) |
| if err != nil { |
| return nil, err |
| } |
| // return values |
| _, body, s, found = extractSection(s, '(', ')') |
| if found { |
| r, err := extractParams(body, f) |
| if err != nil { |
| return nil, err |
| } |
| switch len(r) { |
| case 0: |
| case 1: |
| if r[0].IsError() { |
| f.Rets.ReturnsError = true |
| } else { |
| f.Rets.Name = r[0].Name |
| f.Rets.Type = r[0].Type |
| } |
| case 2: |
| if !r[1].IsError() { |
| return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") |
| } |
| f.Rets.ReturnsError = true |
| f.Rets.Name = r[0].Name |
| f.Rets.Type = r[0].Type |
| default: |
| return nil, errors.New("Too many return values in \"" + f.src + "\"") |
| } |
| } |
| // fail condition |
| _, body, s, found = extractSection(s, '[', ']') |
| if found { |
| f.Rets.FailCond = body |
| } |
| // dll and dll function names |
| s = trim(s) |
| if s == "" { |
| return f, nil |
| } |
| if !strings.HasPrefix(s, "=") { |
| return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") |
| } |
| s = trim(s[1:]) |
| a := strings.Split(s, ".") |
| switch len(a) { |
| case 1: |
| f.dllfuncname = a[0] |
| case 2: |
| f.dllname = a[0] |
| f.dllfuncname = a[1] |
| default: |
| return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") |
| } |
| return f, nil |
| } |
| |
| // DLLName returns DLL name for function f. |
| func (f *Fn) DLLName() string { |
| if f.dllname == "" { |
| return "kernel32" |
| } |
| return f.dllname |
| } |
| |
| // DLLName returns DLL function name for function f. |
| func (f *Fn) DLLFuncName() string { |
| if f.dllfuncname == "" { |
| return f.Name |
| } |
| return f.dllfuncname |
| } |
| |
| // ParamList returns source code for function f parameters. |
| func (f *Fn) ParamList() string { |
| return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") |
| } |
| |
| // ParamPrintList returns source code of trace printing part correspondent |
| // to syscall input parameters. |
| func (f *Fn) ParamPrintList() string { |
| return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) |
| } |
| |
| // ParamCount return number of syscall parameters for function f. |
| func (f *Fn) ParamCount() int { |
| n := 0 |
| for _, p := range f.Params { |
| n += len(p.SyscallArgList()) |
| } |
| return n |
| } |
| |
| // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... |
| // to use. It returns parameter count for correspondent SyscallX function. |
| func (f *Fn) SyscallParamCount() int { |
| n := f.ParamCount() |
| switch { |
| case n <= 3: |
| return 3 |
| case n <= 6: |
| return 6 |
| case n <= 9: |
| return 9 |
| case n <= 12: |
| return 12 |
| case n <= 15: |
| return 15 |
| default: |
| panic("too many arguments to system call") |
| } |
| } |
| |
| // Syscall determines which SyscallX function to use for function f. |
| func (f *Fn) Syscall() string { |
| c := f.SyscallParamCount() |
| if c == 3 { |
| return windowsdot() + "Syscall" |
| } |
| return windowsdot() + "Syscall" + strconv.Itoa(c) |
| } |
| |
| // SyscallParamList returns source code for SyscallX parameters for function f. |
| func (f *Fn) SyscallParamList() string { |
| a := make([]string, 0) |
| for _, p := range f.Params { |
| a = append(a, p.SyscallArgList()...) |
| } |
| for len(a) < f.SyscallParamCount() { |
| a = append(a, "0") |
| } |
| return strings.Join(a, ", ") |
| } |
| |
| // IsUTF16 is true, if f is W (utf16) function. It is false |
| // for all A (ascii) functions. |
| func (f *Fn) IsUTF16() bool { |
| s := f.DLLFuncName() |
| return s[len(s)-1] == 'W' |
| } |
| |
| // StrconvFunc returns name of Go string to OS string function for f. |
| func (f *Fn) StrconvFunc() string { |
| if f.IsUTF16() { |
| return windowsdot() + "UTF16PtrFromString" |
| } |
| return windowsdot() + "BytePtrFromString" |
| } |
| |
| // StrconvType returns Go type name used for OS string for f. |
| func (f *Fn) StrconvType() string { |
| if f.IsUTF16() { |
| return "*uint16" |
| } |
| return "*byte" |
| } |
| |
| // Source files and functions. |
| type Source struct { |
| Funcs []*Fn |
| Files []string |
| } |
| |
| // ParseFiles parses files listed in fs and extracts all syscall |
| // functions listed in sys comments. It returns source files |
| // and functions collection *Source if successful. |
| func ParseFiles(fs []string) (*Source, error) { |
| src := &Source{ |
| Funcs: make([]*Fn, 0), |
| Files: make([]string, 0), |
| } |
| for _, file := range fs { |
| if err := src.ParseFile(file); err != nil { |
| return nil, err |
| } |
| } |
| return src, nil |
| } |
| |
| // DLLs return dll names for a source set src. |
| func (src *Source) DLLs() []string { |
| uniq := make(map[string]bool) |
| r := make([]string, 0) |
| for _, f := range src.Funcs { |
| name := f.DLLName() |
| if _, found := uniq[name]; !found { |
| uniq[name] = true |
| r = append(r, name) |
| } |
| } |
| return r |
| } |
| |
| // ParseFile adds adition file path to a source set src. |
| func (src *Source) ParseFile(path string) error { |
| file, err := os.Open(path) |
| if err != nil { |
| return err |
| } |
| defer file.Close() |
| |
| s := bufio.NewScanner(file) |
| for s.Scan() { |
| t := trim(s.Text()) |
| if len(t) < 7 { |
| continue |
| } |
| if !strings.HasPrefix(t, "//sys") { |
| continue |
| } |
| t = t[5:] |
| if !(t[0] == ' ' || t[0] == '\t') { |
| continue |
| } |
| f, err := newFn(t[1:]) |
| if err != nil { |
| return err |
| } |
| src.Funcs = append(src.Funcs, f) |
| } |
| if err := s.Err(); err != nil { |
| return err |
| } |
| src.Files = append(src.Files, path) |
| |
| // get package name |
| fset := token.NewFileSet() |
| _, err = file.Seek(0, 0) |
| if err != nil { |
| return err |
| } |
| pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) |
| if err != nil { |
| return err |
| } |
| packageName = pkg.Name.Name |
| |
| return nil |
| } |
| |
| // Generate output source file from a source set src. |
| func (src *Source) Generate(w io.Writer) error { |
| funcMap := template.FuncMap{ |
| "windowsdot": windowsdot, |
| "packagename": packagename, |
| } |
| t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) |
| err := t.Execute(w, src) |
| if err != nil { |
| return errors.New("Failed to execute template: " + err.Error()) |
| } |
| return nil |
| } |
| |
| func usage() { |
| fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n") |
| flag.PrintDefaults() |
| os.Exit(1) |
| } |
| |
| func main() { |
| flag.Usage = usage |
| flag.Parse() |
| if len(os.Args) <= 1 { |
| fmt.Fprintf(os.Stderr, "no files to parse provided\n") |
| usage() |
| } |
| src, err := ParseFiles(os.Args[1:]) |
| if err != nil { |
| log.Fatal(err) |
| } |
| if err := src.Generate(os.Stdout); err != nil { |
| log.Fatal(err) |
| } |
| } |
| |
| // TODO: use println instead to print in the following template |
| const srcTemplate = ` |
| |
| {{define "main"}}// go build mksyscall_windows.go && ./mksyscall_windows{{range .Files}} {{.}}{{end}} |
| // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT |
| |
| package {{packagename}} |
| |
| import "unsafe"{{if windowsdot}} |
| import "code.google.com/p/go.sys/windows"{{end}} |
| |
| var ( |
| {{template "dlls" .}} |
| {{template "funcnames" .}}) |
| {{range .Funcs}}{{template "funcbody" .}}{{end}} |
| {{end}} |
| |
| {{/* help functions */}} |
| |
| {{define "dlls"}}{{range .DLLs}} mod{{.}} = {{windowsdot}}NewLazyDLL("{{.}}.dll") |
| {{end}}{{end}} |
| |
| {{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") |
| {{end}}{{end}} |
| |
| {{define "funcbody"}} |
| func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{ |
| {{template "tmpvars" .}} {{template "syscall" .}} |
| {{template "seterror" .}}{{template "printtrace" .}} return |
| } |
| {{end}} |
| |
| {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} |
| {{end}}{{end}}{{end}} |
| |
| {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} |
| |
| {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} |
| {{end}}{{end}} |
| |
| {{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") |
| {{end}}{{end}} |
| |
| ` |