1
18
19 package test
20
21 import (
22 "context"
23 "fmt"
24 "net"
25 "strings"
26 "testing"
27 "time"
28
29 "google.golang.org/grpc"
30 "google.golang.org/grpc/codes"
31 "google.golang.org/grpc/credentials"
32 "google.golang.org/grpc/credentials/insecure"
33 "google.golang.org/grpc/credentials/local"
34 "google.golang.org/grpc/internal/stubserver"
35 "google.golang.org/grpc/peer"
36 "google.golang.org/grpc/status"
37
38 testgrpc "google.golang.org/grpc/interop/grpc_testing"
39 testpb "google.golang.org/grpc/interop/grpc_testing"
40 )
41
42 func testLocalCredsE2ESucceed(network, address string) error {
43 ss := &stubserver.StubServer{
44 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
45 pr, ok := peer.FromContext(ctx)
46 if !ok {
47 return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
48 }
49 type internalInfo interface {
50 GetCommonAuthInfo() credentials.CommonAuthInfo
51 }
52 var secLevel credentials.SecurityLevel
53 if info, ok := (pr.AuthInfo).(internalInfo); ok {
54 secLevel = info.GetCommonAuthInfo().SecurityLevel
55 } else {
56 return nil, status.Errorf(codes.Unauthenticated, "peer.AuthInfo does not implement GetCommonAuthInfo()")
57 }
58
59 switch network {
60 case "unix":
61 if secLevel != credentials.PrivacyAndIntegrity {
62 return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.PrivacyAndIntegrity)
63 }
64 case "tcp":
65 if secLevel != credentials.NoSecurity {
66 return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity)
67 }
68 }
69 return &testpb.Empty{}, nil
70 },
71 }
72
73 sopts := []grpc.ServerOption{grpc.Creds(local.NewCredentials())}
74 s := grpc.NewServer(sopts...)
75 defer s.Stop()
76
77 testgrpc.RegisterTestServiceServer(s, ss)
78
79 lis, err := net.Listen(network, address)
80 if err != nil {
81 return fmt.Errorf("Failed to create listener: %v", err)
82 }
83
84 go s.Serve(lis)
85
86 var cc *grpc.ClientConn
87 lisAddr := lis.Addr().String()
88
89 switch network {
90 case "unix":
91 cc, err = grpc.Dial(lisAddr, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer(
92 func(ctx context.Context, addr string) (net.Conn, error) {
93 return net.Dial("unix", addr)
94 }))
95 case "tcp":
96 cc, err = grpc.NewClient(lisAddr, grpc.WithTransportCredentials(local.NewCredentials()))
97 default:
98 return fmt.Errorf("unsupported network %q", network)
99 }
100 if err != nil {
101 return fmt.Errorf("Failed to dial server: %v, %v", err, lisAddr)
102 }
103 defer cc.Close()
104
105 c := testgrpc.NewTestServiceClient(cc)
106 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
107 defer cancel()
108
109 if _, err = c.EmptyCall(ctx, &testpb.Empty{}); err != nil {
110 return fmt.Errorf("EmptyCall(_, _) = _, %v; want _, <nil>", err)
111 }
112 return nil
113 }
114
115 func (s) TestLocalCredsLocalhost(t *testing.T) {
116 if err := testLocalCredsE2ESucceed("tcp", "localhost:0"); err != nil {
117 t.Fatalf("Failed e2e test for localhost: %v", err)
118 }
119 }
120
121 func (s) TestLocalCredsUDS(t *testing.T) {
122 addr := fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano())
123 if err := testLocalCredsE2ESucceed("unix", addr); err != nil {
124 t.Fatalf("Failed e2e test for UDS: %v", err)
125 }
126 }
127
// connWrapper is a net.Conn whose RemoteAddr reports a caller-chosen address
// instead of the real peer's; every other method is delegated to the embedded
// connection.
type connWrapper struct {
	net.Conn
	remote net.Addr
}

// RemoteAddr returns the spoofed address held by the wrapper rather than the
// embedded connection's remote address.
func (cw connWrapper) RemoteAddr() net.Addr {
	spoofed := cw.remote
	return spoofed
}
136
137 type lisWrapper struct {
138 net.Listener
139 remote net.Addr
140 }
141
142 func spoofListener(l net.Listener, remote net.Addr) net.Listener {
143 return &lisWrapper{l, remote}
144 }
145
146 func (l *lisWrapper) Accept() (net.Conn, error) {
147 c, err := l.Listener.Accept()
148 if err != nil {
149 return nil, err
150 }
151 return connWrapper{c, l.remote}, nil
152 }
153
154 func spoofDialer(addr net.Addr) func(target string, t time.Duration) (net.Conn, error) {
155 return func(t string, d time.Duration) (net.Conn, error) {
156 c, err := net.DialTimeout("tcp", t, d)
157 if err != nil {
158 return nil, err
159 }
160 return connWrapper{c, addr}, nil
161 }
162 }
163
164 func testLocalCredsE2EFail(dopts []grpc.DialOption) error {
165 ss := &stubserver.StubServer{
166 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
167 return &testpb.Empty{}, nil
168 },
169 }
170
171 sopts := []grpc.ServerOption{grpc.Creds(local.NewCredentials())}
172 s := grpc.NewServer(sopts...)
173 defer s.Stop()
174
175 testgrpc.RegisterTestServiceServer(s, ss)
176
177 lis, err := net.Listen("tcp", "localhost:0")
178 if err != nil {
179 return fmt.Errorf("Failed to create listener: %v", err)
180 }
181
182 var fakeClientAddr, fakeServerAddr net.Addr
183 fakeClientAddr = &net.IPAddr{
184 IP: net.ParseIP("10.8.9.10"),
185 Zone: "",
186 }
187 fakeServerAddr = &net.IPAddr{
188 IP: net.ParseIP("10.8.9.11"),
189 Zone: "",
190 }
191
192 go s.Serve(spoofListener(lis, fakeClientAddr))
193
194 cc, err := grpc.NewClient(lis.Addr().String(), append(dopts, grpc.WithDialer(spoofDialer(fakeServerAddr)))...)
195 if err != nil {
196 return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String())
197 }
198 defer cc.Close()
199
200 c := testgrpc.NewTestServiceClient(cc)
201 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
202 defer cancel()
203
204 _, err = c.EmptyCall(ctx, &testpb.Empty{})
205 return err
206 }
207
208 func isExpected(got, want error) bool {
209 return status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message())
210 }
211
212 func (s) TestLocalCredsClientFail(t *testing.T) {
213
214 opts := []grpc.DialOption{grpc.WithTransportCredentials(local.NewCredentials())}
215 want := status.Error(codes.Unavailable, "transport: authentication handshake failed: local credentials rejected connection to non-local address")
216 if err := testLocalCredsE2EFail(opts); !isExpected(err, want) {
217 t.Fatalf("testLocalCredsE2EFail() = %v; want %v", err, want)
218 }
219 }
220
221 func (s) TestLocalCredsServerFail(t *testing.T) {
222
223 opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
224 if err := testLocalCredsE2EFail(opts); status.Code(err) != codes.Unavailable {
225 t.Fatalf("testLocalCredsE2EFail() = %v; want %v", err, codes.Unavailable)
226 }
227 }
228
View as plain text