1 package modrequirements
2
3 import (
4 "context"
5 "fmt"
6 "runtime"
7 "slices"
8 "sync"
9 "sync/atomic"
10
11 "cuelang.org/go/internal/mod/mvs"
12 "cuelang.org/go/internal/mod/semver"
13 "cuelang.org/go/internal/par"
14 "cuelang.org/go/mod/module"
15 )
16
// majorVersionDefault records what a Requirements knows about the default
// major version of one module base path (a module path without a major
// version suffix).
type majorVersionDefault struct {
	// version is the default major version, for example "v1".
	version string

	// explicitDefault is set when the default was supplied by the caller
	// rather than inferred from the root modules. ambiguousDefault is set
	// when there is no explicit default and the root modules do not agree
	// on a single major version for the base path.
	explicitDefault, ambiguousDefault bool
}
22
23
24
25
26
27 type Requirements struct {
28 registry Registry
29 mainModuleVersion module.Version
30
31
32
33
34 rootModules []module.Version
35 maxRootVersion map[string]string
36
37
38 origDefaultMajorVersions map[string]string
39
40
41
42
43 defaultMajorVersions map[string]majorVersionDefault
44
45 graphOnce sync.Once
46 graph atomic.Pointer[cachedGraph]
47 }
48
49
50
51 type Registry interface {
52 Requirements(ctx context.Context, m module.Version) ([]module.Version, error)
53 }
54
55
56
57 type cachedGraph struct {
58 mg *ModuleGraph
59 err error
60 }
61
62
63
64
65
66
67
68
69
70
71
72
73
// NewRequirements returns a Requirements for the main module with path
// mainModulePath, using reg to look up the requirements of dependencies.
//
// rootModules lists the direct dependencies of the main module. It must
// not contain the main module itself, every element must be valid, and it
// must be sorted by path and then by version; NewRequirements panics if
// any of these does not hold. The slice is retained, not copied.
//
// defaultMajorVersions maps module paths without a major version suffix
// to a bare major version such as "v1". NewRequirements panics if a key
// carries a major version or a value is not a bare major version. Root
// modules whose base path has no entry in the map get an inferred default.
func NewRequirements(mainModulePath string, reg Registry, rootModules []module.Version, defaultMajorVersions map[string]string) *Requirements {
	mainModuleVersion := module.MustNewVersion(mainModulePath, "")

	// Check each root on its own first, so these panics take precedence
	// over the ordering check below.
	for idx := range rootModules {
		root := rootModules[idx]
		switch {
		case root.Path() == mainModulePath:
			panic(fmt.Sprintf("NewRequirements called with untrimmed build list: rootModules[%v] is a main module", idx))
		case !root.IsValid():
			panic("NewRequirements with invalid zero version")
		}
	}

	// Verify ordering and, at the same time, record the highest version
	// listed for each path (a path may appear more than once).
	maxVersions := make(map[string]string, len(rootModules))
	for idx, root := range rootModules {
		if idx > 0 {
			prev := rootModules[idx-1]
			samePath := prev.Path() == root.Path()
			if prev.Path() > root.Path() || (samePath && semver.Compare(prev.Version(), root.Version()) > 0) {
				panic(fmt.Sprintf("NewRequirements called with unsorted roots: %v", rootModules))
			}
		}
		if cur, seen := maxVersions[root.Path()]; !seen || semver.Compare(cur, root.Version()) < 0 {
			maxVersions[root.Path()] = root.Version()
		}
	}

	rs := &Requirements{
		registry:          reg,
		mainModuleVersion: mainModuleVersion,
		rootModules:       rootModules,
		maxRootVersion:    maxVersions,
	}
	rs.initDefaultMajorVersions(defaultMajorVersions)
	return rs
}
106
107
108
109 func (rs *Requirements) WithDefaultMajorVersions(defaults map[string]string) *Requirements {
110 rs1 := &Requirements{
111 registry: rs.registry,
112 mainModuleVersion: rs.mainModuleVersion,
113 rootModules: rs.rootModules,
114 maxRootVersion: rs.maxRootVersion,
115 }
116
117
118
119 rs1.graph.Store(rs.graph.Load())
120 if rs1.GraphIsLoaded() {
121 rs1.graphOnce.Do(func() {})
122 }
123 rs1.initDefaultMajorVersions(defaults)
124 return rs1
125 }
126
// initDefaultMajorVersions sets rs.origDefaultMajorVersions and computes
// rs.defaultMajorVersions from the explicit defaults and the root modules.
//
// Each entry of defaultMajorVersions maps a module path without a major
// version suffix to a bare major version such as "v1". The function panics
// if a key carries a major version or a value is not a bare major version.
//
// Explicit entries always take precedence. For every other base path used
// by a non-local root module, the default is that root's major version.
// If the roots for the base path use more than one major version, the
// default is marked ambiguous.
func (rs *Requirements) initDefaultMajorVersions(defaultMajorVersions map[string]string) {
	rs.origDefaultMajorVersions = defaultMajorVersions
	rs.defaultMajorVersions = make(map[string]majorVersionDefault, len(defaultMajorVersions))
	for mpath, v := range defaultMajorVersions {
		if _, _, ok := module.SplitPathVersion(mpath); ok {
			panic(fmt.Sprintf("NewRequirements called with major version in defaultMajorVersions %q", mpath))
		}
		if semver.Major(v) != v {
			panic(fmt.Sprintf("NewRequirements called with invalid major version %q for module %q", v, mpath))
		}
		rs.defaultMajorVersions[mpath] = majorVersionDefault{
			version:         v,
			explicitDefault: true,
		}
	}

	// Infer defaults for base paths that have no explicit default.
	for _, m := range rs.rootModules {
		if m.IsLocal() {
			continue
		}
		mpath := m.BasePath()
		major := semver.Major(m.Version())
		d, ok := rs.defaultMajorVersions[mpath]
		switch {
		case !ok:
			rs.defaultMajorVersions[mpath] = majorVersionDefault{
				version: major,
			}
		case d.explicitDefault, d.version == major:
			// An explicit default always wins. Also, rootModules may list
			// the same module path more than once (with different
			// versions). Seeing the same major version again must not
			// make the default ambiguous.
		default:
			d.ambiguousDefault = true
			rs.defaultMajorVersions[mpath] = d
		}
	}
}
163
164
165
166
167 func (rs *Requirements) RootSelected(mpath string) (version string, ok bool) {
168 if mpath == rs.mainModuleVersion.Path() {
169 return "", true
170 }
171 if v, ok := rs.maxRootVersion[mpath]; ok {
172 return v, true
173 }
174 return "", false
175 }
176
177
178
179 func (rs *Requirements) DefaultMajorVersions() map[string]string {
180 return rs.origDefaultMajorVersions
181 }
182
// MajorVersionDefaultStatus describes how the default major version
// reported by Requirements.DefaultMajorVersion was determined.
type MajorVersionDefaultStatus byte

const (
	// ExplicitDefault: the default was supplied explicitly by the caller.
	ExplicitDefault MajorVersionDefaultStatus = 0
	// NonExplicitDefault: the default was inferred from the root modules.
	NonExplicitDefault MajorVersionDefaultStatus = 1
	// NoDefault: no default major version is known for the path.
	NoDefault MajorVersionDefaultStatus = 2
	// AmbiguousDefault: there is no explicit default and the root modules
	// do not determine a single major version.
	AmbiguousDefault MajorVersionDefaultStatus = 3
)
191
192
193
194
195
196 func (rs *Requirements) DefaultMajorVersion(mpath string) (string, MajorVersionDefaultStatus) {
197 d, ok := rs.defaultMajorVersions[mpath]
198 switch {
199 case !ok:
200 return "", NoDefault
201 case d.ambiguousDefault:
202 return "", AmbiguousDefault
203 case d.explicitDefault:
204 return d.version, ExplicitDefault
205 default:
206 return d.version, NonExplicitDefault
207 }
208 }
209
210
211
212
213 func (rs *Requirements) RootModules() []module.Version {
214 return slices.Clip(rs.rootModules)
215 }
216
217
218
219
220
221
222
223
224
// Graph returns the module graph for rs, loading it on first use.
//
// Loading happens at most once. Later and concurrent callers wait for it
// and then receive the same result, so only the ctx of the call that
// performs the load is used. If the requirements of some modules could not
// be loaded, Graph returns the partially built graph along with the error.
func (rs *Requirements) Graph(ctx context.Context) (*ModuleGraph, error) {
	rs.graphOnce.Do(func() {
		g, err := rs.readModGraph(ctx)
		rs.graph.Store(&cachedGraph{mg: g, err: err})
	})
	result := rs.graph.Load()
	return result.mg, result.err
}
233
234
235 func (rs *Requirements) GraphIsLoaded() bool {
236 return rs.graph.Load() != nil
237 }
238
239
240
241
242
243
244 type ModuleGraph struct {
245 g *mvs.Graph[module.Version]
246
247 buildListOnce sync.Once
248 buildList []module.Version
249 }
250
251
252
253
254
255
256
257
258
259
// cueModSummary asks rs.registry for the requirements of module m and
// wraps them in a modFileSummary. Registry errors are returned unchanged.
func (rs *Requirements) cueModSummary(ctx context.Context, m module.Version) (*modFileSummary, error) {
	reqs, err := rs.registry.Requirements(ctx, m)
	if err != nil {
		return nil, err
	}
	summary := &modFileSummary{module: m, require: reqs}
	return summary, nil
}
271
272 type modFileSummary struct {
273 module module.Version
274 require []module.Version
275 }
276
277
278
279
280
// readModGraph builds the module graph for rs. It first records the main
// module's requirement on each root module. It then fetches, concurrently
// (up to GOMAXPROCS at a time), the requirements of every root that is
// neither local nor at version "none", and adds them to the graph.
//
// It panics if a root module is invalid. If any fetch fails, the partially
// built graph is still returned, together with an error produced by
// findError.
func (rs *Requirements) readModGraph(ctx context.Context) (*ModuleGraph, error) {
	mg := &ModuleGraph{
		g: mvs.NewGraph[module.Version](module.Versions{}, cmpVersion, []module.Version{rs.mainModuleVersion}),
	}
	mg.g.Require(rs.mainModuleVersion, rs.rootModules)

	var (
		// mu guards mg.g and failed, which the load workers update.
		mu     sync.Mutex
		failed bool

		queue     = par.NewQueue(runtime.GOMAXPROCS(0))
		summaries par.ErrCache[module.Version, *modFileSummary]

		// queued records the roots that already have a load scheduled, so
		// a root listed twice is fetched only once. It is only accessed
		// by this goroutine.
		queued = make(map[module.Version]bool)
	)

	// load fetches the requirements of m (at most once, via summaries) and
	// adds them to the graph. Failures are remembered in summaries so that
	// findError can report them.
	load := func(m module.Version) {
		summaries.Do(m, func() (*modFileSummary, error) {
			summary, err := rs.cueModSummary(ctx, m)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failed = true
			} else {
				mg.g.Require(m, summary.require)
			}
			return summary, err
		})
	}

	for _, root := range rs.rootModules {
		if !root.IsValid() {
			panic("root module version is invalid")
		}
		if root.IsLocal() || root.Version() == "none" || queued[root] {
			continue
		}
		queued[root] = true

		root := root // per-iteration copy for the closure (needed before Go 1.22)
		queue.Add(func() {
			load(root)
		})
	}
	<-queue.Idle()

	if !failed {
		return mg, nil
	}
	return mg, mg.findError(&summaries)
}
344
345
346
347
348
349
350 func (mg *ModuleGraph) RequiredBy(m module.Version) (reqs []module.Version, ok bool) {
351 return mg.g.RequiredBy(m)
352 }
353
354
355
356
357 func (mg *ModuleGraph) Selected(path string) (version string) {
358 return mg.g.Selected(path)
359 }
360
361
362
363
364 func (mg *ModuleGraph) WalkBreadthFirst(f func(m module.Version)) {
365 mg.g.WalkBreadthFirst(f)
366 }
367
368
369
370
371
372
373
374
375
// BuildList returns the build list computed by the underlying mvs.Graph.
// It is computed once and cached; every call returns the same slice, so
// callers must not modify its elements.
func (mg *ModuleGraph) BuildList() []module.Version {
	mg.buildListOnce.Do(func() {
		list := mg.g.BuildList()
		// Clip the capacity so that a caller appending to the shared
		// result cannot overwrite storage other callers can see.
		mg.buildList = slices.Clip(list)
	})
	return mg.buildList
}
382
383 func (mg *ModuleGraph) findError(loadCache *par.ErrCache[module.Version, *modFileSummary]) error {
384 errStack := mg.g.FindPath(func(m module.Version) bool {
385 _, err := loadCache.Get(m)
386 return err != nil && err != par.ErrCacheEntryNotFound
387 })
388 if len(errStack) > 0 {
389
390
391
392
393
394
395 _, err := loadCache.Get(errStack[len(errStack)-1])
396 var noUpgrade func(from, to module.Version) bool
397 return mvs.NewBuildListError[module.Version](err, errStack, module.Versions{}, noUpgrade)
398 }
399
400 return nil
401 }
402
403
404
405
406
407
408
409 func cmpVersion(v1, v2 string) int {
410 if v2 == "" {
411 if v1 == "" {
412 return 0
413 }
414 return -1
415 }
416 if v1 == "" {
417 return 1
418 }
419 return semver.Compare(v1, v2)
420 }
421
View as plain text