1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 package main
72
73 import (
74 "bytes"
75 "flag"
76 "fmt"
77 "go/ast"
78 "go/format"
79 "go/printer"
80 "go/token"
81 "go/types"
82 "log"
83 "os"
84 "strconv"
85 "strings"
86 "unicode"
87
88 "golang.org/x/tools/go/packages"
89 )
90
// Command-line flags. The pointers are filled in by flag.Parse in main.
var (
	// outputFile (-o) names the file to write; empty means standard output.
	outputFile = flag.String("o", "", "write output to `file` (default standard output)")
	// dstPath (-dst) is the import path of the package the bundle is written for.
	dstPath = flag.String("dst", ".", "set destination import `path`")
	// pkgName (-pkg) overrides the package name used in the generated file.
	pkgName = flag.String("pkg", "", "set destination package `name`")
	// prefix (-prefix) is prepended to every renamed identifier.
	prefix = flag.String("prefix", "&_", "set bundled identifier prefix to `p` (default is \"&_\", where & stands for the original name)")
	// buildTags (-tags) is emitted as the generated file's build constraint.
	buildTags = flag.String("tags", "", "the build constraints to be inserted into the generated file")
)

// importMap maps an original import path to its replacement. Entries are
// added by addImportMap, once per -import flag.
var importMap = make(map[string]string)
100
101 func init() {
102 flag.Var(flagFunc(addImportMap), "import", "rewrite import using `map`, of form old=new (can be repeated)")
103 }
104
105 func addImportMap(s string) {
106 if strings.Count(s, "=") != 1 {
107 log.Fatal("-import argument must be of the form old=new")
108 }
109 i := strings.Index(s, "=")
110 old, new := s[:i], s[i+1:]
111 if old == "" || new == "" {
112 log.Fatal("-import argument must be of the form old=new; old and new must be non-empty")
113 }
114 importMap[old] = new
115 }
116
// usage writes the one-line synopsis to standard error and then prints the
// defaults of all registered flags.
func usage() {
	fmt.Fprint(os.Stderr, "Usage: bundle [options] <src>\n")
	flag.PrintDefaults()
}
121
122 func main() {
123 log.SetPrefix("bundle: ")
124 log.SetFlags(0)
125
126 flag.Usage = usage
127 flag.Parse()
128 args := flag.Args()
129 if len(args) != 1 {
130 usage()
131 os.Exit(2)
132 }
133
134 cfg := &packages.Config{Mode: packages.NeedName}
135 pkgs, err := packages.Load(cfg, *dstPath)
136 if err != nil {
137 log.Fatalf("cannot load destination package: %v", err)
138 }
139 if packages.PrintErrors(pkgs) > 0 || len(pkgs) != 1 {
140 log.Fatalf("failed to load destination package")
141 }
142 if *pkgName == "" {
143 *pkgName = pkgs[0].Name
144 }
145
146 code, err := bundle(args[0], pkgs[0].PkgPath, *pkgName, *prefix, *buildTags)
147 if err != nil {
148 log.Fatal(err)
149 }
150 if *outputFile != "" {
151 err := os.WriteFile(*outputFile, code, 0666)
152 if err != nil {
153 log.Fatal(err)
154 }
155 } else {
156 _, err := os.Stdout.Write(code)
157 if err != nil {
158 log.Fatal(err)
159 }
160 }
161 }
162
163
// isStandardImportPath reports whether the first slash-separated element of
// path contains no dot. The caller uses this to separate standard-library
// imports from the rest (paths such as "example.com/x" have a dotted first
// element).
func isStandardImportPath(path string) bool {
	first := path
	if slash := strings.IndexByte(path, '/'); slash >= 0 {
		first = path[:slash]
	}
	return strings.IndexByte(first, '.') < 0
}
172
173 var testingOnlyPackagesConfig *packages.Config
174
175 func bundle(src, dst, dstpkg, prefix, buildTags string) ([]byte, error) {
176
177 cfg := &packages.Config{}
178 if testingOnlyPackagesConfig != nil {
179 *cfg = *testingOnlyPackagesConfig
180 } else {
181
182
183 cfg.Env = append(os.Environ(), "GOFLAGS=-mod=mod")
184 }
185 cfg.Mode = packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo
186 pkgs, err := packages.Load(cfg, src)
187 if err != nil {
188 return nil, err
189 }
190 if packages.PrintErrors(pkgs) > 0 || len(pkgs) != 1 {
191 return nil, fmt.Errorf("failed to load source package")
192 }
193 pkg := pkgs[0]
194
195 if strings.Contains(prefix, "&") {
196 prefix = strings.Replace(prefix, "&", pkg.Syntax[0].Name.Name, -1)
197 }
198
199 objsToUpdate := make(map[types.Object]bool)
200 var rename func(from types.Object)
201 rename = func(from types.Object) {
202 if !objsToUpdate[from] {
203 objsToUpdate[from] = true
204
205
206
207
208
209
210 if _, ok := from.(*types.TypeName); ok {
211 for id, obj := range pkg.TypesInfo.Uses {
212 if obj == from {
213 if field := pkg.TypesInfo.Defs[id]; field != nil {
214 rename(field)
215 }
216 }
217 }
218 }
219 }
220 }
221
222
223 scope := pkg.Types.Scope()
224 for _, name := range scope.Names() {
225 rename(scope.Lookup(name))
226 }
227
228 var out bytes.Buffer
229 if buildTags != "" {
230 fmt.Fprintf(&out, "//go:build %s\n", buildTags)
231 }
232
233 fmt.Fprintf(&out, "// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.\n")
234 if *outputFile != "" && buildTags == "" {
235 fmt.Fprintf(&out, "//go:generate bundle %s\n", strings.Join(quoteArgs(os.Args[1:]), " "))
236 } else {
237 fmt.Fprintf(&out, "// $ bundle %s\n", strings.Join(os.Args[1:], " "))
238 }
239 fmt.Fprintf(&out, "\n")
240
241
242 for _, f := range pkg.Syntax {
243 if doc := f.Doc.Text(); strings.TrimSpace(doc) != "" {
244 for _, line := range strings.Split(doc, "\n") {
245 fmt.Fprintf(&out, "// %s\n", line)
246 }
247 }
248 }
249
250 fmt.Fprintln(&out)
251
252 fmt.Fprintf(&out, "package %s\n\n", dstpkg)
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269 var pkgStd = make(map[string]bool)
270 var pkgExt = make(map[string]bool)
271 for _, f := range pkg.Syntax {
272 for _, imp := range f.Imports {
273 path, err := strconv.Unquote(imp.Path.Value)
274 if err != nil {
275 log.Fatalf("invalid import path string: %v", err)
276 }
277 if path == dst {
278 continue
279 }
280 if newPath, ok := importMap[path]; ok {
281 path = newPath
282 }
283
284 var name string
285 if imp.Name != nil {
286 name = imp.Name.Name
287 }
288 spec := fmt.Sprintf("%s %q", name, path)
289 if isStandardImportPath(path) {
290 pkgStd[spec] = true
291 } else {
292 pkgExt[spec] = true
293 }
294 }
295 }
296
297
298 fmt.Fprintln(&out, "import (")
299 for p := range pkgStd {
300 fmt.Fprintf(&out, "\t%s\n", p)
301 }
302 if len(pkgExt) > 0 {
303 fmt.Fprintln(&out)
304 }
305 for p := range pkgExt {
306 fmt.Fprintf(&out, "\t%s\n", p)
307 }
308 fmt.Fprint(&out, ")\n\n")
309
310
311 for _, f := range pkg.Syntax {
312
313 for id, obj := range pkg.TypesInfo.Defs {
314 if objsToUpdate[obj] {
315 id.Name = prefix + obj.Name()
316 }
317 }
318 for id, obj := range pkg.TypesInfo.Uses {
319 if objsToUpdate[obj] {
320 id.Name = prefix + obj.Name()
321 }
322 }
323
324
325
326
327 ast.Inspect(f, func(n ast.Node) bool {
328 if sel, ok := n.(*ast.SelectorExpr); ok {
329 if id, ok := sel.X.(*ast.Ident); ok {
330 if obj, ok := pkg.TypesInfo.Uses[id].(*types.PkgName); ok {
331 if obj.Imported().Path() == dst {
332 id.Name = "@@@"
333 }
334 }
335 }
336 }
337 return true
338 })
339
340 last := f.Package
341 if len(f.Imports) > 0 {
342 imp := f.Imports[len(f.Imports)-1]
343 last = imp.End()
344 if imp.Comment != nil {
345 if e := imp.Comment.End(); e > last {
346 last = e
347 }
348 }
349 }
350
351
352
353 var buf bytes.Buffer
354 for _, decl := range f.Decls {
355 if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT {
356 continue
357 }
358
359 beg, end := sourceRange(decl)
360
361 printComments(&out, f.Comments, last, beg)
362
363 buf.Reset()
364 format.Node(&buf, pkg.Fset, &printer.CommentedNode{Node: decl, Comments: f.Comments})
365
366
367 out.Write(bytes.Replace(buf.Bytes(), []byte("@@@."), nil, -1))
368
369 last = printSameLineComment(&out, f.Comments, pkg.Fset, end)
370
371 out.WriteString("\n\n")
372 }
373
374 printLastComments(&out, f.Comments, last)
375 }
376
377
378 result, err := format.Source(out.Bytes())
379 if err != nil {
380 log.Fatalf("formatting failed: %v", err)
381 }
382
383 return result, nil
384 }
385
386
387
// sourceRange returns the span of decl widened to include its comments: beg
// moves back to the start of the doc comment of a function or general
// declaration, and end moves forward to the end of the trailing comment on
// the last spec of a general declaration when that comment ends after the
// declaration itself.
func sourceRange(decl ast.Decl) (beg, end token.Pos) {
	beg, end = decl.Pos(), decl.End()

	switch d := decl.(type) {
	case *ast.FuncDecl:
		if d.Doc != nil {
			beg = d.Doc.Pos()
		}

	case *ast.GenDecl:
		if d.Doc != nil {
			beg = d.Doc.Pos()
		}
		if n := len(d.Specs); n > 0 {
			// Only value and type specs contribute a trailing comment here.
			var trailing *ast.CommentGroup
			switch last := d.Specs[n-1].(type) {
			case *ast.TypeSpec:
				trailing = last.Comment
			case *ast.ValueSpec:
				trailing = last.Comment
			}
			if trailing != nil && trailing.End() > end {
				end = trailing.End()
			}
		}
	}

	return beg, end
}
418
// printComments writes to out every comment group whose start position lies
// in the half-open range [pos, end). Each comment of a group is written on its
// own line and each group is followed by a blank line.
func printComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos, end token.Pos) {
	for _, group := range comments {
		start := group.Pos()
		if start < pos || start >= end {
			continue
		}
		for _, comment := range group.List {
			fmt.Fprintln(out, comment.Text)
		}
		fmt.Fprintln(out)
	}
}
429
430 const infinity = 1 << 30
431
432 func printLastComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos token.Pos) {
433 printComments(out, comments, pos, infinity)
434 }
435
// printSameLineComment looks for the first comment group that starts at or
// after pos on the same source line as pos. If one exists, its comments are
// written to out, one per line, and the group's end position is returned.
// Otherwise nothing is written and pos is returned unchanged.
func printSameLineComment(out *bytes.Buffer, comments []*ast.CommentGroup, fset *token.FileSet, pos token.Pos) token.Pos {
	file := fset.File(pos)
	for _, group := range comments {
		start := group.Pos()
		if start < pos || file.Line(start) != file.Line(pos) {
			continue
		}
		for _, comment := range group.List {
			fmt.Fprintln(out, comment.Text)
		}
		return group.End()
	}
	return pos
}
448
449 func quoteArgs(ss []string) []string {
450
451
452
453
454
455
456
457
458
459 var qs []string
460 for _, s := range ss {
461 if s == "" || containsSpace(s) {
462 s = strconv.Quote(s)
463 }
464 qs = append(qs, s)
465 }
466 return qs
467 }
468
// containsSpace reports whether s contains at least one rune for which
// unicode.IsSpace is true.
func containsSpace(s string) bool {
	for _, ch := range s {
		if !unicode.IsSpace(ch) {
			continue
		}
		return true
	}
	return false
}
477
// flagFunc adapts a func(string) to the Set/String method pair expected of a
// flag value, so that each occurrence of the flag invokes the function.
type flagFunc func(string)

// Set calls the wrapped function with the flag's argument. It always returns
// nil; the wrapped function has no way to report an error through Set.
func (f flagFunc) Set(s string) error {
	f(s)
	return nil
}

// String always returns the empty string.
func (f flagFunc) String() string {
	return ""
}
486
View as plain text