1
16
17 package validation
18
19 import (
20 "fmt"
21 "net/netip"
22 "regexp"
23
24 "k8s.io/apimachinery/pkg/util/sets"
25 "k8s.io/apimachinery/pkg/util/validation/field"
26
27 gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
28 )
29
30 var (
31
32 protocolsHostnameInvalid = map[gatewayv1.ProtocolType]struct{}{
33 gatewayv1.TCPProtocolType: {},
34 gatewayv1.UDPProtocolType: {},
35 }
36
37 protocolsTLSInvalid = map[gatewayv1.ProtocolType]struct{}{
38 gatewayv1.HTTPProtocolType: {},
39 gatewayv1.UDPProtocolType: {},
40 gatewayv1.TCPProtocolType: {},
41 }
42
43 protocolsTLSRequired = map[gatewayv1.ProtocolType]struct{}{
44 gatewayv1.HTTPSProtocolType: {},
45 gatewayv1.TLSProtocolType: {},
46 }
47
48 validHostnameAddress = `^(\*\.)?[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
49 validHostnameRegexp = regexp.MustCompile(validHostnameAddress)
50 )
51
52
53
54
55
56
57
58
59 func ValidateGateway(gw *gatewayv1.Gateway) field.ErrorList {
60 return ValidateGatewaySpec(&gw.Spec, field.NewPath("spec"))
61 }
62
63
64
65 func ValidateGatewaySpec(spec *gatewayv1.GatewaySpec, path *field.Path) field.ErrorList {
66 var errs field.ErrorList
67 errs = append(errs, validateGatewayListeners(spec.Listeners, path.Child("listeners"))...)
68 errs = append(errs, validateGatewayAddresses(spec.Addresses, path.Child("addresses"))...)
69 return errs
70 }
71
72
73
74 func validateGatewayListeners(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
75 var errs field.ErrorList
76 errs = append(errs, ValidateListenerTLSConfig(listeners, path)...)
77 errs = append(errs, validateListenerHostname(listeners, path)...)
78 errs = append(errs, ValidateTLSCertificateRefs(listeners, path)...)
79 errs = append(errs, ValidateListenerNames(listeners, path)...)
80 errs = append(errs, validateHostnameProtocolPort(listeners, path)...)
81 return errs
82 }
83
84
85
86 func ValidateListenerTLSConfig(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
87 var errs field.ErrorList
88 for i, l := range listeners {
89 if isProtocolInSubset(l.Protocol, protocolsTLSRequired) && l.TLS == nil {
90 errs = append(errs, field.Forbidden(path.Index(i).Child("tls"), fmt.Sprintf("must be set for protocol %v", l.Protocol)))
91 }
92 if isProtocolInSubset(l.Protocol, protocolsTLSInvalid) && l.TLS != nil {
93 errs = append(errs, field.Forbidden(path.Index(i).Child("tls"), fmt.Sprintf("should be empty for protocol %v", l.Protocol)))
94 }
95 }
96 return errs
97 }
98
99 func isProtocolInSubset(protocol gatewayv1.ProtocolType, set map[gatewayv1.ProtocolType]struct{}) bool {
100 _, ok := set[protocol]
101 return ok
102 }
103
104
105
106 func validateListenerHostname(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
107 var errs field.ErrorList
108 for i, h := range listeners {
109 if isProtocolInSubset(h.Protocol, protocolsHostnameInvalid) && h.Hostname != nil {
110 errs = append(errs, field.Forbidden(path.Index(i).Child("hostname"), fmt.Sprintf("should be empty for protocol %v", h.Protocol)))
111 }
112 }
113 return errs
114 }
115
116
117
118
119 func ValidateTLSCertificateRefs(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
120 var errs field.ErrorList
121 for i, c := range listeners {
122 if isProtocolInSubset(c.Protocol, protocolsTLSRequired) && c.TLS != nil {
123 if *c.TLS.Mode == gatewayv1.TLSModeTerminate && len(c.TLS.CertificateRefs) == 0 {
124 errs = append(errs, field.Forbidden(path.Index(i).Child("tls").Child("certificateRefs"), "should be set and not empty when TLSModeType is Terminate"))
125 }
126 }
127 }
128 return errs
129 }
130
131
132
133 func ValidateListenerNames(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
134 var errs field.ErrorList
135 nameMap := make(map[gatewayv1.SectionName]struct{}, len(listeners))
136 for i, c := range listeners {
137 if _, found := nameMap[c.Name]; found {
138 errs = append(errs, field.Duplicate(path.Index(i).Child("name"), "must be unique within the Gateway"))
139 }
140 nameMap[c.Name] = struct{}{}
141 }
142 return errs
143 }
144
145
146
147 func validateHostnameProtocolPort(listeners []gatewayv1.Listener, path *field.Path) field.ErrorList {
148 var errs field.ErrorList
149 hostnameProtocolPortSets := sets.Set[string]{}
150 for i, listener := range listeners {
151 hostname := new(gatewayv1.Hostname)
152 if listener.Hostname != nil {
153 hostname = listener.Hostname
154 }
155 protocol := listener.Protocol
156 port := listener.Port
157 hostnameProtocolPort := fmt.Sprintf("%s:%s:%d", *hostname, protocol, port)
158 if hostnameProtocolPortSets.Has(hostnameProtocolPort) {
159 errs = append(errs, field.Duplicate(path.Index(i), "combination of port, protocol, and hostname must be unique for each listener"))
160 } else {
161 hostnameProtocolPortSets.Insert(hostnameProtocolPort)
162 }
163 }
164 return errs
165 }
166
167
168
169 func validateGatewayAddresses(addresses []gatewayv1.GatewayAddress, path *field.Path) field.ErrorList {
170 var errs field.ErrorList
171 ipAddrSet, hostnameAddrSet := sets.Set[string]{}, sets.Set[string]{}
172 for i, address := range addresses {
173 if address.Type != nil {
174 if *address.Type == gatewayv1.IPAddressType {
175 if _, err := netip.ParseAddr(address.Value); err != nil {
176 errs = append(errs, field.Invalid(path.Index(i), address.Value, "invalid ip address"))
177 }
178 if ipAddrSet.Has(address.Value) {
179 errs = append(errs, field.Duplicate(path.Index(i), address.Value))
180 } else {
181 ipAddrSet.Insert(address.Value)
182 }
183 } else if *address.Type == gatewayv1.HostnameAddressType {
184 if !validHostnameRegexp.MatchString(address.Value) {
185 errs = append(errs, field.Invalid(path.Index(i), address.Value, fmt.Sprintf("must only contain valid characters (matching %s)", validHostnameAddress)))
186 }
187 if hostnameAddrSet.Has(address.Value) {
188 errs = append(errs, field.Duplicate(path.Index(i), address.Value))
189 } else {
190 hostnameAddrSet.Insert(address.Value)
191 }
192 }
193 }
194 }
195 return errs
196 }
197
View as plain text