1
18
19 package local
20
21 import (
22 "context"
23 "fmt"
24 "net"
25 "runtime"
26 "strings"
27 "testing"
28 "time"
29
30 "google.golang.org/grpc/credentials"
31 "google.golang.org/grpc/internal/grpctest"
32 )
33
// defaultTestTimeout bounds the context that clientLocalHandshake creates for
// the client-side handshake.
const defaultTestTimeout = 10 * time.Second
35
// s embeds grpctest.Tester and is the receiver type for the test methods in
// this file; Test passes an s value to grpctest.RunSubTests.
type s struct {
	grpctest.Tester
}
39
// Test hands an s value to grpctest.RunSubTests. Presumably that runs each
// Test* method declared on s as a subtest — confirm against grpctest.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
43
44 func (s) TestGetSecurityLevel(t *testing.T) {
45 testCases := []struct {
46 testNetwork string
47 testAddr string
48 want credentials.SecurityLevel
49 }{
50 {
51 testNetwork: "tcp",
52 testAddr: "127.0.0.1:10000",
53 want: credentials.NoSecurity,
54 },
55 {
56 testNetwork: "tcp",
57 testAddr: "[::1]:10000",
58 want: credentials.NoSecurity,
59 },
60 {
61 testNetwork: "unix",
62 testAddr: "/tmp/grpc_fullstack_test",
63 want: credentials.PrivacyAndIntegrity,
64 },
65 {
66 testNetwork: "tcp",
67 testAddr: "192.168.0.1:10000",
68 want: credentials.InvalidSecurityLevel,
69 },
70 }
71 for _, tc := range testCases {
72 got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr)
73 if got != tc.want {
74 t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String())
75 }
76 }
77 }
78
// serverHandshake is the signature of a server-side handshake function: it
// takes a connection and returns the resulting AuthInfo or an error.
type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
80
81 func getSecurityLevelFromAuthInfo(ai credentials.AuthInfo) credentials.SecurityLevel {
82 if c, ok := ai.(interface {
83 GetCommonAuthInfo() credentials.CommonAuthInfo
84 }); ok {
85 return c.GetCommonAuthInfo().SecurityLevel
86 }
87 return credentials.InvalidSecurityLevel
88 }
89
90
91 func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
92 cred := NewCredentials()
93 _, authInfo, err := cred.ServerHandshake(conn)
94 if err != nil {
95 return nil, err
96 }
97 return authInfo, nil
98 }
99
100
101 func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
102 cred := NewCredentials()
103 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
104 defer cancel()
105
106 _, authInfo, err := cred.ClientHandshake(ctx, lisAddr, conn)
107 if err != nil {
108 return nil, err
109 }
110 return authInfo, nil
111 }
112
113
114 func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) {
115 conn, _ := net.Dial(network, lisAddr)
116 defer conn.Close()
117 clientAuthInfo, err := hs(conn, lisAddr)
118 if err != nil {
119 return nil, fmt.Errorf("Error on client while handshake")
120 }
121 return clientAuthInfo, nil
122 }
123
// testServerHandleResult is the value serverHandle sends on its done channel.
type testServerHandleResult struct {
	// authInfo is the AuthInfo returned by the server handshake; nil on failure.
	authInfo credentials.AuthInfo
	// err is non-nil when accepting the connection or the handshake failed.
	err error
}
128
129
130 func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) {
131 serverRawConn, err := lis.Accept()
132 if err != nil {
133 done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)}
134 return
135 }
136 serverAuthInfo, err := hs(serverRawConn)
137 if err != nil {
138 serverRawConn.Close()
139 done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)}
140 return
141 }
142 done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil}
143 }
144
145 func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) {
146 done := make(chan testServerHandleResult, 1)
147 const timeout = 5 * time.Second
148 timer := time.NewTimer(timeout)
149 defer timer.Stop()
150 go serverHandle(serverLocalHandshake, done, lis)
151 defer lis.Close()
152 clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
153 if err != nil {
154 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: %v", err)
155 }
156 select {
157 case <-timer.C:
158 return credentials.InvalidSecurityLevel, fmt.Errorf("Test didn't finish in time")
159 case serverHandleResult := <-done:
160 if serverHandleResult.err != nil {
161 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
162 }
163 clientSecLevel := getSecurityLevelFromAuthInfo(clientAuthInfo)
164 serverSecLevel := getSecurityLevelFromAuthInfo(serverHandleResult.authInfo)
165
166 if clientSecLevel == credentials.InvalidSecurityLevel {
167 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at client-side: client's AuthInfo does not implement GetCommonAuthInfo()")
168 }
169 if serverSecLevel == credentials.InvalidSecurityLevel {
170 return credentials.InvalidSecurityLevel, fmt.Errorf("Error at server-side: server's AuthInfo does not implement GetCommonAuthInfo()")
171 }
172 if clientSecLevel != serverSecLevel {
173 return credentials.InvalidSecurityLevel, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
174 }
175 return clientSecLevel, nil
176 }
177 }
178
179 func (s) TestServerAndClientHandshake(t *testing.T) {
180 testCases := []struct {
181 testNetwork string
182 testAddr string
183 want credentials.SecurityLevel
184 }{
185 {
186 testNetwork: "tcp",
187 testAddr: "127.0.0.1:0",
188 want: credentials.NoSecurity,
189 },
190 {
191 testNetwork: "tcp",
192 testAddr: "[::1]:0",
193 want: credentials.NoSecurity,
194 },
195 {
196 testNetwork: "tcp",
197 testAddr: "localhost:0",
198 want: credentials.NoSecurity,
199 },
200 {
201 testNetwork: "unix",
202 testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()),
203 want: credentials.PrivacyAndIntegrity,
204 },
205 }
206 for _, tc := range testCases {
207 if runtime.GOOS == "windows" && tc.testNetwork == "unix" {
208 t.Skip("skipping tests for unix connections on Windows")
209 }
210 t.Run("serverAndClientHandshakeResult", func(t *testing.T) {
211 lis, err := net.Listen(tc.testNetwork, tc.testAddr)
212 if err != nil {
213 if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
214 strings.Contains(err.Error(), "socket: address family not supported by protocol") {
215 t.Skipf("no support for address %v", tc.testAddr)
216 }
217 t.Fatalf("Failed to listen: %v", err)
218 }
219 got, err := serverAndClientHandshake(lis)
220 if got != tc.want {
221 t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want)
222 }
223 })
224 }
225 }
226
View as plain text