...
1 package fieldset
2
3 import (
4 "fmt"
5 "strings"
6
7 "github.com/vektah/gqlparser/v2/ast"
8
9 "github.com/99designs/gqlgen/codegen"
10 "github.com/99designs/gqlgen/codegen/templates"
11 )
12
13
14
type (
	// Set is a parsed field set: one Field per selected leaf path.
	Set []Field

	// Field is a single selection path, listed from the outermost field name
	// down to the selected leaf, e.g. ["author", "id"] for "author { id }".
	Field []string
)
19
20
21 func New(raw string, prefix []string) Set {
22 if !strings.Contains(raw, "{") {
23 return parseUnnestedKeyFieldSet(raw, prefix)
24 }
25
26 var (
27 ret = Set{}
28 subPrefix = prefix
29 )
30 before, during, after := extractSubs(raw)
31
32 if before != "" {
33 befores := New(before, prefix)
34 if len(befores) > 0 {
35 subPrefix = befores[len(befores)-1]
36 ret = append(ret, befores[:len(befores)-1]...)
37 }
38 }
39 if during != "" {
40 ret = append(ret, New(during, subPrefix)...)
41 }
42 if after != "" {
43 ret = append(ret, New(after, prefix)...)
44 }
45 return ret
46 }
47
48
49 func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition {
50 objType := schemaType
51 def := objType.Fields.ForName(f[0])
52
53 for _, part := range f[1:] {
54 if objType.Kind != ast.Object {
55 panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: `, objType.Name, f))
56 }
57 x := def.Type.Name()
58 objType = schema.Types[x]
59 if objType == nil {
60 panic("invalid schema type: " + x)
61 }
62 def = objType.Fields.ForName(part)
63 }
64 if def == nil {
65 return nil
66 }
67 ret := *def
68 ret.Name = f.ToGoPrivate()
69
70 return &ret
71 }
72
73
74 func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
75 var def *codegen.Field
76
77 for _, part := range f {
78 def = fieldByName(obj, part)
79 if def == nil {
80 panic("unable to find field " + f[0])
81 }
82 obj = objects.ByName(def.TypeReference.Definition.Name)
83 }
84 return def
85 }
86
87
88 func (f Field) ToGo() string {
89 var ret string
90
91 for _, field := range f {
92 ret += templates.ToGo(field)
93 }
94 return ret
95 }
96
97
98 func (f Field) ToGoPrivate() string {
99 var ret string
100
101 for i, field := range f {
102 if i == 0 {
103 field = trimArgumentFromFieldName(field)
104 ret += templates.ToGoPrivate(field)
105 continue
106 }
107 ret += templates.ToGo(field)
108 }
109 return ret
110 }
111
112
113 func (f Field) Join(str string) string {
114 return strings.Join(f, str)
115 }
116
117
118 func (f Field) JoinGo(str string) string {
119 strs := []string{}
120
121 for _, s := range f {
122 strs = append(strs, templates.ToGo(s))
123 }
124 return strings.Join(strs, str)
125 }
126
127 func (f Field) LastIndex() int {
128 return len(f) - 1
129 }
130
131
132
133
134 func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
135 ret := Set{}
136
137 for _, s := range strings.Fields(raw) {
138 next := append(prefix[0:len(prefix):len(prefix)], s)
139 ret = append(ret, next)
140 }
141 return ret
142 }
143
144
145 func extractSubs(str string) (string, string, string) {
146 start := strings.Index(str, "{")
147 end := matchingBracketIndex(str, start)
148
149 if start < 0 || end < 0 {
150 panic("invalid key fieldSet: " + str)
151 }
152 return trimArgumentFromFieldName(strings.TrimSpace(str[:start])), strings.TrimSpace(str[start+1 : end]), strings.TrimSpace(str[end+1:])
153 }
154
155
// matchingBracketIndex returns the byte index of the "}" that closes the brace
// assumed to be at str[start], skipping over nested "{ ... }" pairs. The
// character at str[start] itself is not checked.
//
// It returns -1 when start is negative, when nothing follows start, or when
// the brace is never closed.
func matchingBracketIndex(str string, start int) int {
	if start < 0 || start+1 >= len(str) {
		return -1
	}
	open := 0
	// Byte iteration is safe: '{' and '}' are ASCII, and ASCII bytes never
	// occur inside multi-byte UTF-8 sequences.
	for i := start + 1; i < len(str); i++ {
		switch str[i] {
		case '{':
			open++
		case '}':
			if open == 0 {
				return i
			}
			open--
		}
	}
	return -1
}
175
176 func fieldByName(obj *codegen.Object, name string) *codegen.Field {
177 for _, field := range obj.Fields {
178 field.Name = trimArgumentFromFieldName(field.Name)
179 if field.Name == name {
180 return field
181 }
182 }
183 return nil
184 }
185
186
187
// trimArgumentFromFieldName returns raw up to, but not including, its first
// "(", which removes an argument list from a field name. raw is returned
// unchanged when it contains no "(".
func trimArgumentFromFieldName(raw string) string {
	if i := strings.IndexByte(raw, '('); i >= 0 {
		return raw[:i]
	}
	return raw
}
191
View as plain text