1 package code
2
3 import (
4 "bytes"
5 "errors"
6 "fmt"
7 "os"
8 "os/exec"
9 "path/filepath"
10 "runtime/debug"
11 "strings"
12 "sync"
13
14 "golang.org/x/tools/go/packages"
15 )
16
// Build information for the running binary is read lazily: CleanupUserPackages
// fills modInfo inside once.Do, so debug.ReadBuildInfo runs at most once per
// process. modInfo stays nil when build information is unavailable.
var (
	once    sync.Once
	modInfo *debug.BuildInfo
)
21
22 var mode = packages.NeedName |
23 packages.NeedFiles |
24 packages.NeedImports |
25 packages.NeedTypes |
26 packages.NeedSyntax |
27 packages.NeedTypesInfo |
28 packages.NeedModule |
29 packages.NeedDeps
30
31 type (
32
33
34 Packages struct {
35 packages map[string]*packages.Package
36 importToName map[string]string
37 loadErrors []error
38 buildFlags []string
39
40 numLoadCalls int
41 numNameCalls int
42 }
43
44 Option func(p *Packages)
45 )
46
47
48 func WithBuildTags(tags ...string) func(p *Packages) {
49 return func(p *Packages) {
50 p.buildFlags = append(p.buildFlags, "-tags", strings.Join(tags, ","))
51 }
52 }
53
54
55
56 func NewPackages(opts ...Option) *Packages {
57 p := &Packages{}
58 for _, opt := range opts {
59 opt(p)
60 }
61 return p
62 }
63
64 func (p *Packages) CleanupUserPackages() {
65 once.Do(func() {
66 var ok bool
67 modInfo, ok = debug.ReadBuildInfo()
68 if !ok {
69 modInfo = nil
70 }
71 })
72
73
74 if modInfo != nil {
75 var toRemove []string
76 for k := range p.packages {
77 if !strings.HasPrefix(k, modInfo.Main.Path) {
78 toRemove = append(toRemove, k)
79 }
80 }
81 for _, k := range toRemove {
82 delete(p.packages, k)
83 }
84 } else {
85 p.packages = nil
86 }
87 }
88
89
90
91 func (p *Packages) ReloadAll(importPaths ...string) []*packages.Package {
92 if p.packages != nil {
93 p.CleanupUserPackages()
94 }
95 return p.LoadAll(importPaths...)
96 }
97
98
99
100 func (p *Packages) LoadAll(importPaths ...string) []*packages.Package {
101 if p.packages == nil {
102 p.packages = map[string]*packages.Package{}
103 }
104
105 missing := make([]string, 0, len(importPaths))
106 for _, path := range importPaths {
107 if _, ok := p.packages[path]; ok {
108 continue
109 }
110 missing = append(missing, path)
111 }
112
113 if len(missing) > 0 {
114 p.numLoadCalls++
115 pkgs, err := packages.Load(&packages.Config{
116 Mode: mode,
117 BuildFlags: p.buildFlags,
118 }, missing...)
119 if err != nil {
120 p.loadErrors = append(p.loadErrors, err)
121 }
122
123 for _, pkg := range pkgs {
124 p.addToCache(pkg)
125 }
126 }
127
128 res := make([]*packages.Package, 0, len(importPaths))
129 for _, path := range importPaths {
130 res = append(res, p.packages[NormalizeVendor(path)])
131 }
132 return res
133 }
134
135 func (p *Packages) addToCache(pkg *packages.Package) {
136 imp := NormalizeVendor(pkg.PkgPath)
137 p.packages[imp] = pkg
138 for _, imp := range pkg.Imports {
139 if _, found := p.packages[NormalizeVendor(imp.PkgPath)]; !found {
140 p.addToCache(imp)
141 }
142 }
143 }
144
145
146 func (p *Packages) Load(importPath string) *packages.Package {
147
148 if p.packages != nil {
149 if pkg, ok := p.packages[importPath]; ok {
150 return pkg
151 }
152 }
153
154 pkgs := p.LoadAll(importPath)
155 if len(pkgs) == 0 {
156 return nil
157 }
158 return pkgs[0]
159 }
160
161
162
163 func (p *Packages) LoadWithTypes(importPath string) *packages.Package {
164 pkg := p.Load(importPath)
165 if pkg == nil || pkg.TypesInfo == nil {
166 p.numLoadCalls++
167 pkgs, err := packages.Load(&packages.Config{
168 Mode: mode,
169 BuildFlags: p.buildFlags,
170 }, importPath)
171 if err != nil {
172 p.loadErrors = append(p.loadErrors, err)
173 return nil
174 }
175 p.addToCache(pkgs[0])
176 pkg = pkgs[0]
177 }
178 return pkg
179 }
180
181
182 func (p *Packages) NameForPackage(importPath string) string {
183 if importPath == "" {
184 panic(errors.New("import path can not be empty"))
185 }
186 if p.importToName == nil {
187 p.importToName = map[string]string{}
188 }
189
190 importPath = NormalizeVendor(importPath)
191
192
193 if name := p.importToName[importPath]; name != "" {
194 return name
195 }
196
197
198 pkg := p.packages[importPath]
199
200 if pkg == nil {
201
202 p.numNameCalls++
203 pkgs, err := packages.Load(&packages.Config{
204 Mode: packages.NeedName,
205 BuildFlags: p.buildFlags,
206 }, importPath)
207 if err != nil {
208 p.loadErrors = append(p.loadErrors, err)
209 } else {
210 pkg = pkgs[0]
211 }
212 }
213
214 if pkg == nil || pkg.Name == "" {
215 return SanitizePackageName(filepath.Base(importPath))
216 }
217
218 p.importToName[importPath] = pkg.Name
219
220 return pkg.Name
221 }
222
223
224
225 func (p *Packages) Evict(importPath string) {
226 delete(p.packages, importPath)
227
228 for _, pkg := range p.packages {
229 for _, imported := range pkg.Imports {
230 if imported.PkgPath == importPath {
231 p.Evict(pkg.PkgPath)
232 }
233 }
234 }
235 }
236
237 func (p *Packages) ModTidy() error {
238 p.packages = nil
239 tidyCmd := exec.Command("go", "mod", "tidy")
240 tidyCmd.Stdout = os.Stdout
241 tidyCmd.Stderr = os.Stdout
242 if err := tidyCmd.Run(); err != nil {
243 return fmt.Errorf("go mod tidy failed: %w", err)
244 }
245 return nil
246 }
247
248
249 func (p *Packages) Errors() PkgErrors {
250 var res []error
251 res = append(res, p.loadErrors...)
252 for _, pkg := range p.packages {
253 for _, err := range pkg.Errors {
254 res = append(res, err)
255 }
256 }
257 return res
258 }
259
260 func (p *Packages) Count() int {
261 return len(p.packages)
262 }
263
// PkgErrors gathers the errors produced while loading packages into a single
// error value.
type PkgErrors []error

// Error renders the collected errors after a "packages.Load: " prefix, each one
// followed by a newline.
func (p PkgErrors) Error() string {
	buf := new(bytes.Buffer)
	buf.WriteString("packages.Load: ")
	for i := range p {
		buf.WriteString(p[i].Error())
		buf.WriteByte('\n')
	}
	return buf.String()
}
274
View as plain text