1 package graphql
2
3 import (
4 "context"
5 "testing"
6 "time"
7
8 "github.com/stretchr/testify/require"
9 "github.com/vektah/gqlparser/v2/ast"
10 )
11
12
13 type testGraphRequestContext struct {
14 opContext *OperationContext
15 }
16
17 func (t *testGraphRequestContext) Deadline() (deadline time.Time, ok bool) {
18 return time.Time{}, false
19 }
20
21 func (t *testGraphRequestContext) Done() <-chan struct{} {
22 return nil
23 }
24
25 func (t *testGraphRequestContext) Err() error {
26 return nil
27 }
28
29 func (t *testGraphRequestContext) Value(key interface{}) interface{} {
30 return t.opContext
31 }
32
33 func TestGetOperationContext(t *testing.T) {
34 rc := &OperationContext{}
35
36 t.Run("with operation context", func(t *testing.T) {
37 ctx := WithOperationContext(context.Background(), rc)
38
39 require.True(t, HasOperationContext(ctx))
40 require.Equal(t, rc, GetOperationContext(ctx))
41 })
42
43 t.Run("without operation context", func(t *testing.T) {
44 ctx := context.Background()
45
46 require.False(t, HasOperationContext(ctx))
47 require.Panics(t, func() {
48 GetOperationContext(ctx)
49 })
50 })
51
52 t.Run("with nil operation context", func(t *testing.T) {
53 ctx := &testGraphRequestContext{opContext: nil}
54
55 require.False(t, HasOperationContext(ctx))
56 require.Panics(t, func() {
57 GetOperationContext(ctx)
58 })
59 })
60 }
61
62 func TestCollectAllFields(t *testing.T) {
63 t.Run("collect fields", func(t *testing.T) {
64 ctx := testContext(ast.SelectionSet{
65 &ast.Field{
66 Name: "field",
67 },
68 })
69 s := CollectAllFields(ctx)
70 require.Equal(t, []string{"field"}, s)
71 })
72
73 t.Run("unique field names", func(t *testing.T) {
74 ctx := testContext(ast.SelectionSet{
75 &ast.Field{
76 Name: "field",
77 },
78 &ast.Field{
79 Name: "field",
80 Alias: "field alias",
81 },
82 })
83 s := CollectAllFields(ctx)
84 require.Equal(t, []string{"field"}, s)
85 })
86
87 t.Run("collect fragments", func(t *testing.T) {
88 ctx := testContext(ast.SelectionSet{
89 &ast.Field{
90 Name: "fieldA",
91 },
92 &ast.InlineFragment{
93 TypeCondition: "ExampleTypeA",
94 SelectionSet: ast.SelectionSet{
95 &ast.Field{
96 Name: "fieldA",
97 },
98 },
99 },
100 &ast.InlineFragment{
101 TypeCondition: "ExampleTypeB",
102 SelectionSet: ast.SelectionSet{
103 &ast.Field{
104 Name: "fieldB",
105 },
106 },
107 },
108 })
109 s := CollectAllFields(ctx)
110 require.Equal(t, []string{"fieldA", "fieldB"}, s)
111 })
112
113 t.Run("collect fragments with same field name on different types", func(t *testing.T) {
114 ctx := testContext(ast.SelectionSet{
115 &ast.InlineFragment{
116 TypeCondition: "ExampleTypeA",
117 SelectionSet: ast.SelectionSet{
118 &ast.Field{
119 Name: "fieldA",
120 ObjectDefinition: &ast.Definition{Name: "ExampleTypeA"},
121 },
122 },
123 },
124 &ast.InlineFragment{
125 TypeCondition: "ExampleTypeB",
126 SelectionSet: ast.SelectionSet{
127 &ast.Field{
128 Name: "fieldA",
129 ObjectDefinition: &ast.Definition{Name: "ExampleTypeB"},
130 },
131 },
132 },
133 })
134 resCtx := GetFieldContext(ctx)
135 collected := CollectFields(GetOperationContext(ctx), resCtx.Field.Selections, nil)
136 require.Len(t, collected, 2)
137 require.NotEqual(t, collected[0], collected[1])
138 require.Equal(t, collected[0].Name, collected[1].Name)
139 })
140
141 t.Run("collect fragments with same field name and different alias", func(t *testing.T) {
142 ctx := testContext(ast.SelectionSet{
143 &ast.InlineFragment{
144 TypeCondition: "ExampleTypeA",
145 SelectionSet: ast.SelectionSet{
146 &ast.Field{
147 Name: "fieldA",
148 Alias: "fieldA",
149 ObjectDefinition: &ast.Definition{Name: "ExampleTypeA"},
150 },
151 &ast.Field{
152 Name: "fieldA",
153 Alias: "fieldA Alias",
154 ObjectDefinition: &ast.Definition{Name: "ExampleTypeA"},
155 },
156 },
157 ObjectDefinition: &ast.Definition{Name: "ExampleType", Kind: ast.Interface},
158 },
159 })
160 resCtx := GetFieldContext(ctx)
161 collected := CollectFields(GetOperationContext(ctx), resCtx.Field.Selections, nil)
162 require.Len(t, collected, 2)
163 require.NotEqual(t, collected[0], collected[1])
164 require.Equal(t, collected[0].Name, collected[1].Name)
165 require.NotEqual(t, collected[0].Alias, collected[1].Alias)
166 })
167 }
168
View as plain text