1
18
19
20
21
22
23
24 package alts
25
26 import (
27 "context"
28 "errors"
29 "fmt"
30 "net"
31 "sync"
32 "time"
33
34 "google.golang.org/grpc/credentials"
35 core "google.golang.org/grpc/credentials/alts/internal"
36 "google.golang.org/grpc/credentials/alts/internal/handshaker"
37 "google.golang.org/grpc/credentials/alts/internal/handshaker/service"
38 altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
39 "google.golang.org/grpc/grpclog"
40 "google.golang.org/grpc/internal/googlecloud"
41 )
42
const (
	// hypervisorHandshakerServiceAddress is the default handshaker service
	// address. DefaultClientOptions and DefaultServerOptions use it, and
	// newALTS falls back to it whenever an empty address is supplied.
	hypervisorHandshakerServiceAddress = "dns:///metadata.google.internal.:8080"
	// defaultTimeout bounds the context used for a server-side handshake.
	defaultTimeout = 30.0 * time.Second
	// Bounds of the RPC protocol version range offered during the handshake
	// (see maxRPCVersion / minRPCVersion). Currently both ends are 2.1.
	protocolVersionMaxMajor = 2
	protocolVersionMaxMinor = 1
	protocolVersionMinMajor = 2
	protocolVersionMinMinor = 1
)
56
var (
	// vmOnGCP records whether the process was detected (via
	// googlecloud.OnGCE) to be running on GCP. It is computed by the first
	// call to newALTS and checked at the start of every handshake.
	vmOnGCP bool
	// once guards the one-time initialization of vmOnGCP.
	once sync.Once
	// maxRPCVersion and minRPCVersion bound the RPC protocol version range
	// this implementation offers to the handshaker and checks peers against.
	// They are shared by every handshake and must not be mutated.
	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMaxMajor,
		Minor: protocolVersionMaxMinor,
	}
	minRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMinMajor,
		Minor: protocolVersionMinMinor,
	}
	// ErrUntrustedPlatform is returned by ClientHandshake and ServerHandshake
	// when the process was not detected to be running on GCP.
	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
	// logger is this package's grpclog component; checkRPCVersions uses it
	// to report invalid (nil) arguments.
	logger = grpclog.Component("alts")
)
74
75
76
77
78
79
// AuthInfo exposes ALTS-specific information about an established
// connection. ClientHandshake and ServerHandshake require the auth info
// produced by the handshaker to implement this interface and fail otherwise.
type AuthInfo interface {
	// ApplicationProtocol returns the application protocol of the
	// connection (presumably the one negotiated during the handshake —
	// confirm against the handshaker implementation).
	ApplicationProtocol() string
	// RecordProtocol returns the record protocol of the connection
	// (presumably the negotiated one — confirm).
	RecordProtocol() string
	// SecurityLevel returns the security level of the connection.
	SecurityLevel() altspb.SecurityLevel
	// PeerServiceAccount returns the service account of the remote peer.
	PeerServiceAccount() string
	// LocalServiceAccount returns the service account of the local side.
	LocalServiceAccount() string
	// PeerRPCVersions returns the RPC protocol version range advertised by
	// the peer; the handshakes check it against the local range.
	PeerRPCVersions() *altspb.RpcProtocolVersions
}
97
98
99
// ClientOptions configures client-side ALTS credentials created by
// NewClientCreds.
type ClientOptions struct {
	// TargetServiceAccounts is passed to the client handshaker as the list
	// of acceptable server service accounts (NOTE(review): acceptance
	// semantics are enforced by the handshaker service — confirm).
	TargetServiceAccounts []string
	// HandshakerServiceAddress is the address of the handshaker service to
	// dial. If empty, the default hypervisor handshaker address is used.
	HandshakerServiceAddress string
}
108
109
110
111 func DefaultClientOptions() *ClientOptions {
112 return &ClientOptions{
113 HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
114 }
115 }
116
117
118
// ServerOptions configures server-side ALTS credentials created by
// NewServerCreds.
type ServerOptions struct {
	// HandshakerServiceAddress is the address of the handshaker service to
	// dial. If empty, the default hypervisor handshaker address is used.
	HandshakerServiceAddress string
}
124
125
126
127 func DefaultServerOptions() *ServerOptions {
128 return &ServerOptions{
129 HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
130 }
131 }
132
133
134
// altsTC is the credentials.TransportCredentials implementation returned by
// NewClientCreds and NewServerCreds.
type altsTC struct {
	// info is returned by value from Info and mutated by OverrideServerName.
	info *credentials.ProtocolInfo
	// side records whether these credentials were created for the client or
	// the server side.
	side core.Side
	// accounts are the target service accounts handed to the client
	// handshaker; nil for server-side credentials.
	accounts []string
	// hsAddress is the handshaker service address dialed for each handshake.
	hsAddress string
}
141
142
143 func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
144 return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
145 }
146
147
148 func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
149 return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
150 }
151
152 func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
153 once.Do(func() {
154 vmOnGCP = googlecloud.OnGCE()
155 })
156 if hsAddress == "" {
157 hsAddress = hypervisorHandshakerServiceAddress
158 }
159 return &altsTC{
160 info: &credentials.ProtocolInfo{
161 SecurityProtocol: "alts",
162 SecurityVersion: "1.0",
163 },
164 side: side,
165 accounts: accounts,
166 hsAddress: hsAddress,
167 }
168 }
169
170
// ClientHandshake performs the client side of the ALTS handshake over rawConn,
// using the handshaker service at g.hsAddress. addr is passed to the
// handshaker as the target name.
//
// It returns ErrUntrustedPlatform, without dialing anything, when the process
// was not detected to be running on GCP. On success it returns the secured
// connection and the handshaker's auth info. That auth info is guaranteed to
// implement AuthInfo and to advertise RPC versions compatible with this
// client's range.
//
// err is a named result. The deferred cleanups below inspect it, so every
// error path releases the child context and closes the client handshaker.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}

	// Dial the handshaker service. `:=` here reuses the named result err
	// (same scope), so the deferred closures below observe later failures.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// hsConn is deliberately not closed here. NOTE(review): presumably
	// service.Dial shares one connection across handshakes — confirm before
	// adding a Close.

	// Derive a cancelable child context. It is canceled only when a non-nil
	// error is returned. On success it is intentionally left uncanceled
	// (NOTE(review): presumably because the returned connection or handshaker
	// may still use it — confirm). It is still released once the parent ctx
	// is done.
	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(ctx)
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	opts := handshaker.DefaultClientHandshakerOptions()
	opts.TargetName = addr
	opts.TargetServiceAccounts = g.accounts
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Close the handshaker only on failure; on success ownership passes to
	// the returned connection.
	defer func() {
		if err != nil {
			chs.Close()
		}
	}()
	secConn, authInfo, err := chs.ClientHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
	}
	// Only compatibility matters here; the negotiated common version is
	// discarded.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
224
225
// ServerHandshake performs the server side of the ALTS handshake over rawConn,
// using the handshaker service at g.hsAddress.
//
// It returns ErrUntrustedPlatform, without dialing anything, when the process
// was not detected to be running on GCP. The whole handshake is bounded by
// defaultTimeout. On success it returns the secured connection and the
// handshaker's auth info. That auth info is guaranteed to implement AuthInfo
// and to advertise RPC versions compatible with this server's range.
//
// err is a named result that the deferred handshaker cleanup inspects.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}
	// Dial the handshaker service. `:=` reuses the named result err (same
	// scope). hsConn is not closed here. NOTE(review): presumably it is
	// shared across handshakes — confirm.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}

	// Unlike the client side, the handshake context is always canceled when
	// this method returns.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()
	opts := handshaker.DefaultServerHandshakerOptions()
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Close the handshaker only on failure; on success ownership passes to
	// the returned connection.
	defer func() {
		if err != nil {
			shs.Close()
		}
	}()
	secConn, authInfo, err := shs.ServerHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
	}
	// Only compatibility matters here; the negotiated common version is
	// discarded.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
267
268 func (g *altsTC) Info() credentials.ProtocolInfo {
269 return *g.info
270 }
271
272 func (g *altsTC) Clone() credentials.TransportCredentials {
273 info := *g.info
274 var accounts []string
275 if g.accounts != nil {
276 accounts = make([]string, len(g.accounts))
277 copy(accounts, g.accounts)
278 }
279 return &altsTC{
280 info: &info,
281 side: g.side,
282 hsAddress: g.hsAddress,
283 accounts: accounts,
284 }
285 }
286
287 func (g *altsTC) OverrideServerName(serverNameOverride string) error {
288 g.info.ServerName = serverNameOverride
289 return nil
290 }
291
292
293 func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
294 switch {
295 case v1.GetMajor() > v2.GetMajor(),
296 v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
297 return 1
298 case v1.GetMajor() < v2.GetMajor(),
299 v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
300 return -1
301 }
302 return 0
303 }
304
305
306
307
308
309
310 func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
311 if local == nil || peer == nil {
312 logger.Error("invalid checkRPCVersions argument, either local or peer is nil.")
313 return false, nil
314 }
315
316
317 maxCommonVersion := local.GetMaxRpcVersion()
318 if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
319 maxCommonVersion = peer.GetMaxRpcVersion()
320 }
321
322
323 minCommonVersion := peer.GetMinRpcVersion()
324 if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
325 minCommonVersion = local.GetMinRpcVersion()
326 }
327
328 if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
329 return false, nil
330 }
331 return true, maxCommonVersion
332 }
333
View as plain text