...

Source file src/go.einride.tech/aip/filtering/checker.go

Documentation: go.einride.tech/aip/filtering

     1  package filtering
     2  
     3  import (
     4  	"fmt"
     5  	"time"
     6  
     7  	expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
     8  	"google.golang.org/protobuf/proto"
     9  )
    10  
    11  type Checker struct {
    12  	declarations *Declarations
    13  	expr         *expr.Expr
    14  	sourceInfo   *expr.SourceInfo
    15  	typeMap      map[int64]*expr.Type
    16  }
    17  
    18  func (c *Checker) Init(exp *expr.Expr, sourceInfo *expr.SourceInfo, declarations *Declarations) {
    19  	*c = Checker{
    20  		expr:         exp,
    21  		declarations: declarations,
    22  		sourceInfo:   sourceInfo,
    23  		typeMap:      make(map[int64]*expr.Type, len(sourceInfo.GetPositions())),
    24  	}
    25  }
    26  
    27  func (c *Checker) Check() (*expr.CheckedExpr, error) {
    28  	if err := c.checkExpr(c.expr); err != nil {
    29  		return nil, err
    30  	}
    31  	resultType, ok := c.getType(c.expr)
    32  	if !ok {
    33  		return nil, c.errorf(c.expr, "unknown result type")
    34  	}
    35  	if !proto.Equal(resultType, TypeBool) {
    36  		return nil, c.errorf(c.expr, "non-bool result type")
    37  	}
    38  	return &expr.CheckedExpr{
    39  		TypeMap:    c.typeMap,
    40  		SourceInfo: c.sourceInfo,
    41  		Expr:       c.expr,
    42  	}, nil
    43  }
    44  
    45  func (c *Checker) checkExpr(e *expr.Expr) error {
    46  	if e == nil {
    47  		return nil
    48  	}
    49  	switch e.GetExprKind().(type) {
    50  	case *expr.Expr_ConstExpr:
    51  		switch e.GetConstExpr().GetConstantKind().(type) {
    52  		case *expr.Constant_BoolValue:
    53  			return c.checkBoolLiteral(e)
    54  		case *expr.Constant_DoubleValue:
    55  			return c.checkDoubleLiteral(e)
    56  		case *expr.Constant_Int64Value:
    57  			return c.checkInt64Literal(e)
    58  		case *expr.Constant_StringValue:
    59  			return c.checkStringLiteral(e)
    60  		default:
    61  			return c.errorf(e, "unsupported constant kind")
    62  		}
    63  	case *expr.Expr_IdentExpr:
    64  		return c.checkIdentExpr(e)
    65  	case *expr.Expr_SelectExpr:
    66  		return c.checkSelectExpr(e)
    67  	case *expr.Expr_CallExpr:
    68  		return c.checkCallExpr(e)
    69  	default:
    70  		return c.errorf(e, "unsupported expr kind")
    71  	}
    72  }
    73  
    74  func (c *Checker) checkIdentExpr(e *expr.Expr) error {
    75  	identExpr := e.GetIdentExpr()
    76  	ident, ok := c.declarations.LookupIdent(identExpr.GetName())
    77  	if !ok {
    78  		return c.errorf(e, "undeclared identifier '%s'", identExpr.GetName())
    79  	}
    80  	if err := c.setType(e, ident.GetIdent().GetType()); err != nil {
    81  		return c.wrapf(err, e, "identifier '%s'", identExpr.GetName())
    82  	}
    83  	return nil
    84  }
    85  
    86  func (c *Checker) checkSelectExpr(e *expr.Expr) (err error) {
    87  	defer func() {
    88  		if err != nil {
    89  			err = c.wrapf(err, e, "check select expr")
    90  		}
    91  	}()
    92  	if qualifiedName, ok := toQualifiedName(e); ok {
    93  		if ident, ok := c.declarations.LookupIdent(qualifiedName); ok {
    94  			return c.setType(e, ident.GetIdent().GetType())
    95  		}
    96  	}
    97  	selectExpr := e.GetSelectExpr()
    98  	if selectExpr.GetOperand() == nil {
    99  		return c.errorf(e, "missing operand")
   100  	}
   101  	if err := c.checkExpr(selectExpr.GetOperand()); err != nil {
   102  		return err
   103  	}
   104  	operandType, ok := c.getType(selectExpr.GetOperand())
   105  	if !ok {
   106  		return c.errorf(e, "failed to get operand type")
   107  	}
   108  	switch operandType.GetTypeKind().(type) {
   109  	case *expr.Type_MapType_:
   110  		return c.setType(e, operandType.GetMapType().GetValueType())
   111  	default:
   112  		return c.errorf(e, "unsupported operand type")
   113  	}
   114  }
   115  
   116  func (c *Checker) checkCallExpr(e *expr.Expr) (err error) {
   117  	defer func() {
   118  		if err != nil {
   119  			err = c.wrapf(err, e, "check call expr")
   120  		}
   121  	}()
   122  	callExpr := e.GetCallExpr()
   123  	for _, arg := range callExpr.GetArgs() {
   124  		if err := c.checkExpr(arg); err != nil {
   125  			return err
   126  		}
   127  	}
   128  	functionDeclaration, ok := c.declarations.LookupFunction(callExpr.GetFunction())
   129  	if !ok {
   130  		return c.errorf(e, "undeclared function '%s'", callExpr.GetFunction())
   131  	}
   132  	functionOverload, err := c.resolveCallExprFunctionOverload(e, functionDeclaration)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	if err := c.checkCallExprBuiltinFunctionOverloads(e, functionOverload); err != nil {
   137  		return err
   138  	}
   139  	return c.setType(e, functionOverload.GetResultType())
   140  }
   141  
   142  func (c *Checker) resolveCallExprFunctionOverload(
   143  	e *expr.Expr,
   144  	functionDeclaration *expr.Decl,
   145  ) (*expr.Decl_FunctionDecl_Overload, error) {
   146  	callExpr := e.GetCallExpr()
   147  	for _, overload := range functionDeclaration.GetFunction().GetOverloads() {
   148  		if len(callExpr.GetArgs()) != len(overload.GetParams()) {
   149  			continue
   150  		}
   151  		if len(overload.GetTypeParams()) == 0 {
   152  			allTypesMatch := true
   153  			for i, param := range overload.GetParams() {
   154  				argType, ok := c.getType(callExpr.GetArgs()[i])
   155  				if !ok {
   156  					return nil, c.errorf(callExpr.GetArgs()[i], "unknown type")
   157  				}
   158  				if !proto.Equal(argType, param) {
   159  					allTypesMatch = false
   160  					break
   161  				}
   162  			}
   163  			if allTypesMatch {
   164  				return overload, nil
   165  			}
   166  		}
   167  		// TODO: Add support for type parameters.
   168  	}
   169  	var argTypes []string
   170  	for _, arg := range callExpr.GetArgs() {
   171  		t, ok := c.getType(arg)
   172  		if !ok {
   173  			argTypes = append(argTypes, "UNKNOWN")
   174  		} else {
   175  			argTypes = append(argTypes, t.String())
   176  		}
   177  	}
   178  	return nil, c.errorf(e, "no matching overload found for calling '%s' with %s", callExpr.GetFunction(), argTypes)
   179  }
   180  
   181  func (c *Checker) checkCallExprBuiltinFunctionOverloads(
   182  	e *expr.Expr,
   183  	functionOverload *expr.Decl_FunctionDecl_Overload,
   184  ) error {
   185  	callExpr := e.GetCallExpr()
   186  	switch functionOverload.GetOverloadId() {
   187  	case FunctionOverloadTimestampString:
   188  		if constExpr := callExpr.GetArgs()[0].GetConstExpr(); constExpr != nil {
   189  			if _, err := time.Parse(time.RFC3339, constExpr.GetStringValue()); err != nil {
   190  				return c.errorf(callExpr.GetArgs()[0], "invalid timestamp. Should be in RFC3339 format")
   191  			}
   192  		}
   193  	case FunctionOverloadDurationString:
   194  		if constExpr := callExpr.GetArgs()[0].GetConstExpr(); constExpr != nil {
   195  			if _, err := time.ParseDuration(constExpr.GetStringValue()); err != nil {
   196  				return c.errorf(callExpr.GetArgs()[0], "invalid duration")
   197  			}
   198  		}
   199  	case FunctionOverloadLessThanTimestampString,
   200  		FunctionOverloadGreaterThanTimestampString,
   201  		FunctionOverloadLessEqualsTimestampString,
   202  		FunctionOverloadGreaterEqualsTimestampString,
   203  		FunctionOverloadEqualsTimestampString,
   204  		FunctionOverloadNotEqualsTimestampString:
   205  		if constExpr := callExpr.GetArgs()[1].GetConstExpr(); constExpr != nil {
   206  			if _, err := time.Parse(time.RFC3339, constExpr.GetStringValue()); err != nil {
   207  				return c.errorf(callExpr.GetArgs()[0], "invalid timestamp. Should be in RFC3339 format")
   208  			}
   209  		}
   210  	}
   211  	return nil
   212  }
   213  
   214  func (c *Checker) checkInt64Literal(e *expr.Expr) error {
   215  	return c.setType(e, TypeInt)
   216  }
   217  
   218  func (c *Checker) checkStringLiteral(e *expr.Expr) error {
   219  	return c.setType(e, TypeString)
   220  }
   221  
   222  func (c *Checker) checkDoubleLiteral(e *expr.Expr) error {
   223  	return c.setType(e, TypeFloat)
   224  }
   225  
   226  func (c *Checker) checkBoolLiteral(e *expr.Expr) error {
   227  	return c.setType(e, TypeBool)
   228  }
   229  
   230  func (c *Checker) errorf(_ *expr.Expr, format string, args ...interface{}) error {
   231  	// TODO: Include the provided expr.
   232  	return &typeError{
   233  		message: fmt.Sprintf(format, args...),
   234  	}
   235  }
   236  
   237  func (c *Checker) wrapf(err error, _ *expr.Expr, format string, args ...interface{}) error {
   238  	// TODO: Include the provided expr.
   239  	return &typeError{
   240  		message: fmt.Sprintf(format, args...),
   241  		err:     err,
   242  	}
   243  }
   244  
   245  func (c *Checker) setType(e *expr.Expr, t *expr.Type) error {
   246  	if existingT, ok := c.typeMap[e.GetId()]; ok && !proto.Equal(t, existingT) {
   247  		return c.errorf(e, "type conflict between %s and %s", t, existingT)
   248  	}
   249  	c.typeMap[e.GetId()] = t
   250  	return nil
   251  }
   252  
   253  func (c *Checker) getType(e *expr.Expr) (*expr.Type, bool) {
   254  	t, ok := c.typeMap[e.GetId()]
   255  	if !ok {
   256  		return nil, false
   257  	}
   258  	return t, true
   259  }
   260  
   261  func toQualifiedName(e *expr.Expr) (string, bool) {
   262  	switch kind := e.GetExprKind().(type) {
   263  	case *expr.Expr_IdentExpr:
   264  		return kind.IdentExpr.GetName(), true
   265  	case *expr.Expr_SelectExpr:
   266  		if kind.SelectExpr.GetTestOnly() {
   267  			return "", false
   268  		}
   269  		parent, ok := toQualifiedName(kind.SelectExpr.GetOperand())
   270  		if !ok {
   271  			return "", false
   272  		}
   273  		return parent + "." + kind.SelectExpr.GetField(), true
   274  	default:
   275  		return "", false
   276  	}
   277  }
   278  

View as plain text