1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package server
18
19 import (
20 "bytes"
21 "context"
22 "crypto/sha256"
23 "crypto/x509"
24 "encoding/hex"
25 "fmt"
26 "io"
27 "log"
28 "net/http"
29 "os"
30 "runtime"
31 "strings"
32 "sync"
33 "time"
34
35 "github.com/sassoftware/relic/config"
36 "github.com/sassoftware/relic/lib/compresshttp"
37 "github.com/sassoftware/relic/lib/isologger"
38 "github.com/sassoftware/relic/lib/x509tools"
39 "github.com/sassoftware/relic/token/worker"
40 )
41
// Server is the HTTP handler (see ServeHTTP) that identifies clients by their
// TLS certificate and routes requests to the key, signing, directory and
// health endpoints.
type Server struct {
	// Config is the server configuration; client and key lookups read from it.
	Config *config.Config
	// ErrorLog receives server log output. New creates it writing to stderr;
	// ReopenLogger redirects it to the configured log file.
	ErrorLog *log.Logger
	// closeLog is the currently open log file; ReopenLogger closes it when
	// replacing it with a freshly opened one.
	closeLog io.Closer
	// logMu serializes the log-output swap performed by ReopenLogger.
	logMu sync.Mutex
	// Closed is closed by Close to signal that the server is shutting down.
	Closed <-chan bool
	// closeCh is the send side of the same channel as Closed (see New).
	// Close sets it to nil after closing it so the channel is closed once.
	closeCh chan<- bool
	// tokens maps each served token name to its worker. New populates it,
	// and Close closes every entry.
	tokens map[string]*worker.WorkerToken
}
51
52 func (s *Server) callHandler(request *http.Request, lw *loggingWriter) (response Response, err error) {
53 defer func() {
54 if caught := recover(); caught != nil {
55 const size = 64 << 10
56 buf := make([]byte, size)
57 buf = buf[:runtime.Stack(buf, false)]
58 response = s.LogError(request, caught, buf)
59 err = nil
60 }
61 }()
62 ctx := request.Context()
63 ctx, errResponse := s.getUserRoles(ctx, request)
64 if errResponse != nil {
65 return errResponse, nil
66 }
67 request = request.WithContext(ctx)
68 lw.r = request
69 if err := compresshttp.DecompressRequest(request); err == compresshttp.ErrUnacceptableEncoding {
70 return StringResponse(http.StatusNotAcceptable, err.Error()), nil
71 } else if err != nil {
72 return nil, err
73 }
74 if request.URL.Path == "/health" {
75
76 return s.serveHealth(request)
77 } else if GetClientName(request) == "" {
78 return AccessDeniedResponse, nil
79 }
80 if strings.HasPrefix(request.URL.Path, "/keys/") {
81 return s.serveGetKey(request)
82 }
83 switch request.URL.Path {
84 case "/":
85 return s.serveHome(request)
86 case "/list_keys":
87 return s.serveListKeys(request)
88 case "/sign":
89 return s.serveSign(request, lw)
90 case "/directory":
91 return s.serveDirectory(request)
92 default:
93 return ErrorResponse(http.StatusNotFound), nil
94 }
95 }
96
97 func formatSubject(cert *x509.Certificate) string {
98 return x509tools.FormatPkixName(cert.RawSubject, x509tools.NameStyleOpenSsl)
99 }
100
101 func (s *Server) getUserRoles(ctx context.Context, request *http.Request) (context.Context, Response) {
102 if request.TLS != nil && len(request.TLS.PeerCertificates) != 0 {
103 cert := request.TLS.PeerCertificates[0]
104 digest := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
105 encoded := hex.EncodeToString(digest[:])
106 var useDN bool
107 client := s.Config.Clients[encoded]
108 if client == nil {
109 var saved error
110 for _, c2 := range s.Config.Clients {
111 match, err := c2.Match(request.TLS.PeerCertificates)
112 if match {
113 client = c2
114 useDN = true
115 break
116 } else if err != nil {
117
118 saved = err
119 }
120 }
121 if client == nil && saved != nil {
122 s.Logr(request, "client cert verification failed: %s\n", saved)
123 }
124 }
125 if client == nil {
126 s.Logr(request, "access denied: unknown fingerprint %s on certificate: %s\n", encoded, formatSubject(cert))
127 return nil, AccessDeniedResponse
128 }
129 name := client.Nickname
130 if name == "" {
131 name = encoded[:12]
132 }
133 ctx = context.WithValue(ctx, ctxClientName, name)
134 ctx = context.WithValue(ctx, ctxRoles, client.Roles)
135 if useDN {
136 ctx = context.WithValue(ctx, ctxClientDN, formatSubject(cert))
137 }
138 }
139 return ctx, nil
140 }
141
142 func (s *Server) CheckKeyAccess(request *http.Request, keyName string) *config.KeyConfig {
143 keyConf, err := s.Config.GetKey(keyName)
144 if err != nil {
145 return nil
146 }
147 clientRoles := GetClientRoles(request)
148 for _, keyRole := range keyConf.Roles {
149 for _, clientRole := range clientRoles {
150 if keyRole == clientRole {
151 return keyConf
152 }
153 }
154 }
155 return nil
156 }
157
158 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
159 writer.Header().Set("Accept-Encoding", compresshttp.AcceptedEncodings)
160 lw := &loggingWriter{ResponseWriter: writer, s: s, r: request, st: time.Now()}
161 defer lw.Close()
162 response, err := s.callHandler(request, lw)
163 if err != nil {
164 if request.Context().Err() != nil {
165 s.Logr(request, "client disconnected")
166 response = StringResponse(http.StatusBadRequest, "client disconnected")
167 } else {
168 response = s.LogError(lw.r, err, nil)
169 }
170 }
171 if response != nil {
172 for k, v := range response.Headers() {
173 lw.Header().Set(k, v)
174 }
175 ae := request.Header.Get("Accept-Encoding")
176 r := bytes.NewReader(response.Bytes())
177 if response.Status() >= 300 {
178
179 ae = ""
180 }
181 if err := compresshttp.CompressResponse(r, ae, lw, response.Status()); err != nil {
182 response = s.LogError(lw.r, err, nil)
183 writeResponse(lw, response)
184 }
185 }
186 }
187
188 func (s *Server) Close() error {
189 if s.closeCh != nil {
190 close(s.closeCh)
191 s.closeCh = nil
192 }
193 for _, t := range s.tokens {
194 t.Close()
195 }
196 return nil
197 }
198
199 func (s *Server) ReopenLogger() error {
200 if s.Config.Server.LogFile == "" {
201 return nil
202 }
203 f, err := os.OpenFile(s.Config.Server.LogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
204 if err != nil {
205 return err
206 }
207 s.logMu.Lock()
208 defer s.logMu.Unlock()
209 isologger.SetOutput(s.ErrorLog, f, isologger.RFC3339Milli)
210 if s.closeLog != nil {
211 s.closeLog.Close()
212 }
213 s.closeLog = f
214 return nil
215 }
216
217 func New(config *config.Config) (*Server, error) {
218 closed := make(chan bool)
219 s := &Server{
220 Config: config,
221 Closed: closed,
222 closeCh: closed,
223 ErrorLog: log.New(os.Stderr, "", 0),
224 tokens: make(map[string]*worker.WorkerToken),
225 }
226 if err := s.ReopenLogger(); err != nil {
227 return nil, fmt.Errorf("failed to open logfile: %s", err)
228 }
229 for _, name := range config.ListServedTokens() {
230 tok, err := worker.New(config, name)
231 if err != nil {
232 for _, t := range s.tokens {
233 t.Close()
234 }
235 return nil, err
236 }
237 s.tokens[name] = tok
238 }
239 if err := s.startHealthCheck(); err != nil {
240 return nil, err
241 }
242 return s, nil
243 }
244
View as plain text