1
18
19
20 package fallback
21
22 import (
23 "context"
24 "crypto/tls"
25 "fmt"
26 "net"
27
28 "google.golang.org/grpc/credentials"
29 "google.golang.org/grpc/grpclog"
30 )
31
// defaultHTTPSPort is appended to a fallback address that does not carry its
// own port.
const defaultHTTPSPort = "443"

// ALPN protocol identifiers advertised to the fallback server.
const (
	alpnProtoStrHTTP = "http/1.1" // HTTP/1.1
	alpnProtoStrH2   = "h2"       // HTTP/2 (required for gRPC)
)
37
38
39
40 var FallbackTLSConfigGRPC = tls.Config{
41 MinVersion: tls.VersionTLS13,
42 ClientSessionCache: nil,
43 NextProtos: []string{alpnProtoStrH2},
44 }
45
46
47
48 var FallbackTLSConfigHTTP = tls.Config{
49 MinVersion: tls.VersionTLS13,
50 ClientSessionCache: nil,
51 NextProtos: []string{alpnProtoStrH2, alpnProtoStrHTTP},
52 }
53
54
55
56
57
58
59
60
// ClientHandshake is the signature of a fallback client handshake. It is
// invoked with the target server name, the original connection, and the error
// produced by the S2A client handshake with that target. It returns the
// connection to use instead, its auth info, or an error.
type ClientHandshake func(ctx context.Context, targetServer string, conn net.Conn, s2aErr error) (net.Conn, credentials.AuthInfo, error)
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79 func DefaultFallbackClientHandshakeFunc(fallbackAddr string) (ClientHandshake, error) {
80 var fallbackDialer = tls.Dialer{Config: &FallbackTLSConfigGRPC}
81 return defaultFallbackClientHandshakeFuncInternal(fallbackAddr, fallbackDialer.DialContext)
82 }
83
// defaultFallbackClientHandshakeFuncInternal builds a ClientHandshake that uses
// dialContextFunc to open a "tcp" connection to the normalized fallbackAddr.
// dialContextFunc is injectable so tests can replace the real TLS dialer.
//
// It returns an error immediately if fallbackAddr cannot be normalized.
//
// The returned handshake behaves as follows:
//   - If dialing fails, it returns an error that includes both the dial
//     failure and (wrapped with %w) the original S2A handshake error.
//   - If the dialed connection is not a *tls.Conn, it closes that connection
//     and returns an error that wraps the S2A handshake error.
//   - Otherwise, it closes the original conn (if non-nil) and returns the
//     fallback connection with a credentials.TLSInfo describing it at
//     PrivacyAndIntegrity security level.
func defaultFallbackClientHandshakeFuncInternal(fallbackAddr string, dialContextFunc func(context.Context, string, string) (net.Conn, error)) (ClientHandshake, error) {
	fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
	if err != nil {
		if grpclog.V(1) {
			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
		}
		return nil, err
	}
	return func(ctx context.Context, targetServer string, conn net.Conn, s2aErr error) (net.Conn, credentials.AuthInfo, error) {
		fbConn, fbErr := dialContextFunc(ctx, "tcp", fallbackServerAddr)
		if fbErr != nil {
			grpclog.Infof("dialing to fallback server %s failed: %v", fallbackServerAddr, fbErr)
			return nil, nil, fmt.Errorf("dialing to fallback server %s failed: %v; S2A client handshake with %s error: %w", fallbackServerAddr, fbErr, targetServer, s2aErr)
		}

		tc, ok := fbConn.(*tls.Conn)
		if !ok {
			// The connection is unusable for us, but it is still open: release
			// it instead of leaking the socket. Best effort; the close error is
			// irrelevant next to the error being returned.
			if fbConn != nil {
				_ = fbConn.Close()
			}
			grpclog.Infof("the connection with fallback server is expected to be tls but isn't")
			return nil, nil, fmt.Errorf("the connection with fallback server is expected to be tls but isn't; S2A client handshake with %s error: %w", targetServer, s2aErr)
		}

		// ConnectionState returns a snapshot; take it once and reuse it.
		state := tc.ConnectionState()
		tlsInfo := credentials.TLSInfo{
			State: state,
			CommonAuthInfo: credentials.CommonAuthInfo{
				SecurityLevel: credentials.PrivacyAndIntegrity,
			},
		}
		if grpclog.V(1) {
			grpclog.Infof("ConnectionState.NegotiatedProtocol: %v", state.NegotiatedProtocol)
			grpclog.Infof("ConnectionState.HandshakeComplete: %v", state.HandshakeComplete)
			grpclog.Infof("ConnectionState.ServerName: %v", state.ServerName)
		}
		// The fallback connection replaces the original one, which is no longer
		// needed. Guard against a nil conn so the handshake cannot panic here.
		if conn != nil {
			_ = conn.Close()
		}
		return fbConn, tlsInfo, nil
	}, nil
}
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// DefaultFallbackDialerAndAddress returns a TLS dialer configured with
// FallbackTLSConfigHTTP, together with the normalized fallback server address
// (fallbackAddr with the default HTTPS port appended when it has none).
//
// It returns an error, and a nil dialer, if fallbackAddr is empty.
func DefaultFallbackDialerAndAddress(fallbackAddr string) (*tls.Dialer, string, error) {
	addr, procErr := processFallbackAddr(fallbackAddr)
	if procErr != nil {
		if grpclog.V(1) {
			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, procErr)
		}
		return nil, "", procErr
	}
	dialer := &tls.Dialer{Config: &FallbackTLSConfigHTTP}
	return dialer, addr, nil
}
150
151 func processFallbackAddr(fallbackAddr string) (string, error) {
152 var fallbackServerAddr string
153 var err error
154
155 if fallbackAddr == "" {
156 return "", fmt.Errorf("empty fallback address")
157 }
158 _, _, err = net.SplitHostPort(fallbackAddr)
159 if err != nil {
160
161 fallbackServerAddr = net.JoinHostPort(fallbackAddr, defaultHTTPSPort)
162 } else {
163
164 fallbackServerAddr = fallbackAddr
165 }
166 return fallbackServerAddr, nil
167 }
168
View as plain text