1import cProfile
2import difflib
3import logging
4import pstats
5from typing import Optional, Tuple, Union
6
7from ambassador import IR, Cache, Config, EnvoyConfig
8from ambassador.fetch import ResourceFetcher
9from ambassador.ir.ir import IRFileChecker
10from ambassador.utils import NullSecretHandler, SecretHandler, Timer
11
# Types
OptionalStats = Optional[pstats.Stats]


class Profiler:
    """Context manager that runs cProfile over the code inside its ``with`` block.

    Usage::

        with Profiler() as pr:
            work()
        stats = pr.stats()
    """

    def __init__(self) -> None:
        self.pr = cProfile.Profile()

    def __enter__(self) -> "Profiler":
        self.pr.enable()
        # Return self so ``with Profiler() as pr:`` binds the profiler, not None.
        return self

    def __exit__(self, *args) -> None:
        # Returning None (falsy) lets exceptions from the block propagate.
        self.pr.disable()

    def stats(self) -> OptionalStats:
        """Return the collected statistics, sorted by internal time ("tottime").

        The profiler should have been entered at least once; pstats raises
        TypeError when asked to build Stats from an empty profile.
        """
        return pstats.Stats(self.pr).sort_stats("tottime")


class NullProfiler(Profiler):
    """Stand-in for Profiler that records nothing.

    Lets callers use a single ``with`` block whether or not profiling was
    requested; ``stats()`` always returns None.
    """

    def __init__(self) -> None:
        # Deliberately skip Profiler.__init__ so no cProfile.Profile is created.
        pass

    def __enter__(self) -> "NullProfiler":
        # Mirror Profiler: bind the instance in ``with ... as``.
        return self

    def __exit__(self, *args) -> None:
        pass

    def stats(self) -> OptionalStats:
        return None
42
43
class Madness:
    """Build an Ambassador IR and Envoy config from a snapshot file.

    The input is either a watt snapshot (``watt_path``) or a Kubernetes YAML
    file (``yaml_path``). It is fetched and loaded into ``self.aconf`` once,
    at construction time. ``build_ir``, ``build_econf`` and ``build`` can then
    be called repeatedly, optionally reusing ``self.cache`` and optionally
    running under cProfile. Each phase is timed with a ``Timer``, and
    ``summarize`` logs those timers.
    """

    def __init__(
        self,
        watt_path: Optional[str] = None,
        yaml_path: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
        secret_handler: Optional[SecretHandler] = None,
        file_checker: Optional[IRFileChecker] = None,
    ) -> None:
        """Fetch and load the input snapshot.

        :param watt_path: path to a watt snapshot; takes precedence over yaml_path
        :param yaml_path: path to a Kubernetes YAML file (parsed with k8s=True)
        :param logger: logger to use; if omitted, logging.basicConfig is called
            at INFO level and the "mockery" logger is used
        :param secret_handler: defaults to a NullSecretHandler
        :param file_checker: defaults to a checker that accepts every path
        :raises RuntimeError: if neither watt_path nor yaml_path is given
        """
        if not logger:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s madness %(levelname)s: %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )

            logger = logging.getLogger("mockery")

        self.logger = logger

        if not secret_handler:
            secret_handler = NullSecretHandler(logger, None, None, "0")

        if not file_checker:
            file_checker = lambda f: True

        self.secret_handler = secret_handler
        # NOTE(review): file_checker is stored but never passed to IR in this
        # class — confirm whether build_ir should forward it.
        self.file_checker = file_checker

        self.reset_cache()

        self.aconf_timer = Timer("aconf")
        self.fetcher_timer = Timer("fetcher")
        self.ir_timer = Timer("ir")
        self.econf_timer = Timer("econf")

        self.aconf = Config()

        with self.fetcher_timer:
            self.fetcher = ResourceFetcher(self.logger, self.aconf)

            # Read through a context manager so the file handle is closed
            # promptly instead of whenever the garbage collector gets to it.
            if watt_path:
                with open(watt_path, "r", encoding="utf-8") as f:
                    self.fetcher.parse_watt(f.read())
            elif yaml_path:
                with open(yaml_path, "r", encoding="utf-8") as f:
                    self.fetcher.parse_yaml(f.read(), k8s=True)
            else:
                raise RuntimeError("either watt_path or yaml_path must be provided")

        with self.aconf_timer:
            self.aconf.load_all(self.fetcher.sorted())

    def reset_cache(self) -> None:
        """Replace self.cache with a fresh Cache."""
        self.cache = Cache(self.logger)

    def summarize(self) -> None:
        """Log the summary of each phase timer: fetcher, aconf, ir, econf."""
        for timer in (self.fetcher_timer, self.aconf_timer, self.ir_timer, self.econf_timer):
            # NOTE(review): relies on Timer truthiness to skip some timers —
            # confirm what a falsy Timer means.
            if timer:
                self.logger.info(timer.summary())

    def build_ir(self, cache=True, profile=False, summarize=True) -> Tuple[IR, OptionalStats]:
        """Build an IR from the loaded configuration.

        :param cache: if true, reuse self.cache; otherwise build with no cache
        :param profile: if true, run the build under cProfile
        :param summarize: if true, log all phase timers afterwards
        :return: (ir, stats); stats is None unless profile is true
        """
        self.ir_timer.reset()

        _cache = self.cache if cache else None
        _pr = Profiler() if profile else NullProfiler()

        with self.ir_timer, _pr:
            ir = IR(self.aconf, cache=_cache, secret_handler=self.secret_handler)

        if summarize:
            self.summarize()

        return (ir, _pr.stats())

    def build_econf(
        self, ir: Union[IR, Tuple[IR, OptionalStats]], cache=True, profile=False, summarize=True
    ) -> Tuple[EnvoyConfig, OptionalStats]:
        """Generate a V2 Envoy config from an IR.

        :param ir: an IR, or the (ir, stats) tuple returned by build_ir
        :param cache: if true, reuse self.cache; otherwise generate with no cache
        :param profile: if true, run the generation under cProfile
        :param summarize: if true, log all phase timers afterwards
        :return: (econf, stats); stats is None unless profile is true
        """
        self.econf_timer.reset()

        _cache = self.cache if cache else None
        _pr = Profiler() if profile else NullProfiler()

        # Accept build_ir's (ir, stats) tuple as well as a bare IR.
        _ir: Optional[IR] = ir[0] if isinstance(ir, tuple) else ir

        # Check the unwrapped IR: checking the argument would let a
        # (None, stats) tuple through.
        assert _ir is not None

        with self.econf_timer, _pr:
            econf = EnvoyConfig.generate(_ir, "V2", cache=_cache)

        if summarize:
            self.summarize()

        return (econf, _pr.stats())

    def build(self, cache=True, profile=False) -> Tuple[IR, EnvoyConfig, OptionalStats]:
        """Build the IR and then the Envoy config, and log the phase timers.

        :param cache: if true, both phases reuse self.cache
        :param profile: if true, profile both phases together in one cProfile run
        :return: (ir, econf, stats); stats is None unless profile is true
        """
        _pr = Profiler() if profile else NullProfiler()

        with _pr:
            # build_ir/build_econf take a boolean flag and resolve it to
            # self.cache themselves, so pass the flag, not a Cache object.
            ir, _ = self.build_ir(cache=cache, profile=False, summarize=False)
            econf, _ = self.build_econf(ir, cache=cache, profile=False, summarize=False)

        self.summarize()

        return (ir, econf, _pr.stats())

    def diff(self, *rsrcs) -> None:
        """Print a context diff for each adjacent pair of resources whose JSON differs.

        Each resource must provide as_json(). Prints nothing when every
        resource serializes identically.
        """
        jsons = [rsrc.as_json() for rsrc in rsrcs]

        if len(set(jsons)) == 1:
            return

        for i, (j1, j2) in enumerate(zip(jsons, jsons[1:])):
            if j1 == j2:
                continue

            print("\n--------")

            # context_diff ends its control lines with lineterm ("\n" by
            # default); strip it so print() doesn't double-space the output.
            for line in difflib.context_diff(
                j1.split("\n"), j2.split("\n"), fromfile=f"rsrcs[{i}]", tofile=f"rsrcs[{i+1}]"
            ):
                print(line.rstrip())