1
18
19
20 package xds
21
22 import (
23 "context"
24 "crypto/tls"
25 "crypto/x509"
26 "errors"
27 "fmt"
28 "strings"
29 "unsafe"
30
31 "google.golang.org/grpc/attributes"
32 "google.golang.org/grpc/credentials/tls/certprovider"
33 "google.golang.org/grpc/internal"
34 "google.golang.org/grpc/internal/xds/matcher"
35 "google.golang.org/grpc/resolver"
36 )
37
// init publishes GetHandshakeInfo through the internal package hook
// internal.GetXDSHandshakeInfoForTesting. Presumably this lets tests in other
// packages read the handshake info attached to an address without importing
// this package directly — confirm with the hook's users.
func init() {
	internal.GetXDSHandshakeInfoForTesting = GetHandshakeInfo
}
41
42
43
// handshakeAttrKey is the key under which SetHandshakeInfo stores, and
// GetHandshakeInfo looks up, the handshake-info pointer in resolver.Address
// attributes. Because the type is unexported, code in other packages cannot
// construct an equal key.
type handshakeAttrKey struct{}
45
46
47 func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool {
48 if hi == nil && other == nil {
49 return true
50 }
51 if hi == nil || other == nil {
52 return false
53 }
54 if hi.rootProvider != other.rootProvider ||
55 hi.identityProvider != other.identityProvider ||
56 hi.requireClientCert != other.requireClientCert ||
57 len(hi.sanMatchers) != len(other.sanMatchers) {
58 return false
59 }
60 for i := range hi.sanMatchers {
61 if !hi.sanMatchers[i].Equal(other.sanMatchers[i]) {
62 return false
63 }
64 }
65 return true
66 }
67
68
69
// SetHandshakeInfo returns addr with hiPtr stored in its attributes under
// handshakeAttrKey, where GetHandshakeInfo can retrieve it. addr is received
// by value, so the caller's own Address variable is not reassigned.
//
// hiPtr is presumably a pointer that callers update to point at a
// *HandshakeInfo — confirm with the callers.
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address {
	updated := addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr)
	addr.Attributes = updated
	return addr
}
74
75
// GetHandshakeInfo returns the *unsafe.Pointer that SetHandshakeInfo stored
// in attr. It returns nil if no value is stored under handshakeAttrKey or the
// stored value has a different type.
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer {
	if ptr, ok := attr.Value(handshakeAttrKey{}).(*unsafe.Pointer); ok {
		return ptr
	}
	return nil
}
81
82
83
84
// HandshakeInfo holds the xDS security configuration used to:
//   - build TLS configs for a handshake (ClientSideTLSConfig and
//     ServerSideTLSConfig);
//   - check a peer certificate's SANs (MatchingSANExists).
//
// In the code shown here, fields are set only by NewHandshakeInfo and are not
// modified by any method afterwards.
type HandshakeInfo struct {
	// rootProvider supplies trusted root certificates. ClientSideTLSConfig
	// requires it and uses its roots as RootCAs. ServerSideTLSConfig treats
	// it as optional and, when set, uses its roots as ClientCAs.
	rootProvider certprovider.Provider

	// identityProvider supplies this side's certificate chain.
	// ServerSideTLSConfig requires it. ClientSideTLSConfig treats it as
	// optional and, when set, presents its certificates to the server.
	identityProvider certprovider.Provider

	// sanMatchers are checked against the peer certificate's SANs by
	// MatchingSANExists. An empty list accepts any certificate.
	sanMatchers []matcher.StringMatcher

	// requireClientCert makes ServerSideTLSConfig use
	// tls.RequireAndVerifyClientCert instead of tls.NoClientCert.
	requireClientCert bool
}
93
94
95
96 func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo {
97 return &HandshakeInfo{
98 rootProvider: rootProvider,
99 identityProvider: identityProvider,
100 sanMatchers: sanMatchers,
101 requireClientCert: requireClientCert,
102 }
103 }
104
105
106
107 func (hi *HandshakeInfo) UseFallbackCreds() bool {
108 if hi == nil {
109 return true
110 }
111 return hi.identityProvider == nil && hi.rootProvider == nil
112 }
113
114
115
116 func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher {
117 return append([]matcher.StringMatcher{}, hi.sanMatchers...)
118 }
119
120
121
// ClientSideTLSConfig builds the tls.Config for the client side of an xDS
// TLS handshake.
//
// Provider requirements:
//   - The root certificate provider is mandatory. A nil receiver is treated
//     the same as a missing root provider, and an error is returned instead
//     of panicking.
//   - The identity provider is optional. When set, its certificates are
//     presented to the server.
//
// Errors returned by a provider's KeyMaterial call are wrapped with %w, so
// callers can inspect the underlying cause with errors.Is / errors.As. The
// error text is unchanged from the plain %v form.
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
	if hi == nil || hi.rootProvider == nil {
		return nil, errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server")
	}
	rootProv, idProv := hi.rootProvider, hi.identityProvider

	rootKM, err := rootProv.KeyMaterial(ctx)
	if err != nil {
		return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %w", err)
	}

	// Standard chain verification is disabled here. NOTE(review): presumably
	// the caller verifies the chain against RootCAs and checks SANs itself
	// (see MatchingSANExists). Confirm with the credentials code that consumes
	// this config.
	cfg := &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
		RootCAs:            rootKM.Roots,
	}

	if idProv != nil {
		idKM, err := idProv.KeyMaterial(ctx)
		if err != nil {
			return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %w", err)
		}
		cfg.Certificates = idKM.Certs
	}
	return cfg, nil
}
157
158
159
// ServerSideTLSConfig builds the tls.Config for the server side of an xDS
// TLS handshake.
//
// Provider requirements:
//   - The identity provider is mandatory. A nil receiver is treated the same
//     as a missing identity provider, and an error is returned instead of
//     panicking.
//   - The root provider is optional. When set, its roots become ClientCAs.
//
// Client certificates are requested and verified only when requireClientCert
// is set. Identity key material is fetched before root key material.
//
// Errors returned by a provider's KeyMaterial call are wrapped with %w, so
// callers can inspect the underlying cause with errors.Is / errors.As. The
// error text is unchanged from the plain %v form.
func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config, error) {
	if hi == nil || hi.identityProvider == nil {
		return nil, errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server")
	}
	rootProv, idProv := hi.rootProvider, hi.identityProvider

	idKM, err := idProv.KeyMaterial(ctx)
	if err != nil {
		return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %w", err)
	}

	cfg := &tls.Config{
		ClientAuth:   tls.NoClientCert,
		NextProtos:   []string{"h2"},
		Certificates: idKM.Certs,
	}
	if hi.requireClientCert {
		// NOTE(review): if requireClientCert is set but rootProvider is nil,
		// ClientCAs stays nil. crypto/tls then verifies client certs with a
		// nil x509 Roots pool, which means the system roots. Confirm that
		// config validation upstream rejects this combination.
		cfg.ClientAuth = tls.RequireAndVerifyClientCert
	}

	if rootProv != nil {
		rootKM, err := rootProv.KeyMaterial(ctx)
		if err != nil {
			return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %w", err)
		}
		cfg.ClientCAs = rootKM.Roots
	}
	return cfg, nil
}
193
194
195
196
197
198
// MatchingSANExists reports whether any SAN in cert is accepted by at least
// one configured SAN matcher. If no matchers are configured, every
// certificate is accepted.
//
// SANs are tried in this order, stopping at the first match:
//  1. DNS names, which get DNS-aware exact matching (see matchSAN).
//  2. Email addresses.
//  3. IP addresses, in their String form.
//  4. URIs, in their String form.
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
	if len(hi.sanMatchers) == 0 {
		return true
	}

	for i := range cert.DNSNames {
		if hi.matchSAN(cert.DNSNames[i], true) {
			return true
		}
	}
	for i := range cert.EmailAddresses {
		if hi.matchSAN(cert.EmailAddresses[i], false) {
			return true
		}
	}
	// IPs and URIs are stringified lazily, one at a time, as they are checked.
	for i := range cert.IPAddresses {
		if hi.matchSAN(cert.IPAddresses[i].String(), false) {
			return true
		}
	}
	for i := range cert.URIs {
		if hi.matchSAN(cert.URIs[i].String(), false) {
			return true
		}
	}
	return false
}
227
228
229 func (hi *HandshakeInfo) matchSAN(san string, isDNS bool) bool {
230 for _, matcher := range hi.sanMatchers {
231 if em := matcher.ExactMatch(); em != "" && isDNS {
232
233
234
235
236 if dnsMatch(em, san) {
237 return true
238 }
239 continue
240 }
241 if matcher.Match(san) {
242 return true
243 }
244 }
245 return false
246 }
247
248
249
250
251
252
253
// dnsMatch reports whether the DNS name host matches the DNS SAN pattern
// san. Both are compared case-insensitively as absolute names (a trailing "."
// is added if missing).
//
// Without a '*', san must equal host exactly. Wildcard patterns follow these
// rules:
//   - The only accepted wildcard form is "*.<rest>", where '*' is the whole
//     left-most label and appears nowhere else. For example, "*a.example.com",
//     "a*.example.com" and "a.*.example.com" never match.
//   - The bare pattern "*" (which normalizes to "*.") is rejected.
//   - '*' stands for exactly one non-empty label, so "*.example.com" matches
//     "test.example.com" but neither "example.com" nor
//     "sub.test.example.com".
func dnsMatch(host, san string) bool {
	normalize := func(name string) string {
		if !strings.HasSuffix(name, ".") {
			name += "."
		}
		return strings.ToLower(name)
	}
	host, san = normalize(host), normalize(san)

	if !strings.Contains(san, "*") {
		return host == san
	}

	if san == "*." || !strings.HasPrefix(san, "*.") {
		return false
	}
	// suffix is the non-wildcard tail of the pattern, e.g. ".example.com.".
	suffix := san[1:]
	if strings.Contains(suffix, "*") {
		return false
	}

	// The length check guarantees the wildcard covers at least one character,
	// so ".example.com" does not match "*.example.com".
	if len(host) < len(san) || !strings.HasSuffix(host, suffix) {
		return false
	}
	// The part of host that the '*' stands for must be a single label.
	label := host[:len(host)-len(suffix)]
	return strings.IndexByte(label, '.') == -1
}
298
View as plain text