1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package main
19
20 import (
21 "bytes"
22 "flag"
23 "fmt"
24 "go/ast"
25 "go/constant"
26 "go/format"
27 "go/printer"
28 "go/token"
29 "go/types"
30 "io"
31 "log"
32 "os"
33 "path/filepath"
34 "regexp"
35 "strings"
36
37 "golang.org/x/tools/go/packages"
38 )
39
// help is the usage text that main prints, with surrounding whitespace
// trimmed, when it is run without a command.
const help = `
Usage: go run qgo.go [flags] extract [packages]

Commands:
  extract   For each given package, write <name>.go to the current
            directory containing a forwarding wrapper for every exported
            function and a literal copy of every exported constant.

            Methods, functions whose name starts with "New", and
            functions that do not return exactly one value or a
            (value, error) pair are skipped and listed on stdout.

Flags:
  -exclude   comma-separated list of regexps of entries to exclude
  -stripstr  remove the String suffix from function names
`
48
49
50
// copyright is the license header that extract writes at the very top of
// every generated file, ahead of genLine and the package clause.
const copyright = `// Copyright 2020 The CUE Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
`
69
var (
	// genLine is the provenance comment placed in the header of every
	// generated file; main builds it from the command-line arguments.
	genLine string

	// exclude is the raw -exclude flag: comma-separated regexps of names
	// to leave out of the generated output.
	exclude = flag.String("exclude", "", "comma-separated list of regexps of entries to exclude")

	// stripstr reports whether a trailing "String" is removed from the
	// names of generated functions.
	stripstr = flag.Bool("stripstr", false, "Remove String suffix from functions")
)
76
// init makes the standard logger prefix each message with only the short
// file name and line number of the call site (no date or time).
func init() { log.SetFlags(log.Lshortfile) }
80
81 func main() {
82 flag.Parse()
83
84 genLine = "// Originally generated with: go run qgo.go " + strings.Join(os.Args[1:], " ")
85
86 args := flag.Args()
87 if len(args) == 0 {
88 fmt.Println(strings.TrimSpace(help))
89 return
90 }
91
92 command := args[0]
93 args = args[1:]
94
95 switch command {
96 case "extract":
97 extract(args)
98 }
99 }
100
101 var exclusions []*regexp.Regexp
102
103 func initExclusions() {
104 for _, re := range strings.Split(*exclude, ",") {
105 if re != "" {
106 exclusions = append(exclusions, regexp.MustCompile(re))
107 }
108 }
109 }
110
111 func filter(name string) bool {
112 if !ast.IsExported(name) {
113 return true
114 }
115 for _, ex := range exclusions {
116 if ex.MatchString(name) {
117 return true
118 }
119 }
120 return false
121 }
122
// pkgName returns the last element of the current working directory; extract
// uses it as the package clause of every generated file. Failing to determine
// the working directory is fatal.
func pkgName() string {
	cwd, err := os.Getwd()
	if err == nil {
		return filepath.Base(cwd)
	}
	log.Fatal(err)
	return "" // not reached: log.Fatal exits the process
}
130
131 type extracter struct {
132 pkg *packages.Package
133 }
134
135 func extract(args []string) {
136 cfg := &packages.Config{
137 Mode: packages.LoadFiles |
138 packages.LoadAllSyntax |
139 packages.LoadTypes,
140 }
141 pkgs, err := packages.Load(cfg, args...)
142 if err != nil {
143 log.Fatal(err)
144 }
145
146 e := extracter{}
147
148 lastPkg := ""
149 var w *bytes.Buffer
150 initExclusions()
151
152 flushFile := func() {
153 if w != nil && w.Len() > 0 {
154 b, err := format.Source(w.Bytes())
155 if err != nil {
156 log.Fatal(err)
157 }
158 err = os.WriteFile(lastPkg+".go", b, 0644)
159 if err != nil {
160 log.Fatal(err)
161 }
162 }
163 w = &bytes.Buffer{}
164 }
165
166 for _, p := range pkgs {
167 e.pkg = p
168 for _, f := range p.Syntax {
169 if lastPkg != p.Name {
170 flushFile()
171 lastPkg = p.Name
172 fmt.Fprint(w, copyright)
173 fmt.Fprintln(w)
174 fmt.Fprintln(w, genLine)
175 fmt.Fprintln(w)
176 fmt.Fprintf(w, "package %s\n", pkgName())
177 fmt.Fprintln(w)
178 fmt.Fprintf(w, "import %q", p.PkgPath)
179 fmt.Fprintln(w)
180 }
181
182 for _, d := range f.Decls {
183 switch x := d.(type) {
184 case *ast.FuncDecl:
185 e.reportFun(w, x)
186 case *ast.GenDecl:
187 e.reportDecl(w, x)
188 }
189 }
190 }
191 }
192 flushFile()
193 }
194
195 func (e *extracter) reportFun(w io.Writer, x *ast.FuncDecl) {
196 if filter(x.Name.Name) {
197 return
198 }
199 pkgName := e.pkg.Name
200 override := ""
201 params := []ast.Expr{}
202 if x.Type.Params != nil {
203 for _, f := range x.Type.Params.List {
204 tx := f.Type
205 if star, isStar := tx.(*ast.StarExpr); isStar {
206 if i, ok := star.X.(*ast.Ident); ok && ast.IsExported(i.Name) {
207 f.Type = &ast.SelectorExpr{X: ast.NewIdent(pkgName), Sel: i}
208 if isStar {
209 f.Type = &ast.StarExpr{X: f.Type}
210 }
211 }
212 }
213 for _, n := range f.Names {
214 params = append(params, n)
215 if n.Name == pkgName {
216 override = pkgName + x.Name.Name
217 }
218 }
219 }
220 }
221 var fn ast.Expr = &ast.SelectorExpr{
222 X: ast.NewIdent(pkgName),
223 Sel: x.Name,
224 }
225 if override != "" {
226 fn = ast.NewIdent(override)
227 }
228 x.Body = &ast.BlockStmt{List: []ast.Stmt{
229 &ast.ReturnStmt{Results: []ast.Expr{&ast.CallExpr{
230 Fun: fn,
231 Args: params,
232 }}},
233 }}
234 if name := x.Name.Name; *stripstr && strings.HasSuffix(name, "String") {
235 newName := name[:len(name)-len("String")]
236 x.Name = ast.NewIdent(newName)
237 if x.Doc != nil {
238 for _, c := range x.Doc.List {
239 c.Text = strings.Replace(c.Text, name, newName, -1)
240 }
241 }
242 }
243 types := []ast.Expr{}
244 if x.Recv == nil && x.Type != nil && x.Type.Results != nil && !strings.HasPrefix(x.Name.Name, "New") {
245 for _, f := range x.Type.Results.List {
246 if len(f.Names) == 0 {
247 types = append(types, f.Type)
248 } else {
249 for range f.Names {
250 types = append(types, f.Type)
251 }
252 }
253 }
254 }
255 if len(types) != 1 {
256 switch len(types) {
257 case 2:
258 if i, ok := types[1].(*ast.Ident); ok && i.Name == "error" {
259 break
260 }
261 fallthrough
262 default:
263 fmt.Printf("Skipping ")
264 x.Doc = nil
265 printer.Fprint(os.Stdout, e.pkg.Fset, x)
266 fmt.Println()
267 return
268 }
269 }
270 fmt.Fprintln(w)
271 printer.Fprint(w, e.pkg.Fset, x.Doc)
272 printer.Fprint(w, e.pkg.Fset, x)
273 fmt.Fprint(w, "\n")
274 if override != "" {
275 fmt.Fprintf(w, "var %s = %s.%s\n\n", override, pkgName, x.Name.Name)
276 }
277 }
278
279 func (e *extracter) reportDecl(w io.Writer, x *ast.GenDecl) {
280 if x.Tok != token.CONST {
281 return
282 }
283 k := 0
284 for _, s := range x.Specs {
285 if v, ok := s.(*ast.ValueSpec); ok && !filter(v.Names[0].Name) {
286 if v.Values == nil {
287 v.Values = make([]ast.Expr, len(v.Names))
288 }
289 for i, expr := range v.Names {
290
291 if _, ok := v.Values[i].(*ast.BasicLit); ok {
292 continue
293 }
294 tv, _ := types.Eval(e.pkg.Fset, e.pkg.Types, v.Pos(), v.Names[0].Name)
295 tok := token.ILLEGAL
296 switch tv.Value.Kind() {
297 case constant.Bool:
298 v.Values[i] = ast.NewIdent(tv.Value.ExactString())
299 continue
300 case constant.String:
301 tok = token.STRING
302 case constant.Int:
303 tok = token.INT
304 case constant.Float:
305 tok = token.FLOAT
306 default:
307 fmt.Printf("Skipping %s\n", v.Names)
308 continue
309 }
310 v.Values[i] = &ast.BasicLit{
311 ValuePos: expr.Pos(),
312 Kind: tok,
313 Value: tv.Value.ExactString(),
314 }
315 }
316 v.Type = nil
317 x.Specs[k] = v
318 k++
319 }
320 }
321 x.Specs = x.Specs[:k]
322 if len(x.Specs) == 0 {
323 return
324 }
325 fmt.Fprintln(w)
326 printer.Fprint(w, e.pkg.Fset, x)
327 fmt.Fprintln(w)
328 }
329
View as plain text