1 package api
2
3 import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "net"
11 "net/http"
12 "sync/atomic"
13 "time"
14
15 "github.com/julienschmidt/httprouter"
16 "github.com/linkerd/linkerd2/controller/k8s"
17 pkgk8s "github.com/linkerd/linkerd2/pkg/k8s"
18 "github.com/linkerd/linkerd2/pkg/prometheus"
19 pkgTls "github.com/linkerd/linkerd2/pkg/tls"
20 pb "github.com/linkerd/linkerd2/viz/tap/gen/tap"
21 log "github.com/sirupsen/logrus"
22 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
23 )
24
25
26 type Server struct {
27 *http.Server
28 listener net.Listener
29 router *httprouter.Router
30 allowedNames []string
31 certValue *atomic.Value
32 log *log.Entry
33 }
34
35
// NewServer builds the tap API server.
//
// Setup steps:
//   - Starts a filesystem watcher for the TLS credentials under
//     pkgk8s.MountPathTLSBase.
//   - Reads the client CA, allowed names and username/group headers via
//     serverAuth.
//   - Binds a TCP listener on addr.
//   - Loads the initial serving certificate into the server's certificate
//     store, which the watcher keeps up to date afterwards.
//
// When disableCommonNames is true, the allowed-names list read from the
// cluster is replaced by an empty list. validate then performs no
// client-name check.
//
// Errors are returned if:
//   - serverAuth fails;
//   - the client CA PEM contains no parseable certificate;
//   - the listener cannot be bound;
//   - the initial certificate cannot be loaded. In this case the listener is
//     closed before returning.
func NewServer(
	ctx context.Context,
	addr string,
	k8sAPI *k8s.API,
	grpcTapServer pb.TapServer,
	disableCommonNames bool,
) (*Server, error) {
	updateEvent := make(chan struct{})
	errEvent := make(chan error)
	watcher := pkgTls.NewFsCredsWatcher(pkgk8s.MountPathTLSBase, updateEvent, errEvent).
		WithFilePaths(pkgk8s.MountPathTLSCrtPEM, pkgk8s.MountPathTLSKeyPEM)
	go func() {
		if err := watcher.StartWatching(ctx); err != nil {
			log.Fatalf("Failed to start creds watcher: %s", err)
		}
	}()

	clientCAPem, allowedNames, usernameHeader, groupHeader, err := serverAuth(ctx, k8sAPI)
	if err != nil {
		return nil, err
	}

	if disableCommonNames {
		allowedNames = []string{}
	}

	// The right-hand side still refers to the logrus package; the shadowing
	// local only comes into scope after this statement.
	log := log.WithFields(log.Fields{
		"component": "tap",
		"addr":      addr,
	})

	// AppendCertsFromPEM reports false when no certificate could be parsed.
	// With an empty pool, any client certificate presented under
	// VerifyClientCertIfGiven fails verification, so fail fast here rather
	// than rejecting every handshake later.
	clientCertPool := x509.NewCertPool()
	if !clientCertPool.AppendCertsFromPEM([]byte(clientCAPem)) {
		return nil, errors.New("failed to parse any client CA certificate from the configured PEM")
	}

	httpServer := &http.Server{
		Addr:              addr,
		ReadHeaderTimeout: 15 * time.Second,
		TLSConfig: &tls.Config{
			ClientAuth: tls.VerifyClientCertIfGiven,
			ClientCAs:  clientCertPool,
			MinVersion: tls.VersionTLS12,
		},
	}

	var emptyCert atomic.Value
	h := &handler{
		k8sAPI:         k8sAPI,
		usernameHeader: usernameHeader,
		groupHeader:    groupHeader,
		grpcTapServer:  grpcTapServer,
		log:            log,
	}

	lis, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("net.Listen failed with: %w", err)
	}

	s := &Server{
		Server:       httpServer,
		listener:     lis,
		router:       initRouter(h),
		allowedNames: allowedNames,
		certValue:    &emptyCert,
		log:          log,
	}
	s.Handler = prometheus.WithTelemetry(s)
	// Certificates are served dynamically so rotations picked up by the
	// watcher take effect without restarting the server.
	httpServer.TLSConfig.GetCertificate = s.getCertificate

	if err := watcher.UpdateCert(s.certValue); err != nil {
		// The caller gets no *Server back, so nothing else could close the
		// listener bound above; release the port here (best effort).
		_ = lis.Close()
		return nil, fmt.Errorf("failed to initialize certificate: %w", err)
	}

	go watcher.ProcessEvents(log, s.certValue, updateEvent, errEvent)

	return s, nil
}
114
115
// Start serves HTTPS on the listener bound in NewServer and blocks until
// serving ends.
//
// The cert/key file arguments are empty because the certificate is supplied
// through TLSConfig.GetCertificate. A nil error or http.ErrServerClosed makes
// Start return quietly; any other error is passed to a.log.Fatal.
// ctx is currently unused.
func (a *Server) Start(ctx context.Context) {
	a.log.Infof("starting tap API server on %s", a.Server.Addr)

	serveErr := a.ServeTLS(a.listener, "", "")
	if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
		return
	}
	a.log.Fatal(serveErr)
}
125
// getCertificate returns the serving certificate currently stored in
// a.certValue and ignores the ClientHello. It has the tls.Config.GetCertificate
// signature.
//
// If nothing (or something other than a non-nil *tls.Certificate) has been
// stored, it returns an error so the handshake fails cleanly. The previous
// version used an unchecked type assertion, which panics on an empty
// atomic.Value.
func (a *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cert, ok := a.certValue.Load().(*tls.Certificate)
	if !ok || cert == nil {
		return nil, errors.New("tap API server certificate is not loaded")
	}
	return cert, nil
}
129
130
// ServeHTTP implements http.Handler. Each request is checked by validate.
// Accepted requests are dispatched to the router. Rejected ones are logged at
// debug level and answered through renderJSONError with
// http.StatusBadRequest.
func (a *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	a.log.Debugf("ServeHTTP(): %+v", req)

	validationErr := a.validate(req)
	if validationErr == nil {
		a.router.ServeHTTP(w, req)
		return
	}

	a.log.Debug(validationErr)
	renderJSONError(w, validationErr, http.StatusBadRequest)
}
140
141
142 func (a *Server) validate(req *http.Request) error {
143
144 if len(a.allowedNames) > 0 {
145 for _, cn := range a.allowedNames {
146 for _, clientCert := range req.TLS.PeerCertificates {
147
148 if cn == clientCert.Subject.CommonName || isSubjectAlternateName(clientCert, cn) {
149 return nil
150 }
151 }
152 }
153
154 clientNames := []string{}
155 for _, clientCert := range req.TLS.PeerCertificates {
156 clientNames = append(clientNames, clientCert.Subject.CommonName)
157 }
158 return fmt.Errorf("no valid CN found. allowed names: %s, client names: %s", a.allowedNames, clientNames)
159 }
160 return nil
161 }
162
163
164
165
166
167 func serverAuth(ctx context.Context, k8sAPI *k8s.API) (string, []string, string, string, error) {
168
169 cm, err := k8sAPI.Client.CoreV1().
170 ConfigMaps(metav1.NamespaceSystem).
171 Get(ctx, pkgk8s.ExtensionAPIServerAuthenticationConfigMapName, metav1.GetOptions{})
172 if err != nil {
173 return "", nil, "", "", fmt.Errorf("failed to load [%s] config: %w", pkgk8s.ExtensionAPIServerAuthenticationConfigMapName, err)
174 }
175
176 clientCAPem, ok := cm.Data[pkgk8s.ExtensionAPIServerAuthenticationRequestHeaderClientCAFileKey]
177 if !ok {
178 return "", nil, "", "", fmt.Errorf("no client CA cert available for apiextension-server")
179 }
180
181 allowedNames, err := deserializeStrings(cm.Data["requestheader-allowed-names"])
182 if err != nil {
183 return "", nil, "", "", err
184 }
185
186 usernameHeaders, err := deserializeStrings(cm.Data["requestheader-username-headers"])
187 if err != nil {
188 return "", nil, "", "", err
189 }
190 usernameHeader := ""
191 if len(usernameHeaders) > 0 {
192 usernameHeader = usernameHeaders[0]
193 }
194
195 groupHeaders, err := deserializeStrings(cm.Data["requestheader-group-headers"])
196 if err != nil {
197 return "", nil, "", "", err
198 }
199 groupHeader := ""
200 if len(groupHeaders) > 0 {
201 groupHeader = groupHeaders[0]
202 }
203
204 return clientCAPem, allowedNames, usernameHeader, groupHeader, nil
205 }
206
207
// deserializeStrings decodes a JSON array of strings.
//
// An empty input, or a JSON null, yields a nil slice and no error. Malformed
// JSON returns the json.Unmarshal error with a nil slice.
func deserializeStrings(in string) ([]string, error) {
	var parsed []string
	if in != "" {
		if err := json.Unmarshal([]byte(in), &parsed); err != nil {
			return nil, err
		}
	}
	return parsed, nil
}
218
219
220
// isSubjectAlternateName reports whether name exactly equals one of cert's
// subject alternative names.
//
// It checks the DNS names, email addresses, IP addresses (in their String
// form) and URIs (in their String form). The comparison is case-sensitive
// and uses no wildcard matching.
func isSubjectAlternateName(cert *x509.Certificate, name string) bool {
	containsName := func(candidates []string) bool {
		for _, candidate := range candidates {
			if candidate == name {
				return true
			}
		}
		return false
	}

	if containsName(cert.DNSNames) || containsName(cert.EmailAddresses) {
		return true
	}
	for i := range cert.IPAddresses {
		if cert.IPAddresses[i].String() == name {
			return true
		}
	}
	for i := range cert.URIs {
		if cert.URIs[i].String() == name {
			return true
		}
	}
	return false
}
244
View as plain text