1
16
17 package main
18
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strings"

	"gopkg.in/yaml.v2"
)
32
33
// Package is the subset of one `go list -json` record that this tool reads.
// Values are produced by decodePackages, which decodes the JSON stream with
// encoding/json; the yaml tags below play no part in that decoding.
type Package struct {
	// Dir is the package's directory; ForbiddenImportsFor uses it to decide
	// whether a restriction applies to the package at all.
	Dir string `yaml:",omitempty"`
	// ImportPath identifies the package in violation reports.
	ImportPath string `yaml:",omitempty"`
	// Imports are the package's non-test imports as reported by `go list`;
	// they are always checked.
	Imports []string `yaml:",omitempty"`
	// TestImports and XTestImports are the imports of the package's test
	// files as reported by `go list`; they are checked unless the
	// restriction sets ExcludeTests.
	TestImports  []string `yaml:",omitempty"`
	XTestImports []string `yaml:",omitempty"`
}
41
42
43
// ImportRestriction is one entry of the YAML configuration file: it names a
// tree of packages and the imports that packages in that tree may use.
type ImportRestriction struct {
	// BaseDir is the root directory of the restricted tree. Despite the YAML
	// key name, main treats it as a path relative to the working directory:
	// absolute values are rejected, and the tree is handed to `go list` as a
	// "./"-prefixed pattern. Package directories are compared against it with
	// filepath.Abs.
	//
	// NOTE(review): isForbidden also uses BaseDir as an import-path prefix
	// (strings.HasPrefix(imp, BaseDir)); a "./"-style directory will normally
	// never match a Go import path there — confirm whether that is intended.
	BaseDir string `yaml:"baseImportPath"`

	// IgnoredSubTrees are directories whose packages are exempt from this
	// restriction even though they lie under BaseDir. Like BaseDir, they are
	// resolved with filepath.Abs, i.e. relative to the working directory.
	IgnoredSubTrees []string `yaml:"ignoredSubTrees,omitempty"`

	// AllowedImports are import paths that restricted packages may use; an
	// import is allowed if it equals an entry or lies beneath it ("entry/...").
	AllowedImports []string `yaml:"allowedImports"`

	// ExcludeTests, when true, limits checking to a package's non-test
	// imports (TestImports and XTestImports are skipped).
	ExcludeTests bool `yaml:"excludeTests"`
}
61
62
63
64 func (i *ImportRestriction) ForbiddenImportsFor(pkg Package) ([]string, error) {
65 if restricted, err := i.isRestrictedDir(pkg.Dir); err != nil {
66 return []string{}, err
67 } else if !restricted {
68 return []string{}, nil
69 }
70
71 return i.forbiddenImportsFor(pkg), nil
72 }
73
74
75
76
77
78
79 func (i *ImportRestriction) isRestrictedDir(dir string) (bool, error) {
80 if under, err := isPathUnder(i.BaseDir, dir); err != nil {
81 return false, err
82 } else if !under {
83 return false, nil
84 }
85
86 for _, ignored := range i.IgnoredSubTrees {
87 if under, err := isPathUnder(ignored, dir); err != nil {
88 return false, err
89 } else if under {
90 return false, nil
91 }
92 }
93
94 return true, nil
95 }
96
97
// isPathUnder reports whether path is base itself or lies somewhere beneath
// it.
//
// Both arguments are first made absolute with filepath.Abs, so relative
// paths are resolved against the working directory. The comparison is purely
// lexical: symlinks are not followed.
//
// An error is returned if either path cannot be made absolute, or if
// filepath.Rel cannot express path relative to base.
func isPathUnder(base, path string) (bool, error) {
	absBase, err := filepath.Abs(base)
	if err != nil {
		return false, err
	}
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false, err
	}

	relPath, err := filepath.Rel(absBase, absPath)
	if err != nil {
		return false, err
	}

	// path escapes base exactly when the relative path is ".." or begins with
	// a "../" element. A bare prefix test on ".." would also reject entries
	// under base whose names merely start with two dots, such as "..foo".
	escapes := relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator))
	return !escapes, nil
}
117
118
119
120
121 func (i *ImportRestriction) forbiddenImportsFor(pkg Package) []string {
122 forbiddenImportSet := map[string]struct{}{}
123 imports := pkg.Imports
124 if !i.ExcludeTests {
125 imports = append(imports, append(pkg.TestImports, pkg.XTestImports...)...)
126 }
127 for _, imp := range imports {
128 if i.isForbidden(imp) {
129 forbiddenImportSet[imp] = struct{}{}
130 }
131 }
132
133 var forbiddenImports []string
134 for imp := range forbiddenImportSet {
135 forbiddenImports = append(forbiddenImports, imp)
136 }
137 return forbiddenImports
138 }
139
140
141
142
143
144
145 func (i *ImportRestriction) isForbidden(imp string) bool {
146 importsBelowRoot := strings.HasPrefix(imp, rootPackage)
147 importsBelowBase := strings.HasPrefix(imp, i.BaseDir)
148 importsAllowed := false
149 for _, allowed := range i.AllowedImports {
150 exactlyImportsAllowed := imp == allowed
151 importsBelowAllowed := strings.HasPrefix(imp, fmt.Sprintf("%s/", allowed))
152 importsAllowed = importsAllowed || (importsBelowAllowed || exactlyImportsAllowed)
153 }
154
155 return importsBelowRoot && !importsBelowBase && !importsAllowed
156 }
157
// rootPackage is the import-path prefix given as the first command-line
// argument (assigned in main). isForbidden only considers imports that start
// with this prefix; imports outside it are never reported.
var rootPackage string
159
160 func main() {
161 if len(os.Args) != 3 {
162 log.Fatalf("Usage: %s <root> <restrictions.yaml>", os.Args[0])
163 }
164
165 rootPackage = os.Args[1]
166 configFile := os.Args[2]
167 importRestrictions, err := loadImportRestrictions(configFile)
168 if err != nil {
169 log.Fatalf("Failed to load import restrictions: %v", err)
170 }
171
172 foundForbiddenImports := false
173 for _, restriction := range importRestrictions {
174 baseDir := restriction.BaseDir
175 if filepath.IsAbs(baseDir) {
176 log.Fatalf("%q appears to be an absolute path", baseDir)
177 }
178 if !strings.HasPrefix(baseDir, "./") {
179 baseDir = "./" + baseDir
180 }
181 baseDir = strings.TrimRight(baseDir, "/")
182 log.Printf("Inspecting imports under %s/...\n", baseDir)
183
184 packages, err := resolvePackageTree(baseDir)
185 if err != nil {
186 log.Fatalf("Failed to resolve package tree: %v", err)
187 } else if len(packages) == 0 {
188 log.Fatalf("Found no packages under tree %s", baseDir)
189 }
190
191 log.Printf("- validating imports for %d packages", len(packages))
192 restrictionViolated := false
193 for _, pkg := range packages {
194 if forbidden, err := restriction.ForbiddenImportsFor(pkg); err != nil {
195 log.Fatalf("-- failed to validate imports: %v", err)
196 } else if len(forbidden) != 0 {
197 logForbiddenPackages(pkg.ImportPath, forbidden)
198 restrictionViolated = true
199 }
200 }
201 if restrictionViolated {
202 foundForbiddenImports = true
203 log.Println("- FAIL")
204 } else {
205 log.Println("- OK")
206 }
207 }
208
209 if foundForbiddenImports {
210 os.Exit(1)
211 }
212 }
213
214 func loadImportRestrictions(configFile string) ([]ImportRestriction, error) {
215 config, err := os.ReadFile(configFile)
216 if err != nil {
217 return nil, fmt.Errorf("failed to load configuration from %s: %v", configFile, err)
218 }
219
220 var importRestrictions []ImportRestriction
221 if err := yaml.Unmarshal(config, &importRestrictions); err != nil {
222 return nil, fmt.Errorf("failed to unmarshal from %s: %v", configFile, err)
223 }
224
225 return importRestrictions, nil
226 }
227
228 func resolvePackageTree(treeBase string) ([]Package, error) {
229 cmd := "go"
230 args := []string{"list", "-json", fmt.Sprintf("%s/...", treeBase)}
231 c := exec.Command(cmd, args...)
232 stdout, err := c.Output()
233 if err != nil {
234 var message string
235 if ee, ok := err.(*exec.ExitError); ok {
236 message = fmt.Sprintf("%v\n%v", ee, string(ee.Stderr))
237 } else {
238 message = fmt.Sprintf("%v", err)
239 }
240 return nil, fmt.Errorf("failed to run `%s %s`: %v", cmd, strings.Join(args, " "), message)
241 }
242
243 packages, err := decodePackages(bytes.NewReader(stdout))
244 if err != nil {
245 return nil, fmt.Errorf("failed to decode packages: %v", err)
246 }
247
248 return packages, nil
249 }
250
251 func decodePackages(r io.Reader) ([]Package, error) {
252
253
254
255
256
257 var packages []Package
258 decoder := json.NewDecoder(r)
259 for decoder.More() {
260 var pkg Package
261 if err := decoder.Decode(&pkg); err != nil {
262 return nil, fmt.Errorf("invalid package: %v", err)
263 }
264 packages = append(packages, pkg)
265 }
266
267 return packages, nil
268 }
269
// logForbiddenPackages writes a violation report to the standard logger: a
// header line naming the offending package (base), then one line per
// forbidden import, in the order given.
func logForbiddenPackages(base string, forbidden []string) {
	log.Printf("-- found forbidden imports for %s:\n", base)
	for idx := range forbidden {
		log.Printf("--- %s\n", forbidden[idx])
	}
}
276
View as plain text