...

Source file src/go.starlark.net/starlark/interp.go

Documentation: go.starlark.net/starlark

     1  package starlark
     2  
     3  // This file defines the bytecode interpreter.
     4  
     5  import (
     6  	"fmt"
     7  	"os"
     8  	"sync/atomic"
     9  	"unsafe"
    10  
    11  	"go.starlark.net/internal/compile"
    12  	"go.starlark.net/internal/spell"
    13  	"go.starlark.net/resolve"
    14  	"go.starlark.net/syntax"
    15  )
    16  
    17  const vmdebug = false // TODO(adonovan): use a bitfield of specific kinds of error.
    18  
    19  // TODO(adonovan):
    20  // - optimize position table.
    21  // - opt: record MaxIterStack during compilation and preallocate the stack.
    22  
    23  func (fn *Function) CallInternal(thread *Thread, args Tuple, kwargs []Tuple) (Value, error) {
    24  	// Postcondition: args is not mutated. This is stricter than required by Callable,
    25  	// but allows CALL to avoid a copy.
    26  
    27  	if !resolve.AllowRecursion {
    28  		// detect recursion
    29  		for _, fr := range thread.stack[:len(thread.stack)-1] {
    30  			// We look for the same function code,
    31  			// not function value, otherwise the user could
    32  			// defeat the check by writing the Y combinator.
    33  			if frfn, ok := fr.Callable().(*Function); ok && frfn.funcode == fn.funcode {
    34  				return nil, fmt.Errorf("function %s called recursively", fn.Name())
    35  			}
    36  		}
    37  	}
    38  
    39  	f := fn.funcode
    40  	fr := thread.frameAt(0)
    41  
    42  	// Allocate space for stack and locals.
    43  	// Logically these do not escape from this frame
    44  	// (See https://github.com/golang/go/issues/20533.)
    45  	//
    46  	// This heap allocation looks expensive, but I was unable to get
    47  	// more than 1% real time improvement in a large alloc-heavy
    48  	// benchmark (in which this alloc was 8% of alloc-bytes)
    49  	// by allocating space for 8 Values in each frame, or
    50  	// by allocating stack by slicing an array held by the Thread
    51  	// that is expanded in chunks of min(k, nspace), for k=256 or 1024.
    52  	nlocals := len(f.Locals)
    53  	nspace := nlocals + f.MaxStack
    54  	space := make([]Value, nspace)
    55  	locals := space[:nlocals:nlocals] // local variables, starting with parameters
    56  	stack := space[nlocals:]          // operand stack
    57  
    58  	// Digest arguments and set parameters.
    59  	err := setArgs(locals, fn, args, kwargs)
    60  	if err != nil {
    61  		return nil, thread.evalError(err)
    62  	}
    63  
    64  	fr.locals = locals
    65  
    66  	if vmdebug {
    67  		fmt.Printf("Entering %s @ %s\n", f.Name, f.Position(0))
    68  		fmt.Printf("%d stack, %d locals\n", len(stack), len(locals))
    69  		defer fmt.Println("Leaving ", f.Name)
    70  	}
    71  
    72  	// Spill indicated locals to cells.
    73  	// Each cell is a separate alloc to avoid spurious liveness.
    74  	for _, index := range f.Cells {
    75  		locals[index] = &cell{locals[index]}
    76  	}
    77  
    78  	// TODO(adonovan): add static check that beneath this point
    79  	// - there is exactly one return statement
    80  	// - there is no redefinition of 'err'.
    81  
    82  	var iterstack []Iterator // stack of active iterators
    83  
    84  	// Use defer so that application panics can pass through
    85  	// interpreter without leaving thread in a bad state.
    86  	defer func() {
    87  		// ITERPOP the rest of the iterator stack.
    88  		for _, iter := range iterstack {
    89  			iter.Done()
    90  		}
    91  
    92  		fr.locals = nil
    93  	}()
    94  
    95  	sp := 0
    96  	var pc uint32
    97  	var result Value
    98  	code := f.Code
    99  loop:
   100  	for {
   101  		thread.Steps++
   102  		if thread.Steps >= thread.maxSteps {
   103  			if thread.OnMaxSteps != nil {
   104  				thread.OnMaxSteps(thread)
   105  			} else {
   106  				thread.Cancel("too many steps")
   107  			}
   108  		}
   109  		if reason := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&thread.cancelReason))); reason != nil {
   110  			err = fmt.Errorf("Starlark computation cancelled: %s", *(*string)(reason))
   111  			break loop
   112  		}
   113  
   114  		fr.pc = pc
   115  
   116  		op := compile.Opcode(code[pc])
   117  		pc++
   118  		var arg uint32
   119  		if op >= compile.OpcodeArgMin {
   120  			// TODO(adonovan): opt: profile this.
   121  			// Perhaps compiling big endian would be less work to decode?
   122  			for s := uint(0); ; s += 7 {
   123  				b := code[pc]
   124  				pc++
   125  				arg |= uint32(b&0x7f) << s
   126  				if b < 0x80 {
   127  					break
   128  				}
   129  			}
   130  		}
   131  		if vmdebug {
   132  			fmt.Fprintln(os.Stderr, stack[:sp]) // very verbose!
   133  			compile.PrintOp(f, fr.pc, op, arg)
   134  		}
   135  
   136  		switch op {
   137  		case compile.NOP:
   138  			// nop
   139  
   140  		case compile.DUP:
   141  			stack[sp] = stack[sp-1]
   142  			sp++
   143  
   144  		case compile.DUP2:
   145  			stack[sp] = stack[sp-2]
   146  			stack[sp+1] = stack[sp-1]
   147  			sp += 2
   148  
   149  		case compile.POP:
   150  			sp--
   151  
   152  		case compile.EXCH:
   153  			stack[sp-2], stack[sp-1] = stack[sp-1], stack[sp-2]
   154  
   155  		case compile.EQL, compile.NEQ, compile.GT, compile.LT, compile.LE, compile.GE:
   156  			op := syntax.Token(op-compile.EQL) + syntax.EQL
   157  			y := stack[sp-1]
   158  			x := stack[sp-2]
   159  			sp -= 2
   160  			ok, err2 := Compare(op, x, y)
   161  			if err2 != nil {
   162  				err = err2
   163  				break loop
   164  			}
   165  			stack[sp] = Bool(ok)
   166  			sp++
   167  
   168  		case compile.PLUS,
   169  			compile.MINUS,
   170  			compile.STAR,
   171  			compile.SLASH,
   172  			compile.SLASHSLASH,
   173  			compile.PERCENT,
   174  			compile.AMP,
   175  			compile.PIPE,
   176  			compile.CIRCUMFLEX,
   177  			compile.LTLT,
   178  			compile.GTGT,
   179  			compile.IN:
   180  			binop := syntax.Token(op-compile.PLUS) + syntax.PLUS
   181  			if op == compile.IN {
   182  				binop = syntax.IN // IN token is out of order
   183  			}
   184  			y := stack[sp-1]
   185  			x := stack[sp-2]
   186  			sp -= 2
   187  			z, err2 := Binary(binop, x, y)
   188  			if err2 != nil {
   189  				err = err2
   190  				break loop
   191  			}
   192  			stack[sp] = z
   193  			sp++
   194  
   195  		case compile.UPLUS, compile.UMINUS, compile.TILDE:
   196  			var unop syntax.Token
   197  			if op == compile.TILDE {
   198  				unop = syntax.TILDE
   199  			} else {
   200  				unop = syntax.Token(op-compile.UPLUS) + syntax.PLUS
   201  			}
   202  			x := stack[sp-1]
   203  			y, err2 := Unary(unop, x)
   204  			if err2 != nil {
   205  				err = err2
   206  				break loop
   207  			}
   208  			stack[sp-1] = y
   209  
   210  		case compile.INPLACE_ADD:
   211  			y := stack[sp-1]
   212  			x := stack[sp-2]
   213  			sp -= 2
   214  
   215  			// It's possible that y is not Iterable but
   216  			// nonetheless defines x+y, in which case we
   217  			// should fall back to the general case.
   218  			var z Value
   219  			if xlist, ok := x.(*List); ok {
   220  				if yiter, ok := y.(Iterable); ok {
   221  					if err = xlist.checkMutable("apply += to"); err != nil {
   222  						break loop
   223  					}
   224  					listExtend(xlist, yiter)
   225  					z = xlist
   226  				}
   227  			}
   228  			if z == nil {
   229  				z, err = Binary(syntax.PLUS, x, y)
   230  				if err != nil {
   231  					break loop
   232  				}
   233  			}
   234  
   235  			stack[sp] = z
   236  			sp++
   237  
   238  		case compile.INPLACE_PIPE:
   239  			y := stack[sp-1]
   240  			x := stack[sp-2]
   241  			sp -= 2
   242  
   243  			// It's possible that y is not Dict but
   244  			// nonetheless defines x|y, in which case we
   245  			// should fall back to the general case.
   246  			var z Value
   247  			if xdict, ok := x.(*Dict); ok {
   248  				if ydict, ok := y.(*Dict); ok {
   249  					if err = xdict.ht.checkMutable("apply |= to"); err != nil {
   250  						break loop
   251  					}
   252  					xdict.ht.addAll(&ydict.ht) // can't fail
   253  					z = xdict
   254  				}
   255  			}
   256  			if z == nil {
   257  				z, err = Binary(syntax.PIPE, x, y)
   258  				if err != nil {
   259  					break loop
   260  				}
   261  			}
   262  
   263  			stack[sp] = z
   264  			sp++
   265  
   266  		case compile.NONE:
   267  			stack[sp] = None
   268  			sp++
   269  
   270  		case compile.TRUE:
   271  			stack[sp] = True
   272  			sp++
   273  
   274  		case compile.FALSE:
   275  			stack[sp] = False
   276  			sp++
   277  
   278  		case compile.MANDATORY:
   279  			stack[sp] = mandatory{}
   280  			sp++
   281  
   282  		case compile.JMP:
   283  			pc = arg
   284  
   285  		case compile.CALL, compile.CALL_VAR, compile.CALL_KW, compile.CALL_VAR_KW:
   286  			var kwargs Value
   287  			if op == compile.CALL_KW || op == compile.CALL_VAR_KW {
   288  				kwargs = stack[sp-1]
   289  				sp--
   290  			}
   291  
   292  			var args Value
   293  			if op == compile.CALL_VAR || op == compile.CALL_VAR_KW {
   294  				args = stack[sp-1]
   295  				sp--
   296  			}
   297  
   298  			// named args (pairs)
   299  			var kvpairs []Tuple
   300  			if nkvpairs := int(arg & 0xff); nkvpairs > 0 {
   301  				kvpairs = make([]Tuple, 0, nkvpairs)
   302  				kvpairsAlloc := make(Tuple, 2*nkvpairs) // allocate a single backing array
   303  				sp -= 2 * nkvpairs
   304  				for i := 0; i < nkvpairs; i++ {
   305  					pair := kvpairsAlloc[:2:2]
   306  					kvpairsAlloc = kvpairsAlloc[2:]
   307  					pair[0] = stack[sp+2*i]   // name
   308  					pair[1] = stack[sp+2*i+1] // value
   309  					kvpairs = append(kvpairs, pair)
   310  				}
   311  			}
   312  			if kwargs != nil {
   313  				// Add key/value items from **kwargs dictionary.
   314  				dict, ok := kwargs.(IterableMapping)
   315  				if !ok {
   316  					err = fmt.Errorf("argument after ** must be a mapping, not %s", kwargs.Type())
   317  					break loop
   318  				}
   319  				items := dict.Items()
   320  				for _, item := range items {
   321  					if _, ok := item[0].(String); !ok {
   322  						err = fmt.Errorf("keywords must be strings, not %s", item[0].Type())
   323  						break loop
   324  					}
   325  				}
   326  				if len(kvpairs) == 0 {
   327  					kvpairs = items
   328  				} else {
   329  					kvpairs = append(kvpairs, items...)
   330  				}
   331  			}
   332  
   333  			// positional args
   334  			var positional Tuple
   335  			if npos := int(arg >> 8); npos > 0 {
   336  				positional = stack[sp-npos : sp]
   337  				sp -= npos
   338  
   339  				// Copy positional arguments into a new array,
   340  				// unless the callee is another Starlark function,
   341  				// in which case it can be trusted not to mutate them.
   342  				if _, ok := stack[sp-1].(*Function); !ok || args != nil {
   343  					positional = append(Tuple(nil), positional...)
   344  				}
   345  			}
   346  			if args != nil {
   347  				// Add elements from *args sequence.
   348  				iter := Iterate(args)
   349  				if iter == nil {
   350  					err = fmt.Errorf("argument after * must be iterable, not %s", args.Type())
   351  					break loop
   352  				}
   353  				var elem Value
   354  				for iter.Next(&elem) {
   355  					positional = append(positional, elem)
   356  				}
   357  				iter.Done()
   358  			}
   359  
   360  			function := stack[sp-1]
   361  
   362  			if vmdebug {
   363  				fmt.Printf("VM call %s args=%s kwargs=%s @%s\n",
   364  					function, positional, kvpairs, f.Position(fr.pc))
   365  			}
   366  
   367  			thread.endProfSpan()
   368  			z, err2 := Call(thread, function, positional, kvpairs)
   369  			thread.beginProfSpan()
   370  			if err2 != nil {
   371  				err = err2
   372  				break loop
   373  			}
   374  			if vmdebug {
   375  				fmt.Printf("Resuming %s @ %s\n", f.Name, f.Position(0))
   376  			}
   377  			stack[sp-1] = z
   378  
   379  		case compile.ITERPUSH:
   380  			x := stack[sp-1]
   381  			sp--
   382  			iter := Iterate(x)
   383  			if iter == nil {
   384  				err = fmt.Errorf("%s value is not iterable", x.Type())
   385  				break loop
   386  			}
   387  			iterstack = append(iterstack, iter)
   388  
   389  		case compile.ITERJMP:
   390  			iter := iterstack[len(iterstack)-1]
   391  			if iter.Next(&stack[sp]) {
   392  				sp++
   393  			} else {
   394  				pc = arg
   395  			}
   396  
   397  		case compile.ITERPOP:
   398  			n := len(iterstack) - 1
   399  			iterstack[n].Done()
   400  			iterstack = iterstack[:n]
   401  
   402  		case compile.NOT:
   403  			stack[sp-1] = !stack[sp-1].Truth()
   404  
   405  		case compile.RETURN:
   406  			result = stack[sp-1]
   407  			break loop
   408  
   409  		case compile.SETINDEX:
   410  			z := stack[sp-1]
   411  			y := stack[sp-2]
   412  			x := stack[sp-3]
   413  			sp -= 3
   414  			err = setIndex(x, y, z)
   415  			if err != nil {
   416  				break loop
   417  			}
   418  
   419  		case compile.INDEX:
   420  			y := stack[sp-1]
   421  			x := stack[sp-2]
   422  			sp -= 2
   423  			z, err2 := getIndex(x, y)
   424  			if err2 != nil {
   425  				err = err2
   426  				break loop
   427  			}
   428  			stack[sp] = z
   429  			sp++
   430  
   431  		case compile.ATTR:
   432  			x := stack[sp-1]
   433  			name := f.Prog.Names[arg]
   434  			y, err2 := getAttr(x, name)
   435  			if err2 != nil {
   436  				err = err2
   437  				break loop
   438  			}
   439  			stack[sp-1] = y
   440  
   441  		case compile.SETFIELD:
   442  			y := stack[sp-1]
   443  			x := stack[sp-2]
   444  			sp -= 2
   445  			name := f.Prog.Names[arg]
   446  			if err2 := setField(x, name, y); err2 != nil {
   447  				err = err2
   448  				break loop
   449  			}
   450  
   451  		case compile.MAKEDICT:
   452  			stack[sp] = new(Dict)
   453  			sp++
   454  
   455  		case compile.SETDICT, compile.SETDICTUNIQ:
   456  			dict := stack[sp-3].(*Dict)
   457  			k := stack[sp-2]
   458  			v := stack[sp-1]
   459  			sp -= 3
   460  			oldlen := dict.Len()
   461  			if err2 := dict.SetKey(k, v); err2 != nil {
   462  				err = err2
   463  				break loop
   464  			}
   465  			if op == compile.SETDICTUNIQ && dict.Len() == oldlen {
   466  				err = fmt.Errorf("duplicate key: %v", k)
   467  				break loop
   468  			}
   469  
   470  		case compile.APPEND:
   471  			elem := stack[sp-1]
   472  			list := stack[sp-2].(*List)
   473  			sp -= 2
   474  			list.elems = append(list.elems, elem)
   475  
   476  		case compile.SLICE:
   477  			x := stack[sp-4]
   478  			lo := stack[sp-3]
   479  			hi := stack[sp-2]
   480  			step := stack[sp-1]
   481  			sp -= 4
   482  			res, err2 := slice(x, lo, hi, step)
   483  			if err2 != nil {
   484  				err = err2
   485  				break loop
   486  			}
   487  			stack[sp] = res
   488  			sp++
   489  
   490  		case compile.UNPACK:
   491  			n := int(arg)
   492  			iterable := stack[sp-1]
   493  			sp--
   494  			iter := Iterate(iterable)
   495  			if iter == nil {
   496  				err = fmt.Errorf("got %s in sequence assignment", iterable.Type())
   497  				break loop
   498  			}
   499  			i := 0
   500  			sp += n
   501  			for i < n && iter.Next(&stack[sp-1-i]) {
   502  				i++
   503  			}
   504  			var dummy Value
   505  			if iter.Next(&dummy) {
   506  				// NB: Len may return -1 here in obscure cases.
   507  				err = fmt.Errorf("too many values to unpack (got %d, want %d)", Len(iterable), n)
   508  				break loop
   509  			}
   510  			iter.Done()
   511  			if i < n {
   512  				err = fmt.Errorf("too few values to unpack (got %d, want %d)", i, n)
   513  				break loop
   514  			}
   515  
   516  		case compile.CJMP:
   517  			if stack[sp-1].Truth() {
   518  				pc = arg
   519  			}
   520  			sp--
   521  
   522  		case compile.CONSTANT:
   523  			stack[sp] = fn.module.constants[arg]
   524  			sp++
   525  
   526  		case compile.MAKETUPLE:
   527  			n := int(arg)
   528  			tuple := make(Tuple, n)
   529  			sp -= n
   530  			copy(tuple, stack[sp:])
   531  			stack[sp] = tuple
   532  			sp++
   533  
   534  		case compile.MAKELIST:
   535  			n := int(arg)
   536  			elems := make([]Value, n)
   537  			sp -= n
   538  			copy(elems, stack[sp:])
   539  			stack[sp] = NewList(elems)
   540  			sp++
   541  
   542  		case compile.MAKEFUNC:
   543  			funcode := f.Prog.Functions[arg]
   544  			tuple := stack[sp-1].(Tuple)
   545  			n := len(tuple) - len(funcode.Freevars)
   546  			defaults := tuple[:n:n]
   547  			freevars := tuple[n:]
   548  			stack[sp-1] = &Function{
   549  				funcode:  funcode,
   550  				module:   fn.module,
   551  				defaults: defaults,
   552  				freevars: freevars,
   553  			}
   554  
   555  		case compile.LOAD:
   556  			n := int(arg)
   557  			module := string(stack[sp-1].(String))
   558  			sp--
   559  
   560  			if thread.Load == nil {
   561  				err = fmt.Errorf("load not implemented by this application")
   562  				break loop
   563  			}
   564  
   565  			thread.endProfSpan()
   566  			dict, err2 := thread.Load(thread, module)
   567  			thread.beginProfSpan()
   568  			if err2 != nil {
   569  				err = wrappedError{
   570  					msg:   fmt.Sprintf("cannot load %s: %v", module, err2),
   571  					cause: err2,
   572  				}
   573  				break loop
   574  			}
   575  
   576  			for i := 0; i < n; i++ {
   577  				from := string(stack[sp-1-i].(String))
   578  				v, ok := dict[from]
   579  				if !ok {
   580  					err = fmt.Errorf("load: name %s not found in module %s", from, module)
   581  					if n := spell.Nearest(from, dict.Keys()); n != "" {
   582  						err = fmt.Errorf("%s (did you mean %s?)", err, n)
   583  					}
   584  					break loop
   585  				}
   586  				stack[sp-1-i] = v
   587  			}
   588  
   589  		case compile.SETLOCAL:
   590  			locals[arg] = stack[sp-1]
   591  			sp--
   592  
   593  		case compile.SETLOCALCELL:
   594  			locals[arg].(*cell).v = stack[sp-1]
   595  			sp--
   596  
   597  		case compile.SETGLOBAL:
   598  			fn.module.globals[arg] = stack[sp-1]
   599  			sp--
   600  
   601  		case compile.LOCAL:
   602  			x := locals[arg]
   603  			if x == nil {
   604  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   605  				break loop
   606  			}
   607  			stack[sp] = x
   608  			sp++
   609  
   610  		case compile.FREE:
   611  			stack[sp] = fn.freevars[arg]
   612  			sp++
   613  
   614  		case compile.LOCALCELL:
   615  			v := locals[arg].(*cell).v
   616  			if v == nil {
   617  				err = fmt.Errorf("local variable %s referenced before assignment", f.Locals[arg].Name)
   618  				break loop
   619  			}
   620  			stack[sp] = v
   621  			sp++
   622  
   623  		case compile.FREECELL:
   624  			v := fn.freevars[arg].(*cell).v
   625  			if v == nil {
   626  				err = fmt.Errorf("local variable %s referenced before assignment", f.Freevars[arg].Name)
   627  				break loop
   628  			}
   629  			stack[sp] = v
   630  			sp++
   631  
   632  		case compile.GLOBAL:
   633  			x := fn.module.globals[arg]
   634  			if x == nil {
   635  				err = fmt.Errorf("global variable %s referenced before assignment", f.Prog.Globals[arg].Name)
   636  				break loop
   637  			}
   638  			stack[sp] = x
   639  			sp++
   640  
   641  		case compile.PREDECLARED:
   642  			name := f.Prog.Names[arg]
   643  			x := fn.module.predeclared[name]
   644  			if x == nil {
   645  				err = fmt.Errorf("internal error: predeclared variable %s is uninitialized", name)
   646  				break loop
   647  			}
   648  			stack[sp] = x
   649  			sp++
   650  
   651  		case compile.UNIVERSAL:
   652  			stack[sp] = Universe[f.Prog.Names[arg]]
   653  			sp++
   654  
   655  		default:
   656  			err = fmt.Errorf("unimplemented: %s", op)
   657  			break loop
   658  		}
   659  	}
   660  	// (deferred cleanup runs here)
   661  	return result, err
   662  }
   663  
   664  type wrappedError struct {
   665  	msg   string
   666  	cause error
   667  }
   668  
   669  func (e wrappedError) Error() string {
   670  	return e.msg
   671  }
   672  
   673  // Implements the xerrors.Wrapper interface
   674  // https://godoc.org/golang.org/x/xerrors#Wrapper
   675  func (e wrappedError) Unwrap() error {
   676  	return e.cause
   677  }
   678  
   679  // mandatory is a sentinel value used in a function's defaults tuple
   680  // to indicate that a (keyword-only) parameter is mandatory.
   681  type mandatory struct{}
   682  
   683  func (mandatory) String() string        { return "mandatory" }
   684  func (mandatory) Type() string          { return "mandatory" }
   685  func (mandatory) Freeze()               {} // immutable
   686  func (mandatory) Truth() Bool           { return False }
   687  func (mandatory) Hash() (uint32, error) { return 0, nil }
   688  
   689  // A cell is a box containing a Value.
   690  // Local variables marked as cells hold their value indirectly
   691  // so that they may be shared by outer and inner nested functions.
   692  // Cells are always accessed using indirect {FREE,LOCAL,SETLOCAL}CELL instructions.
   693  // The FreeVars tuple contains only cells.
   694  // The FREE instruction always yields a cell.
   695  type cell struct{ v Value }
   696  
   697  func (c *cell) String() string { return "cell" }
   698  func (c *cell) Type() string   { return "cell" }
   699  func (c *cell) Freeze() {
   700  	if c.v != nil {
   701  		c.v.Freeze()
   702  	}
   703  }
   704  func (c *cell) Truth() Bool           { panic("unreachable") }
   705  func (c *cell) Hash() (uint32, error) { panic("unreachable") }
   706  

View as plain text