...

Source file src/github.com/sassoftware/relic/cmdline/workercmd/handler.go

Documentation: github.com/sassoftware/relic/cmdline/workercmd

     1  //
     2  // Copyright (c) SAS Institute Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package workercmd
    18  
    19  import (
    20  	"crypto"
    21  	"crypto/hmac"
    22  	"crypto/rand"
    23  	"crypto/rsa"
    24  	"crypto/x509"
    25  	"encoding/json"
    26  	"errors"
    27  	"fmt"
    28  	"io/ioutil"
    29  	"log"
    30  	"net/http"
    31  	"os"
    32  	"sync"
    33  	"time"
    34  
    35  	"github.com/miekg/pkcs11"
    36  	"github.com/sassoftware/relic/cmdline/shared"
    37  	"github.com/sassoftware/relic/internal/workerrpc"
    38  	"github.com/sassoftware/relic/token"
    39  )
    40  
// fatalErrors is an arbitrarily-chosen set of PKCS#11 error codes that
// indicate that the token session is busted and that the worker should exit
// and start over. ServeHTTP consults it to decide whether a pkcs11Error
// returned while handling a request should shut the worker down.
var fatalErrors = map[pkcs11Error]bool{
	pkcs11.CKR_CRYPTOKI_NOT_INITIALIZED: true,
	pkcs11.CKR_DEVICE_REMOVED:           true,
	pkcs11.CKR_GENERAL_ERROR:            true,
	pkcs11.CKR_HOST_MEMORY:              true,
	pkcs11.CKR_LIBRARY_LOAD_FAILED:      true,
	pkcs11.CKR_SESSION_CLOSED:           true,
	pkcs11.CKR_SESSION_HANDLE_INVALID:   true,
	pkcs11.CKR_TOKEN_NOT_PRESENT:        true,
	pkcs11.CKR_TOKEN_NOT_RECOGNIZED:     true,
	pkcs11.CKR_USER_NOT_LOGGED_IN:       true,
}
    55  
    56  func (h *handler) healthCheck() {
    57  	interval := time.Duration(shared.CurrentConfig.Server.TokenCheckInterval) * time.Second
    58  	timeout := time.Duration(shared.CurrentConfig.Server.TokenCheckTimeout) * time.Second
    59  	ppid := os.Getppid()
    60  	tick := time.NewTicker(interval)
    61  	tmt := time.NewTimer(timeout)
    62  	errch := make(chan error)
    63  	for {
    64  		// check if parent process went away
    65  		if os.Getppid() != ppid {
    66  			log.Println("error: parent process disappeared, worker stopping", ppid, os.Getppid())
    67  			h.shutdown()
    68  			return
    69  		}
    70  		// check if token is alive
    71  		go func() {
    72  			errch <- h.token.Ping()
    73  		}()
    74  		var err error
    75  		select {
    76  		case err = <-errch:
    77  			// ping completed
    78  		case <-tmt.C:
    79  			// timed out
    80  			err = fmt.Errorf("timed out after %s", timeout)
    81  		}
    82  		if err != nil {
    83  			// stop the worker on error
    84  			log.Printf("error: health check of token \"%s\" failed: %s", h.token.Config().Name(), err)
    85  			h.shutdown()
    86  			return
    87  		}
    88  		// wait for next tick
    89  		<-tick.C
    90  		// reset timeout
    91  		if !tmt.Stop() {
    92  			<-tmt.C
    93  		}
    94  		tmt.Reset(timeout)
    95  	}
    96  }
    97  
// handler is the http.Handler that serves worker RPC requests against a
// token.
type handler struct {
	// token is the token that Ping, GetKey and Sign requests are served from.
	token    token.Token
	// cookie is the value the Auth-Cookie request header must match.
	cookie   []byte
	// keyCache maps key names to the token.Key handles returned by the
	// token; it is populated lazily by getKey.
	keyCache sync.Map
	// shutdown is called to stop the worker when a health check fails or a
	// request hits a fatal token error.
	shutdown func()
}
   104  
   105  func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   106  	// validate auth cookie
   107  	cookie := req.Header.Get("Auth-Cookie")
   108  	if !hmac.Equal([]byte(cookie), []byte(h.cookie)) {
   109  		rw.WriteHeader(http.StatusForbidden)
   110  		return
   111  	}
   112  	// dispatch
   113  	resp, err := h.handle(rw, req)
   114  	if err != nil {
   115  		resp.Retryable = true
   116  		resp.Err = err.Error()
   117  		if e, ok := err.(pkcs11Error); ok {
   118  			if fatalErrors[e] {
   119  				log.Printf("error: terminating worker for token \"%s\" due to error: %s", h.token.Config().Name(), err)
   120  				go h.shutdown()
   121  				// errors that cause the worker to restart are also retryable
   122  				resp.Retryable = true
   123  			} else {
   124  				// pkcs11 errors not in fatalErrors are probably user error, so don't retry
   125  				resp.Retryable = false
   126  			}
   127  		}
   128  	}
   129  	// marshal response
   130  	blob, err := json.Marshal(resp)
   131  	if err != nil {
   132  		log.Printf("error: worker for token \"%s\": %s", h.token.Config().Name(), err)
   133  		return
   134  	}
   135  	rw.Write(blob)
   136  }
   137  
   138  func (h *handler) handle(rw http.ResponseWriter, req *http.Request) (resp workerrpc.Response, err error) {
   139  	blob, err := ioutil.ReadAll(req.Body)
   140  	if err != nil {
   141  		return resp, err
   142  	}
   143  	var rr workerrpc.Request
   144  	if err := json.Unmarshal(blob, &rr); err != nil {
   145  		return resp, err
   146  	}
   147  	switch req.URL.Path {
   148  	case workerrpc.Ping:
   149  		return resp, h.token.Ping()
   150  	case workerrpc.GetKey:
   151  		key, err := h.getKey(rr.KeyName)
   152  		if err != nil {
   153  			return resp, err
   154  		}
   155  		resp.ID = key.GetID()
   156  		resp.Value, err = x509.MarshalPKIXPublicKey(key.Public())
   157  		return resp, err
   158  	case workerrpc.Sign:
   159  		hash := crypto.Hash(rr.Hash)
   160  		opts := crypto.SignerOpts(hash)
   161  		if rr.SaltLength != nil {
   162  			opts = &rsa.PSSOptions{SaltLength: *rr.SaltLength, Hash: hash}
   163  		}
   164  		key, err := h.getKey(rr.KeyName)
   165  		if err != nil {
   166  			return resp, err
   167  		}
   168  		resp.Value, err = key.Sign(rand.Reader, rr.Digest, opts)
   169  		return resp, err
   170  	default:
   171  		return resp, errors.New("invalid method: " + req.URL.Path)
   172  	}
   173  }
   174  
   175  // cache key handles
   176  func (h *handler) getKey(keyName string) (token.Key, error) {
   177  	key, _ := h.keyCache.Load(keyName)
   178  	if key == nil {
   179  		var err error
   180  		key, err = h.token.GetKey(keyName)
   181  		if err != nil {
   182  			return nil, err
   183  		}
   184  		h.keyCache.Store(keyName, key)
   185  	}
   186  	return key.(token.Key), nil
   187  }
   188  

View as plain text