1
18
19
20 package main
21
22 import (
23 "context"
24 "flag"
25 "log"
26 "net"
27 "os"
28 "os/exec"
29 "syscall"
30 "time"
31
32 "golang.org/x/sys/unix"
33 "google.golang.org/grpc"
34 _ "google.golang.org/grpc/balancer/grpclb"
35 "google.golang.org/grpc/credentials"
36 "google.golang.org/grpc/credentials/alts"
37 "google.golang.org/grpc/credentials/google"
38 _ "google.golang.org/grpc/xds/googledirectpath"
39
40 testgrpc "google.golang.org/grpc/interop/grpc_testing"
41 testpb "google.golang.org/grpc/interop/grpc_testing"
42 )
43
// Command-line flags: which server to reach and with which credentials, the
// shell command used to force fallback, how long fallback may take, and which
// scenario to run.
var (
	serverURI               = flag.String("server_uri", "dns:///staging-grpc-directpath-fallback-test.googleapis.com:443", "The server host name")
	customCredentialsType   = flag.String("custom_credentials_type", "", "Client creds to use")
	induceFallbackCmd       = flag.String("induce_fallback_cmd", "", "Command to induce fallback e.g. by making certain addresses unroutable")
	fallbackDeadlineSeconds = flag.Int("fallback_deadline_seconds", 1, "How long to wait for fallback to happen after induce_fallback_cmd")
	testCase                = flag.String("test_case", "",
		`Configure different test cases. Valid options are:
fallback_before_startup : LB/backend connections fail before RPC's have been made;
fallback_after_startup : LB/backend connections fail after RPC's have been made;`)
)

// Loggers shared by the whole client. Both write to stderr and stamp each line
// with the date, the time and the calling file:line; only the prefix differs.
var (
	infoLog  = log.New(os.Stderr, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile)
	errorLog = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.Ltime|log.Ldate)
)
56
57 func doRPCAndGetPath(client testgrpc.TestServiceClient, timeout time.Duration) testpb.GrpclbRouteType {
58 infoLog.Printf("doRPCAndGetPath timeout:%v\n", timeout)
59 ctx, cancel := context.WithTimeout(context.Background(), timeout)
60 defer cancel()
61 req := &testpb.SimpleRequest{
62 FillGrpclbRouteType: true,
63 }
64 reply, err := client.UnaryCall(ctx, req)
65 if err != nil {
66 infoLog.Printf("doRPCAndGetPath error:%v\n", err)
67 return testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_UNKNOWN
68 }
69 g := reply.GetGrpclbRouteType()
70 infoLog.Printf("doRPCAndGetPath got grpclb route type: %v\n", g)
71 if g != testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_FALLBACK && g != testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_BACKEND {
72 errorLog.Fatalf("Expected grpclb route type to be either backend or fallback; got: %d", g)
73 }
74 return g
75 }
76
77 func dialTCPUserTimeout(ctx context.Context, addr string) (net.Conn, error) {
78 control := func(network, address string, c syscall.RawConn) error {
79 var syscallErr error
80 controlErr := c.Control(func(fd uintptr) {
81 syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, 20000)
82 })
83 if syscallErr != nil {
84 errorLog.Fatalf("syscall error setting sockopt TCP_USER_TIMEOUT: %v", syscallErr)
85 }
86 if controlErr != nil {
87 errorLog.Fatalf("control error setting sockopt TCP_USER_TIMEOUT: %v", syscallErr)
88 }
89 return nil
90 }
91 d := &net.Dialer{
92 Control: control,
93 }
94 return d.DialContext(ctx, "tcp", addr)
95 }
96
97 func createTestConn() *grpc.ClientConn {
98 opts := []grpc.DialOption{
99 grpc.WithContextDialer(dialTCPUserTimeout),
100 }
101 switch *customCredentialsType {
102 case "tls":
103 creds := credentials.NewClientTLSFromCert(nil, "")
104 opts = append(opts, grpc.WithTransportCredentials(creds))
105 case "alts":
106 creds := alts.NewClientCreds(alts.DefaultClientOptions())
107 opts = append(opts, grpc.WithTransportCredentials(creds))
108 case "google_default_credentials":
109 opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
110 case "compute_engine_channel_creds":
111 opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
112 default:
113 errorLog.Fatalf("Invalid --custom_credentials_type:%v", *customCredentialsType)
114 }
115 conn, err := grpc.Dial(*serverURI, opts...)
116 if err != nil {
117 errorLog.Fatalf("Fail to dial: %v", err)
118 }
119 return conn
120 }
121
122 func runCmd(command string) {
123 infoLog.Printf("Running cmd:|%v|\n", command)
124 if err := exec.Command("bash", "-c", command).Run(); err != nil {
125 errorLog.Fatalf("error running cmd:|%v| : %v", command, err)
126 }
127 }
128
129 func waitForFallbackAndDoRPCs(client testgrpc.TestServiceClient, fallbackDeadline time.Time) {
130 fallbackRetryCount := 0
131 fellBack := false
132 for time.Now().Before(fallbackDeadline) {
133 g := doRPCAndGetPath(client, 20*time.Second)
134 if g == testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_FALLBACK {
135 infoLog.Println("Made one successul RPC to a fallback. Now expect the same for the rest.")
136 fellBack = true
137 break
138 } else if g == testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_BACKEND {
139 errorLog.Fatalf("Got RPC type backend. This suggests an error in test implementation")
140 } else {
141 infoLog.Println("Retryable RPC failure on iteration:", fallbackRetryCount)
142 }
143 fallbackRetryCount++
144 }
145 if !fellBack {
146 infoLog.Fatalf("Didn't fall back before deadline: %v\n", fallbackDeadline)
147 }
148 for i := 0; i < 30; i++ {
149 if g := doRPCAndGetPath(client, 20*time.Second); g != testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_FALLBACK {
150 errorLog.Fatalf("Expected RPC to take grpclb route type FALLBACK. Got: %v", g)
151 }
152 time.Sleep(time.Second)
153 }
154 }
155
156 func doFallbackBeforeStartup() {
157 runCmd(*induceFallbackCmd)
158 fallbackDeadline := time.Now().Add(time.Duration(*fallbackDeadlineSeconds) * time.Second)
159 conn := createTestConn()
160 defer conn.Close()
161 client := testgrpc.NewTestServiceClient(conn)
162 waitForFallbackAndDoRPCs(client, fallbackDeadline)
163 }
164
165 func doFallbackAfterStartup() {
166 conn := createTestConn()
167 defer conn.Close()
168 client := testgrpc.NewTestServiceClient(conn)
169 if g := doRPCAndGetPath(client, 20*time.Second); g != testpb.GrpclbRouteType_GRPCLB_ROUTE_TYPE_BACKEND {
170 errorLog.Fatalf("Expected RPC to take grpclb route type BACKEND. Got: %v", g)
171 }
172 runCmd(*induceFallbackCmd)
173 fallbackDeadline := time.Now().Add(time.Duration(*fallbackDeadlineSeconds) * time.Second)
174 waitForFallbackAndDoRPCs(client, fallbackDeadline)
175 }
176
177 func main() {
178 flag.Parse()
179 if len(*induceFallbackCmd) == 0 {
180 errorLog.Fatalf("--induce_fallback_cmd unset")
181 }
182 switch *testCase {
183 case "fallback_before_startup":
184 doFallbackBeforeStartup()
185 log.Printf("FallbackBeforeStartup done!\n")
186 case "fallback_after_startup":
187 doFallbackAfterStartup()
188 log.Printf("FallbackAfterStartup done!\n")
189 default:
190 errorLog.Fatalf("Unsupported test case: %v", *testCase)
191 }
192 }
193
View as plain text