1 package validator_test
2
3 import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "regexp"
8 "sort"
9 "strconv"
10 "strings"
11 "testing"
12
13 "github.com/stretchr/testify/require"
14 "github.com/vektah/gqlparser/v2"
15 "github.com/vektah/gqlparser/v2/ast"
16 "github.com/vektah/gqlparser/v2/gqlerror"
17 "gopkg.in/yaml.v2"
18 )
19
20 type Spec struct {
21 Name string
22 Rule string
23 Schema string
24 Query string
25 Errors gqlerror.List
26 }
27
28 type Deviation struct {
29 Rule string
30 Errors []*gqlerror.Error
31 Skip string
32
33 pattern *regexp.Regexp
34 }
35
36 func TestValidation(t *testing.T) {
37 var rawSchemas []string
38 readYaml("./imported/spec/schemas.yml", &rawSchemas)
39
40 var deviations []*Deviation
41 readYaml("./imported/deviations.yml", &deviations)
42 for _, d := range deviations {
43 d.pattern = regexp.MustCompile("^" + d.Rule + "$")
44 }
45
46 var schemas = make([]*ast.Schema, 0, len(rawSchemas))
47 for i, schema := range rawSchemas {
48 schema, err := gqlparser.LoadSchema(&ast.Source{Input: schema, Name: fmt.Sprintf("schemas.yml[%d]", i)})
49 if err != nil {
50 panic(err)
51 }
52 schemas = append(schemas, schema)
53 }
54
55 err := filepath.Walk("./", func(path string, info os.FileInfo, err error) error {
56 if info.IsDir() || !strings.HasSuffix(path, ".spec.yml") {
57 return nil
58 }
59
60 runSpec(t, schemas, deviations, path)
61 return nil
62 })
63 require.NoError(t, err)
64 }
65
66 func runSpec(t *testing.T, schemas []*ast.Schema, deviations []*Deviation, filename string) {
67 ruleName := strings.TrimSuffix(filepath.Base(filename), ".spec.yml")
68
69 var specs []Spec
70 readYaml(filename, &specs)
71 t.Run(ruleName, func(t *testing.T) {
72 for _, spec := range specs {
73 if len(spec.Errors) == 0 {
74 spec.Errors = nil
75 }
76 t.Run(spec.Name, func(t *testing.T) {
77 for _, deviation := range deviations {
78 if deviation.pattern.MatchString(ruleName + "/" + spec.Name) {
79 if deviation.Skip != "" {
80 t.Skip(deviation.Skip)
81 }
82 if deviation.Errors != nil {
83 spec.Errors = deviation.Errors
84 }
85 }
86 }
87
88
89 var schema *ast.Schema
90 if idx, err := strconv.Atoi(spec.Schema); err != nil {
91 var gqlErr error
92 schema, gqlErr = gqlparser.LoadSchema(&ast.Source{Input: spec.Schema, Name: spec.Name})
93 if gqlErr != nil {
94 t.Fatal(err)
95 }
96 } else {
97 schema = schemas[idx]
98 }
99 _, errList := gqlparser.LoadQuery(schema, spec.Query)
100 var finalErrors gqlerror.List
101 for _, err := range errList {
102
103 if spec.Rule != "" && err.Rule != spec.Rule {
104 continue
105 }
106 finalErrors = append(finalErrors, err)
107 }
108
109 for i := range spec.Errors {
110 spec.Errors[i].Rule = spec.Rule
111
112
113 spec.Errors[i].Message = strings.Replace(spec.Errors[i].Message, "; Did you mean", ". Did you mean", -1)
114 }
115 sort.Slice(spec.Errors, compareErrors(spec.Errors))
116 sort.Slice(finalErrors, compareErrors(finalErrors))
117
118 if len(finalErrors) != len(spec.Errors) {
119 t.Errorf("wrong number of errors returned\ngot:\n%s\nwant:\n%s", finalErrors.Error(), spec.Errors)
120 } else {
121 for i := range spec.Errors {
122 expected := spec.Errors[i]
123 actual := finalErrors[i]
124 if actual.Rule != spec.Rule {
125 continue
126 }
127 var errLocs []string
128 if expected.Message != actual.Message {
129 errLocs = append(errLocs, "message mismatch")
130 }
131 if len(expected.Locations) > 0 && len(actual.Locations) == 0 {
132 errLocs = append(errLocs, "missing location")
133 }
134 if len(expected.Locations) > 0 && len(actual.Locations) > 0 {
135 found := false
136 for _, loc := range expected.Locations {
137 if actual.Locations[0].Line == loc.Line {
138 found = true
139 break
140 }
141 }
142
143 if !found {
144 errLocs = append(errLocs, "line")
145 }
146 }
147
148 if len(errLocs) > 0 {
149 t.Errorf("%s\ngot: %s\nwant: %s", strings.Join(errLocs, ", "), finalErrors[i].Error(), spec.Errors[i].Error())
150 }
151 }
152 }
153
154 if t.Failed() {
155 t.Logf("name: '%s'", spec.Name)
156 t.Log("\nquery:", spec.Query)
157 }
158 })
159 }
160 })
161 }
162
163 func compareErrors(errors gqlerror.List) func(i, j int) bool {
164 return func(i, j int) bool {
165 cmp := strings.Compare(errors[i].Message, errors[j].Message)
166 if cmp == 0 && len(errors[i].Locations) > 0 && len(errors[j].Locations) > 0 {
167 return errors[i].Locations[0].Line > errors[j].Locations[0].Line
168 }
169 return cmp < 0
170 }
171 }
172
173 func readYaml(filename string, result interface{}) {
174 b, err := os.ReadFile(filename)
175 if err != nil {
176 panic(err)
177 }
178 err = yaml.Unmarshal(b, result)
179 if err != nil {
180 panic(fmt.Errorf("unable to load %s: %s", filename, err.Error()))
181 }
182 }
183
View as plain text