1
2
3
4
19
20 package gce
21
22 import (
23 "context"
24 "fmt"
25 "net/http"
26 "strconv"
27 "strings"
28 "time"
29
30 compute "google.golang.org/api/compute/v1"
31
32 v1 "k8s.io/api/core/v1"
33 "k8s.io/apimachinery/pkg/util/sets"
34 "k8s.io/apimachinery/pkg/util/wait"
35 cloudprovider "k8s.io/cloud-provider"
36 "k8s.io/kubernetes/test/e2e/framework"
37 gcecloud "k8s.io/legacy-cloud-providers/gce"
38 )
39
40
41
// MakeFirewallNameForLBService returns the GCE firewall rule name used for the
// load balancer called name: the name with a "k8s-fw-" prefix.
func MakeFirewallNameForLBService(name string) string {
	const prefix = "k8s-fw-"
	return fmt.Sprintf("%s%s", prefix, name)
}
45
46
47 func ConstructFirewallForLBService(svc *v1.Service, nodeTag string) *compute.Firewall {
48 if svc.Spec.Type != v1.ServiceTypeLoadBalancer {
49 framework.Failf("can not construct firewall rule for non-loadbalancer type service")
50 }
51 fw := compute.Firewall{}
52 fw.Name = MakeFirewallNameForLBService(cloudprovider.DefaultLoadBalancerName(svc))
53 fw.TargetTags = []string{nodeTag}
54 if svc.Spec.LoadBalancerSourceRanges == nil {
55 fw.SourceRanges = []string{"0.0.0.0/0"}
56 } else {
57 fw.SourceRanges = svc.Spec.LoadBalancerSourceRanges
58 }
59 for _, sp := range svc.Spec.Ports {
60 fw.Allowed = append(fw.Allowed, &compute.FirewallAllowed{
61 IPProtocol: strings.ToLower(string(sp.Protocol)),
62 Ports: []string{strconv.Itoa(int(sp.Port))},
63 })
64 }
65 return &fw
66 }
67
68
69
70 func MakeHealthCheckFirewallNameForLBService(clusterID, name string, isNodesHealthCheck bool) string {
71 return gcecloud.MakeHealthCheckFirewallName(clusterID, name, isNodesHealthCheck)
72 }
73
74
75 func ConstructHealthCheckFirewallForLBService(clusterID string, svc *v1.Service, nodeTag string, isNodesHealthCheck bool) *compute.Firewall {
76 if svc.Spec.Type != v1.ServiceTypeLoadBalancer {
77 framework.Failf("can not construct firewall rule for non-loadbalancer type service")
78 }
79 fw := compute.Firewall{}
80 fw.Name = MakeHealthCheckFirewallNameForLBService(clusterID, cloudprovider.DefaultLoadBalancerName(svc), isNodesHealthCheck)
81 fw.TargetTags = []string{nodeTag}
82 fw.SourceRanges = gcecloud.L4LoadBalancerSrcRanges()
83 healthCheckPort := gcecloud.GetNodesHealthCheckPort()
84 if !isNodesHealthCheck {
85 healthCheckPort = svc.Spec.HealthCheckNodePort
86 }
87 fw.Allowed = []*compute.FirewallAllowed{
88 {
89 IPProtocol: "tcp",
90 Ports: []string{fmt.Sprintf("%d", healthCheckPort)},
91 },
92 }
93 return &fw
94 }
95
96
97 func PackProtocolsPortsFromFirewall(alloweds []*compute.FirewallAllowed) []string {
98 protocolPorts := []string{}
99 for _, allowed := range alloweds {
100 for _, port := range allowed.Ports {
101 protocolPorts = append(protocolPorts, strings.ToLower(allowed.IPProtocol+"/"+port))
102 }
103 }
104 return protocolPorts
105 }
106
// portRange is an inclusive span of ports [min, max] for a single protocol. It
// is produced by toPortRange, which stores the protocol uppercased.
type portRange struct {
	protocol string
	min, max int
}

// toPortRange parses "protocol/port" or "protocol/first-last" (for example
// "tcp/80" or "udp/1000-2000") into a portRange. The protocol is uppercased
// and the bounds are inclusive.
//
// It returns an error, together with a zero portRange, if any of these hold:
//   - s does not contain exactly one '/';
//   - the port part has more than one '-';
//   - a bound is not an integer;
//   - the range is inverted (first > last).
func toPortRange(s string) (portRange, error) {
	protoPorts := strings.Split(s, "/")
	if len(protoPorts) != 2 {
		return portRange{}, fmt.Errorf("expected a single '/' in %q", s)
	}

	var lo, hi int
	bounds := strings.Split(protoPorts[1], "-")
	switch len(bounds) {
	case 1:
		v, err := strconv.Atoi(bounds[0])
		if err != nil {
			return portRange{}, fmt.Errorf("invalid port in %q: %w", s, err)
		}
		lo, hi = v, v
	case 2:
		start, err := strconv.Atoi(bounds[0])
		if err != nil {
			return portRange{}, fmt.Errorf("invalid range start in %q: %w", s, err)
		}
		end, err := strconv.Atoi(bounds[1])
		if err != nil {
			return portRange{}, fmt.Errorf("invalid range end in %q: %w", s, err)
		}
		lo, hi = start, end
	default:
		return portRange{}, fmt.Errorf("unexpected range value %q", protoPorts[1])
	}

	// An inverted range covers no port. Reject it explicitly instead of
	// letting coverage checks fail later with a misleading message.
	if lo > hi {
		return portRange{}, fmt.Errorf("invalid range in %q: start %d is greater than end %d", s, lo, hi)
	}

	return portRange{protocol: strings.ToUpper(protoPorts[0]), min: lo, max: hi}, nil
}
145
146
147
148
149
150 func isPortsSubset(requiredPorts, coverage []string) error {
151 for _, reqPort := range requiredPorts {
152 rRange, err := toPortRange(reqPort)
153 if err != nil {
154 return err
155 }
156 if rRange.min != rRange.max {
157 return fmt.Errorf("requiring a range is not supported: %q", reqPort)
158 }
159
160 var covered bool
161 for _, c := range coverage {
162 cRange, err := toPortRange(c)
163 if err != nil {
164 return err
165 }
166
167 if rRange.protocol != cRange.protocol {
168 continue
169 }
170
171 if rRange.min >= cRange.min && rRange.min <= cRange.max {
172 covered = true
173 break
174 }
175 }
176
177 if !covered {
178 return fmt.Errorf("%q is not covered by %v", reqPort, coverage)
179 }
180 }
181 return nil
182 }
183
184
185
186
187 func SameStringArray(result, expected []string, include bool) error {
188 res := sets.NewString(result...)
189 exp := sets.NewString(expected...)
190 if !include {
191 diff := res.Difference(exp)
192 if len(diff) != 0 {
193 return fmt.Errorf("found differences: %v", diff)
194 }
195 } else {
196 if !res.IsSuperset(exp) {
197 return fmt.Errorf("some elements are missing: expected %v, got %v", expected, result)
198 }
199 }
200 return nil
201 }
202
203
204
205 func VerifyFirewallRule(res, exp *compute.Firewall, network string, portsSubset bool) error {
206 if res == nil || exp == nil {
207 return fmt.Errorf("res and exp must not be nil")
208 }
209 if res.Name != exp.Name {
210 return fmt.Errorf("incorrect name: %v, expected %v", res.Name, exp.Name)
211 }
212
213 actualPorts := PackProtocolsPortsFromFirewall(res.Allowed)
214 expPorts := PackProtocolsPortsFromFirewall(exp.Allowed)
215 if portsSubset {
216 if err := isPortsSubset(expPorts, actualPorts); err != nil {
217 return fmt.Errorf("incorrect allowed protocol ports: %w", err)
218 }
219 } else {
220 if err := SameStringArray(actualPorts, expPorts, false); err != nil {
221 return fmt.Errorf("incorrect allowed protocols ports: %w", err)
222 }
223 }
224
225 if err := SameStringArray(res.SourceRanges, exp.SourceRanges, false); err != nil {
226 return fmt.Errorf("incorrect source ranges %v, expected %v: %w", res.SourceRanges, exp.SourceRanges, err)
227 }
228 if err := SameStringArray(res.SourceTags, exp.SourceTags, false); err != nil {
229 return fmt.Errorf("incorrect source tags %v, expected %v: %w", res.SourceTags, exp.SourceTags, err)
230 }
231 if err := SameStringArray(res.TargetTags, exp.TargetTags, false); err != nil {
232 return fmt.Errorf("incorrect target tags %v, expected %v: %w", res.TargetTags, exp.TargetTags, err)
233 }
234 return nil
235 }
236
237
238 func WaitForFirewallRule(ctx context.Context, gceCloud *gcecloud.Cloud, fwName string, exist bool, timeout time.Duration) (*compute.Firewall, error) {
239 framework.Logf("Waiting up to %v for firewall %v exist=%v", timeout, fwName, exist)
240 var fw *compute.Firewall
241 var err error
242
243 condition := func(ctx context.Context) (bool, error) {
244 fw, err = gceCloud.GetFirewall(fwName)
245 if err != nil && exist ||
246 err == nil && !exist ||
247 err != nil && !exist && !IsGoogleAPIHTTPErrorCode(err, http.StatusNotFound) {
248 return false, nil
249 }
250 return true, nil
251 }
252
253 if err := wait.PollUntilContextTimeout(ctx, 5*time.Second, timeout, true, condition); err != nil {
254 return nil, fmt.Errorf("error waiting for firewall %v exist=%v", fwName, exist)
255 }
256 return fw, nil
257 }
258
View as plain text