1 package configx
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "io/ioutil"
9 "net/url"
10 "os"
11 "reflect"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/sirupsen/logrus"
17
18 "github.com/ory/x/logrusx"
19
20 "github.com/ory/x/jsonschemax"
21
22 "github.com/opentracing/opentracing-go"
23 "github.com/opentracing/opentracing-go/log"
24
25 "github.com/ory/jsonschema/v3"
26 "github.com/ory/x/watcherx"
27
28 "github.com/inhies/go-bytesize"
29 "github.com/knadh/koanf/providers/posflag"
30 "github.com/spf13/pflag"
31
32 "github.com/ory/x/stringsx"
33 "github.com/ory/x/tracing"
34
35 "github.com/knadh/koanf"
36 "github.com/knadh/koanf/parsers/json"
37 "github.com/pkg/errors"
38 "github.com/rs/cors"
39 )
40
// tuple is a single configuration key/value pair. Provider keeps slices of
// tuples for its base and forced values and loads each one into koanf through
// NewKoanfConfmap.
type tuple struct {
	// Key is the configuration key the value is stored under.
	Key string
	// Value is the raw value loaded for Key.
	Value interface{}
}
45
// Provider wraps a koanf instance that is rebuilt ("forked") from all
// configured sources and swapped in whenever the configuration changes.
type Provider struct {
	// l is held by replaceKoanf while the embedded Koanf and cancelFork are
	// swapped.
	l sync.Mutex
	// Koanf is the currently active configuration; replaceKoanf replaces it.
	*koanf.Koanf
	// immutables lists keys whose values must stay the same across file
	// reloads; a reload that changes one of them is rejected.
	immutables []string

	// originalContext is the parent of the context created for every fork.
	originalContext context.Context
	// cancelFork cancels the context of the currently active fork. It is
	// invoked when that fork is replaced.
	cancelFork context.CancelFunc

	// schema is the raw JSON schema handed to New. It is also used to build
	// the schema-defaults and environment providers.
	schema []byte
	// flags is optional. When set, it supplies config file paths (FlagConfig)
	// and is loaded as a configuration source itself.
	flags *pflag.FlagSet
	// validator is the compiled form of schema used by validate.
	validator *jsonschema.Schema
	// onChanges holds the callbacks invoked by runOnChanges.
	onChanges []func(watcherx.Event, error)
	// onValidationError is called by validate when validation fails.
	onValidationError func(k *koanf.Koanf, err error)
	// excludeFieldsFromTracing holds substrings; any key containing one of
	// them is logged as "[redacted]" by traceConfig.
	excludeFieldsFromTracing []string
	// tracer is optional; startSpan falls back to the global tracer when it
	// is unset.
	tracer *tracing.Tracer
	// forcedValues are loaded last when forking; Set appends to them.
	forcedValues []tuple
	// baseValues are loaded right after the schema defaults when forking.
	baseValues []tuple
	// files are config file paths loaded before any paths taken from flags.
	files []string
	// skipValidation makes validate return nil without checking anything.
	skipValidation bool
	// logger receives debug, warning and error messages from the provider.
	logger *logrusx.Logger
}
67
const (
	// FlagConfig is the name of the flag that lists the config files to load.
	FlagConfig = "config"
	// Delimiter is the key delimiter passed to koanf.New when a new
	// configuration is built.
	Delimiter = "."
)
72
73
74 func RegisterConfigFlag(flags *pflag.FlagSet, fallback []string) {
75 flags.StringSliceP(FlagConfig, "c", fallback, "Config files to load, overwriting in the order specified.")
76 }
77
78
79
80
81
82
83
84
85 func New(schema []byte, modifiers ...OptionModifier) (*Provider, error) {
86 schemaID, comp, err := newCompiler(schema)
87 if err != nil {
88 return nil, err
89 }
90
91 validator, err := comp.Compile(schemaID)
92 if err != nil {
93 return nil, err
94 }
95
96 l := logrus.New()
97 l.Out = ioutil.Discard
98
99 p := &Provider{
100 originalContext: context.Background(),
101 schema: schema,
102 validator: validator,
103 onValidationError: func(k *koanf.Koanf, err error) {},
104 excludeFieldsFromTracing: []string{"dsn", "secret", "password", "key"},
105 logger: logrusx.New("discarding config logger", "", logrusx.UseLogger(l)),
106 }
107
108 for _, m := range modifiers {
109 m(p)
110 }
111
112 k, _, cancelFork, err := p.forkKoanf()
113 if err != nil {
114 return nil, err
115 }
116
117 p.replaceKoanf(k, cancelFork)
118 return p, nil
119 }
120
121 func (p *Provider) replaceKoanf(k *koanf.Koanf, cancelFork context.CancelFunc) {
122 p.l.Lock()
123 defer p.l.Unlock()
124 if p.cancelFork != nil {
125 p.cancelFork()
126 }
127 p.Koanf = k
128 p.cancelFork = cancelFork
129 }
130
131 func (p *Provider) validate(k *koanf.Koanf) error {
132 if p.skipValidation {
133 return nil
134 }
135
136 out, err := k.Marshal(json.Parser())
137 if err != nil {
138 return errors.WithStack(err)
139 }
140 if err := p.validator.Validate(bytes.NewReader(out)); err != nil {
141 p.onValidationError(k, err)
142 return err
143 }
144
145 return nil
146 }
147
148 func (p *Provider) forkKoanf() (*koanf.Koanf, context.Context, context.CancelFunc, error) {
149 fork, cancel := context.WithCancel(p.originalContext)
150 span, fork := p.startSpan(fork, LoadSpanOpName)
151 defer span.Finish()
152
153 k := koanf.New(Delimiter)
154 dp, err := NewKoanfSchemaDefaults(p.schema)
155 if err != nil {
156 cancel()
157 return nil, nil, nil, err
158 }
159
160 ep, err := NewKoanfEnv("", p.schema)
161 if err != nil {
162 cancel()
163 return nil, nil, nil, err
164 }
165
166
167 if err := k.Load(dp, nil); err != nil {
168 cancel()
169 return nil, nil, nil, err
170 }
171
172 for _, t := range p.baseValues {
173 if err := k.Load(NewKoanfConfmap([]tuple{t}), nil); err != nil {
174 cancel()
175 return nil, nil, nil, err
176 }
177 }
178
179 var paths []string
180 if p.flags != nil {
181 p, _ := p.flags.GetStringSlice(FlagConfig)
182 paths = append(paths, p...)
183 }
184
185 if err := p.addAndWatchConfigFiles(fork, append(p.files, paths...), k); err != nil {
186 cancel()
187 return nil, nil, nil, err
188 }
189
190 if p.flags != nil {
191 if err := k.Load(posflag.Provider(p.flags, ".", k), nil); err != nil {
192 cancel()
193 return nil, nil, nil, err
194 }
195 }
196
197 if err := k.Load(ep, nil); err != nil {
198 cancel()
199 return nil, nil, nil, err
200 }
201
202
203 for _, t := range p.forcedValues {
204 if err := k.Load(NewKoanfConfmap([]tuple{t}), nil); err != nil {
205 cancel()
206 return nil, nil, nil, err
207 }
208 }
209
210 if err := p.validate(k); err != nil {
211 cancel()
212 return nil, nil, nil, err
213 }
214
215 p.traceConfig(fork, k, LoadSpanOpName)
216 return k, fork, cancel, nil
217 }
218
219
220 func (p *Provider) SetTracer(ctx context.Context, t *tracing.Tracer) {
221 p.tracer = t
222 p.traceConfig(ctx, p.Koanf, SnapshotSpanOpName)
223 }
224
225 func (p *Provider) startSpan(ctx context.Context, opName string) (opentracing.Span, context.Context) {
226 tracer := opentracing.GlobalTracer()
227 if p.tracer != nil && p.tracer.Tracer() != nil {
228 tracer = p.tracer.Tracer()
229 }
230 return opentracing.StartSpanFromContextWithTracer(ctx, tracer, opName)
231 }
232
233 func (p *Provider) traceConfig(ctx context.Context, k *koanf.Koanf, opName string) {
234 span, ctx := p.startSpan(ctx, opName)
235 defer span.Finish()
236
237 span.SetTag("component", "github.com/ory/x/configx")
238
239 fields := make([]log.Field, 0, len(k.Keys()))
240 for _, key := range k.Keys() {
241 var redact bool
242 for _, e := range p.excludeFieldsFromTracing {
243 if strings.Contains(key, e) {
244 redact = true
245 }
246 }
247
248 if redact {
249 fields = append(fields, log.Object(key, "[redacted]"))
250 } else {
251 fields = append(fields, log.Object(key, k.Get(key)))
252 }
253 }
254
255 span.LogFields(fields...)
256 }
257
258 func (p *Provider) runOnChanges(e watcherx.Event, err error) {
259 for k := range p.onChanges {
260 p.onChanges[k](e, err)
261 }
262 }
263
264 func (p *Provider) addAndWatchConfigFiles(ctx context.Context, paths []string, k *koanf.Koanf) error {
265 p.logger.WithField("files", paths).Debug("Adding config files.")
266
267 watchForFileChanges := func(c watcherx.EventChannel) {
268
269 for e := range c {
270 switch et := e.(type) {
271 case *watcherx.ErrorEvent:
272 p.runOnChanges(e, et)
273 continue
274 default:
275 nk, _, cancel, err := p.forkKoanf()
276 if err != nil {
277 p.runOnChanges(e, err)
278 continue
279 }
280
281 var cancelReload bool
282 for _, key := range p.immutables {
283 if !reflect.DeepEqual(k.Get(key), nk.Get(key)) {
284 cancel()
285 cancelReload = true
286 p.runOnChanges(e, NewImmutableError(key, fmt.Sprintf("%v", k.Get(key)), fmt.Sprintf("%v", nk.Get(key))))
287 break
288 }
289 }
290
291 if cancelReload {
292 continue
293 }
294
295 p.replaceKoanf(nk, cancel)
296 p.runOnChanges(e, nil)
297 }
298 }
299 }
300
301 for _, path := range paths {
302 fp, err := NewKoanfFile(ctx, path)
303 if err != nil {
304 return err
305 }
306
307 if err := k.Load(fp, nil); err != nil {
308 return err
309 }
310
311 c := make(watcherx.EventChannel)
312 if _, err := fp.WatchChannel(c); err != nil {
313 return err
314 }
315
316 go watchForFileChanges(c)
317 }
318
319 return nil
320 }
321
322 func (p *Provider) Set(key string, value interface{}) error {
323 p.forcedValues = append(p.forcedValues, tuple{Key: key, Value: value})
324
325 k, _, cancel, err := p.forkKoanf()
326 if err != nil {
327 return err
328 }
329
330 p.replaceKoanf(k, cancel)
331 return nil
332 }
333
334 func (p *Provider) BoolF(key string, fallback bool) bool {
335 if !p.Koanf.Exists(key) {
336 return fallback
337 }
338
339 return p.Bool(key)
340 }
341
342 func (p *Provider) StringF(key string, fallback string) string {
343 if !p.Koanf.Exists(key) {
344 return fallback
345 }
346
347 return p.String(key)
348 }
349
350 func (p *Provider) StringsF(key string, fallback []string) (val []string) {
351 if !p.Koanf.Exists(key) {
352 return fallback
353 }
354
355 return p.Strings(key)
356 }
357
358 func (p *Provider) IntF(key string, fallback int) (val int) {
359 if !p.Koanf.Exists(key) {
360 return fallback
361 }
362
363 return p.Int(key)
364 }
365
366 func (p *Provider) Float64F(key string, fallback float64) (val float64) {
367 if !p.Koanf.Exists(key) {
368 return fallback
369 }
370
371 return p.Float64(key)
372 }
373
374 func (p *Provider) DurationF(key string, fallback time.Duration) (val time.Duration) {
375 if !p.Koanf.Exists(key) {
376 return fallback
377 }
378
379 return p.Duration(key)
380 }
381
382 func (p *Provider) ByteSizeF(key string, fallback bytesize.ByteSize) bytesize.ByteSize {
383 if !p.Koanf.Exists(key) {
384 return fallback
385 }
386
387 switch v := p.Koanf.Get(key).(type) {
388 case string:
389
390 dec, err := bytesize.Parse(v)
391 if err != nil {
392 p.logger.WithField("key", key).WithField("raw_value", v).WithError(err).Warnf("error parsing byte size value, using fallback of %s", fallback)
393 return fallback
394 }
395 return dec
396 case float64:
397
398 return bytesize.ByteSize(v)
399 case bytesize.ByteSize:
400 return v
401 default:
402 p.logger.WithField("key", key).WithField("raw_type", fmt.Sprintf("%T", v)).WithField("raw_value", fmt.Sprintf("%+v", v)).Errorf("error converting byte size value because of unknown type, using fallback of %s", fallback)
403 return fallback
404 }
405 }
406
407 func (p *Provider) GetF(key string, fallback interface{}) (val interface{}) {
408 if !p.Exists(key) {
409 return fallback
410 }
411
412 return p.Get(key)
413 }
414
415 func (p *Provider) CORS(prefix string, defaults cors.Options) (cors.Options, bool) {
416 if len(prefix) > 0 {
417 prefix = strings.TrimRight(prefix, ".") + "."
418 }
419
420 return cors.Options{
421 AllowedOrigins: p.StringsF(prefix+"cors.allowed_origins", defaults.AllowedOrigins),
422 AllowedMethods: p.StringsF(prefix+"cors.allowed_methods", defaults.AllowedMethods),
423 AllowedHeaders: p.StringsF(prefix+"cors.allowed_headers", defaults.AllowedHeaders),
424 ExposedHeaders: p.StringsF(prefix+"cors.exposed_headers", defaults.ExposedHeaders),
425 AllowCredentials: p.BoolF(prefix+"cors.allow_credentials", defaults.AllowCredentials),
426 OptionsPassthrough: p.BoolF(prefix+"cors.options_passthrough", defaults.OptionsPassthrough),
427 MaxAge: p.IntF(prefix+"cors.max_age", defaults.MaxAge),
428 Debug: p.BoolF(prefix+"cors.debug", defaults.Debug),
429 }, p.Bool(prefix + "cors.enabled")
430 }
431
432 func (p *Provider) TracingConfig(serviceName string) *tracing.Config {
433 return &tracing.Config{
434 ServiceName: p.StringF("tracing.service_name", serviceName),
435 Provider: p.String("tracing.provider"),
436 Jaeger: &tracing.JaegerConfig{
437 LocalAgentHostPort: p.String("tracing.providers.jaeger.local_agent_address"),
438 SamplerType: p.StringF("tracing.providers.jaeger.sampling.type", "const"),
439 SamplerValue: p.Float64F("tracing.providers.jaeger.sampling.value", float64(1)),
440 SamplerServerURL: p.String("tracing.providers.jaeger.sampling.server_url"),
441 Propagation: stringsx.Coalesce(
442 os.Getenv("JAEGER_PROPAGATION"),
443 p.String("tracing.providers.jaeger.propagation"),
444 ),
445 },
446 Zipkin: &tracing.ZipkinConfig{
447 ServerURL: p.String("tracing.providers.zipkin.server_url"),
448 },
449 }
450 }
451
452 func (p *Provider) RequestURIF(path string, fallback *url.URL) *url.URL {
453 switch t := p.Get(path).(type) {
454 case *url.URL:
455 return t
456 case url.URL:
457 return &t
458 case string:
459 if parsed, err := url.ParseRequestURI(t); err == nil {
460 return parsed
461 }
462 }
463
464 return fallback
465 }
466
467 func (p *Provider) URIF(path string, fallback *url.URL) *url.URL {
468 switch t := p.Get(path).(type) {
469 case *url.URL:
470 return t
471 case url.URL:
472 return &t
473 case string:
474 if parsed, err := url.Parse(t); err == nil {
475 return parsed
476 }
477 }
478
479 return fallback
480 }
481
482
483 func (p *Provider) PrintHumanReadableValidationErrors(w io.Writer, err error) {
484 p.printHumanReadableValidationErrors(p.Koanf, w, err)
485 }
486
487 func (p *Provider) printHumanReadableValidationErrors(k *koanf.Koanf, w io.Writer, err error) {
488 if err == nil {
489 return
490 }
491
492 _, _ = fmt.Fprintln(os.Stderr, "")
493 conf, innerErr := k.Marshal(json.Parser())
494 if innerErr != nil {
495 _, _ = fmt.Fprintf(w, "Unable to unmarshal configuration: %+v", innerErr)
496 }
497
498 jsonschemax.FormatValidationErrorForCLI(w, conf, err)
499 }
500
View as plain text