1
2
3
4
5 package grpc
6
7 import (
8 "bytes"
9 "context"
10 "log"
11 "strings"
12 "testing"
13
14 "cloud.google.com/go/compute/metadata"
15 "github.com/google/go-cmp/cmp"
16 "golang.org/x/oauth2/google"
17 "google.golang.org/api/internal"
18 "google.golang.org/grpc"
19 )
20
21 func TestDial(t *testing.T) {
22 oldDialContext := dialContext
23
24 dialContext = func(ctxGot context.Context, target string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
25 if len(opts) != 4 {
26 t.Fatalf("got: %d, want: 4", len(opts))
27 }
28 return nil, nil
29 }
30 defer func() {
31 dialContext = oldDialContext
32 }()
33
34 var o internal.DialSettings
35 dial(context.Background(), false, &o)
36 }
37
38 func TestCheckDirectPathEndPoint(t *testing.T) {
39 for _, testcase := range []struct {
40 name string
41 endpoint string
42 want bool
43 }{
44 {
45 name: "empty endpoint are disallowed",
46 endpoint: "",
47 want: false,
48 },
49 {
50 name: "dns schemes are allowed",
51 endpoint: "dns:///foo",
52 want: true,
53 },
54 {
55 name: "host without no prefix are allowed",
56 endpoint: "foo",
57 want: true,
58 },
59 {
60 name: "host with port are allowed",
61 endpoint: "foo:1234",
62 want: true,
63 },
64 {
65 name: "non-dns schemes are disallowed",
66 endpoint: "https://foo",
67 want: false,
68 },
69 } {
70 t.Run(testcase.name, func(t *testing.T) {
71 if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want {
72 t.Fatalf("got %v, want %v", got, testcase.want)
73 }
74 })
75 }
76 }
77
78 func TestLogDirectPathMisconfigAttrempDirectPathNotSet(t *testing.T) {
79 o := &internal.DialSettings{}
80 o.EnableDirectPathXds = true
81
82 endpoint := "abc.googleapis.com"
83
84 creds, err := internal.Creds(context.Context(context.Background()), o)
85 if err != nil {
86 t.Fatalf("failed to create creds")
87 }
88
89 buf := bytes.Buffer{}
90 log.SetOutput(&buf)
91
92 logDirectPathMisconfig(endpoint, creds.TokenSource, o)
93
94 wantedLog := "WARNING: DirectPath is misconfigured. Please set the EnableDirectPath option along with the EnableDirectPathXds option."
95 if !strings.Contains(buf.String(), wantedLog) {
96 t.Fatalf("got: %v, want: %v", buf.String(), wantedLog)
97 }
98
99 }
100
101 func TestLogDirectPathMisconfigWrongCredential(t *testing.T) {
102 o := &internal.DialSettings{}
103 o.EnableDirectPath = true
104 o.EnableDirectPathXds = true
105
106 endpoint := "abc.googleapis.com"
107
108 creds := &google.Credentials{}
109
110 buf := bytes.Buffer{}
111 log.SetOutput(&buf)
112
113 logDirectPathMisconfig(endpoint, creds.TokenSource, o)
114
115 wantedLog := "WARNING: DirectPath is misconfigured. Please make sure the token source is fetched from GCE metadata server and the default service account is used."
116 if !strings.Contains(buf.String(), wantedLog) {
117 t.Fatalf("got: %v, want: %v", buf.String(), wantedLog)
118 }
119
120 }
121
122 func TestLogDirectPathMisconfigNotOnGCE(t *testing.T) {
123 o := &internal.DialSettings{}
124 o.EnableDirectPath = true
125 o.EnableDirectPathXds = true
126
127 endpoint := "abc.googleapis.com"
128
129 creds, err := internal.Creds(context.Context(context.Background()), o)
130 if err != nil {
131 t.Fatalf("failed to create creds")
132 }
133
134 buf := bytes.Buffer{}
135 log.SetOutput(&buf)
136
137 logDirectPathMisconfig(endpoint, creds.TokenSource, o)
138
139 if !metadata.OnGCE() {
140 wantedLog := "WARNING: DirectPath is misconfigured. DirectPath is only available in a GCE environment."
141 if !strings.Contains(buf.String(), wantedLog) {
142 t.Fatalf("got: %v, want: %v", buf.String(), wantedLog)
143 }
144 }
145
146 }
147
148 func TestGRPCAPIKey_GetRequestMetadata(t *testing.T) {
149 for _, test := range []struct {
150 apiKey string
151 reason string
152 }{
153 {
154 apiKey: "MY_API_KEY",
155 reason: "MY_REQUEST_REASON",
156 },
157 } {
158 ts := grpcAPIKey{
159 apiKey: test.apiKey,
160 requestReason: test.reason,
161 }
162 got, err := ts.GetRequestMetadata(context.Background())
163 if err != nil {
164 t.Fatal(err)
165 }
166 want := map[string]string{
167 "X-goog-api-key": ts.apiKey,
168 "X-goog-request-reason": ts.requestReason,
169 }
170 if diff := cmp.Diff(want, got); diff != "" {
171 t.Errorf("mismatch (-want +got):\n%s", diff)
172 }
173 }
174 }
175
View as plain text