...
1 package extension
2
3 import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/v2/gqlerror"
8
9 "github.com/99designs/gqlgen/complexity"
10 "github.com/99designs/gqlgen/graphql"
11 "github.com/99designs/gqlgen/graphql/errcode"
12 )
13
14 const errComplexityLimit = "COMPLEXITY_LIMIT_EXCEEDED"
15
16
17
18
19 type ComplexityLimit struct {
20 Func func(ctx context.Context, rc *graphql.OperationContext) int
21
22 es graphql.ExecutableSchema
23 }
24
25 var _ interface {
26 graphql.OperationContextMutator
27 graphql.HandlerExtension
28 } = &ComplexityLimit{}
29
30 const complexityExtension = "ComplexityLimit"
31
32 type ComplexityStats struct {
33
34 Complexity int
35
36
37 ComplexityLimit int
38 }
39
40
41 func FixedComplexityLimit(limit int) *ComplexityLimit {
42 return &ComplexityLimit{
43 Func: func(ctx context.Context, rc *graphql.OperationContext) int {
44 return limit
45 },
46 }
47 }
48
49 func (c ComplexityLimit) ExtensionName() string {
50 return complexityExtension
51 }
52
53 func (c *ComplexityLimit) Validate(schema graphql.ExecutableSchema) error {
54 if c.Func == nil {
55 return fmt.Errorf("ComplexityLimit func can not be nil")
56 }
57 c.es = schema
58 return nil
59 }
60
61 func (c ComplexityLimit) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
62 op := rc.Doc.Operations.ForName(rc.OperationName)
63 complexityCalcs := complexity.Calculate(c.es, op, rc.Variables)
64
65 limit := c.Func(ctx, rc)
66
67 rc.Stats.SetExtension(complexityExtension, &ComplexityStats{
68 Complexity: complexityCalcs,
69 ComplexityLimit: limit,
70 })
71
72 if complexityCalcs > limit {
73 err := gqlerror.Errorf("operation has complexity %d, which exceeds the limit of %d", complexityCalcs, limit)
74 errcode.Set(err, errComplexityLimit)
75 return err
76 }
77
78 return nil
79 }
80
81 func GetComplexityStats(ctx context.Context) *ComplexityStats {
82 rc := graphql.GetOperationContext(ctx)
83 if rc == nil {
84 return nil
85 }
86
87 s, _ := rc.Stats.GetExtension(complexityExtension).(*ComplexityStats)
88 return s
89 }
90
View as plain text