1
18
19 package google
20
21 import (
22 "context"
23 "net"
24 "testing"
25
26 "google.golang.org/grpc/credentials"
27 icredentials "google.golang.org/grpc/internal/credentials"
28 "google.golang.org/grpc/internal/grpctest"
29 "google.golang.org/grpc/internal/xds"
30 "google.golang.org/grpc/resolver"
31 )
32
33 type s struct {
34 grpctest.Tester
35 }
36
37 func Test(t *testing.T) {
38 grpctest.RunSubTests(t, s{})
39 }
40
41 type testCreds struct {
42 credentials.TransportCredentials
43 typ string
44 }
45
46 func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
47 return nil, &testAuthInfo{typ: c.typ}, nil
48 }
49
50 func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
51 return nil, &testAuthInfo{typ: c.typ}, nil
52 }
53
// testAuthInfo is a fake credentials.AuthInfo whose AuthType is a fixed,
// caller-chosen string.
type testAuthInfo struct {
	typ string
}

// AuthType returns the type string the testAuthInfo was constructed with.
func (ai *testAuthInfo) AuthType() string { return ai.typ }
61
62 var (
63 testTLS = &testCreds{typ: "tls"}
64 testALTS = &testCreds{typ: "alts"}
65 )
66
67 func overrideNewCredsFuncs() func() {
68 origNewTLS := newTLS
69 newTLS = func() credentials.TransportCredentials {
70 return testTLS
71 }
72 origNewALTS := newALTS
73 newALTS = func() credentials.TransportCredentials {
74 return testALTS
75 }
76 origNewADC := newADC
77 newADC = func(context.Context) (credentials.PerRPCCredentials, error) {
78
79 return nil, nil
80 }
81
82 return func() {
83 newTLS = origNewTLS
84 newALTS = origNewALTS
85 newADC = origNewADC
86 }
87 }
88
89
90
91
92 func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
93 defer overrideNewCredsFuncs()()
94 for bundleTyp, tc := range map[string]credentials.Bundle{
95 "defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
96 "defaultCreds": NewDefaultCredentials(),
97 "computeCreds": NewComputeEngineCredentials(),
98 } {
99 tests := []struct {
100 name string
101 ctx context.Context
102 wantTyp string
103 }{
104 {
105 name: "no cluster name",
106 ctx: context.Background(),
107 wantTyp: "tls",
108 },
109 {
110 name: "with non-CFE cluster name",
111 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
112 Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
113 }),
114
115 wantTyp: "alts",
116 },
117 {
118 name: "with CFE cluster name",
119 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
120 Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
121 }),
122
123 wantTyp: "tls",
124 },
125 {
126 name: "with xdstp CFE cluster name",
127 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
128 Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
129 }),
130
131 wantTyp: "tls",
132 },
133 {
134 name: "with xdstp non-CFE cluster name",
135 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
136 Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
137 }),
138
139 wantTyp: "alts",
140 },
141 }
142 for _, tt := range tests {
143 t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
144 _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
145 if err != nil {
146 t.Fatalf("ClientHandshake failed: %v", err)
147 }
148 if gotType := info.AuthType(); gotType != tt.wantTyp {
149 t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
150 }
151
152 _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
153 if err != nil {
154 t.Fatalf("ClientHandshake failed: %v", err)
155 }
156
157 if gotType := infoServer.AuthType(); gotType != "tls" {
158 t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
159 }
160 })
161 }
162 }
163 }
164
View as plain text