1from typing import TYPE_CHECKING, Literal, Optional
2from typing import cast as typecast
3
4from ..config import Config
5from ..resource import Resource
6from ..utils import RichStatus
7from .ircluster import IRCluster
8from .irfilter import IRFilter
9
10if TYPE_CHECKING:
11 from .ir import IR # pragma: no cover
12
13
class IRAuth(IRFilter):
    """IR filter describing the external authentication ("extauth") service.

    Configuration is gathered in :meth:`setup` from the ``authentication``
    module and from every ``auth_configs`` resource in the Ambassador config.
    Each resource that names an ``auth_service`` contributes one entry to
    ``self.hosts``; :meth:`add_mappings` later turns those entries into a
    single (merged) :class:`IRCluster`.
    """

    # Cluster built by add_mappings(); reset to None at the start of that call.
    cluster: Optional[IRCluster]
    # ``protocol_version`` of the most recently loaded auth resource ("v2" if unset).
    protocol_version: Literal["v2", "v3"]

    def __init__(
        self,
        ir: "IR",
        aconf: Config,
        rkey: str = "ir.auth",
        kind: str = "IRAuth",
        name: str = "extauth",
        namespace: Optional[str] = None,
        type: Optional[str] = "decoder",
        **kwargs,
    ) -> None:
        """Initialise the filter with empty/default auth settings.

        All arguments are forwarded to ``IRFilter.__init__`` together with the
        initial values of the auth-specific fields (no cluster, no hosts, empty
        header lists, a 3000 ms connect timeout).
        """
        super().__init__(
            ir=ir,
            aconf=aconf,
            rkey=rkey,
            kind=kind,
            name=name,
            namespace=namespace,
            cluster=None,
            timeout_ms=None,
            connect_timeout_ms=3000,
            path_prefix=None,
            api_version=None,
            allowed_headers=[],
            allowed_request_headers=[],
            allowed_authorization_headers=[],
            hosts={},
            type=type,
            **kwargs,
        )

    def setup(self, ir: "IR", aconf: Config) -> bool:
        """Load every auth resource and report whether the filter is active.

        Returns:
            ``True`` when at least one resource registered an auth host,
            ``False`` otherwise.
        """
        module_info = aconf.get_module("authentication")

        if module_info:
            self._load_auth(module_info, ir)

        config_info = aconf.get_config("auth_configs")

        if config_info:
            for config in config_info.values():
                self._load_auth(config, ir)

        if not self.hosts:
            self.logger.debug("IRAuth: no AuthServices, going inactive")
            return False

        self.logger.debug("IRAuth: going active")
        return True

    def add_mappings(self, ir: "IR", aconf: Config):
        """Build the auth cluster from the registered hosts and add it to ``ir``.

        One :class:`IRCluster` is created per host; the first becomes
        ``self.cluster`` and each later one is merged into it. A failed merge
        is reported through ``post_error``.
        """
        # Every host entry is (weight, grpc, ctx_name, location) -- the shape that
        # _load_auth stores. The fallback must have the same four fields, otherwise
        # the unpacking below raises ValueError.
        fallback_hosts = {"127.0.0.1:5000": (100, False, None, "-internal-")}
        cluster_hosts = self.get("hosts", fallback_hosts)

        self.cluster = None
        cluster_good = False

        for service, params in cluster_hosts.items():
            weight, grpc, ctx_name, location = params

            self.logger.debug(
                "IRAuth: svc %s, weight %s, grpc %s, ctx_name %s, location %s",
                service,
                weight,
                grpc,
                ctx_name,
                location,
            )

            cluster = IRCluster(
                ir=ir,
                aconf=aconf,
                parent_ir_resource=self,
                location=location,
                service=service,
                host_rewrite=self.get("host_rewrite", False),
                ctx_name=ctx_name,
                grpc=grpc,
                marker="extauth",
                stats_name=self.get("stats_name", None),
                circuit_breakers=self.get("circuit_breakers", None),
            )

            cluster.referenced_by(self)

            # NOTE(review): the flag is reset on every iteration, so a failed merge
            # is masked when a later host merges cleanly. Kept as-is; confirm intent.
            cluster_good = True

            if self.cluster:
                if not self.cluster.merge(cluster):
                    self.post_error(
                        RichStatus.fromError(
                            "auth canary %s can only change service!" % cluster.name
                        )
                    )
                    cluster_good = False
            else:
                self.cluster = cluster

        if cluster_good:
            ir.add_cluster(typecast(IRCluster, self.cluster))
            self.referenced_by(typecast(IRCluster, self.cluster))

    def _load_auth(self, module: Resource, ir: "IR"):
        """Merge one auth resource (``module``) into this filter's settings.

        Copies the resource's fields onto ``self`` (applying defaults), posts
        errors for unsupported or missing values, and, when the resource names
        an ``auth_service``, records it in ``self.hosts``.
        """
        self.namespace = module.get("namespace", self.namespace)
        if self.location == "--internal--":
            self.sourced_by(module)

        for key in ["path_prefix", "timeout_ms", "cluster", "allow_request_body", "proto"]:
            value = module.get(key, None)

            if not value:
                continue

            previous = self.get(key, None)

            if previous and (previous != value):
                # Don't use self.post_error() here, since we need to explicitly override the
                # resource. And don't use self.ir.post_error, since our module isn't an IRResource.
                self.ir.aconf.post_error(
                    "AuthService cannot support multiple %s values; using %s" % (key, previous),
                    resource=module,
                )
            else:
                self[key] = value

        self.referenced_by(module)

        # A truthy value on the resource wins; only a missing/None value falls back
        # to the Ambassador module. Any other falsy value leaves the setting untouched.
        add_linkerd_headers = module.get("add_linkerd_headers", None)
        if add_linkerd_headers:
            self["add_linkerd_headers"] = add_linkerd_headers
        elif add_linkerd_headers is None:
            self["add_linkerd_headers"] = ir.ambassador_module.get("add_linkerd_headers", False)

        # Prefer the resource's circuit breakers, then the Ambassador module's.
        circuit_breakers = module.get("circuit_breakers", None) or ir.ambassador_module.get(
            "circuit_breakers"
        )
        if circuit_breakers:
            self["circuit_breakers"] = circuit_breakers

        # NOTE(review): these unconditional assignments overwrite allow_request_body,
        # proto and timeout_ms regardless of the outcome of the conflict check above.
        self["allow_request_body"] = module.get("allow_request_body", False)
        self["include_body"] = module.get("include_body", None)
        self["api_version"] = module.get("apiVersion", None)
        self["proto"] = module.get("proto", "http")
        self["timeout_ms"] = module.get("timeout_ms", 5000)
        self["connect_timeout_ms"] = module.get("connect_timeout_ms", 3000)
        self["cluster_idle_timeout_ms"] = module.get("cluster_idle_timeout_ms", None)
        self["cluster_max_connection_lifetime_ms"] = module.get(
            "cluster_max_connection_lifetime_ms", None
        )
        self["add_auth_headers"] = module.get("add_auth_headers", {})
        self.__to_header_list("allowed_headers", module)
        self.__to_header_list("allowed_request_headers", module)
        self.__to_header_list("allowed_authorization_headers", module)

        if self["proto"] not in ["grpc", "http"]:
            # The field being validated is ``proto`` (not ``proto_version``).
            self.post_error(
                f'AuthService: proto {self["proto"]} is unsupported, proto must be "grpc" or "http"'
            )

        self.protocol_version = module.get("protocol_version", "v2")
        if self["proto"] == "grpc" and self.protocol_version not in ["v3"]:
            self.post_error(
                f'AuthService: protocol_version {self.protocol_version} is unsupported, protocol_version must be "v3"'
            )

        self["stats_name"] = module.get("stats_name", None)

        status_on_error = module.get("status_on_error", None)
        if status_on_error:
            self["status_on_error"] = status_on_error

        failure_mode_allow = module.get("failure_mode_allow", None)
        if failure_mode_allow:
            self["failure_mode_allow"] = failure_mode_allow

        # Required fields check.
        if self["api_version"] is None:
            self.post_error(RichStatus.fromError("AuthService config requires apiVersion field"))

        if self["proto"] is None:
            self.post_error(RichStatus.fromError("AuthService requires proto field."))

        if self.get("include_body") and self.get("allow_request_body"):
            self.post_error("AuthService ignoring allow_request_body since include_body is present")
            del self["allow_request_body"]

        auth_service = module.get("auth_service", None)
        weight = 100  # Can't support arbitrary weights right now.

        if auth_service:
            is_grpc = self["proto"] == "grpc"
            self.hosts[auth_service] = (weight, is_grpc, module.get("tls", None), module.location)

    def __to_header_list(self, list_name, module):
        """Merge ``module[list_name]`` into ``self[list_name]``, lower-cased and de-duplicated.

        Headers are visited in sorted order and appended only if their
        lower-cased form is not already present. Nothing happens when the
        resource has no (or an empty) list under ``list_name``.
        """
        headers = module.get(list_name, None)

        if not headers:
            return

        allowed_headers = self.get(list_name, [])

        for hdr in sorted(headers):
            lowered = hdr.lower()
            if lowered not in allowed_headers:
                allowed_headers.append(lowered)

        self[list_name] = allowed_headers
View as plain text