...
1 package filtering
2
3 import (
4 "testing"
5
6 expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
7 "google.golang.org/protobuf/testing/protocmp"
8 "gotest.tools/v3/assert"
9 )
10
11 func TestApplyMacros(t *testing.T) {
12 t.Parallel()
13 for _, tt := range []struct {
14 name string
15 filter string
16 declarations []DeclarationOption
17 macros []Macro
18 macroDeclarations []DeclarationOption
19 expected *expr.Expr
20 errorContains string
21 }{
22 {
23 filter: `annotations.schedule = "test"`,
24 declarations: []DeclarationOption{
25 DeclareStandardFunctions(),
26 DeclareIdent("annotations", TypeMap(TypeString, TypeString)),
27 },
28 macros: []Macro{
29 func(cursor *Cursor) {
30 callExpr := cursor.Expr().GetCallExpr()
31 if callExpr == nil {
32 return
33 }
34 if callExpr.GetFunction() != FunctionEquals {
35 return
36 }
37 if len(callExpr.GetArgs()) != 2 {
38 return
39 }
40 arg0Select := callExpr.GetArgs()[0].GetSelectExpr()
41 if arg0Select == nil || arg0Select.GetOperand().GetIdentExpr().GetName() != "annotations" {
42 return
43 }
44 arg1String := callExpr.GetArgs()[1].GetConstExpr().GetStringValue()
45 if arg1String == "" {
46 return
47 }
48 cursor.Replace(Has(arg0Select.GetOperand(), String(arg0Select.GetField()+"="+arg1String)))
49 },
50 },
51 macroDeclarations: []DeclarationOption{
52 DeclareStandardFunctions(),
53 DeclareIdent("annotations", TypeList(TypeString)),
54 },
55 expected: Has(Text("annotations"), String("schedule=test")),
56 },
57 } {
58 tt := tt
59 t.Run(tt.name, func(t *testing.T) {
60 t.Parallel()
61 declarations, err := NewDeclarations(tt.declarations...)
62 assert.NilError(t, err)
63 filter, err := ParseFilter(&mockRequest{filter: tt.filter}, declarations)
64 if err != nil && tt.errorContains != "" {
65 assert.ErrorContains(t, err, tt.errorContains)
66 return
67 }
68 assert.NilError(t, err)
69 macroDeclarations, err := NewDeclarations(tt.macroDeclarations...)
70 assert.NilError(t, err)
71 actual, err := ApplyMacros(filter, macroDeclarations, tt.macros...)
72 if err != nil && tt.errorContains != "" {
73 assert.ErrorContains(t, err, tt.errorContains)
74 return
75 }
76 assert.NilError(t, err)
77 assert.DeepEqual(
78 t,
79 tt.expected,
80 actual.CheckedExpr.GetExpr(),
81 protocmp.Transform(),
82 protocmp.IgnoreFields(&expr.Expr{}, "id"),
83 )
84 })
85 }
86 }
87
// mockRequest is a minimal stand-in for a request that carries a filter
// string. Tests hand it to ParseFilter instead of a real request message.
type mockRequest struct{ filter string }
91
92 func (m *mockRequest) GetFilter() string {
93 return m.filter
94 }
95
View as plain text