...

Source file src/github.com/pelletier/go-toml/v2/unmarshaler.go

Documentation: github.com/pelletier/go-toml/v2

     1  package toml
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"math"
    10  	"reflect"
    11  	"strings"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/pelletier/go-toml/v2/internal/danger"
    16  	"github.com/pelletier/go-toml/v2/internal/tracker"
    17  	"github.com/pelletier/go-toml/v2/unstable"
    18  )
    19  
    20  // Unmarshal deserializes a TOML document into a Go value.
    21  //
    22  // It is a shortcut for Decoder.Decode() with the default options.
    23  func Unmarshal(data []byte, v interface{}) error {
    24  	p := unstable.Parser{}
    25  	p.Reset(data)
    26  	d := decoder{p: &p}
    27  
    28  	return d.FromParser(v)
    29  }
    30  
    31  // Decoder reads and decode a TOML document from an input stream.
    32  type Decoder struct {
    33  	// input
    34  	r io.Reader
    35  
    36  	// global settings
    37  	strict bool
    38  
    39  	// toggles unmarshaler interface
    40  	unmarshalerInterface bool
    41  }
    42  
    43  // NewDecoder creates a new Decoder that will read from r.
    44  func NewDecoder(r io.Reader) *Decoder {
    45  	return &Decoder{r: r}
    46  }
    47  
    48  // DisallowUnknownFields causes the Decoder to return an error when the
    49  // destination is a struct and the input contains a key that does not match a
    50  // non-ignored field.
    51  //
    52  // In that case, the Decoder returns a StrictMissingError that can be used to
    53  // retrieve the individual errors as well as generate a human readable
    54  // description of the missing fields.
    55  func (d *Decoder) DisallowUnknownFields() *Decoder {
    56  	d.strict = true
    57  	return d
    58  }
    59  
    60  // EnableUnmarshalerInterface allows to enable unmarshaler interface.
    61  //
    62  // With this feature enabled, types implementing the unstable/Unmarshaler
    63  // interface can be decoded from any structure of the document. It allows types
    64  // that don't have a straightfoward TOML representation to provide their own
    65  // decoding logic.
    66  //
    67  // Currently, types can only decode from a single value. Tables and array tables
    68  // are not supported.
    69  //
    70  // *Unstable:* This method does not follow the compatibility guarantees of
    71  // semver. It can be changed or removed without a new major version being
    72  // issued.
    73  func (d *Decoder) EnableUnmarshalerInterface() *Decoder {
    74  	d.unmarshalerInterface = true
    75  	return d
    76  }
    77  
    78  // Decode the whole content of r into v.
    79  //
    80  // By default, values in the document that don't exist in the target Go value
    81  // are ignored. See Decoder.DisallowUnknownFields() to change this behavior.
    82  //
    83  // When a TOML local date, time, or date-time is decoded into a time.Time, its
    84  // value is represented in time.Local timezone. Otherwise the appropriate Local*
    85  // structure is used. For time values, precision up to the nanosecond is
    86  // supported by truncating extra digits.
    87  //
    88  // Empty tables decoded in an interface{} create an empty initialized
    89  // map[string]interface{}.
    90  //
    91  // Types implementing the encoding.TextUnmarshaler interface are decoded from a
    92  // TOML string.
    93  //
    94  // When decoding a number, go-toml will return an error if the number is out of
    95  // bounds for the target type (which includes negative numbers when decoding
    96  // into an unsigned int).
    97  //
    98  // If an error occurs while decoding the content of the document, this function
    99  // returns a toml.DecodeError, providing context about the issue. When using
   100  // strict mode and a field is missing, a `toml.StrictMissingError` is
   101  // returned. In any other case, this function returns a standard Go error.
   102  //
   103  // # Type mapping
   104  //
   105  // List of supported TOML types and their associated accepted Go types:
   106  //
   107  //	String           -> string
   108  //	Integer          -> uint*, int*, depending on size
   109  //	Float            -> float*, depending on size
   110  //	Boolean          -> bool
   111  //	Offset Date-Time -> time.Time
   112  //	Local Date-time  -> LocalDateTime, time.Time
   113  //	Local Date       -> LocalDate, time.Time
   114  //	Local Time       -> LocalTime, time.Time
   115  //	Array            -> slice and array, depending on elements types
   116  //	Table            -> map and struct
   117  //	Inline Table     -> same as Table
   118  //	Array of Tables  -> same as Array and Table
   119  func (d *Decoder) Decode(v interface{}) error {
   120  	b, err := ioutil.ReadAll(d.r)
   121  	if err != nil {
   122  		return fmt.Errorf("toml: %w", err)
   123  	}
   124  
   125  	p := unstable.Parser{}
   126  	p.Reset(b)
   127  	dec := decoder{
   128  		p: &p,
   129  		strict: strict{
   130  			Enabled: d.strict,
   131  		},
   132  		unmarshalerInterface: d.unmarshalerInterface,
   133  	}
   134  
   135  	return dec.FromParser(v)
   136  }
   137  
   138  type decoder struct {
   139  	// Which parser instance in use for this decoding session.
   140  	p *unstable.Parser
   141  
   142  	// Flag indicating that the current expression is stashed.
   143  	// If set to true, calling nextExpr will not actually pull a new expression
   144  	// but turn off the flag instead.
   145  	stashedExpr bool
   146  
   147  	// Skip expressions until a table is found. This is set to true when a
   148  	// table could not be created (missing field in map), so all KV expressions
   149  	// need to be skipped.
   150  	skipUntilTable bool
   151  
   152  	// Flag indicating that the current array/slice table should be cleared because
   153  	// it is the first encounter of an array table.
   154  	clearArrayTable bool
   155  
   156  	// Tracks position in Go arrays.
   157  	// This is used when decoding [[array tables]] into Go arrays. Given array
   158  	// tables are separate TOML expression, we need to keep track of where we
   159  	// are at in the Go array, as we can't just introspect its size.
   160  	arrayIndexes map[reflect.Value]int
   161  
   162  	// Tracks keys that have been seen, with which type.
   163  	seen tracker.SeenTracker
   164  
   165  	// Strict mode
   166  	strict strict
   167  
   168  	// Flag that enables/disables unmarshaler interface.
   169  	unmarshalerInterface bool
   170  
   171  	// Current context for the error.
   172  	errorContext *errorContext
   173  }
   174  
   175  type errorContext struct {
   176  	Struct reflect.Type
   177  	Field  []int
   178  }
   179  
   180  func (d *decoder) typeMismatchError(toml string, target reflect.Type) error {
   181  	return fmt.Errorf("toml: %s", d.typeMismatchString(toml, target))
   182  }
   183  
   184  func (d *decoder) typeMismatchString(toml string, target reflect.Type) string {
   185  	if d.errorContext != nil && d.errorContext.Struct != nil {
   186  		ctx := d.errorContext
   187  		f := ctx.Struct.FieldByIndex(ctx.Field)
   188  		return fmt.Sprintf("cannot decode TOML %s into struct field %s.%s of type %s", toml, ctx.Struct, f.Name, f.Type)
   189  	}
   190  	return fmt.Sprintf("cannot decode TOML %s into a Go value of type %s", toml, target)
   191  }
   192  
   193  func (d *decoder) expr() *unstable.Node {
   194  	return d.p.Expression()
   195  }
   196  
   197  func (d *decoder) nextExpr() bool {
   198  	if d.stashedExpr {
   199  		d.stashedExpr = false
   200  		return true
   201  	}
   202  	return d.p.NextExpression()
   203  }
   204  
   205  func (d *decoder) stashExpr() {
   206  	d.stashedExpr = true
   207  }
   208  
   209  func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int {
   210  	if d.arrayIndexes == nil {
   211  		d.arrayIndexes = make(map[reflect.Value]int, 1)
   212  	}
   213  
   214  	idx, ok := d.arrayIndexes[v]
   215  
   216  	if !ok {
   217  		d.arrayIndexes[v] = 0
   218  	} else if shouldAppend {
   219  		idx++
   220  		d.arrayIndexes[v] = idx
   221  	}
   222  
   223  	return idx
   224  }
   225  
   226  func (d *decoder) FromParser(v interface{}) error {
   227  	r := reflect.ValueOf(v)
   228  	if r.Kind() != reflect.Ptr {
   229  		return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
   230  	}
   231  
   232  	if r.IsNil() {
   233  		return fmt.Errorf("toml: decoding pointer target cannot be nil")
   234  	}
   235  
   236  	r = r.Elem()
   237  	if r.Kind() == reflect.Interface && r.IsNil() {
   238  		newMap := map[string]interface{}{}
   239  		r.Set(reflect.ValueOf(newMap))
   240  	}
   241  
   242  	err := d.fromParser(r)
   243  	if err == nil {
   244  		return d.strict.Error(d.p.Data())
   245  	}
   246  
   247  	var e *unstable.ParserError
   248  	if errors.As(err, &e) {
   249  		return wrapDecodeError(d.p.Data(), e)
   250  	}
   251  
   252  	return err
   253  }
   254  
   255  func (d *decoder) fromParser(root reflect.Value) error {
   256  	for d.nextExpr() {
   257  		err := d.handleRootExpression(d.expr(), root)
   258  		if err != nil {
   259  			return err
   260  		}
   261  	}
   262  
   263  	return d.p.Error()
   264  }
   265  
   266  /*
   267  Rules for the unmarshal code:
   268  
   269  - The stack is used to keep track of which values need to be set where.
   270  - handle* functions <=> switch on a given unstable.Kind.
   271  - unmarshalX* functions need to unmarshal a node of kind X.
   272  - An "object" is either a struct or a map.
   273  */
   274  
// handleRootExpression processes one top-level expression of the document
// against the root value v. The expression is a key-value, a [table] header,
// or an [[array table]] header.
//
// While skipUntilTable is set, key-value expressions are ignored. A table
// header resets it before being handled. When the handler returns a valid
// value and skipping is not active, that value is stored into v.
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
	var x reflect.Value
	var err error
	var first bool // used to clear array tables on first use

	// Key-values that are being skipped are not registered with the tracker.
	if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
		first, err = d.seen.CheckExpression(expr)
		if err != nil {
			return err
		}
	}

	switch expr.Kind {
	case unstable.KeyValue:
		if d.skipUntilTable {
			return nil
		}
		x, err = d.handleKeyValue(expr, v)
	case unstable.Table:
		d.skipUntilTable = false
		d.strict.EnterTable(expr)
		x, err = d.handleTable(expr.Key(), v)
	case unstable.ArrayTable:
		d.skipUntilTable = false
		d.strict.EnterArrayTable(expr)
		// first comes from the seen tracker; it marks the first encounter of
		// this array table, when existing destination elements are cleared.
		d.clearArrayTable = first
		x, err = d.handleArrayTable(expr.Key(), v)
	default:
		panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
	}

	// If handling turned skipping on, the key did not resolve in v: nothing is
	// stored, and table headers are reported to strict mode as missing.
	if d.skipUntilTable {
		if expr.Kind == unstable.Table || expr.Kind == unstable.ArrayTable {
			d.strict.MissingTable(expr)
		}
	} else if err == nil && x.IsValid() {
		v.Set(x)
	}

	return err
}
   316  
   317  func (d *decoder) handleArrayTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   318  	if key.Next() {
   319  		return d.handleArrayTablePart(key, v)
   320  	}
   321  	return d.handleKeyValues(v)
   322  }
   323  
// handleArrayTableCollectionLast handles the final part of an [[array table]]
// key. v is the collection receiving the new table:
//   - for slices, a new element is appended;
//   - for Go arrays, the next index (tracked by arrayIndex) is used;
//   - interfaces and pointers are initialized and unwrapped first.
//
// When valid, the returned value is what the caller must store in place of
// v (for slices, the slice with the new element appended).
func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
	switch v.Kind() {
	case reflect.Interface:
		// Make the interface content a settable []interface{}: create one when
		// the interface is empty, replace a slice of any other type, and copy
		// a slice that cannot be modified in place.
		elem := v.Elem()
		if !elem.IsValid() {
			elem = reflect.New(sliceInterfaceType).Elem()
			elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
		} else if elem.Kind() == reflect.Slice {
			if elem.Type() != sliceInterfaceType {
				elem = reflect.New(sliceInterfaceType).Elem()
				elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
			} else if !elem.CanSet() {
				nelem := reflect.New(sliceInterfaceType).Elem()
				nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap()))
				reflect.Copy(nelem, elem)
				elem = nelem
			}
			// First encounter of this array table: drop existing elements.
			if d.clearArrayTable && elem.Len() > 0 {
				elem.SetLen(0)
				d.clearArrayTable = false
			}
		}
		return d.handleArrayTableCollectionLast(key, elem)
	case reflect.Ptr:
		// Allocate nil pointers, then store the pointee's result back through
		// the pointer.
		elem := v.Elem()
		if !elem.IsValid() {
			ptr := reflect.New(v.Type().Elem())
			v.Set(ptr)
			elem = ptr.Elem()
		}

		elem, err := d.handleArrayTableCollectionLast(key, elem)
		if err != nil {
			return reflect.Value{}, err
		}
		v.Elem().Set(elem)

		return v, nil
	case reflect.Slice:
		// First encounter of this array table: drop existing elements.
		if d.clearArrayTable && v.Len() > 0 {
			v.SetLen(0)
			d.clearArrayTable = false
		}
		elemType := v.Type().Elem()
		var elem reflect.Value
		if elemType.Kind() == reflect.Interface {
			// An array table element is a table, so interface elements start
			// as a map.
			elem = makeMapStringInterface()
		} else {
			elem = reflect.New(elemType).Elem()
		}
		elem2, err := d.handleArrayTable(key, elem)
		if err != nil {
			return reflect.Value{}, err
		}
		if elem2.IsValid() {
			elem = elem2
		}
		return reflect.Append(v, elem), nil
	case reflect.Array:
		// Go arrays cannot grow: move to the next index, and fail once it runs
		// past the end of the array.
		idx := d.arrayIndex(true, v)
		if idx >= v.Len() {
			return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
		}
		elem := v.Index(idx)
		_, err := d.handleArrayTable(key, elem)
		return v, err
	default:
		return reflect.Value{}, d.typeMismatchError("array table", v.Type())
	}
}
   394  
   395  // When parsing an array table expression, each part of the key needs to be
   396  // evaluated like a normal key, but if it returns a collection, it also needs to
   397  // point to the last element of the collection. Unless it is the last part of
   398  // the key, then it needs to create a new element at the end.
   399  func (d *decoder) handleArrayTableCollection(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   400  	if key.IsLast() {
   401  		return d.handleArrayTableCollectionLast(key, v)
   402  	}
   403  
   404  	switch v.Kind() {
   405  	case reflect.Ptr:
   406  		elem := v.Elem()
   407  		if !elem.IsValid() {
   408  			ptr := reflect.New(v.Type().Elem())
   409  			v.Set(ptr)
   410  			elem = ptr.Elem()
   411  		}
   412  
   413  		elem, err := d.handleArrayTableCollection(key, elem)
   414  		if err != nil {
   415  			return reflect.Value{}, err
   416  		}
   417  		if elem.IsValid() {
   418  			v.Elem().Set(elem)
   419  		}
   420  
   421  		return v, nil
   422  	case reflect.Slice:
   423  		elem := v.Index(v.Len() - 1)
   424  		x, err := d.handleArrayTable(key, elem)
   425  		if err != nil || d.skipUntilTable {
   426  			return reflect.Value{}, err
   427  		}
   428  		if x.IsValid() {
   429  			elem.Set(x)
   430  		}
   431  
   432  		return v, err
   433  	case reflect.Array:
   434  		idx := d.arrayIndex(false, v)
   435  		if idx >= v.Len() {
   436  			return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
   437  		}
   438  		elem := v.Index(idx)
   439  		_, err := d.handleArrayTable(key, elem)
   440  		return v, err
   441  	}
   442  
   443  	return d.handleArrayTable(key, v)
   444  }
   445  
   446  func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) {
   447  	var rv reflect.Value
   448  
   449  	// First, dispatch over v to make sure it is a valid object.
   450  	// There is no guarantee over what it could be.
   451  	switch v.Kind() {
   452  	case reflect.Ptr:
   453  		elem := v.Elem()
   454  		if !elem.IsValid() {
   455  			v.Set(reflect.New(v.Type().Elem()))
   456  		}
   457  		elem = v.Elem()
   458  		return d.handleKeyPart(key, elem, nextFn, makeFn)
   459  	case reflect.Map:
   460  		vt := v.Type()
   461  
   462  		// Create the key for the map element. Convert to key type.
   463  		mk, err := d.keyFromData(vt.Key(), key.Node().Data)
   464  		if err != nil {
   465  			return reflect.Value{}, err
   466  		}
   467  
   468  		// If the map does not exist, create it.
   469  		if v.IsNil() {
   470  			vt := v.Type()
   471  			v = reflect.MakeMap(vt)
   472  			rv = v
   473  		}
   474  
   475  		mv := v.MapIndex(mk)
   476  		set := false
   477  		if !mv.IsValid() {
   478  			// If there is no value in the map, create a new one according to
   479  			// the map type. If the element type is interface, create either a
   480  			// map[string]interface{} or a []interface{} depending on whether
   481  			// this is the last part of the array table key.
   482  
   483  			t := vt.Elem()
   484  			if t.Kind() == reflect.Interface {
   485  				mv = makeFn()
   486  			} else {
   487  				mv = reflect.New(t).Elem()
   488  			}
   489  			set = true
   490  		} else if mv.Kind() == reflect.Interface {
   491  			mv = mv.Elem()
   492  			if !mv.IsValid() {
   493  				mv = makeFn()
   494  			}
   495  			set = true
   496  		} else if !mv.CanAddr() {
   497  			vt := v.Type()
   498  			t := vt.Elem()
   499  			oldmv := mv
   500  			mv = reflect.New(t).Elem()
   501  			mv.Set(oldmv)
   502  			set = true
   503  		}
   504  
   505  		x, err := nextFn(key, mv)
   506  		if err != nil {
   507  			return reflect.Value{}, err
   508  		}
   509  
   510  		if x.IsValid() {
   511  			mv = x
   512  			set = true
   513  		}
   514  
   515  		if set {
   516  			v.SetMapIndex(mk, mv)
   517  		}
   518  	case reflect.Struct:
   519  		path, found := structFieldPath(v, string(key.Node().Data))
   520  		if !found {
   521  			d.skipUntilTable = true
   522  			return reflect.Value{}, nil
   523  		}
   524  
   525  		if d.errorContext == nil {
   526  			d.errorContext = new(errorContext)
   527  		}
   528  		t := v.Type()
   529  		d.errorContext.Struct = t
   530  		d.errorContext.Field = path
   531  
   532  		f := fieldByIndex(v, path)
   533  		x, err := nextFn(key, f)
   534  		if err != nil || d.skipUntilTable {
   535  			return reflect.Value{}, err
   536  		}
   537  		if x.IsValid() {
   538  			f.Set(x)
   539  		}
   540  		d.errorContext.Field = nil
   541  		d.errorContext.Struct = nil
   542  	case reflect.Interface:
   543  		if v.Elem().IsValid() {
   544  			v = v.Elem()
   545  		} else {
   546  			v = makeMapStringInterface()
   547  		}
   548  
   549  		x, err := d.handleKeyPart(key, v, nextFn, makeFn)
   550  		if err != nil {
   551  			return reflect.Value{}, err
   552  		}
   553  		if x.IsValid() {
   554  			v = x
   555  		}
   556  		rv = v
   557  	default:
   558  		panic(fmt.Errorf("unhandled part: %s", v.Kind()))
   559  	}
   560  
   561  	return rv, nil
   562  }
   563  
   564  // HandleArrayTablePart navigates the Go structure v using the key v. It is
   565  // only used for the prefix (non-last) parts of an array-table. When
   566  // encountering a collection, it should go to the last element.
   567  func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   568  	var makeFn valueMakerFn
   569  	if key.IsLast() {
   570  		makeFn = makeSliceInterface
   571  	} else {
   572  		makeFn = makeMapStringInterface
   573  	}
   574  	return d.handleKeyPart(key, v, d.handleArrayTableCollection, makeFn)
   575  }
   576  
   577  // HandleTable returns a reference when it has checked the next expression but
   578  // cannot handle it.
   579  func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   580  	if v.Kind() == reflect.Slice {
   581  		if v.Len() == 0 {
   582  			return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
   583  		}
   584  		elem := v.Index(v.Len() - 1)
   585  		x, err := d.handleTable(key, elem)
   586  		if err != nil {
   587  			return reflect.Value{}, err
   588  		}
   589  		if x.IsValid() {
   590  			elem.Set(x)
   591  		}
   592  		return reflect.Value{}, nil
   593  	}
   594  	if key.Next() {
   595  		// Still scoping the key
   596  		return d.handleTablePart(key, v)
   597  	}
   598  	// Done scoping the key.
   599  	// Now handle all the key-value expressions in this table.
   600  	return d.handleKeyValues(v)
   601  }
   602  
   603  // Handle root expressions until the end of the document or the next
   604  // non-key-value.
   605  func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
   606  	var rv reflect.Value
   607  	for d.nextExpr() {
   608  		expr := d.expr()
   609  		if expr.Kind != unstable.KeyValue {
   610  			// Stash the expression so that fromParser can just loop and use
   611  			// the right handler.
   612  			// We could just recurse ourselves here, but at least this gives a
   613  			// chance to pop the stack a bit.
   614  			d.stashExpr()
   615  			break
   616  		}
   617  
   618  		_, err := d.seen.CheckExpression(expr)
   619  		if err != nil {
   620  			return reflect.Value{}, err
   621  		}
   622  
   623  		x, err := d.handleKeyValue(expr, v)
   624  		if err != nil {
   625  			return reflect.Value{}, err
   626  		}
   627  		if x.IsValid() {
   628  			v = x
   629  			rv = x
   630  		}
   631  	}
   632  	return rv, nil
   633  }
   634  
   635  type (
   636  	handlerFn    func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
   637  	valueMakerFn func() reflect.Value
   638  )
   639  
   640  func makeMapStringInterface() reflect.Value {
   641  	return reflect.MakeMap(mapStringInterfaceType)
   642  }
   643  
   644  func makeSliceInterface() reflect.Value {
   645  	return reflect.MakeSlice(sliceInterfaceType, 0, 16)
   646  }
   647  
   648  func (d *decoder) handleTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   649  	return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface)
   650  }
   651  
   652  func (d *decoder) tryTextUnmarshaler(node *unstable.Node, v reflect.Value) (bool, error) {
   653  	// Special case for time, because we allow to unmarshal to it from
   654  	// different kind of AST nodes.
   655  	if v.Type() == timeType {
   656  		return false, nil
   657  	}
   658  
   659  	if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
   660  		err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
   661  		if err != nil {
   662  			return false, unstable.NewParserError(d.p.Raw(node.Raw), "%w", err)
   663  		}
   664  
   665  		return true, nil
   666  	}
   667  
   668  	return false, nil
   669  }
   670  
   671  func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
   672  	for v.Kind() == reflect.Ptr {
   673  		v = initAndDereferencePointer(v)
   674  	}
   675  
   676  	if d.unmarshalerInterface {
   677  		if v.CanAddr() && v.Addr().CanInterface() {
   678  			if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
   679  				return outi.UnmarshalTOML(value)
   680  			}
   681  		}
   682  	}
   683  
   684  	ok, err := d.tryTextUnmarshaler(value, v)
   685  	if ok || err != nil {
   686  		return err
   687  	}
   688  
   689  	switch value.Kind {
   690  	case unstable.String:
   691  		return d.unmarshalString(value, v)
   692  	case unstable.Integer:
   693  		return d.unmarshalInteger(value, v)
   694  	case unstable.Float:
   695  		return d.unmarshalFloat(value, v)
   696  	case unstable.Bool:
   697  		return d.unmarshalBool(value, v)
   698  	case unstable.DateTime:
   699  		return d.unmarshalDateTime(value, v)
   700  	case unstable.LocalDate:
   701  		return d.unmarshalLocalDate(value, v)
   702  	case unstable.LocalTime:
   703  		return d.unmarshalLocalTime(value, v)
   704  	case unstable.LocalDateTime:
   705  		return d.unmarshalLocalDateTime(value, v)
   706  	case unstable.InlineTable:
   707  		return d.unmarshalInlineTable(value, v)
   708  	case unstable.Array:
   709  		return d.unmarshalArray(value, v)
   710  	default:
   711  		panic(fmt.Errorf("handleValue not implemented for %s", value.Kind))
   712  	}
   713  }
   714  
// unmarshalArray decodes a TOML array node into v, which must be a slice, a
// Go array, or an interface.
//
// Existing slice contents are discarded (length reset to 0) before the
// elements are appended. For Go arrays, TOML elements beyond the array length
// are ignored, and Go elements beyond the TOML array length are left
// unchanged. Interfaces are decoded as []interface{}.
func (d *decoder) unmarshalArray(array *unstable.Node, v reflect.Value) error {
	switch v.Kind() {
	case reflect.Slice:
		if v.IsNil() {
			v.Set(reflect.MakeSlice(v.Type(), 0, 16))
		} else {
			v.SetLen(0)
		}
	case reflect.Array:
		// arrays are always initialized
	case reflect.Interface:
		// Decode into a settable []interface{}: a new one when the interface
		// is empty or holds another slice type, a copy when the current slice
		// cannot be set. The result is then stored into the interface.
		elem := v.Elem()
		if !elem.IsValid() {
			elem = reflect.New(sliceInterfaceType).Elem()
			elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
		} else if elem.Kind() == reflect.Slice {
			if elem.Type() != sliceInterfaceType {
				elem = reflect.New(sliceInterfaceType).Elem()
				elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
			} else if !elem.CanSet() {
				nelem := reflect.New(sliceInterfaceType).Elem()
				nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap()))
				reflect.Copy(nelem, elem)
				elem = nelem
			}
		}
		err := d.unmarshalArray(array, elem)
		if err != nil {
			return err
		}
		v.Set(elem)
		return nil
	default:
		// TODO: use newDecodeError, but first the parser needs to fill
		//   array.Data.
		return d.typeMismatchError("array", v.Type())
	}

	elemType := v.Type().Elem()

	it := array.Children()
	idx := 0
	for it.Next() {
		n := it.Node()

		// TODO: optimize
		if v.Kind() == reflect.Slice {
			// Decode into a fresh zero element, then append it.
			elem := reflect.New(elemType).Elem()

			err := d.handleValue(n, elem)
			if err != nil {
				return err
			}

			v.Set(reflect.Append(v, elem))
		} else { // array
			// Fixed size: stop silently once the Go array is full.
			if idx >= v.Len() {
				return nil
			}
			elem := v.Index(idx)
			err := d.handleValue(n, elem)
			if err != nil {
				return err
			}
			idx++
		}
	}

	return nil
}
   785  
   786  func (d *decoder) unmarshalInlineTable(itable *unstable.Node, v reflect.Value) error {
   787  	// Make sure v is an initialized object.
   788  	switch v.Kind() {
   789  	case reflect.Map:
   790  		if v.IsNil() {
   791  			v.Set(reflect.MakeMap(v.Type()))
   792  		}
   793  	case reflect.Struct:
   794  	// structs are always initialized.
   795  	case reflect.Interface:
   796  		elem := v.Elem()
   797  		if !elem.IsValid() {
   798  			elem = makeMapStringInterface()
   799  			v.Set(elem)
   800  		}
   801  		return d.unmarshalInlineTable(itable, elem)
   802  	default:
   803  		return unstable.NewParserError(d.p.Raw(itable.Raw), "cannot store inline table in Go type %s", v.Kind())
   804  	}
   805  
   806  	it := itable.Children()
   807  	for it.Next() {
   808  		n := it.Node()
   809  
   810  		x, err := d.handleKeyValue(n, v)
   811  		if err != nil {
   812  			return err
   813  		}
   814  		if x.IsValid() {
   815  			v = x
   816  		}
   817  	}
   818  
   819  	return nil
   820  }
   821  
   822  func (d *decoder) unmarshalDateTime(value *unstable.Node, v reflect.Value) error {
   823  	dt, err := parseDateTime(value.Data)
   824  	if err != nil {
   825  		return err
   826  	}
   827  
   828  	v.Set(reflect.ValueOf(dt))
   829  	return nil
   830  }
   831  
   832  func (d *decoder) unmarshalLocalDate(value *unstable.Node, v reflect.Value) error {
   833  	ld, err := parseLocalDate(value.Data)
   834  	if err != nil {
   835  		return err
   836  	}
   837  
   838  	if v.Type() == timeType {
   839  		cast := ld.AsTime(time.Local)
   840  		v.Set(reflect.ValueOf(cast))
   841  		return nil
   842  	}
   843  
   844  	v.Set(reflect.ValueOf(ld))
   845  
   846  	return nil
   847  }
   848  
   849  func (d *decoder) unmarshalLocalTime(value *unstable.Node, v reflect.Value) error {
   850  	lt, rest, err := parseLocalTime(value.Data)
   851  	if err != nil {
   852  		return err
   853  	}
   854  
   855  	if len(rest) > 0 {
   856  		return unstable.NewParserError(rest, "extra characters at the end of a local time")
   857  	}
   858  
   859  	v.Set(reflect.ValueOf(lt))
   860  	return nil
   861  }
   862  
   863  func (d *decoder) unmarshalLocalDateTime(value *unstable.Node, v reflect.Value) error {
   864  	ldt, rest, err := parseLocalDateTime(value.Data)
   865  	if err != nil {
   866  		return err
   867  	}
   868  
   869  	if len(rest) > 0 {
   870  		return unstable.NewParserError(rest, "extra characters at the end of a local date time")
   871  	}
   872  
   873  	if v.Type() == timeType {
   874  		cast := ldt.AsTime(time.Local)
   875  
   876  		v.Set(reflect.ValueOf(cast))
   877  		return nil
   878  	}
   879  
   880  	v.Set(reflect.ValueOf(ldt))
   881  
   882  	return nil
   883  }
   884  
   885  func (d *decoder) unmarshalBool(value *unstable.Node, v reflect.Value) error {
   886  	b := value.Data[0] == 't'
   887  
   888  	switch v.Kind() {
   889  	case reflect.Bool:
   890  		v.SetBool(b)
   891  	case reflect.Interface:
   892  		v.Set(reflect.ValueOf(b))
   893  	default:
   894  		return unstable.NewParserError(value.Data, "cannot assign boolean to a %t", b)
   895  	}
   896  
   897  	return nil
   898  }
   899  
   900  func (d *decoder) unmarshalFloat(value *unstable.Node, v reflect.Value) error {
   901  	f, err := parseFloat(value.Data)
   902  	if err != nil {
   903  		return err
   904  	}
   905  
   906  	switch v.Kind() {
   907  	case reflect.Float64:
   908  		v.SetFloat(f)
   909  	case reflect.Float32:
   910  		if f > math.MaxFloat32 {
   911  			return unstable.NewParserError(value.Data, "number %f does not fit in a float32", f)
   912  		}
   913  		v.SetFloat(f)
   914  	case reflect.Interface:
   915  		v.Set(reflect.ValueOf(f))
   916  	default:
   917  		return unstable.NewParserError(value.Data, "float cannot be assigned to %s", v.Kind())
   918  	}
   919  
   920  	return nil
   921  }
   922  
   923  const (
   924  	maxInt = int64(^uint(0) >> 1)
   925  	minInt = -maxInt - 1
   926  )
   927  
   928  // Maximum value of uint for decoding. Currently the decoder parses the integer
   929  // into an int64. As a result, on architectures where uint is 64 bits, the
   930  // effective maximum uint we can decode is the maximum of int64. On
   931  // architectures where uint is 32 bits, the maximum value we can decode is
   932  // lower: the maximum of uint32. I didn't find a way to figure out this value at
   933  // compile time, so it is computed during initialization.
   934  var maxUint int64 = math.MaxInt64
   935  
   936  func init() {
   937  	m := uint64(^uint(0))
   938  	if m < uint64(maxUint) {
   939  		maxUint = int64(m)
   940  	}
   941  }
   942  
   943  func (d *decoder) unmarshalInteger(value *unstable.Node, v reflect.Value) error {
   944  	kind := v.Kind()
   945  	if kind == reflect.Float32 || kind == reflect.Float64 {
   946  		return d.unmarshalFloat(value, v)
   947  	}
   948  
   949  	i, err := parseInteger(value.Data)
   950  	if err != nil {
   951  		return err
   952  	}
   953  
   954  	var r reflect.Value
   955  
   956  	switch kind {
   957  	case reflect.Int64:
   958  		v.SetInt(i)
   959  		return nil
   960  	case reflect.Int32:
   961  		if i < math.MinInt32 || i > math.MaxInt32 {
   962  			return fmt.Errorf("toml: number %d does not fit in an int32", i)
   963  		}
   964  
   965  		r = reflect.ValueOf(int32(i))
   966  	case reflect.Int16:
   967  		if i < math.MinInt16 || i > math.MaxInt16 {
   968  			return fmt.Errorf("toml: number %d does not fit in an int16", i)
   969  		}
   970  
   971  		r = reflect.ValueOf(int16(i))
   972  	case reflect.Int8:
   973  		if i < math.MinInt8 || i > math.MaxInt8 {
   974  			return fmt.Errorf("toml: number %d does not fit in an int8", i)
   975  		}
   976  
   977  		r = reflect.ValueOf(int8(i))
   978  	case reflect.Int:
   979  		if i < minInt || i > maxInt {
   980  			return fmt.Errorf("toml: number %d does not fit in an int", i)
   981  		}
   982  
   983  		r = reflect.ValueOf(int(i))
   984  	case reflect.Uint64:
   985  		if i < 0 {
   986  			return fmt.Errorf("toml: negative number %d does not fit in an uint64", i)
   987  		}
   988  
   989  		r = reflect.ValueOf(uint64(i))
   990  	case reflect.Uint32:
   991  		if i < 0 || i > math.MaxUint32 {
   992  			return fmt.Errorf("toml: negative number %d does not fit in an uint32", i)
   993  		}
   994  
   995  		r = reflect.ValueOf(uint32(i))
   996  	case reflect.Uint16:
   997  		if i < 0 || i > math.MaxUint16 {
   998  			return fmt.Errorf("toml: negative number %d does not fit in an uint16", i)
   999  		}
  1000  
  1001  		r = reflect.ValueOf(uint16(i))
  1002  	case reflect.Uint8:
  1003  		if i < 0 || i > math.MaxUint8 {
  1004  			return fmt.Errorf("toml: negative number %d does not fit in an uint8", i)
  1005  		}
  1006  
  1007  		r = reflect.ValueOf(uint8(i))
  1008  	case reflect.Uint:
  1009  		if i < 0 || i > maxUint {
  1010  			return fmt.Errorf("toml: negative number %d does not fit in an uint", i)
  1011  		}
  1012  
  1013  		r = reflect.ValueOf(uint(i))
  1014  	case reflect.Interface:
  1015  		r = reflect.ValueOf(i)
  1016  	default:
  1017  		return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("integer", v.Type()))
  1018  	}
  1019  
  1020  	if !r.Type().AssignableTo(v.Type()) {
  1021  		r = r.Convert(v.Type())
  1022  	}
  1023  
  1024  	v.Set(r)
  1025  
  1026  	return nil
  1027  }
  1028  
  1029  func (d *decoder) unmarshalString(value *unstable.Node, v reflect.Value) error {
  1030  	switch v.Kind() {
  1031  	case reflect.String:
  1032  		v.SetString(string(value.Data))
  1033  	case reflect.Interface:
  1034  		v.Set(reflect.ValueOf(string(value.Data)))
  1035  	default:
  1036  		return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("string", v.Type()))
  1037  	}
  1038  
  1039  	return nil
  1040  }
  1041  
  1042  func (d *decoder) handleKeyValue(expr *unstable.Node, v reflect.Value) (reflect.Value, error) {
  1043  	d.strict.EnterKeyValue(expr)
  1044  
  1045  	v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v)
  1046  	if d.skipUntilTable {
  1047  		d.strict.MissingField(expr)
  1048  		d.skipUntilTable = false
  1049  	}
  1050  
  1051  	d.strict.ExitKeyValue(expr)
  1052  
  1053  	return v, err
  1054  }
  1055  
  1056  func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
  1057  	if key.Next() {
  1058  		// Still scoping the key
  1059  		return d.handleKeyValuePart(key, value, v)
  1060  	}
  1061  	// Done scoping the key.
  1062  	// v is whatever Go value we need to fill.
  1063  	return reflect.Value{}, d.handleValue(value, v)
  1064  }
  1065  
  1066  func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, error) {
  1067  	switch {
  1068  	case stringType.AssignableTo(keyType):
  1069  		return reflect.ValueOf(string(data)), nil
  1070  
  1071  	case stringType.ConvertibleTo(keyType):
  1072  		return reflect.ValueOf(string(data)).Convert(keyType), nil
  1073  
  1074  	case keyType.Implements(textUnmarshalerType):
  1075  		mk := reflect.New(keyType.Elem())
  1076  		if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
  1077  			return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
  1078  		}
  1079  		return mk, nil
  1080  
  1081  	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
  1082  		mk := reflect.New(keyType)
  1083  		if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
  1084  			return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
  1085  		}
  1086  		return mk.Elem(), nil
  1087  	}
  1088  	return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType)
  1089  }
  1090  
// handleKeyValuePart resolves the current segment of a dotted key inside v,
// then continues with the rest of the key via handleKeyValueInner (or
// recursively via itself for pointers, interfaces and copied structs).
//
// It returns the value that must replace v in its parent when v could not be
// updated in place: a newly created map, a newly allocated pointer, the
// contents of an interface, or an addressable copy of a struct. An invalid
// reflect.Value means no replacement is needed.
//
// A key with no matching struct field is not an error here: it sets
// d.skipUntilTable and leaves v untouched. Targets of any other kind produce
// an "unhandled kv part" error.
func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
	// contains the replacement for v
	var rv reflect.Value

	// First, dispatch over v to make sure it is a valid object.
	// There is no guarantee over what it could be.
	switch v.Kind() {
	case reflect.Map:
		vt := v.Type()

		// Convert the key segment into the map's key type.
		mk, err := d.keyFromData(vt.Key(), key.Node().Data)
		if err != nil {
			return reflect.Value{}, err
		}

		// If the map does not exist, create it.
		if v.IsNil() {
			v = reflect.MakeMap(vt)
			rv = v
		}

		// Map elements obtained with MapIndex are not addressable, so the
		// element is decoded into mv and written back with SetMapIndex.
		// A missing element, or any element when key.IsLast() reports true,
		// is replaced by a fresh zero value.
		mv := v.MapIndex(mk)
		set := false
		if !mv.IsValid() || key.IsLast() {
			set = true
			mv = reflect.New(v.Type().Elem()).Elem()
		}

		nv, err := d.handleKeyValueInner(key, value, mv)
		if err != nil {
			return reflect.Value{}, err
		}
		if nv.IsValid() {
			mv = nv
			set = true
		}

		if set {
			v.SetMapIndex(mk, mv)
		}
	case reflect.Struct:
		path, found := structFieldPath(v, string(key.Node().Data))
		if !found {
			// Unknown key: flag it; handleKeyValue reports it to the strict
			// tracker and clears the flag.
			d.skipUntilTable = true
			break
		}

		// Record which struct field is being decoded. It is cleared below
		// only on success, so it is still set when an error is returned.
		if d.errorContext == nil {
			d.errorContext = new(errorContext)
		}
		t := v.Type()
		d.errorContext.Struct = t
		d.errorContext.Field = path

		// fieldByIndex allocates nil pointer fields along the path.
		// NOTE(review): allocating requires v to be settable; a struct read
		// back from a map via MapIndex is not, and the CanAddr fallback below
		// only runs after this call — confirm this case cannot occur.
		f := fieldByIndex(v, path)

		if !f.CanAddr() {
			// If the field is not addressable, need to take a slower path and
			// make a copy of the struct itself to a new location.
			nvp := reflect.New(v.Type())
			nvp.Elem().Set(v)
			v = nvp.Elem()
			_, err := d.handleKeyValuePart(key, value, v)
			if err != nil {
				return reflect.Value{}, err
			}
			return nvp.Elem(), nil
		}
		x, err := d.handleKeyValueInner(key, value, f)
		if err != nil {
			return reflect.Value{}, err
		}

		if x.IsValid() {
			f.Set(x)
		}
		d.errorContext.Struct = nil
		d.errorContext.Field = nil
	case reflect.Interface:
		v = v.Elem()

		// Following encoding/json: decoding an object into an
		// interface{}, it needs to always hold a
		// map[string]interface{}. This is for the types to be
		// consistent whether a previous value was set or not.
		if !v.IsValid() || v.Type() != mapStringInterfaceType {
			v = makeMapStringInterface()
		}

		x, err := d.handleKeyValuePart(key, value, v)
		if err != nil {
			return reflect.Value{}, err
		}
		if x.IsValid() {
			v = x
		}
		// The interface's contents are always handed back to the caller,
		// which stores them into the interface.
		rv = v
	case reflect.Ptr:
		// Allocate the pointee if the pointer is nil; the pointer itself
		// then becomes the replacement for v.
		elem := v.Elem()
		if !elem.IsValid() {
			ptr := reflect.New(v.Type().Elem())
			v.Set(ptr)
			rv = v
			elem = ptr.Elem()
		}

		elem2, err := d.handleKeyValuePart(key, value, elem)
		if err != nil {
			return reflect.Value{}, err
		}
		if elem2.IsValid() {
			elem = elem2
		}
		v.Elem().Set(elem)
	default:
		return reflect.Value{}, fmt.Errorf("unhandled kv part: %s", v.Kind())
	}

	return rv, nil
}
  1211  
  1212  func initAndDereferencePointer(v reflect.Value) reflect.Value {
  1213  	var elem reflect.Value
  1214  	if v.IsNil() {
  1215  		ptr := reflect.New(v.Type().Elem())
  1216  		v.Set(ptr)
  1217  	}
  1218  	elem = v.Elem()
  1219  	return elem
  1220  }
  1221  
  1222  // Same as reflect.Value.FieldByIndex, but creates pointers if needed.
  1223  func fieldByIndex(v reflect.Value, path []int) reflect.Value {
  1224  	for _, x := range path {
  1225  		v = v.Field(x)
  1226  
  1227  		if v.Kind() == reflect.Ptr {
  1228  			if v.IsNil() {
  1229  				v.Set(reflect.New(v.Type().Elem()))
  1230  			}
  1231  			v = v.Elem()
  1232  		}
  1233  	}
  1234  	return v
  1235  }
  1236  
  1237  type fieldPathsMap = map[string][]int
  1238  
  1239  var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap
  1240  
  1241  func structFieldPath(v reflect.Value, name string) ([]int, bool) {
  1242  	t := v.Type()
  1243  
  1244  	cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap)
  1245  	fieldPaths, ok := cache[danger.MakeTypeID(t)]
  1246  
  1247  	if !ok {
  1248  		fieldPaths = map[string][]int{}
  1249  
  1250  		forEachField(t, nil, func(name string, path []int) {
  1251  			fieldPaths[name] = path
  1252  			// extra copy for the case-insensitive match
  1253  			fieldPaths[strings.ToLower(name)] = path
  1254  		})
  1255  
  1256  		newCache := make(map[danger.TypeID]fieldPathsMap, len(cache)+1)
  1257  		newCache[danger.MakeTypeID(t)] = fieldPaths
  1258  		for k, v := range cache {
  1259  			newCache[k] = v
  1260  		}
  1261  		globalFieldPathsCache.Store(newCache)
  1262  	}
  1263  
  1264  	path, ok := fieldPaths[name]
  1265  	if !ok {
  1266  		path, ok = fieldPaths[strings.ToLower(name)]
  1267  	}
  1268  	return path, ok
  1269  }
  1270  
  1271  func forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
  1272  	n := t.NumField()
  1273  	for i := 0; i < n; i++ {
  1274  		f := t.Field(i)
  1275  
  1276  		if !f.Anonymous && f.PkgPath != "" {
  1277  			// only consider exported fields.
  1278  			continue
  1279  		}
  1280  
  1281  		fieldPath := append(path, i)
  1282  		fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]
  1283  
  1284  		name := f.Tag.Get("toml")
  1285  		if name == "-" {
  1286  			continue
  1287  		}
  1288  
  1289  		if i := strings.IndexByte(name, ','); i >= 0 {
  1290  			name = name[:i]
  1291  		}
  1292  
  1293  		if f.Anonymous && name == "" {
  1294  			t2 := f.Type
  1295  			if t2.Kind() == reflect.Ptr {
  1296  				t2 = t2.Elem()
  1297  			}
  1298  
  1299  			if t2.Kind() == reflect.Struct {
  1300  				forEachField(t2, fieldPath, do)
  1301  			}
  1302  			continue
  1303  		}
  1304  
  1305  		if name == "" {
  1306  			name = f.Name
  1307  		}
  1308  
  1309  		do(name, fieldPath)
  1310  	}
  1311  }
  1312  

View as plain text