1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package workercmd
18
19 import (
20 "crypto"
21 "crypto/hmac"
22 "crypto/rand"
23 "crypto/rsa"
24 "crypto/x509"
25 "encoding/json"
26 "errors"
27 "fmt"
28 "io/ioutil"
29 "log"
30 "net/http"
31 "os"
32 "sync"
33 "time"
34
35 "github.com/miekg/pkcs11"
36 "github.com/sassoftware/relic/cmdline/shared"
37 "github.com/sassoftware/relic/internal/workerrpc"
38 "github.com/sassoftware/relic/token"
39 )
40
41
42
// fatalErrors is the set of PKCS#11 return codes that ServeHTTP treats as
// fatal: when a request fails with one of them the worker is shut down, and the
// response is still marked retryable. PKCS#11 errors that are not listed here
// are reported as non-retryable instead.
//
// NOTE(review): pkcs11Error is not declared in the visible part of this file;
// presumably it corresponds to the Error type of github.com/miekg/pkcs11 --
// confirm.
var fatalErrors = map[pkcs11Error]bool{
	pkcs11.CKR_CRYPTOKI_NOT_INITIALIZED: true,
	pkcs11.CKR_DEVICE_REMOVED:           true,
	pkcs11.CKR_GENERAL_ERROR:            true,
	pkcs11.CKR_HOST_MEMORY:              true,
	pkcs11.CKR_LIBRARY_LOAD_FAILED:      true,
	pkcs11.CKR_SESSION_CLOSED:           true,
	pkcs11.CKR_SESSION_HANDLE_INVALID:   true,
	pkcs11.CKR_TOKEN_NOT_PRESENT:        true,
	pkcs11.CKR_TOKEN_NOT_RECOGNIZED:     true,
	pkcs11.CKR_USER_NOT_LOGGED_IN:       true,
}
55
56 func (h *handler) healthCheck() {
57 interval := time.Duration(shared.CurrentConfig.Server.TokenCheckInterval) * time.Second
58 timeout := time.Duration(shared.CurrentConfig.Server.TokenCheckTimeout) * time.Second
59 ppid := os.Getppid()
60 tick := time.NewTicker(interval)
61 tmt := time.NewTimer(timeout)
62 errch := make(chan error)
63 for {
64
65 if os.Getppid() != ppid {
66 log.Println("error: parent process disappeared, worker stopping", ppid, os.Getppid())
67 h.shutdown()
68 return
69 }
70
71 go func() {
72 errch <- h.token.Ping()
73 }()
74 var err error
75 select {
76 case err = <-errch:
77
78 case <-tmt.C:
79
80 err = fmt.Errorf("timed out after %s", timeout)
81 }
82 if err != nil {
83
84 log.Printf("error: health check of token \"%s\" failed: %s", h.token.Config().Name(), err)
85 h.shutdown()
86 return
87 }
88
89 <-tick.C
90
91 if !tmt.Stop() {
92 <-tmt.C
93 }
94 tmt.Reset(timeout)
95 }
96 }
97
// handler serves worker RPC requests over HTTP on behalf of a single token.
type handler struct {
	// token is the token that is pinged by healthCheck and from which keys
	// are fetched by getKey.
	token token.Token
	// cookie is the value the Auth-Cookie request header must match for a
	// request to be accepted by ServeHTTP.
	cookie []byte
	// keyCache holds keys already returned by token.GetKey, indexed by key
	// name (string -> token.Key).
	keyCache sync.Map
	// shutdown is invoked to stop the worker: by healthCheck when the parent
	// process changes or the token ping fails, and by ServeHTTP after a
	// fatal PKCS#11 error.
	shutdown func()
}
104
105 func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
106
107 cookie := req.Header.Get("Auth-Cookie")
108 if !hmac.Equal([]byte(cookie), []byte(h.cookie)) {
109 rw.WriteHeader(http.StatusForbidden)
110 return
111 }
112
113 resp, err := h.handle(rw, req)
114 if err != nil {
115 resp.Retryable = true
116 resp.Err = err.Error()
117 if e, ok := err.(pkcs11Error); ok {
118 if fatalErrors[e] {
119 log.Printf("error: terminating worker for token \"%s\" due to error: %s", h.token.Config().Name(), err)
120 go h.shutdown()
121
122 resp.Retryable = true
123 } else {
124
125 resp.Retryable = false
126 }
127 }
128 }
129
130 blob, err := json.Marshal(resp)
131 if err != nil {
132 log.Printf("error: worker for token \"%s\": %s", h.token.Config().Name(), err)
133 return
134 }
135 rw.Write(blob)
136 }
137
138 func (h *handler) handle(rw http.ResponseWriter, req *http.Request) (resp workerrpc.Response, err error) {
139 blob, err := ioutil.ReadAll(req.Body)
140 if err != nil {
141 return resp, err
142 }
143 var rr workerrpc.Request
144 if err := json.Unmarshal(blob, &rr); err != nil {
145 return resp, err
146 }
147 switch req.URL.Path {
148 case workerrpc.Ping:
149 return resp, h.token.Ping()
150 case workerrpc.GetKey:
151 key, err := h.getKey(rr.KeyName)
152 if err != nil {
153 return resp, err
154 }
155 resp.ID = key.GetID()
156 resp.Value, err = x509.MarshalPKIXPublicKey(key.Public())
157 return resp, err
158 case workerrpc.Sign:
159 hash := crypto.Hash(rr.Hash)
160 opts := crypto.SignerOpts(hash)
161 if rr.SaltLength != nil {
162 opts = &rsa.PSSOptions{SaltLength: *rr.SaltLength, Hash: hash}
163 }
164 key, err := h.getKey(rr.KeyName)
165 if err != nil {
166 return resp, err
167 }
168 resp.Value, err = key.Sign(rand.Reader, rr.Digest, opts)
169 return resp, err
170 default:
171 return resp, errors.New("invalid method: " + req.URL.Path)
172 }
173 }
174
175
176 func (h *handler) getKey(keyName string) (token.Key, error) {
177 key, _ := h.keyCache.Load(keyName)
178 if key == nil {
179 var err error
180 key, err = h.token.GetKey(keyName)
181 if err != nil {
182 return nil, err
183 }
184 h.keyCache.Store(keyName, key)
185 }
186 return key.(token.Key), nil
187 }
188
View as plain text