...
1 package apollofederatedtracingv1
2
import (
	"context"
	"encoding/base64"
	"fmt"
	"os"

	"github.com/99designs/gqlgen/graphql"
	"google.golang.org/protobuf/proto"
)
12
// Tracer is the gqlgen handler extension for Apollo Federated Tracing v1
// ("ftv1"). Its interceptor methods only do work for operations whose request
// asks for an ftv1 trace.
//
// NOTE(review): ClientName, Version and Hostname are not read by any code in
// this file; confirm how they are consumed elsewhere in the package.
type Tracer struct {
	ClientName string
	Version    string
	Hostname   string
}

// treeBuilderKey is an unexported string-based type used for context keys, so
// values stored under it cannot collide with keys defined by other packages.
type treeBuilderKey string

// key is the context key under which the per-operation *TreeBuilder is stored.
const key treeBuilderKey = "treeBuilder"
26
27 var _ interface {
28 graphql.HandlerExtension
29 graphql.ResponseInterceptor
30 graphql.FieldInterceptor
31 graphql.OperationInterceptor
32 } = &Tracer{}
33
34
35 func (Tracer) ExtensionName() string {
36 return "ApolloFederatedTracingV1"
37 }
38
39
40 func (Tracer) Validate(graphql.ExecutableSchema) error {
41 return nil
42 }
43
// shouldTrace reports whether the current request opted in to ftv1 tracing.
// It requires an operation context in ctx, and the operation's
// "apollo-federation-include-trace" header must equal "ftv1".
func (t *Tracer) shouldTrace(ctx context.Context) bool {
	// Without an operation context there are no headers to inspect; checking
	// first also avoids calling GetOperationContext on a bare context.
	if !graphql.HasOperationContext(ctx) {
		return false
	}
	requested := graphql.GetOperationContext(ctx).Headers.Get("apollo-federation-include-trace")
	return requested == "ftv1"
}
48
// getTreeBuilder returns the *TreeBuilder stored in ctx under key. It returns
// nil when nothing is stored there or the stored value is of another type.
func (t *Tracer) getTreeBuilder(ctx context.Context) *TreeBuilder {
	// A comma-ok assertion on a nil or mismatched interface value yields the
	// zero value (nil), so no separate nil check is required.
	builder, _ := ctx.Value(key).(*TreeBuilder)
	return builder
}
59
60
// InterceptOperation implements graphql.OperationInterceptor. For requests
// that opted in to tracing (see shouldTrace), it stores a fresh TreeBuilder
// from NewTreeBuilder in the context handed to next. Otherwise ctx is passed
// through unchanged.
func (t *Tracer) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
	if t.shouldTrace(ctx) {
		// One builder per operation; the field and response interceptors
		// retrieve it through getTreeBuilder.
		ctx = context.WithValue(ctx, key, NewTreeBuilder())
	}
	return next(ctx)
}
67
68
69
// InterceptField implements graphql.FieldInterceptor. When the request opted
// in to tracing and a TreeBuilder is present in ctx, it calls
// tb.WillResolveField(ctx) before delegating to next. The resolver's result
// and error are returned unchanged in every case.
func (t *Tracer) InterceptField(ctx context.Context, next graphql.Resolver) (interface{}, error) {
	if t.shouldTrace(ctx) {
		if tb := t.getTreeBuilder(ctx); tb != nil {
			// Presumably records this field in the trace tree; see
			// TreeBuilder.WillResolveField.
			tb.WillResolveField(ctx)
		}
	}
	return next(ctx)
}
80
81
82
// InterceptResponse implements graphql.ResponseInterceptor. When the request
// opted in to tracing (see shouldTrace) and InterceptOperation installed a
// TreeBuilder, it attaches the trace under the "ftv1" response extension.
//
// The extension is registered up front as a *string. It is filled in by a
// deferred function that runs once next has returned (or panicked). That
// function calls tb.StopTimer, protobuf-marshals tb.Trace, and stores the
// bytes base64-encoded with the standard encoding. tb.StartTimer is called
// before next runs.
//
// If marshalling fails, the error is written to stderr and the extension is
// left as the empty string. The bytes proto.Marshal returned alongside the
// error are not encoded.
func (t *Tracer) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
	if !t.shouldTrace(ctx) {
		return next(ctx)
	}
	tb := t.getTreeBuilder(ctx)
	if tb == nil {
		return next(ctx)
	}

	tb.StartTimer(ctx)

	encoded := new(string)
	graphql.RegisterExtension(ctx, "ftv1", encoded)

	// Deferred so the timer is stopped and the trace serialised only after
	// next has finished resolving the response.
	defer func() {
		tb.StopTimer(ctx)

		p, err := proto.Marshal(tb.Trace)
		if err != nil {
			// Previously a bare fmt.Print(err) went to stdout with no context or
			// newline, and a possibly partial buffer was still encoded.
			fmt.Fprintf(os.Stderr, "apollofederatedtracingv1: marshaling ftv1 trace: %v\n", err)
			return
		}
		*encoded = base64.StdEncoding.EncodeToString(p)
	}()

	return next(ctx)
}
113
View as plain text