...

Source file src/github.com/ory/x/configx/provider.go

Documentation: github.com/ory/x/configx

     1  package configx
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/url"
    10  	"os"
    11  	"reflect"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/sirupsen/logrus"
    17  
    18  	"github.com/ory/x/logrusx"
    19  
    20  	"github.com/ory/x/jsonschemax"
    21  
    22  	"github.com/opentracing/opentracing-go"
    23  	"github.com/opentracing/opentracing-go/log"
    24  
    25  	"github.com/ory/jsonschema/v3"
    26  	"github.com/ory/x/watcherx"
    27  
    28  	"github.com/inhies/go-bytesize"
    29  	"github.com/knadh/koanf/providers/posflag"
    30  	"github.com/spf13/pflag"
    31  
    32  	"github.com/ory/x/stringsx"
    33  	"github.com/ory/x/tracing"
    34  
    35  	"github.com/knadh/koanf"
    36  	"github.com/knadh/koanf/parsers/json"
    37  	"github.com/pkg/errors"
    38  	"github.com/rs/cors"
    39  )
    40  
// tuple is a single configuration key/value pair. forkKoanf loads the
// provider's baseValues and forcedValues one tuple at a time through
// NewKoanfConfmap.
type tuple struct {
	Key   string
	Value interface{}
}
    45  
// Provider loads configuration from several sources into a koanf instance,
// validates it against a JSON schema, and swaps in a freshly loaded instance
// when a watched config file changes or Set is called. The embedded
// *koanf.Koanf exposes the getters of the currently active configuration.
type Provider struct {
	// l guards Koanf and cancelFork while replaceKoanf swaps them.
	// NOTE(review): the getters in this file (BoolF, StringF, GetF, ...) read
	// p.Koanf without holding l, while file-watcher goroutines may call
	// replaceKoanf concurrently — confirm whether this race is acceptable.
	l sync.Mutex
	*koanf.Koanf
	// immutables lists keys whose value may not change on a file reload; a
	// reload that changes any of them is rejected with an immutable error.
	immutables []string

	// originalContext is the parent of every fork's context; cancelFork
	// cancels the context of the active fork, which stops the file watchers
	// started for it.
	originalContext context.Context
	cancelFork      context.CancelFunc

	// schema is the raw JSON schema (defaults, env mapping, validation) and
	// validator its compiled form. flags, when set, supplies extra config file
	// paths via --config plus flag values. onChanges callbacks receive every
	// file-watch outcome; onValidationError is called when validate fails.
	// Keys containing any excludeFieldsFromTracing entry are redacted in
	// traces; tracer overrides the global opentracing tracer when set.
	// baseValues are loaded right after the schema defaults, forcedValues
	// last (Set appends to them). files are config paths loaded before any
	// --config paths. skipValidation disables validate. logger discards its
	// output unless replaced.
	schema                   []byte
	flags                    *pflag.FlagSet
	validator                *jsonschema.Schema
	onChanges                []func(watcherx.Event, error)
	onValidationError        func(k *koanf.Koanf, err error)
	excludeFieldsFromTracing []string
	tracer                   *tracing.Tracer
	forcedValues             []tuple
	baseValues               []tuple
	files                    []string
	skipValidation           bool
	logger                   *logrusx.Logger
}
    67  
// FlagConfig is the name of the flag registered by RegisterConfigFlag; its
// values are loaded as additional config files. Delimiter separates the
// segments of a configuration key path such as "tracing.provider" (it is
// passed to koanf.New).
const (
	FlagConfig = "config"
	Delimiter  = "."
)
    72  
    73  // RegisterConfigFlag registers the "--config" flag on pflag.FlagSet.
    74  func RegisterConfigFlag(flags *pflag.FlagSet, fallback []string) {
    75  	flags.StringSliceP(FlagConfig, "c", fallback, "Config files to load, overwriting in the order specified.")
    76  }
    77  
    78  // New creates a new provider instance or errors.
    79  // Configuration values are loaded in the following order:
    80  //
    81  // 1. Defaults from the JSON Schema
    82  // 2. Config files (yaml, yml, toml, json)
    83  // 3. Command line flags
    84  // 4. Environment variables
    85  func New(schema []byte, modifiers ...OptionModifier) (*Provider, error) {
    86  	schemaID, comp, err := newCompiler(schema)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	validator, err := comp.Compile(schemaID)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	l := logrus.New()
    97  	l.Out = ioutil.Discard
    98  
    99  	p := &Provider{
   100  		originalContext:          context.Background(),
   101  		schema:                   schema,
   102  		validator:                validator,
   103  		onValidationError:        func(k *koanf.Koanf, err error) {},
   104  		excludeFieldsFromTracing: []string{"dsn", "secret", "password", "key"},
   105  		logger:                   logrusx.New("discarding config logger", "", logrusx.UseLogger(l)),
   106  	}
   107  
   108  	for _, m := range modifiers {
   109  		m(p)
   110  	}
   111  
   112  	k, _, cancelFork, err := p.forkKoanf()
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	p.replaceKoanf(k, cancelFork)
   118  	return p, nil
   119  }
   120  
   121  func (p *Provider) replaceKoanf(k *koanf.Koanf, cancelFork context.CancelFunc) {
   122  	p.l.Lock()
   123  	defer p.l.Unlock()
   124  	if p.cancelFork != nil {
   125  		p.cancelFork()
   126  	}
   127  	p.Koanf = k
   128  	p.cancelFork = cancelFork
   129  }
   130  
   131  func (p *Provider) validate(k *koanf.Koanf) error {
   132  	if p.skipValidation {
   133  		return nil
   134  	}
   135  
   136  	out, err := k.Marshal(json.Parser())
   137  	if err != nil {
   138  		return errors.WithStack(err)
   139  	}
   140  	if err := p.validator.Validate(bytes.NewReader(out)); err != nil {
   141  		p.onValidationError(k, err)
   142  		return err
   143  	}
   144  
   145  	return nil
   146  }
   147  
   148  func (p *Provider) forkKoanf() (*koanf.Koanf, context.Context, context.CancelFunc, error) {
   149  	fork, cancel := context.WithCancel(p.originalContext)
   150  	span, fork := p.startSpan(fork, LoadSpanOpName)
   151  	defer span.Finish()
   152  
   153  	k := koanf.New(Delimiter)
   154  	dp, err := NewKoanfSchemaDefaults(p.schema)
   155  	if err != nil {
   156  		cancel()
   157  		return nil, nil, nil, err
   158  	}
   159  
   160  	ep, err := NewKoanfEnv("", p.schema)
   161  	if err != nil {
   162  		cancel()
   163  		return nil, nil, nil, err
   164  	}
   165  
   166  	// Load defaults
   167  	if err := k.Load(dp, nil); err != nil {
   168  		cancel()
   169  		return nil, nil, nil, err
   170  	}
   171  
   172  	for _, t := range p.baseValues {
   173  		if err := k.Load(NewKoanfConfmap([]tuple{t}), nil); err != nil {
   174  			cancel()
   175  			return nil, nil, nil, err
   176  		}
   177  	}
   178  
   179  	var paths []string
   180  	if p.flags != nil {
   181  		p, _ := p.flags.GetStringSlice(FlagConfig)
   182  		paths = append(paths, p...)
   183  	}
   184  
   185  	if err := p.addAndWatchConfigFiles(fork, append(p.files, paths...), k); err != nil {
   186  		cancel()
   187  		return nil, nil, nil, err
   188  	}
   189  
   190  	if p.flags != nil {
   191  		if err := k.Load(posflag.Provider(p.flags, ".", k), nil); err != nil {
   192  			cancel()
   193  			return nil, nil, nil, err
   194  		}
   195  	}
   196  
   197  	if err := k.Load(ep, nil); err != nil {
   198  		cancel()
   199  		return nil, nil, nil, err
   200  	}
   201  
   202  	// Workaround for https://github.com/knadh/koanf/pull/47
   203  	for _, t := range p.forcedValues {
   204  		if err := k.Load(NewKoanfConfmap([]tuple{t}), nil); err != nil {
   205  			cancel()
   206  			return nil, nil, nil, err
   207  		}
   208  	}
   209  
   210  	if err := p.validate(k); err != nil {
   211  		cancel()
   212  		return nil, nil, nil, err
   213  	}
   214  
   215  	p.traceConfig(fork, k, LoadSpanOpName)
   216  	return k, fork, cancel, nil
   217  }
   218  
   219  // TraceSnapshot will send the configuration to the tracer.
   220  func (p *Provider) SetTracer(ctx context.Context, t *tracing.Tracer) {
   221  	p.tracer = t
   222  	p.traceConfig(ctx, p.Koanf, SnapshotSpanOpName)
   223  }
   224  
   225  func (p *Provider) startSpan(ctx context.Context, opName string) (opentracing.Span, context.Context) {
   226  	tracer := opentracing.GlobalTracer()
   227  	if p.tracer != nil && p.tracer.Tracer() != nil {
   228  		tracer = p.tracer.Tracer()
   229  	}
   230  	return opentracing.StartSpanFromContextWithTracer(ctx, tracer, opName)
   231  }
   232  
   233  func (p *Provider) traceConfig(ctx context.Context, k *koanf.Koanf, opName string) {
   234  	span, ctx := p.startSpan(ctx, opName)
   235  	defer span.Finish()
   236  
   237  	span.SetTag("component", "github.com/ory/x/configx")
   238  
   239  	fields := make([]log.Field, 0, len(k.Keys()))
   240  	for _, key := range k.Keys() {
   241  		var redact bool
   242  		for _, e := range p.excludeFieldsFromTracing {
   243  			if strings.Contains(key, e) {
   244  				redact = true
   245  			}
   246  		}
   247  
   248  		if redact {
   249  			fields = append(fields, log.Object(key, "[redacted]"))
   250  		} else {
   251  			fields = append(fields, log.Object(key, k.Get(key)))
   252  		}
   253  	}
   254  
   255  	span.LogFields(fields...)
   256  }
   257  
   258  func (p *Provider) runOnChanges(e watcherx.Event, err error) {
   259  	for k := range p.onChanges {
   260  		p.onChanges[k](e, err)
   261  	}
   262  }
   263  
   264  func (p *Provider) addAndWatchConfigFiles(ctx context.Context, paths []string, k *koanf.Koanf) error {
   265  	p.logger.WithField("files", paths).Debug("Adding config files.")
   266  
   267  	watchForFileChanges := func(c watcherx.EventChannel) {
   268  		// Channel is closed automatically on ctx.Done() because of fp.WatchChannel()
   269  		for e := range c {
   270  			switch et := e.(type) {
   271  			case *watcherx.ErrorEvent:
   272  				p.runOnChanges(e, et)
   273  				continue
   274  			default:
   275  				nk, _, cancel, err := p.forkKoanf()
   276  				if err != nil {
   277  					p.runOnChanges(e, err)
   278  					continue
   279  				}
   280  
   281  				var cancelReload bool
   282  				for _, key := range p.immutables {
   283  					if !reflect.DeepEqual(k.Get(key), nk.Get(key)) {
   284  						cancel()
   285  						cancelReload = true
   286  						p.runOnChanges(e, NewImmutableError(key, fmt.Sprintf("%v", k.Get(key)), fmt.Sprintf("%v", nk.Get(key))))
   287  						break
   288  					}
   289  				}
   290  
   291  				if cancelReload {
   292  					continue
   293  				}
   294  
   295  				p.replaceKoanf(nk, cancel)
   296  				p.runOnChanges(e, nil)
   297  			}
   298  		}
   299  	}
   300  
   301  	for _, path := range paths {
   302  		fp, err := NewKoanfFile(ctx, path)
   303  		if err != nil {
   304  			return err
   305  		}
   306  
   307  		if err := k.Load(fp, nil); err != nil {
   308  			return err
   309  		}
   310  
   311  		c := make(watcherx.EventChannel)
   312  		if _, err := fp.WatchChannel(c); err != nil {
   313  			return err
   314  		}
   315  
   316  		go watchForFileChanges(c)
   317  	}
   318  
   319  	return nil
   320  }
   321  
   322  func (p *Provider) Set(key string, value interface{}) error {
   323  	p.forcedValues = append(p.forcedValues, tuple{Key: key, Value: value})
   324  
   325  	k, _, cancel, err := p.forkKoanf()
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	p.replaceKoanf(k, cancel)
   331  	return nil
   332  }
   333  
   334  func (p *Provider) BoolF(key string, fallback bool) bool {
   335  	if !p.Koanf.Exists(key) {
   336  		return fallback
   337  	}
   338  
   339  	return p.Bool(key)
   340  }
   341  
   342  func (p *Provider) StringF(key string, fallback string) string {
   343  	if !p.Koanf.Exists(key) {
   344  		return fallback
   345  	}
   346  
   347  	return p.String(key)
   348  }
   349  
   350  func (p *Provider) StringsF(key string, fallback []string) (val []string) {
   351  	if !p.Koanf.Exists(key) {
   352  		return fallback
   353  	}
   354  
   355  	return p.Strings(key)
   356  }
   357  
   358  func (p *Provider) IntF(key string, fallback int) (val int) {
   359  	if !p.Koanf.Exists(key) {
   360  		return fallback
   361  	}
   362  
   363  	return p.Int(key)
   364  }
   365  
   366  func (p *Provider) Float64F(key string, fallback float64) (val float64) {
   367  	if !p.Koanf.Exists(key) {
   368  		return fallback
   369  	}
   370  
   371  	return p.Float64(key)
   372  }
   373  
   374  func (p *Provider) DurationF(key string, fallback time.Duration) (val time.Duration) {
   375  	if !p.Koanf.Exists(key) {
   376  		return fallback
   377  	}
   378  
   379  	return p.Duration(key)
   380  }
   381  
   382  func (p *Provider) ByteSizeF(key string, fallback bytesize.ByteSize) bytesize.ByteSize {
   383  	if !p.Koanf.Exists(key) {
   384  		return fallback
   385  	}
   386  
   387  	switch v := p.Koanf.Get(key).(type) {
   388  	case string:
   389  		// this type usually comes from user input
   390  		dec, err := bytesize.Parse(v)
   391  		if err != nil {
   392  			p.logger.WithField("key", key).WithField("raw_value", v).WithError(err).Warnf("error parsing byte size value, using fallback of %s", fallback)
   393  			return fallback
   394  		}
   395  		return dec
   396  	case float64:
   397  		// this type comes from json.Unmarshal
   398  		return bytesize.ByteSize(v)
   399  	case bytesize.ByteSize:
   400  		return v
   401  	default:
   402  		p.logger.WithField("key", key).WithField("raw_type", fmt.Sprintf("%T", v)).WithField("raw_value", fmt.Sprintf("%+v", v)).Errorf("error converting byte size value because of unknown type, using fallback of %s", fallback)
   403  		return fallback
   404  	}
   405  }
   406  
   407  func (p *Provider) GetF(key string, fallback interface{}) (val interface{}) {
   408  	if !p.Exists(key) {
   409  		return fallback
   410  	}
   411  
   412  	return p.Get(key)
   413  }
   414  
   415  func (p *Provider) CORS(prefix string, defaults cors.Options) (cors.Options, bool) {
   416  	if len(prefix) > 0 {
   417  		prefix = strings.TrimRight(prefix, ".") + "."
   418  	}
   419  
   420  	return cors.Options{
   421  		AllowedOrigins:     p.StringsF(prefix+"cors.allowed_origins", defaults.AllowedOrigins),
   422  		AllowedMethods:     p.StringsF(prefix+"cors.allowed_methods", defaults.AllowedMethods),
   423  		AllowedHeaders:     p.StringsF(prefix+"cors.allowed_headers", defaults.AllowedHeaders),
   424  		ExposedHeaders:     p.StringsF(prefix+"cors.exposed_headers", defaults.ExposedHeaders),
   425  		AllowCredentials:   p.BoolF(prefix+"cors.allow_credentials", defaults.AllowCredentials),
   426  		OptionsPassthrough: p.BoolF(prefix+"cors.options_passthrough", defaults.OptionsPassthrough),
   427  		MaxAge:             p.IntF(prefix+"cors.max_age", defaults.MaxAge),
   428  		Debug:              p.BoolF(prefix+"cors.debug", defaults.Debug),
   429  	}, p.Bool(prefix + "cors.enabled")
   430  }
   431  
   432  func (p *Provider) TracingConfig(serviceName string) *tracing.Config {
   433  	return &tracing.Config{
   434  		ServiceName: p.StringF("tracing.service_name", serviceName),
   435  		Provider:    p.String("tracing.provider"),
   436  		Jaeger: &tracing.JaegerConfig{
   437  			LocalAgentHostPort: p.String("tracing.providers.jaeger.local_agent_address"),
   438  			SamplerType:        p.StringF("tracing.providers.jaeger.sampling.type", "const"),
   439  			SamplerValue:       p.Float64F("tracing.providers.jaeger.sampling.value", float64(1)),
   440  			SamplerServerURL:   p.String("tracing.providers.jaeger.sampling.server_url"),
   441  			Propagation: stringsx.Coalesce(
   442  				os.Getenv("JAEGER_PROPAGATION"),
   443  				p.String("tracing.providers.jaeger.propagation"),
   444  			),
   445  		},
   446  		Zipkin: &tracing.ZipkinConfig{
   447  			ServerURL: p.String("tracing.providers.zipkin.server_url"),
   448  		},
   449  	}
   450  }
   451  
   452  func (p *Provider) RequestURIF(path string, fallback *url.URL) *url.URL {
   453  	switch t := p.Get(path).(type) {
   454  	case *url.URL:
   455  		return t
   456  	case url.URL:
   457  		return &t
   458  	case string:
   459  		if parsed, err := url.ParseRequestURI(t); err == nil {
   460  			return parsed
   461  		}
   462  	}
   463  
   464  	return fallback
   465  }
   466  
   467  func (p *Provider) URIF(path string, fallback *url.URL) *url.URL {
   468  	switch t := p.Get(path).(type) {
   469  	case *url.URL:
   470  		return t
   471  	case url.URL:
   472  		return &t
   473  	case string:
   474  		if parsed, err := url.Parse(t); err == nil {
   475  			return parsed
   476  		}
   477  	}
   478  
   479  	return fallback
   480  }
   481  
   482  // PrintHumanReadableValidationErrors prints human readable validation errors. Duh.
   483  func (p *Provider) PrintHumanReadableValidationErrors(w io.Writer, err error) {
   484  	p.printHumanReadableValidationErrors(p.Koanf, w, err)
   485  }
   486  
   487  func (p *Provider) printHumanReadableValidationErrors(k *koanf.Koanf, w io.Writer, err error) {
   488  	if err == nil {
   489  		return
   490  	}
   491  
   492  	_, _ = fmt.Fprintln(os.Stderr, "")
   493  	conf, innerErr := k.Marshal(json.Parser())
   494  	if innerErr != nil {
   495  		_, _ = fmt.Fprintf(w, "Unable to unmarshal configuration: %+v", innerErr)
   496  	}
   497  
   498  	jsonschemax.FormatValidationErrorForCLI(w, conf, err)
   499  }
   500  

View as plain text