...

Source file src/github.com/sassoftware/relic/server/server.go

Documentation: github.com/sassoftware/relic/server

     1  //
     2  // Copyright (c) SAS Institute Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package server
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/sha256"
    23  	"crypto/x509"
    24  	"encoding/hex"
    25  	"fmt"
    26  	"io"
    27  	"log"
    28  	"net/http"
    29  	"os"
    30  	"runtime"
    31  	"strings"
    32  	"sync"
    33  	"time"
    34  
    35  	"github.com/sassoftware/relic/config"
    36  	"github.com/sassoftware/relic/lib/compresshttp"
    37  	"github.com/sassoftware/relic/lib/isologger"
    38  	"github.com/sassoftware/relic/lib/x509tools"
    39  	"github.com/sassoftware/relic/token/worker"
    40  )
    41  
// Server carries the state shared by the HTTP handlers in this package: the
// loaded configuration, the error logger and its backing file, the shutdown
// channel, and the worker tokens opened by New.
type Server struct {
	Config   *config.Config                 // configuration passed to New
	ErrorLog *log.Logger                    // created on stderr by New; ReopenLogger redirects it to the configured log file
	closeLog io.Closer                      // log file currently in use; closed when ReopenLogger replaces it
	logMu    sync.Mutex                     // held by ReopenLogger while the log output is swapped
	Closed   <-chan bool                    // receive side of the channel that Close closes
	closeCh  chan<- bool                    // send side of the same channel; set to nil once Close has closed it
	tokens   map[string]*worker.WorkerToken // worker tokens created in New, keyed by token name
}
    51  
    52  func (s *Server) callHandler(request *http.Request, lw *loggingWriter) (response Response, err error) {
    53  	defer func() {
    54  		if caught := recover(); caught != nil {
    55  			const size = 64 << 10
    56  			buf := make([]byte, size)
    57  			buf = buf[:runtime.Stack(buf, false)]
    58  			response = s.LogError(request, caught, buf)
    59  			err = nil
    60  		}
    61  	}()
    62  	ctx := request.Context()
    63  	ctx, errResponse := s.getUserRoles(ctx, request)
    64  	if errResponse != nil {
    65  		return errResponse, nil
    66  	}
    67  	request = request.WithContext(ctx)
    68  	lw.r = request
    69  	if err := compresshttp.DecompressRequest(request); err == compresshttp.ErrUnacceptableEncoding {
    70  		return StringResponse(http.StatusNotAcceptable, err.Error()), nil
    71  	} else if err != nil {
    72  		return nil, err
    73  	}
    74  	if request.URL.Path == "/health" {
    75  		// this view is the only one allowed without a client cert
    76  		return s.serveHealth(request)
    77  	} else if GetClientName(request) == "" {
    78  		return AccessDeniedResponse, nil
    79  	}
    80  	if strings.HasPrefix(request.URL.Path, "/keys/") {
    81  		return s.serveGetKey(request)
    82  	}
    83  	switch request.URL.Path {
    84  	case "/":
    85  		return s.serveHome(request)
    86  	case "/list_keys":
    87  		return s.serveListKeys(request)
    88  	case "/sign":
    89  		return s.serveSign(request, lw)
    90  	case "/directory":
    91  		return s.serveDirectory(request)
    92  	default:
    93  		return ErrorResponse(http.StatusNotFound), nil
    94  	}
    95  }
    96  
    97  func formatSubject(cert *x509.Certificate) string {
    98  	return x509tools.FormatPkixName(cert.RawSubject, x509tools.NameStyleOpenSsl)
    99  }
   100  
   101  func (s *Server) getUserRoles(ctx context.Context, request *http.Request) (context.Context, Response) {
   102  	if request.TLS != nil && len(request.TLS.PeerCertificates) != 0 {
   103  		cert := request.TLS.PeerCertificates[0]
   104  		digest := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
   105  		encoded := hex.EncodeToString(digest[:])
   106  		var useDN bool
   107  		client := s.Config.Clients[encoded]
   108  		if client == nil {
   109  			var saved error
   110  			for _, c2 := range s.Config.Clients {
   111  				match, err := c2.Match(request.TLS.PeerCertificates)
   112  				if match {
   113  					client = c2
   114  					useDN = true
   115  					break
   116  				} else if err != nil {
   117  					// preserve any potentially interesting validation errors
   118  					saved = err
   119  				}
   120  			}
   121  			if client == nil && saved != nil {
   122  				s.Logr(request, "client cert verification failed: %s\n", saved)
   123  			}
   124  		}
   125  		if client == nil {
   126  			s.Logr(request, "access denied: unknown fingerprint %s on certificate: %s\n", encoded, formatSubject(cert))
   127  			return nil, AccessDeniedResponse
   128  		}
   129  		name := client.Nickname
   130  		if name == "" {
   131  			name = encoded[:12]
   132  		}
   133  		ctx = context.WithValue(ctx, ctxClientName, name)
   134  		ctx = context.WithValue(ctx, ctxRoles, client.Roles)
   135  		if useDN {
   136  			ctx = context.WithValue(ctx, ctxClientDN, formatSubject(cert))
   137  		}
   138  	}
   139  	return ctx, nil
   140  }
   141  
   142  func (s *Server) CheckKeyAccess(request *http.Request, keyName string) *config.KeyConfig {
   143  	keyConf, err := s.Config.GetKey(keyName)
   144  	if err != nil {
   145  		return nil
   146  	}
   147  	clientRoles := GetClientRoles(request)
   148  	for _, keyRole := range keyConf.Roles {
   149  		for _, clientRole := range clientRoles {
   150  			if keyRole == clientRole {
   151  				return keyConf
   152  			}
   153  		}
   154  	}
   155  	return nil
   156  }
   157  
   158  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
   159  	writer.Header().Set("Accept-Encoding", compresshttp.AcceptedEncodings)
   160  	lw := &loggingWriter{ResponseWriter: writer, s: s, r: request, st: time.Now()}
   161  	defer lw.Close()
   162  	response, err := s.callHandler(request, lw)
   163  	if err != nil {
   164  		if request.Context().Err() != nil {
   165  			s.Logr(request, "client disconnected")
   166  			response = StringResponse(http.StatusBadRequest, "client disconnected")
   167  		} else {
   168  			response = s.LogError(lw.r, err, nil)
   169  		}
   170  	}
   171  	if response != nil {
   172  		for k, v := range response.Headers() {
   173  			lw.Header().Set(k, v)
   174  		}
   175  		ae := request.Header.Get("Accept-Encoding")
   176  		r := bytes.NewReader(response.Bytes())
   177  		if response.Status() >= 300 {
   178  			// don't compress errors
   179  			ae = ""
   180  		}
   181  		if err := compresshttp.CompressResponse(r, ae, lw, response.Status()); err != nil {
   182  			response = s.LogError(lw.r, err, nil)
   183  			writeResponse(lw, response)
   184  		}
   185  	}
   186  }
   187  
   188  func (s *Server) Close() error {
   189  	if s.closeCh != nil {
   190  		close(s.closeCh)
   191  		s.closeCh = nil
   192  	}
   193  	for _, t := range s.tokens {
   194  		t.Close()
   195  	}
   196  	return nil
   197  }
   198  
   199  func (s *Server) ReopenLogger() error {
   200  	if s.Config.Server.LogFile == "" {
   201  		return nil
   202  	}
   203  	f, err := os.OpenFile(s.Config.Server.LogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
   204  	if err != nil {
   205  		return err
   206  	}
   207  	s.logMu.Lock()
   208  	defer s.logMu.Unlock()
   209  	isologger.SetOutput(s.ErrorLog, f, isologger.RFC3339Milli)
   210  	if s.closeLog != nil {
   211  		s.closeLog.Close()
   212  	}
   213  	s.closeLog = f
   214  	return nil
   215  }
   216  
   217  func New(config *config.Config) (*Server, error) {
   218  	closed := make(chan bool)
   219  	s := &Server{
   220  		Config:   config,
   221  		Closed:   closed,
   222  		closeCh:  closed,
   223  		ErrorLog: log.New(os.Stderr, "", 0),
   224  		tokens:   make(map[string]*worker.WorkerToken),
   225  	}
   226  	if err := s.ReopenLogger(); err != nil {
   227  		return nil, fmt.Errorf("failed to open logfile: %s", err)
   228  	}
   229  	for _, name := range config.ListServedTokens() {
   230  		tok, err := worker.New(config, name)
   231  		if err != nil {
   232  			for _, t := range s.tokens {
   233  				t.Close()
   234  			}
   235  			return nil, err
   236  		}
   237  		s.tokens[name] = tok
   238  	}
   239  	if err := s.startHealthCheck(); err != nil {
   240  		return nil, err
   241  	}
   242  	return s, nil
   243  }
   244  

View as plain text