1 package creds
2
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net"
	"sort"

	"google.golang.org/grpc/credentials"
)
13
// Sentinel errors produced by the transport credentials in this package.
var (
	// ErrClientHandshakeNop is returned by serverTransportCredentials'
	// ClientHandshake: server-side credentials never dial out.
	ErrClientHandshakeNop = errors.New("boulder/grpc/creds: Client-side handshakes are not implemented with serverTransportCredentials")

	// ErrServerHandshakeNop is returned by clientTransportCredentials'
	// ServerHandshake: client-side credentials never accept connections.
	ErrServerHandshakeNop = errors.New("boulder/grpc/creds: Server-side handshakes are not implemented with clientTransportCredentials")

	// ErrOverrideServerNameNop is returned by OverrideServerName on both the
	// client and the server credentials.
	ErrOverrideServerNameNop = errors.New("boulder/grpc/creds: OverrideServerName() is not implemented")

	// ErrNilServerConfig is returned by NewServerCredentials when it is given
	// a nil *tls.Config.
	ErrNilServerConfig = errors.New("boulder/grpc/creds: `serverConfig` must not be nil")

	// ErrEmptyPeerCerts is returned by validateClient when SAN checking is
	// enabled but the connection state carries no peer certificates.
	ErrEmptyPeerCerts = errors.New("boulder/grpc/creds: validateClient given state with empty PeerCertificates")
)
28
// ErrSANNotAccepted is returned by validateClient when none of the DNS or IP
// SANs on the client's leaf certificate are in the server's accepted set.
type ErrSANNotAccepted struct {
	// got holds the SANs the client presented; expected holds the SANs the
	// server would have accepted.
	got, expected []string
}

// Error reports both the presented SANs and the acceptable ones, each
// rendered as a quoted list.
func (e ErrSANNotAccepted) Error() string {
	const format = "boulder/grpc/creds: client certificate SAN was invalid. Got %q, expected one of %q."
	return fmt.Sprintf(format, e.got, e.expected)
}
37
38
39
// clientTransportCredentials is the client-side gRPC TransportCredentials
// implementation: it dials out over TLS, verifying the server against roots
// and presenting clients as its own certificate(s).
type clientTransportCredentials struct {
	// roots verifies the server's certificate chain.
	roots *x509.CertPool
	// clients are offered to the server for client authentication.
	clients []tls.Certificate
	// hostOverride, when non-empty, is used as the expected server name
	// instead of the host taken from the dial address.
	hostOverride string
}
47
48
49 func NewClientCredentials(rootCAs *x509.CertPool, clientCerts []tls.Certificate, hostOverride string) credentials.TransportCredentials {
50 return &clientTransportCredentials{rootCAs, clientCerts, hostOverride}
51 }
52
53
54
55
56
57 func (tc *clientTransportCredentials) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
58 var err error
59 host := tc.hostOverride
60 if host == "" {
61
62
63 host, _, err = net.SplitHostPort(addr)
64 if err != nil {
65 return nil, nil, err
66 }
67 }
68 conn := tls.Client(rawConn, &tls.Config{
69 ServerName: host,
70 RootCAs: tc.roots,
71 Certificates: tc.clients,
72 })
73 err = conn.HandshakeContext(ctx)
74 if err != nil {
75 _ = rawConn.Close()
76 return nil, nil, err
77 }
78 return conn, nil, nil
79 }
80
81
82
83 func (tc *clientTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
84 return nil, nil, ErrServerHandshakeNop
85 }
86
87
88 func (tc *clientTransportCredentials) Info() credentials.ProtocolInfo {
89 return credentials.ProtocolInfo{
90 SecurityProtocol: "tls",
91 SecurityVersion: "1.2",
92 }
93 }
94
95
96 func (tc *clientTransportCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
97 return nil, nil
98 }
99
100
101 func (tc *clientTransportCredentials) RequireTransportSecurity() bool {
102 return true
103 }
104
105
106 func (tc *clientTransportCredentials) Clone() credentials.TransportCredentials {
107 return NewClientCredentials(tc.roots, tc.clients, tc.hostOverride)
108 }
109
110
111 func (tc *clientTransportCredentials) OverrideServerName(serverNameOverride string) error {
112 return ErrOverrideServerNameNop
113 }
114
115
116
// serverTransportCredentials is the server-side gRPC TransportCredentials
// implementation: it runs the TLS handshake described by serverConfig and then
// restricts which client certificates are accepted, by SAN.
type serverTransportCredentials struct {
	// serverConfig drives the server side of each TLS handshake. It is non-nil
	// whenever the value was built by NewServerCredentials.
	serverConfig *tls.Config
	// acceptedSANs is the set of DNS/IP SANs a client leaf certificate may
	// carry; an empty set accepts every client.
	acceptedSANs map[string]struct{}
}
121
122
123 func NewServerCredentials(serverConfig *tls.Config, acceptedSANs map[string]struct{}) (credentials.TransportCredentials, error) {
124 if serverConfig == nil {
125 return nil, ErrNilServerConfig
126 }
127
128 return &serverTransportCredentials{serverConfig, acceptedSANs}, nil
129 }
130
131
132
133
134
135
136
137
138
139
140
141 func (tc *serverTransportCredentials) validateClient(peerState tls.ConnectionState) error {
142
149 if len(tc.acceptedSANs) == 0 {
150 return nil
151 }
152
153
154
155
156 if len(peerState.PeerCertificates) < 1 {
157 return ErrEmptyPeerCerts
158 }
159
160
161
162
163
164
165 leaf := peerState.PeerCertificates[0]
166
167
168
169 var receivedSANs []string
170 receivedSANs = append(receivedSANs, leaf.DNSNames...)
171 for _, ip := range leaf.IPAddresses {
172 receivedSANs = append(receivedSANs, ip.String())
173 }
174
175 for _, name := range receivedSANs {
176 if _, ok := tc.acceptedSANs[name]; ok {
177 return nil
178 }
179 }
180
181
182
183 var acceptableSANs []string
184 for k := range tc.acceptedSANs {
185 acceptableSANs = append(acceptableSANs, k)
186 }
187 return ErrSANNotAccepted{receivedSANs, acceptableSANs}
188 }
189
190
191
192
193 func (tc *serverTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
194
195
196 conn := tls.Server(rawConn, tc.serverConfig)
197 err := conn.Handshake()
198 if err != nil {
199 return nil, nil, err
200 }
201
202
203
204 err = tc.validateClient(conn.ConnectionState())
205 if err != nil {
206 return nil, nil, err
207 }
208
209 return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
210 }
211
212
213
214 func (tc *serverTransportCredentials) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
215 return nil, nil, ErrClientHandshakeNop
216 }
217
218
219 func (tc *serverTransportCredentials) Info() credentials.ProtocolInfo {
220 return credentials.ProtocolInfo{
221 SecurityProtocol: "tls",
222 SecurityVersion: "1.2",
223 }
224 }
225
226
227 func (tc *serverTransportCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
228 return nil, nil
229 }
230
231
232 func (tc *serverTransportCredentials) RequireTransportSecurity() bool {
233 return true
234 }
235
236
237 func (tc *serverTransportCredentials) Clone() credentials.TransportCredentials {
238 clone, _ := NewServerCredentials(tc.serverConfig, tc.acceptedSANs)
239 return clone
240 }
241
242
243 func (tc *serverTransportCredentials) OverrideServerName(serverNameOverride string) error {
244 return ErrOverrideServerNameNop
245 }
246
View as plain text