...
1 package extension
2
3 import (
4 "context"
5 "crypto/sha256"
6 "encoding/hex"
7 "fmt"
8
9 "github.com/mitchellh/mapstructure"
10 "github.com/vektah/gqlparser/v2/gqlerror"
11
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/99designs/gqlgen/graphql/errcode"
14 )
15
// Error reported to clients when a hash-only APQ request references a query
// that is not in the cache.
const (
	// errPersistedQueryNotFound is the human-readable error message.
	errPersistedQueryNotFound = "PersistedQueryNotFound"
	// errPersistedQueryNotFoundCode is the machine-readable error code.
	errPersistedQueryNotFoundCode = "PERSISTED_QUERY_NOT_FOUND"
)
20
21
22
23
24
25 type AutomaticPersistedQuery struct {
26 Cache graphql.Cache
27 }
28
// ApqStats describes how a request used Automatic Persisted Queries. It is
// recorded on the operation context stats under the "APQ" extension key and
// can be read back with GetApqStats.
type ApqStats struct {
	// Hash is the sha256Hash value sent in the request's persistedQuery
	// extension.
	Hash string

	// SentQuery is true when the request carried the full query text along
	// with the hash, and false when only the hash was sent.
	SentQuery bool
}
36
// apqExtension is the key under which *ApqStats is stored in the operation
// context stats.
const apqExtension = "APQ"
38
39 var _ interface {
40 graphql.OperationParameterMutator
41 graphql.HandlerExtension
42 } = AutomaticPersistedQuery{}
43
44 func (a AutomaticPersistedQuery) ExtensionName() string {
45 return "AutomaticPersistedQuery"
46 }
47
48 func (a AutomaticPersistedQuery) Validate(schema graphql.ExecutableSchema) error {
49 if a.Cache == nil {
50 return fmt.Errorf("AutomaticPersistedQuery.Cache can not be nil")
51 }
52 return nil
53 }
54
55 func (a AutomaticPersistedQuery) MutateOperationParameters(ctx context.Context, rawParams *graphql.RawParams) *gqlerror.Error {
56 if rawParams.Extensions["persistedQuery"] == nil {
57 return nil
58 }
59
60 var extension struct {
61 Sha256 string `mapstructure:"sha256Hash"`
62 Version int64 `mapstructure:"version"`
63 }
64
65 if err := mapstructure.Decode(rawParams.Extensions["persistedQuery"], &extension); err != nil {
66 return gqlerror.Errorf("invalid APQ extension data")
67 }
68
69 if extension.Version != 1 {
70 return gqlerror.Errorf("unsupported APQ version")
71 }
72
73 fullQuery := false
74 if rawParams.Query == "" {
75
76 query, ok := a.Cache.Get(ctx, extension.Sha256)
77 if !ok {
78 err := gqlerror.Errorf(errPersistedQueryNotFound)
79 errcode.Set(err, errPersistedQueryNotFoundCode)
80 return err
81 }
82 rawParams.Query = query.(string)
83 } else {
84
85 if computeQueryHash(rawParams.Query) != extension.Sha256 {
86 return gqlerror.Errorf("provided APQ hash does not match query")
87 }
88 a.Cache.Add(ctx, extension.Sha256, rawParams.Query)
89 fullQuery = true
90 }
91
92 graphql.GetOperationContext(ctx).Stats.SetExtension(apqExtension, &ApqStats{
93 Hash: extension.Sha256,
94 SentQuery: fullQuery,
95 })
96
97 return nil
98 }
99
100 func GetApqStats(ctx context.Context) *ApqStats {
101 rc := graphql.GetOperationContext(ctx)
102 if rc == nil {
103 return nil
104 }
105
106 s, _ := rc.Stats.GetExtension(apqExtension).(*ApqStats)
107 return s
108 }
109
// computeQueryHash returns the lowercase hexadecimal SHA-256 digest of query.
// MutateOperationParameters compares this digest against the client-supplied
// sha256Hash.
func computeQueryHash(query string) string {
	hasher := sha256.New()
	hasher.Write([]byte(query)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
114
View as plain text