1 package codegen
2
3 import (
4 "fmt"
5 "strconv"
6 "strings"
7
8 "github.com/vektah/gqlparser/v2/ast"
9
10 "github.com/99designs/gqlgen/codegen/templates"
11 )
12
13 type DirectiveList map[string]*Directive
14
15
16 func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
17 return locationDirectives(dl, ast.DirectiveLocation(location))
18 }
19
// Directive is the code-generation view of a GraphQL directive. In this file it
// is built both from schema definitions (buildDirectives) and from directive
// usages on schema elements (getDirectives).
type Directive struct {
	// DirectiveDefinition is the parsed schema definition. It is embedded so its
	// fields (e.g. Locations, read by IsLocation) are promoted onto Directive.
	*ast.DirectiveDefinition
	// Name is the directive's name as it appears in the schema.
	Name string
	// Args holds one entry per argument declared by the definition. Entries
	// built by buildDirectives carry the evaluated Default; entries built by
	// getDirectives carry the resolved Value (usage-site value or the
	// definition's default).
	Args []*FieldArgument
	// Builtin is copied from the SkipRuntime setting of the matching entry in
	// the builder's Config.Directives. NOTE(review): presumably marks directives
	// that need no runtime implementation — confirm against the templates.
	Builtin bool
}
26
27
28 func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
29 for _, l := range d.Locations {
30 for _, a := range location {
31 if l == a {
32 return true
33 }
34 }
35 }
36
37 return false
38 }
39
40 func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
41 mDirectives := make(map[string]*Directive)
42 for name, d := range directives {
43 if d.IsLocation(location...) {
44 mDirectives[name] = d
45 }
46 }
47 return mDirectives
48 }
49
// buildDirectives converts every directive definition in the schema into a
// *Directive keyed by name. Each argument is bound to a Go type through the
// builder's Binder and, when the schema declares a default, that default is
// evaluated and stored in the argument's Default field.
//
// It returns an error, wrapping the underlying cause, when an argument type
// cannot be bound or a default value cannot be evaluated.
func (b *builder) buildDirectives() (map[string]*Directive, error) {
	directives := make(map[string]*Directive, len(b.Schema.Directives))

	// b.Schema.Directives is a map keyed by name, so names are already unique
	// and no duplicate check is required while copying into directives.
	for name, dir := range b.Schema.Directives {
		var args []*FieldArgument
		for _, arg := range dir.Arguments {
			tr, err := b.Binder.TypeReference(arg.Type, nil)
			if err != nil {
				return nil, fmt.Errorf("directive argument %s(%s): %w", dir.Name, arg.Name, err)
			}

			newArg := &FieldArgument{
				ArgumentDefinition: arg,
				TypeReference:      tr,
				VarName:            templates.ToGoPrivate(arg.Name),
			}

			if arg.DefaultValue != nil {
				// The default is evaluated with no variables (nil map).
				if newArg.Default, err = arg.DefaultValue.Value(nil); err != nil {
					return nil, fmt.Errorf("default value for directive argument %s(%s) is not valid: %w", dir.Name, arg.Name, err)
				}
			}
			args = append(args, newArg)
		}

		directives[name] = &Directive{
			DirectiveDefinition: dir,
			Name:                name,
			Args:                args,
			Builtin:             b.Config.Directives[name].SkipRuntime,
		}
	}

	return directives, nil
}
91
// getDirectives resolves the directive usages in list against the definitions
// previously collected in b.Directives, returning one *Directive per usage in
// the same order as list.
//
// Every argument declared by the definition is included, in definition order,
// with its Value set to the value written at the usage site if present and to
// the definition's Default otherwise (which may be nil). Usage-site arguments
// that the definition does not declare are ignored.
//
// It returns an error when a usage names an unknown directive or when an
// argument value cannot be evaluated.
func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
	dirs := make([]*Directive, len(list))
	for i, d := range list {
		// Resolve the definition first so an unknown directive is reported
		// without evaluating its arguments.
		def, ok := b.Directives[d.Name]
		if !ok {
			return nil, fmt.Errorf("directive %s not found", d.Name)
		}

		argValues := make(map[string]interface{}, len(d.Arguments))
		for _, da := range d.Arguments {
			// Values are evaluated with no variables (nil map).
			val, err := da.Value.Value(nil)
			if err != nil {
				return nil, fmt.Errorf("value for directive argument %s(%s) is not valid: %w", d.Name, da.Name, err)
			}
			argValues[da.Name] = val
		}

		var args []*FieldArgument
		for _, a := range def.Args {
			value := a.Default
			if argValue, ok := argValues[a.Name]; ok {
				value = argValue
			}
			args = append(args, &FieldArgument{
				ArgumentDefinition: a.ArgumentDefinition,
				Value:              value,
				VarName:            a.VarName,
				TypeReference:      a.TypeReference,
			})
		}

		dirs[i] = &Directive{
			Name:                d.Name,
			Args:                args,
			DirectiveDefinition: d.Definition,
			Builtin:             b.Config.Directives[d.Name].SkipRuntime,
		}
	}

	return dirs, nil
}
132
133 func (d *Directive) ArgsFunc() string {
134 if len(d.Args) == 0 {
135 return ""
136 }
137
138 return "dir_" + d.Name + "_args"
139 }
140
// CallArgs returns a comma-separated Go argument list: "ctx, obj, n" followed,
// for each directive argument, by `args["<name>"].(<GoType>)`. The Go type is
// resolved through templates.CurrentImports.LookupType.
func (d *Directive) CallArgs() string {
	parts := make([]string, 0, 3+len(d.Args))
	parts = append(parts, "ctx", "obj", "n")
	for _, a := range d.Args {
		goType := templates.CurrentImports.LookupType(a.TypeReference.GO)
		parts = append(parts, "args["+strconv.Quote(a.Name)+"].("+goType+")")
	}
	return strings.Join(parts, ", ")
}
150
// ResolveArgs returns a comma-separated Go argument list: "ctx", obj and
// "directive<next>", followed for each directive argument by its VarName, or
// by the literal "nil" when the argument has neither a Value nor a Default.
func (d *Directive) ResolveArgs(obj string, next int) string {
	callArgs := append(make([]string, 0, len(d.Args)+3), "ctx", obj, fmt.Sprintf("directive%d", next))
	for _, a := range d.Args {
		if a.Value != nil || a.Default != nil {
			callArgs = append(callArgs, a.VarName)
			continue
		}
		callArgs = append(callArgs, "nil")
	}
	return strings.Join(callArgs, ", ")
}
165
166 func (d *Directive) Declaration() string {
167 res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
168
169 for _, arg := range d.Args {
170 res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO))
171 }
172
173 res += ") (res interface{}, err error)"
174 return res
175 }
176
View as plain text