...
1
18
19
20
21
22
23
24
25
26 package pemfile
27
28 import (
29 "bytes"
30 "context"
31 "crypto/tls"
32 "crypto/x509"
33 "errors"
34 "fmt"
35 "os"
36 "path/filepath"
37 "time"
38
39 "google.golang.org/grpc/credentials/tls/certprovider"
40 "google.golang.org/grpc/grpclog"
41 )
42
43 const defaultCertRefreshDuration = 1 * time.Hour
44
45 var (
46
47 newDistributor = func() distributor { return certprovider.NewDistributor() }
48
49 logger = grpclog.Component("pemfile")
50 )
51
52
53
// Options configures a provider created by NewProvider.
//
// At least one of CertFile, KeyFile and RootFile must be set. CertFile and
// KeyFile must be set together, and both files must live in the same
// directory.
type Options struct {
	// CertFile is the path of the PEM file holding the identity certificate.
	// Optional, but if set KeyFile must be set as well.
	CertFile string

	// KeyFile is the path of the PEM file holding the private key matching
	// CertFile. Optional, but if set CertFile must be set as well.
	KeyFile string

	// RootFile is the path of the PEM file holding the trusted root
	// certificates. Optional.
	RootFile string

	// RefreshDuration is how long the provider waits between re-reads of the
	// files above. Optional: a zero value selects a default of one hour.
	RefreshDuration time.Duration
}
69
70 func (o Options) canonical() []byte {
71 return []byte(fmt.Sprintf("%s:%s:%s:%s", o.CertFile, o.KeyFile, o.RootFile, o.RefreshDuration))
72 }
73
74 func (o Options) validate() error {
75 if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" {
76 return fmt.Errorf("pemfile: at least one credential file needs to be specified")
77 }
78 if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
79 return fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified")
80 }
81
82
83
84
85
86 if certDir, keyDir := filepath.Dir(o.CertFile), filepath.Dir(o.KeyFile); certDir != keyDir {
87 return errors.New("pemfile: certificate and key file must be in the same directory")
88 }
89 return nil
90 }
91
92
93
94 func NewProvider(o Options) (certprovider.Provider, error) {
95 if err := o.validate(); err != nil {
96 return nil, err
97 }
98 return newProvider(o), nil
99 }
100
101
102
103 func newProvider(o Options) certprovider.Provider {
104 if o.RefreshDuration == 0 {
105 o.RefreshDuration = defaultCertRefreshDuration
106 }
107
108 provider := &watcher{opts: o}
109 if o.CertFile != "" && o.KeyFile != "" {
110 provider.identityDistributor = newDistributor()
111 }
112 if o.RootFile != "" {
113 provider.rootDistributor = newDistributor()
114 }
115
116 ctx, cancel := context.WithCancel(context.Background())
117 provider.cancel = cancel
118 go provider.run(ctx)
119 return provider
120 }
121
122
123
124
125
126 type watcher struct {
127 identityDistributor distributor
128 rootDistributor distributor
129 opts Options
130 certFileContents []byte
131 keyFileContents []byte
132 rootFileContents []byte
133 cancel context.CancelFunc
134 }
135
136
137
138
139 type distributor interface {
140 KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error)
141 Set(km *certprovider.KeyMaterial, err error)
142 Stop()
143 }
144
145
146
147
148
149
150
151
152 func (w *watcher) updateIdentityDistributor() {
153 if w.identityDistributor == nil {
154 return
155 }
156
157 certFileContents, err := os.ReadFile(w.opts.CertFile)
158 if err != nil {
159 logger.Warningf("certFile (%s) read failed: %v", w.opts.CertFile, err)
160 return
161 }
162 keyFileContents, err := os.ReadFile(w.opts.KeyFile)
163 if err != nil {
164 logger.Warningf("keyFile (%s) read failed: %v", w.opts.KeyFile, err)
165 return
166 }
167
168 if bytes.Equal(w.certFileContents, certFileContents) && bytes.Equal(w.keyFileContents, keyFileContents) {
169 return
170 }
171
172 cert, err := tls.X509KeyPair(certFileContents, keyFileContents)
173 if err != nil {
174 logger.Warningf("tls.X509KeyPair(%q, %q) failed: %v", certFileContents, keyFileContents, err)
175 return
176 }
177 w.certFileContents = certFileContents
178 w.keyFileContents = keyFileContents
179 w.identityDistributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}, nil)
180 }
181
182
183
184
185
186
187
188
189 func (w *watcher) updateRootDistributor() {
190 if w.rootDistributor == nil {
191 return
192 }
193
194 rootFileContents, err := os.ReadFile(w.opts.RootFile)
195 if err != nil {
196 logger.Warningf("rootFile (%s) read failed: %v", w.opts.RootFile, err)
197 return
198 }
199 trustPool := x509.NewCertPool()
200 if !trustPool.AppendCertsFromPEM(rootFileContents) {
201 logger.Warning("failed to parse root certificate")
202 return
203 }
204
205 if bytes.Equal(w.rootFileContents, rootFileContents) {
206 return
207 }
208
209 w.rootFileContents = rootFileContents
210 w.rootDistributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil)
211 }
212
213
214
215
216 func (w *watcher) run(ctx context.Context) {
217 ticker := time.NewTicker(w.opts.RefreshDuration)
218 for {
219 w.updateIdentityDistributor()
220 w.updateRootDistributor()
221 select {
222 case <-ctx.Done():
223 ticker.Stop()
224 if w.identityDistributor != nil {
225 w.identityDistributor.Stop()
226 }
227 if w.rootDistributor != nil {
228 w.rootDistributor.Stop()
229 }
230 return
231 case <-ticker.C:
232 }
233 }
234 }
235
236
237
238 func (w *watcher) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
239 km := &certprovider.KeyMaterial{}
240 if w.identityDistributor != nil {
241 identityKM, err := w.identityDistributor.KeyMaterial(ctx)
242 if err != nil {
243 return nil, err
244 }
245 km.Certs = identityKM.Certs
246 }
247 if w.rootDistributor != nil {
248 rootKM, err := w.rootDistributor.KeyMaterial(ctx)
249 if err != nil {
250 return nil, err
251 }
252 km.Roots = rootKM.Roots
253 }
254 return km, nil
255 }
256
257
258 func (w *watcher) Close() {
259 w.cancel()
260 }
261
View as plain text