1
18
19 package credentials
20
21 import (
22 "context"
23 "crypto/tls"
24 "net"
25 "strings"
26 "testing"
27 "time"
28
29 "google.golang.org/grpc/internal/grpctest"
30 "google.golang.org/grpc/testdata"
31 )
32
// defaultTestTimeout bounds the context used for each client-side TLS
// handshake performed by gRPCClientHandshake in these tests.
const defaultTestTimeout time.Duration = 10 * time.Second
34
35 type s struct {
36 grpctest.Tester
37 }
38
39 func Test(t *testing.T) {
40 grpctest.RunSubTests(t, s{})
41 }
42
43
// testAuthInfoNoGetCommonAuthInfoMethod is an AuthInfo implementation that
// deliberately does not embed CommonAuthInfo. It is used to check how
// CheckSecurityLevel treats AuthInfo values that carry no security level.
type testAuthInfoNoGetCommonAuthInfoMethod struct{}

// AuthType returns the fixed name of this test AuthInfo implementation.
func (testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string {
	const name = "testAuthInfoNoGetCommonAuthInfoMethod"
	return name
}
49
50
51 type testAuthInfo struct {
52 CommonAuthInfo
53 }
54
55 func (ta testAuthInfo) AuthType() string {
56 return "testAuthInfo"
57 }
58
59 func (s) TestCheckSecurityLevel(t *testing.T) {
60 testCases := []struct {
61 authLevel SecurityLevel
62 testLevel SecurityLevel
63 want bool
64 }{
65 {
66 authLevel: PrivacyAndIntegrity,
67 testLevel: PrivacyAndIntegrity,
68 want: true,
69 },
70 {
71 authLevel: IntegrityOnly,
72 testLevel: PrivacyAndIntegrity,
73 want: false,
74 },
75 {
76 authLevel: IntegrityOnly,
77 testLevel: NoSecurity,
78 want: true,
79 },
80 {
81 authLevel: InvalidSecurityLevel,
82 testLevel: IntegrityOnly,
83 want: true,
84 },
85 {
86 authLevel: InvalidSecurityLevel,
87 testLevel: PrivacyAndIntegrity,
88 want: true,
89 },
90 }
91 for _, tc := range testCases {
92 err := CheckSecurityLevel(testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: tc.authLevel}}, tc.testLevel)
93 if tc.want && (err != nil) {
94 t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String())
95 } else if !tc.want && (err == nil) {
96 t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String())
97
98 }
99 }
100 }
101
102 func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) {
103 if err := CheckSecurityLevel(testAuthInfoNoGetCommonAuthInfoMethod{}, PrivacyAndIntegrity); err != nil {
104 t.Fatalf("CheckSeurityLevel() returned failure but want success")
105 }
106 }
107
108 func (s) TestTLSOverrideServerName(t *testing.T) {
109 expectedServerName := "server.name"
110 c := NewTLS(nil)
111 c.OverrideServerName(expectedServerName)
112 if c.Info().ServerName != expectedServerName {
113 t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
114 }
115 }
116
117 func (s) TestTLSClone(t *testing.T) {
118 expectedServerName := "server.name"
119 c := NewTLS(nil)
120 c.OverrideServerName(expectedServerName)
121 cc := c.Clone()
122 if cc.Info().ServerName != expectedServerName {
123 t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
124 }
125 cc.OverrideServerName("")
126 if c.Info().ServerName != expectedServerName {
127 t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
128 }
129
130 }
131
132 type serverHandshake func(net.Conn) (AuthInfo, error)
133
134 func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {
135 tcs := []struct {
136 name string
137 address string
138 }{
139 {
140 name: "localhost",
141 address: "localhost:0",
142 },
143 {
144 name: "ipv4",
145 address: "127.0.0.1:0",
146 },
147 {
148 name: "ipv6",
149 address: "[::1]:0",
150 },
151 }
152
153 for _, tc := range tcs {
154 t.Run(tc.name, func(t *testing.T) {
155 done := make(chan AuthInfo, 1)
156 lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
157 defer lis.Close()
158 lisAddr := lis.Addr().String()
159 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
160
161 serverAuthInfo, ok := <-done
162 if !ok {
163 t.Fatalf("Error at server-side")
164 }
165 if !compare(clientAuthInfo, serverAuthInfo) {
166 t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
167 }
168 })
169 }
170 }
171
172 func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {
173 done := make(chan AuthInfo, 1)
174 lis := launchServer(t, gRPCServerHandshake, done)
175 defer lis.Close()
176 clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
177
178 serverAuthInfo, ok := <-done
179 if !ok {
180 t.Fatalf("Error at server-side")
181 }
182 if !compare(clientAuthInfo, serverAuthInfo) {
183 t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
184 }
185 }
186
187 func (s) TestServerAndClientHandshake(t *testing.T) {
188 done := make(chan AuthInfo, 1)
189 lis := launchServer(t, gRPCServerHandshake, done)
190 defer lis.Close()
191 clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
192
193 serverAuthInfo, ok := <-done
194 if !ok {
195 t.Fatalf("Error at server-side")
196 }
197 if !compare(clientAuthInfo, serverAuthInfo) {
198 t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
199 }
200 }
201
202 func compare(a1, a2 AuthInfo) bool {
203 if a1.AuthType() != a2.AuthType() {
204 return false
205 }
206 switch a1.AuthType() {
207 case "tls":
208 state1 := a1.(TLSInfo).State
209 state2 := a2.(TLSInfo).State
210 if state1.Version == state2.Version &&
211 state1.HandshakeComplete == state2.HandshakeComplete &&
212 state1.CipherSuite == state2.CipherSuite &&
213 state1.NegotiatedProtocol == state2.NegotiatedProtocol {
214 return true
215 }
216 return false
217 default:
218 return false
219 }
220 }
221
222 func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
223 return launchServerOnListenAddress(t, hs, done, "localhost:0")
224 }
225
226 func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener {
227 lis, err := net.Listen("tcp", address)
228 if err != nil {
229 if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
230 strings.Contains(err.Error(), "socket: address family not supported by protocol") {
231 t.Skipf("no support for address %v", address)
232 }
233 t.Fatalf("Failed to listen: %v", err)
234 }
235 go serverHandle(t, hs, done, lis)
236 return lis
237 }
238
239
240 func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
241 serverRawConn, err := lis.Accept()
242 if err != nil {
243 t.Errorf("Server failed to accept connection: %v", err)
244 close(done)
245 return
246 }
247 serverAuthInfo, err := hs(serverRawConn)
248 if err != nil {
249 t.Errorf("Server failed while handshake. Error: %v", err)
250 serverRawConn.Close()
251 close(done)
252 return
253 }
254 done <- serverAuthInfo
255 }
256
257 func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
258 conn, err := net.Dial("tcp", lisAddr)
259 if err != nil {
260 t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
261 }
262 defer conn.Close()
263 clientAuthInfo, err := hs(conn, lisAddr)
264 if err != nil {
265 t.Fatalf("Error on client while handshake. Error: %v", err)
266 }
267 return clientAuthInfo
268 }
269
270
271 func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
272 serverTLS, err := NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
273 if err != nil {
274 return nil, err
275 }
276 _, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
277 if err != nil {
278 return nil, err
279 }
280 return serverAuthInfo, nil
281 }
282
283
284 func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
285 clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
286 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
287 defer cancel()
288 _, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)
289 if err != nil {
290 return nil, err
291 }
292 return authInfo, nil
293 }
294
295 func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
296 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
297 if err != nil {
298 return nil, err
299 }
300 serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
301 serverConn := tls.Server(conn, serverTLSConfig)
302 err = serverConn.Handshake()
303 if err != nil {
304 return nil, err
305 }
306 return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
307 }
308
309 func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
310 clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
311 clientConn := tls.Client(conn, clientTLSConfig)
312 if err := clientConn.Handshake(); err != nil {
313 return nil, err
314 }
315 return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil
316 }
317
View as plain text