1 package filtering
2
3 import (
4 "fmt"
5 "time"
6
7 expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
8 "google.golang.org/protobuf/proto"
9 )
10
11 type Checker struct {
12 declarations *Declarations
13 expr *expr.Expr
14 sourceInfo *expr.SourceInfo
15 typeMap map[int64]*expr.Type
16 }
17
18 func (c *Checker) Init(exp *expr.Expr, sourceInfo *expr.SourceInfo, declarations *Declarations) {
19 *c = Checker{
20 expr: exp,
21 declarations: declarations,
22 sourceInfo: sourceInfo,
23 typeMap: make(map[int64]*expr.Type, len(sourceInfo.GetPositions())),
24 }
25 }
26
27 func (c *Checker) Check() (*expr.CheckedExpr, error) {
28 if err := c.checkExpr(c.expr); err != nil {
29 return nil, err
30 }
31 resultType, ok := c.getType(c.expr)
32 if !ok {
33 return nil, c.errorf(c.expr, "unknown result type")
34 }
35 if !proto.Equal(resultType, TypeBool) {
36 return nil, c.errorf(c.expr, "non-bool result type")
37 }
38 return &expr.CheckedExpr{
39 TypeMap: c.typeMap,
40 SourceInfo: c.sourceInfo,
41 Expr: c.expr,
42 }, nil
43 }
44
45 func (c *Checker) checkExpr(e *expr.Expr) error {
46 if e == nil {
47 return nil
48 }
49 switch e.GetExprKind().(type) {
50 case *expr.Expr_ConstExpr:
51 switch e.GetConstExpr().GetConstantKind().(type) {
52 case *expr.Constant_BoolValue:
53 return c.checkBoolLiteral(e)
54 case *expr.Constant_DoubleValue:
55 return c.checkDoubleLiteral(e)
56 case *expr.Constant_Int64Value:
57 return c.checkInt64Literal(e)
58 case *expr.Constant_StringValue:
59 return c.checkStringLiteral(e)
60 default:
61 return c.errorf(e, "unsupported constant kind")
62 }
63 case *expr.Expr_IdentExpr:
64 return c.checkIdentExpr(e)
65 case *expr.Expr_SelectExpr:
66 return c.checkSelectExpr(e)
67 case *expr.Expr_CallExpr:
68 return c.checkCallExpr(e)
69 default:
70 return c.errorf(e, "unsupported expr kind")
71 }
72 }
73
74 func (c *Checker) checkIdentExpr(e *expr.Expr) error {
75 identExpr := e.GetIdentExpr()
76 ident, ok := c.declarations.LookupIdent(identExpr.GetName())
77 if !ok {
78 return c.errorf(e, "undeclared identifier '%s'", identExpr.GetName())
79 }
80 if err := c.setType(e, ident.GetIdent().GetType()); err != nil {
81 return c.wrapf(err, e, "identifier '%s'", identExpr.GetName())
82 }
83 return nil
84 }
85
86 func (c *Checker) checkSelectExpr(e *expr.Expr) (err error) {
87 defer func() {
88 if err != nil {
89 err = c.wrapf(err, e, "check select expr")
90 }
91 }()
92 if qualifiedName, ok := toQualifiedName(e); ok {
93 if ident, ok := c.declarations.LookupIdent(qualifiedName); ok {
94 return c.setType(e, ident.GetIdent().GetType())
95 }
96 }
97 selectExpr := e.GetSelectExpr()
98 if selectExpr.GetOperand() == nil {
99 return c.errorf(e, "missing operand")
100 }
101 if err := c.checkExpr(selectExpr.GetOperand()); err != nil {
102 return err
103 }
104 operandType, ok := c.getType(selectExpr.GetOperand())
105 if !ok {
106 return c.errorf(e, "failed to get operand type")
107 }
108 switch operandType.GetTypeKind().(type) {
109 case *expr.Type_MapType_:
110 return c.setType(e, operandType.GetMapType().GetValueType())
111 default:
112 return c.errorf(e, "unsupported operand type")
113 }
114 }
115
116 func (c *Checker) checkCallExpr(e *expr.Expr) (err error) {
117 defer func() {
118 if err != nil {
119 err = c.wrapf(err, e, "check call expr")
120 }
121 }()
122 callExpr := e.GetCallExpr()
123 for _, arg := range callExpr.GetArgs() {
124 if err := c.checkExpr(arg); err != nil {
125 return err
126 }
127 }
128 functionDeclaration, ok := c.declarations.LookupFunction(callExpr.GetFunction())
129 if !ok {
130 return c.errorf(e, "undeclared function '%s'", callExpr.GetFunction())
131 }
132 functionOverload, err := c.resolveCallExprFunctionOverload(e, functionDeclaration)
133 if err != nil {
134 return err
135 }
136 if err := c.checkCallExprBuiltinFunctionOverloads(e, functionOverload); err != nil {
137 return err
138 }
139 return c.setType(e, functionOverload.GetResultType())
140 }
141
142 func (c *Checker) resolveCallExprFunctionOverload(
143 e *expr.Expr,
144 functionDeclaration *expr.Decl,
145 ) (*expr.Decl_FunctionDecl_Overload, error) {
146 callExpr := e.GetCallExpr()
147 for _, overload := range functionDeclaration.GetFunction().GetOverloads() {
148 if len(callExpr.GetArgs()) != len(overload.GetParams()) {
149 continue
150 }
151 if len(overload.GetTypeParams()) == 0 {
152 allTypesMatch := true
153 for i, param := range overload.GetParams() {
154 argType, ok := c.getType(callExpr.GetArgs()[i])
155 if !ok {
156 return nil, c.errorf(callExpr.GetArgs()[i], "unknown type")
157 }
158 if !proto.Equal(argType, param) {
159 allTypesMatch = false
160 break
161 }
162 }
163 if allTypesMatch {
164 return overload, nil
165 }
166 }
167
168 }
169 var argTypes []string
170 for _, arg := range callExpr.GetArgs() {
171 t, ok := c.getType(arg)
172 if !ok {
173 argTypes = append(argTypes, "UNKNOWN")
174 } else {
175 argTypes = append(argTypes, t.String())
176 }
177 }
178 return nil, c.errorf(e, "no matching overload found for calling '%s' with %s", callExpr.GetFunction(), argTypes)
179 }
180
181 func (c *Checker) checkCallExprBuiltinFunctionOverloads(
182 e *expr.Expr,
183 functionOverload *expr.Decl_FunctionDecl_Overload,
184 ) error {
185 callExpr := e.GetCallExpr()
186 switch functionOverload.GetOverloadId() {
187 case FunctionOverloadTimestampString:
188 if constExpr := callExpr.GetArgs()[0].GetConstExpr(); constExpr != nil {
189 if _, err := time.Parse(time.RFC3339, constExpr.GetStringValue()); err != nil {
190 return c.errorf(callExpr.GetArgs()[0], "invalid timestamp. Should be in RFC3339 format")
191 }
192 }
193 case FunctionOverloadDurationString:
194 if constExpr := callExpr.GetArgs()[0].GetConstExpr(); constExpr != nil {
195 if _, err := time.ParseDuration(constExpr.GetStringValue()); err != nil {
196 return c.errorf(callExpr.GetArgs()[0], "invalid duration")
197 }
198 }
199 case FunctionOverloadLessThanTimestampString,
200 FunctionOverloadGreaterThanTimestampString,
201 FunctionOverloadLessEqualsTimestampString,
202 FunctionOverloadGreaterEqualsTimestampString,
203 FunctionOverloadEqualsTimestampString,
204 FunctionOverloadNotEqualsTimestampString:
205 if constExpr := callExpr.GetArgs()[1].GetConstExpr(); constExpr != nil {
206 if _, err := time.Parse(time.RFC3339, constExpr.GetStringValue()); err != nil {
207 return c.errorf(callExpr.GetArgs()[0], "invalid timestamp. Should be in RFC3339 format")
208 }
209 }
210 }
211 return nil
212 }
213
214 func (c *Checker) checkInt64Literal(e *expr.Expr) error {
215 return c.setType(e, TypeInt)
216 }
217
218 func (c *Checker) checkStringLiteral(e *expr.Expr) error {
219 return c.setType(e, TypeString)
220 }
221
222 func (c *Checker) checkDoubleLiteral(e *expr.Expr) error {
223 return c.setType(e, TypeFloat)
224 }
225
226 func (c *Checker) checkBoolLiteral(e *expr.Expr) error {
227 return c.setType(e, TypeBool)
228 }
229
230 func (c *Checker) errorf(_ *expr.Expr, format string, args ...interface{}) error {
231
232 return &typeError{
233 message: fmt.Sprintf(format, args...),
234 }
235 }
236
237 func (c *Checker) wrapf(err error, _ *expr.Expr, format string, args ...interface{}) error {
238
239 return &typeError{
240 message: fmt.Sprintf(format, args...),
241 err: err,
242 }
243 }
244
245 func (c *Checker) setType(e *expr.Expr, t *expr.Type) error {
246 if existingT, ok := c.typeMap[e.GetId()]; ok && !proto.Equal(t, existingT) {
247 return c.errorf(e, "type conflict between %s and %s", t, existingT)
248 }
249 c.typeMap[e.GetId()] = t
250 return nil
251 }
252
253 func (c *Checker) getType(e *expr.Expr) (*expr.Type, bool) {
254 t, ok := c.typeMap[e.GetId()]
255 if !ok {
256 return nil, false
257 }
258 return t, true
259 }
260
261 func toQualifiedName(e *expr.Expr) (string, bool) {
262 switch kind := e.GetExprKind().(type) {
263 case *expr.Expr_IdentExpr:
264 return kind.IdentExpr.GetName(), true
265 case *expr.Expr_SelectExpr:
266 if kind.SelectExpr.GetTestOnly() {
267 return "", false
268 }
269 parent, ok := toQualifiedName(kind.SelectExpr.GetOperand())
270 if !ok {
271 return "", false
272 }
273 return parent + "." + kind.SelectExpr.GetField(), true
274 default:
275 return "", false
276 }
277 }
278
View as plain text