...

Text file src/github.com/emissary-ingress/emissary/v3/python/ambassador/utils.py

Documentation: github.com/emissary-ingress/emissary/v3/python/ambassador

     1#!/usr/bin/env python
     2
     3# Copyright 2018 Datawire. All rights reserved.
     4#
     5# Licensed under the Apache License, Version 2.0 (the "License");
     6# you may not use this file except in compliance with the License.
     7# You may obtain a copy of the License at
     8#
     9#     http://www.apache.org/licenses/LICENSE-2.0
    10#
    11# Unless required by applicable law or agreed to in writing, software
    12# distributed under the License is distributed on an "AS IS" BASIS,
    13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14# See the License for the specific language governing permissions and
    15# limitations under the License
    16
    17import binascii
    18import hashlib
    19import io
    20import logging
    21import os
    22import re
    23import socket
    24import tempfile
    25import threading
    26import time
    27from builtins import bytes
    28from distutils.util import strtobool
    29from typing import TYPE_CHECKING, Any, Dict, List, Optional, TextIO, Union
    30from urllib.parse import urlparse
    31
    32import orjson
    33import requests
    34import yaml
    35from prometheus_client import Gauge
    36
    37from .VERSION import Version
    38
    39if TYPE_CHECKING:
    40    from .config.acresource import ACResource  # pragma: no cover
    41    from .ir import IRResource  # pragma: no cover
    42    from .ir.irtlscontext import IRTLSContext  # pragma: no cover
    43
    44logger = logging.getLogger("utils")
    45logger.setLevel(logging.INFO)
    46
    47# XXX What a hack. There doesn't seem to be a way to convince mypy that SafeLoader
    48# and CSafeLoader share a base class, even though they do. Sigh.
    49
    50yaml_loader: Any = yaml.SafeLoader
    51yaml_dumper: Any = yaml.SafeDumper
    52
    53try:
    54    yaml_loader = yaml.CSafeLoader
    55except AttributeError:
    56    pass
    57
    58try:
    59    yaml_dumper = yaml.CSafeDumper
    60except AttributeError:
    61    pass
    62
    63yaml_logged_loader = False
    64yaml_logged_dumper = False
    65
    66
    67def parse_yaml(serialization: str) -> Any:
    68    global yaml_logged_loader
    69
    70    if not yaml_logged_loader:
    71        yaml_logged_loader = True
    72
    73        # logger.info("YAML: using %s parser" % ("Python" if (yaml_loader == yaml.SafeLoader) else "C"))
    74
    75    return list(yaml.load_all(serialization, Loader=yaml_loader))
    76
    77
    78def dump_yaml(obj: Any, **kwargs) -> str:
    79    global yaml_logged_dumper
    80
    81    if not yaml_logged_dumper:
    82        yaml_logged_dumper = True
    83
    84        # logger.info("YAML: using %s dumper" % ("Python" if (yaml_dumper == yaml.SafeDumper) else "C"))
    85
    86    return yaml.dump(obj, Dumper=yaml_dumper, **kwargs)
    87
    88
    89def parse_json(serialization: str) -> Any:
    90    return orjson.loads(serialization)
    91
    92
    93def dump_json(obj: Any, pretty=False) -> str:
    94    # There's a nicer way to do this in python, I'm sure.
    95    if pretty:
    96        return bytes.decode(
    97            orjson.dumps(
    98                obj, option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2
    99            )
   100        )
   101    else:
   102        return bytes.decode(orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS))
   103
   104
   105def _load_url_contents(
   106    logger: logging.Logger, url: str, stream1: TextIO, stream2: Optional[TextIO] = None
   107) -> bool:
   108    saved = False
   109
   110    try:
   111        with requests.get(url) as r:
   112            if r.status_code == 200:
   113
   114                # All's well, pull the config down.
   115                encoded = b""
   116
   117                try:
   118                    for chunk in r.iter_content(chunk_size=65536):
   119                        # We do this by hand instead of with 'decode_unicode=True'
   120                        # above because setting decode_unicode only decodes text,
   121                        # and WATT hands us application/json...
   122                        encoded += chunk
   123
   124                    decoded = encoded.decode("utf-8")
   125                    stream1.write(decoded)
   126
   127                    if stream2:
   128                        stream2.write(decoded)
   129
   130                    saved = True
   131                except IOError as e:
   132                    logger.error("couldn't save Kubernetes resources: %s" % e)
   133                except Exception as e:
   134                    logger.error("couldn't read Kubernetes resources: %s" % e)
   135    except requests.exceptions.RequestException as e:
   136        logger.error("could not load new snapshot: %s" % e)
   137
   138    return saved
   139
   140
   141def save_url_contents(
   142    logger: logging.Logger, url: str, path: str, stream2: Optional[TextIO] = None
   143) -> bool:
   144    with open(path, "w", encoding="utf-8") as stream:
   145        return _load_url_contents(logger, url, stream, stream2=stream2)
   146
   147
   148def load_url_contents(
   149    logger: logging.Logger, url: str, stream2: Optional[TextIO] = None
   150) -> Optional[str]:
   151    stream = io.StringIO()
   152
   153    saved = _load_url_contents(logger, url, stream, stream2=stream2)
   154
   155    if saved:
   156        return stream.getvalue()
   157    else:
   158        return None
   159
   160
   161def parse_bool(s: Optional[Union[str, bool]]) -> bool:
   162    """
   163    Parse a boolean value from a string. T, True, Y, y, 1 return True;
   164    other things return False.
   165    """
   166
   167    # If `s` is already a bool, return its value.
   168    #
   169    # This allows a caller to not know or care whether their value is already
   170    # a boolean, or if it is a string that needs to be parsed below.
   171    if isinstance(s, bool):
   172        return s
   173
   174    # If we didn't get anything at all, return False.
   175    if not s:
   176        return False
   177
   178    # OK, we got _something_, so try strtobool.
   179    try:
   180        return bool(strtobool(s))  # the linter does not like a Literal[0, 1] being returned here
   181    except ValueError:
   182        return False
   183
   184
   185class SystemInfo:
   186    MyHostName = os.environ.get("HOSTNAME", None)
   187
   188    if not MyHostName:
   189        MyHostName = "localhost"
   190
   191        try:
   192            MyHostName = socket.gethostname()
   193        except:
   194            pass
   195
   196
   197class RichStatus:
   198    def __init__(self, ok, **kwargs):
   199        self.ok = ok
   200        self.info = kwargs
   201        self.info["hostname"] = SystemInfo.MyHostName
   202        self.info["version"] = Version
   203
   204    # Remember that __getattr__ is called only as a last resort if the key
   205    # isn't a normal attr.
   206    def __getattr__(self, key):
   207        return self.info.get(key)
   208
   209    def __bool__(self):
   210        return self.ok
   211
   212    def __nonzero__(self):
   213        return bool(self)
   214
   215    def __contains__(self, key):
   216        return key in self.info
   217
   218    def __str__(self):
   219        attrs = ["%s=%s" % (key, repr(self.info[key])) for key in sorted(self.info.keys())]
   220        astr = " ".join(attrs)
   221
   222        if astr:
   223            astr = " " + astr
   224
   225        return "<RichStatus %s%s>" % ("OK" if self else "BAD", astr)
   226
   227    def as_dict(self):
   228        d = {"ok": self.ok}
   229
   230        for key in self.info.keys():
   231            d[key] = self.info[key]
   232
   233        return d
   234
   235    @classmethod
   236    def fromError(self, error, **kwargs):
   237        kwargs["error"] = error
   238        return RichStatus(False, **kwargs)
   239
   240    @classmethod
   241    def OK(self, **kwargs):
   242        return RichStatus(True, **kwargs)
   243
   244
   245class Timer:
   246    """
   247    Timer is a simple class to measure time. When a Timer is created,
   248    it is given a name, and is stopped.
   249
   250    t = Timer("test timer")
   251
   252    The simplest way to use the Timer is as a context manager:
   253
   254    with t:
   255        something_to_be_timed()
   256
   257    You can also use the start method to start the timer:
   258
   259    t.start()
   260
   261    ...and the .stop method to stop the timer and update the timer's
   262    records.
   263
   264    t.stop()
   265
   266    Timers record the accumulated time and the number of start/stop
   267    cycles (in .accumulated and .cycles respectively). They can also
   268    return the average time per cycle (.average) and minimum and
   269    maximum times per cycle (.minimum and .maximum).
   270    """
   271
   272    name: str
   273    _cycles: int
   274    _starttime: float
   275    _accumulated: float
   276    _minimum: float
   277    _maximum: float
   278    _running: bool
   279    _faketime: float
   280    _gauge: Optional[Gauge] = None
   281
   282    def __init__(self, name: str, prom_metrics_registry: Optional[Any] = None) -> None:
   283        """
   284        Create a Timer, given a name. The Timer is initially stopped.
   285        """
   286
   287        self.name = name
   288
   289        if prom_metrics_registry:
   290            metric_prefix = re.sub(r"\s+", "_", name).lower()
   291            self._gauge = Gauge(
   292                f"{metric_prefix}_time_seconds",
   293                f"Elapsed time on {name} operations",
   294                namespace="ambassador",
   295                registry=prom_metrics_registry,
   296            )
   297
   298        self.reset()
   299
   300    def reset(self) -> None:
   301        self._cycles = 0
   302        self._starttime = 0
   303        self._accumulated = 0.0
   304        self._minimum = 999999999999
   305        self._maximum = -999999999999
   306        self._running = False
   307        self._faketime = 0.0
   308
   309    def __enter__(self):
   310        self.start()
   311        return self
   312
   313    def __exit__(self, type, value, traceback):
   314        self.stop()
   315
   316    def __bool__(self) -> bool:
   317        """
   318        Timers test True in a boolean context if they have timed at least one
   319        cycle.
   320        """
   321        return self._cycles > 0
   322
   323    def start(self, when: Optional[float] = None) -> None:
   324        """
   325        Start a Timer running.
   326
   327        :param when: Optional start time. If not supplied,
   328        the current time is used.
   329        """
   330
   331        # If we're already running, this method silently discards the
   332        # currently-running cycle. Why? Because otherwise, it's a little
   333        # too easy to forget to stop a Timer, cause an Exception, and
   334        # crash the world.
   335        #
   336        # Not that I ever got bitten by this. Of course. [ :P ]
   337
   338        self._starttime = when or time.perf_counter()
   339        self._running = True
   340
   341    def stop(self, when: Optional[float] = None) -> float:
   342        """
   343        Stop a Timer, increment the cycle count, and update the
   344        accumulated time with the amount of time since the Timer
   345        was started.
   346
   347        :param when: Optional stop time. If not supplied,
   348        the current time is used.
   349        :return: The amount of time the Timer has accumulated
   350        """
   351
   352        # If we're already stopped, just return the same thing as the
   353        # previous call to stop. See comments in start() for why this
   354        # isn't an Exception...
   355
   356        if self._running:
   357            if not when:
   358                when = time.perf_counter()
   359
   360            self._running = False
   361            self._cycles += 1
   362
   363            this_cycle = (when - self._starttime) + self._faketime
   364            if self._gauge:
   365                self._gauge.set(this_cycle)
   366
   367            self._faketime = 0
   368
   369            self._accumulated += this_cycle
   370
   371            if this_cycle < self._minimum:
   372                self._minimum = this_cycle
   373
   374            if this_cycle > self._maximum:
   375                self._maximum = this_cycle
   376
   377        return self._accumulated
   378
   379    def faketime(self, faketime: float) -> None:
   380        """
   381        Add fake time to a Timer. This is intended solely for
   382        testing.
   383        """
   384
   385        if not self._running:
   386            raise Exception(f"Timer {self.name}.faketime: not running")
   387
   388        self._faketime = faketime
   389
   390    @property
   391    def cycles(self):
   392        """
   393        The number of timing cycles this Timer has recorded.
   394        """
   395        return self._cycles
   396
   397    @property
   398    def starttime(self):
   399        """
   400        The time this Timer was last started, or 0 if it has
   401        never been started.
   402        """
   403        return self._starttime
   404
   405    @property
   406    def accumulated(self):
   407        """
   408        The amount of time this Timer has accumulated.
   409        """
   410        return self._accumulated
   411
   412    @property
   413    def minimum(self):
   414        """
   415        The minimum single-cycle time this Timer has recorded.
   416        """
   417        return self._minimum
   418
   419    @property
   420    def maximum(self):
   421        """
   422        The maximum single-cycle time this Timer has recorded.
   423        """
   424        return self._maximum
   425
   426    @property
   427    def average(self):
   428        """
   429        The average cycle time for this Timer.
   430        """
   431        if self._cycles > 0:
   432            return self._accumulated / self._cycles
   433
   434        raise Exception(f"Timer {self.name}.average: no cycles to average")
   435
   436    @property
   437    def running(self):
   438        """
   439        Whether or not this Timer is running.
   440        """
   441        return self._running
   442
   443    def __str__(self) -> str:
   444        s = "Timer %s: " % self.name
   445
   446        if self._running:
   447            s += "running, "
   448
   449        s += "%.6f sec" % self._accumulated
   450
   451        return s
   452
   453    def summary(self) -> str:
   454        """
   455        Return a summary of this Timer.
   456        """
   457
   458        return "TIMER %s: %d, %.3f/%.3f/%.3f" % (
   459            self.name,
   460            self.cycles,
   461            self.minimum,
   462            self.average,
   463            self.maximum,
   464        )
   465
   466
   467class DelayTrigger(threading.Thread):
   468    def __init__(self, onfired, timeout=5, name=None):
   469        super().__init__()
   470
   471        if name:
   472            self.name = name
   473
   474        self.trigger_source, self.trigger_dest = socket.socketpair()
   475
   476        self.onfired = onfired
   477        self.timeout = timeout
   478
   479        self.setDaemon(True)
   480        self.start()
   481
   482    def trigger(self):
   483        self.trigger_source.sendall(b"X")
   484
   485    def run(self):
   486        while True:
   487            self.trigger_dest.settimeout(None)
   488            x = self.trigger_dest.recv(128)
   489
   490            self.trigger_dest.settimeout(self.timeout)
   491
   492            while True:
   493                try:
   494                    x = self.trigger_dest.recv(128)
   495                except socket.timeout:
   496                    self.onfired()
   497                    break
   498
   499
   500class PeriodicTrigger(threading.Thread):
   501    def __init__(self, onfired, period=5, name=None):
   502        super().__init__()
   503
   504        if name:
   505            self.name = name
   506
   507        self.onfired = onfired
   508        self.period = period
   509
   510        self.daemon = True
   511        self.start()
   512
   513    def trigger(self):
   514        pass
   515
   516    def run(self):
   517        while True:
   518            time.sleep(self.period)
   519            self.onfired()
   520
   521
   522class SecretInfo:
   523    """
   524    SecretInfo encapsulates a secret, including its name, its namespace, and all of its
   525    ciphertext elements. Pretty much everything in Ambassador that worries about secrets
   526    uses a SecretInfo.
   527    """
   528
   529    def __init__(
   530        self,
   531        name: str,
   532        namespace: str,
   533        secret_type: str,
   534        tls_crt: Optional[str] = None,
   535        tls_key: Optional[str] = None,
   536        user_key: Optional[str] = None,
   537        root_crt: Optional[str] = None,
   538        decode_b64=True,
   539    ) -> None:
   540        self.name = name
   541        self.namespace = namespace
   542        self.secret_type = secret_type
   543
   544        if decode_b64:
   545            if self.is_decodable(tls_crt):
   546                assert tls_crt
   547                tls_crt = self.decode(tls_crt)
   548
   549            if self.is_decodable(tls_key):
   550                assert tls_key
   551                tls_key = self.decode(tls_key)
   552
   553            if self.is_decodable(user_key):
   554                assert user_key
   555                user_key = self.decode(user_key)
   556
   557            if self.is_decodable(root_crt):
   558                assert root_crt
   559                root_crt = self.decode(root_crt)
   560
   561        self.tls_crt = tls_crt
   562        self.tls_key = tls_key
   563        self.user_key = user_key
   564        self.root_crt = root_crt
   565
   566    @staticmethod
   567    def is_decodable(b64_pem: Optional[str]) -> bool:
   568        if not b64_pem:
   569            return False
   570
   571        return not (b64_pem.startswith("-----BEGIN") or b64_pem.startswith("-sanitized-"))
   572
   573    @staticmethod
   574    def decode(b64_pem: str) -> Optional[str]:
   575        """
   576        Do base64 decoding of a cryptographic element.
   577
   578        :param b64_pem: Base64-encoded PEM element
   579        :return: Decoded PEM element
   580        """
   581        utf8_pem = None
   582        pem = None
   583
   584        try:
   585            utf8_pem = binascii.a2b_base64(b64_pem)
   586        except binascii.Error:
   587            return None
   588
   589        try:
   590            pem = utf8_pem.decode("utf-8")
   591        except UnicodeDecodeError:
   592            return None
   593
   594        return pem
   595
   596    @staticmethod
   597    def fingerprint(pem: Optional[str]) -> str:
   598        """
   599        Generate and return a cryptographic fingerprint of a PEM element.
   600
   601        The fingerprint is the uppercase hex SHA-1 signature of the element's UTF-8
   602        representation.
   603
   604        :param pem: PEM element
   605        :return: fingerprint string
   606        """
   607        if not pem:
   608            return "<none>"
   609
   610        h = hashlib.new("sha1")
   611        h.update(pem.encode("utf-8"))
   612        hd = h.hexdigest()[0:16].upper()
   613
   614        keytype = "PEM" if pem.startswith("-----BEGIN") else "RAW"
   615
   616        return f"{keytype}: {hd}"
   617
   618    def to_dict(self) -> Dict[str, Any]:
   619        """
   620        Return the dictionary representation of this SecretInfo.
   621
   622        :return: dict
   623        """
   624        return {
   625            "name": self.name,
   626            "namespace": self.namespace,
   627            "secret_type": self.secret_type,
   628            "tls_crt": self.fingerprint(self.tls_crt),
   629            "tls_key": self.fingerprint(self.tls_key),
   630            "user_key": self.fingerprint(self.user_key),
   631            "root_crt": self.fingerprint(self.root_crt),
   632        }
   633
   634    @classmethod
   635    def from_aconf_secret(cls, aconf_object: "ACResource") -> "SecretInfo":
   636        """
   637        Convert an ACResource containing a secret into a SecretInfo. This is used by the IR.save_secret_info()
   638        to convert saved secrets into SecretInfos.
   639
   640        :param aconf_object: a ACResource containing a secret
   641        :return: SecretInfo
   642        """
   643
   644        tls_crt = aconf_object.get("tls_crt", None)
   645        if not tls_crt:
   646            tls_crt = aconf_object.get("cert-chain_pem")
   647
   648        tls_key = aconf_object.get("tls_key", None)
   649        if not tls_key:
   650            tls_key = aconf_object.get("key_pem")
   651
   652        user_key = aconf_object.get("user_key", None)
   653        if not user_key:
   654            # We didn't have a 'user_key', do we have a `crl_pem` instead?
   655            user_key = aconf_object.get("crl_pem", None)
   656
   657        return SecretInfo(
   658            aconf_object.name,
   659            aconf_object.namespace,
   660            aconf_object.secret_type,
   661            tls_crt,
   662            tls_key,
   663            user_key,
   664            aconf_object.get("root-cert_pem", None),
   665        )
   666
   667    @classmethod
   668    def from_dict(
   669        cls,
   670        resource: "IRResource",
   671        secret_name: str,
   672        namespace: str,
   673        source: str,
   674        cert_data: Optional[Dict[str, Any]],
   675        secret_type="kubernetes.io/tls",
   676    ) -> Optional["SecretInfo"]:
   677        """
   678        Given a secret's name and namespace, and a dictionary of configuration elements, return
   679        a SecretInfo for the secret.
   680
   681        The "source" parameter needs some explanation. When working with secrets in most environments
   682        where Ambassador runs, secrets will be loaded from some external system (e.g. Kubernetes),
   683        and serialized to disk, and the disk serialization is the thing we can actually read the
   684        dictionary of secret data from. The "source" parameter is the thing we read to get the actual
   685        dictionary -- in our example above, "source" would be the pathname of the serialization on
   686        disk, rather than the Kubernetes resource name.
   687
   688        :param resource: owning IRResource
   689        :param secret_name: name of secret
   690        :param namespace: namespace of secret
   691        :param source: source of data
   692        :param cert_data: dictionary of secret info (public and private key, etc.)
   693        :param secret_type: Kubernetes-style secret type
   694        :return:
   695        """
   696        tls_crt = None
   697        tls_key = None
   698        user_key = None
   699
   700        if not cert_data:
   701            resource.ir.logger.error(
   702                f"{resource.kind} {resource.name}: found no certificate in {source}?"
   703            )
   704            return None
   705
   706        if secret_type == "kubernetes.io/tls":
   707            # OK, we have something to work with. Hopefully.
   708            tls_crt = cert_data.get("tls.crt", None)
   709
   710            if not tls_crt:
   711                # Having no public half is definitely an error. Having no private half given a public half
   712                # might be OK, though -- that's up to our caller to decide.
   713                resource.ir.logger.error(
   714                    f"{resource.kind} {resource.name}: found data but no certificate in {source}?"
   715                )
   716                return None
   717
   718            tls_key = cert_data.get("tls.key", None)
   719        elif secret_type == "Opaque":
   720            user_key = cert_data.get("user.key", None)
   721
   722            if not user_key:
   723                # The opaque keys we support must have user.key, but will likely have nothing else.
   724                resource.ir.logger.error(
   725                    f"{resource.kind} {resource.name}: found data but no user.key in {source}?"
   726                )
   727                return None
   728
   729            cert = None
   730        elif secret_type == "istio.io/key-and-cert":
   731            resource.ir.logger.error(
   732                f"{resource.kind} {resource.name}: found data but handler for istio key not finished yet"
   733            )
   734
   735        return SecretInfo(
   736            secret_name, namespace, secret_type, tls_crt=tls_crt, tls_key=tls_key, user_key=user_key
   737        )
   738
   739
   740class SavedSecret:
   741    """
   742    SavedSecret collects information about a secret saved locally, including its name, namespace,
   743    paths to its elements on disk, and a copy of its cert data dictionary.
   744
   745    It's legal for a SavedSecret to have paths, etc, of None, representing a secret for which we
   746    found no information. SavedSecret will evaluate True as a boolean if - and only if - it has
   747    the minimal information needed to represent a real secret.
   748    """
   749
   750    def __init__(
   751        self,
   752        secret_name: str,
   753        namespace: str,
   754        cert_path: Optional[str],
   755        key_path: Optional[str],
   756        user_path: Optional[str],
   757        root_cert_path: Optional[str],
   758        cert_data: Optional[Dict],
   759    ) -> None:
   760        self.secret_name = secret_name
   761        self.namespace = namespace
   762        self.cert_path = cert_path
   763        self.key_path = key_path
   764        self.user_path = user_path
   765        self.root_cert_path = root_cert_path
   766        self.cert_data = cert_data
   767
   768    @property
   769    def name(self) -> str:
   770        return "secret %s in namespace %s" % (self.secret_name, self.namespace)
   771
   772    def __bool__(self) -> bool:
   773        return bool((bool(self.cert_path) or bool(self.user_path)) and (self.cert_data is not None))
   774
   775    def __str__(self) -> str:
   776        return (
   777            "<SavedSecret %s.%s -- cert_path %s, key_path %s, user_path %s, root_cert_path %s, cert_data %s>"
   778            % (
   779                self.secret_name,
   780                self.namespace,
   781                self.cert_path,
   782                self.key_path,
   783                self.user_path,
   784                self.root_cert_path,
   785                "present" if self.cert_data else "absent",
   786            )
   787        )
   788
   789
   790class SecretHandler:
   791    """
   792    SecretHandler: manage secrets for Ambassador. There are two fundamental rules at work here:
   793
   794    - The Python part of Ambassador doesn’t get to talk directly to Kubernetes. Part of this is
   795      because the Python K8s client isn’t maintained all that well. Part is because, for testing,
   796      we need to be able to separate secrets from Kubernetes.
   797    - Most of the handling of secrets (e.g. saving the actual bits of the certs) need to be
   798      common code paths, so that testing them outside of Kube gives results that are valid inside
   799      Kube.
   800
   801    To work within these rules, you’re required to pass a SecretHandler when instantiating an IR.
   802    The SecretHandler mediates access to secrets outside Ambassador, and to the cache of secrets
   803    we've already loaded.
   804
   805    SecretHandler subclasses will typically only need to override load_secret: the other methods
   806    of SecretHandler generally won't need to change, and arguably should not be considered part
   807    of the public interface of SecretHandler.
   808
   809    Finally, note that SecretHandler itself is deliberately written to work correctly with
   810    secrets as they're handed over from watt, which means that it can be instantiated directly
   811    and handed to the IR when we're running "for real" in Kubernetes with watt. Other things
   812    (like mockery and the watch_hook) use subclasses to manage specific needs that they have.
   813    """
   814
   815    logger: logging.Logger
   816    source_root: str
   817    cache_dir: str
   818
   819    def __init__(
   820        self, logger: logging.Logger, source_root: str, cache_dir: str, version: str
   821    ) -> None:
   822        self.logger = logger
   823        self.source_root = source_root
   824        self.cache_dir = cache_dir
   825        self.version = version
   826
   827    def load_secret(
   828        self, resource: "IRResource", secret_name: str, namespace: str
   829    ) -> Optional[SecretInfo]:
   830        """
   831        load_secret: given a secret’s name and namespace, pull it from wherever it really lives,
   832        write it to disk, and return a SecretInfo telling the rest of Ambassador where it got written.
   833
   834        This is the fallback load_secret implementation, which doesn't do anything: it is written
   835        assuming that ir.save_secret_info has already filled ir.saved_secrets with any secrets handed in
   836        from watt, so that load_secrets will never be called for those secrets. Therefore, if load_secrets
   837        gets called at all, it's for a secret that wasn't found, and it should just return None.
   838
   839        :param resource: referencing resource (so that we can correctly default the namespace)
   840        :param secret_name: name of the secret
   841        :param namespace: namespace, if any specific namespace was given
   842        :return: Optional[SecretInfo]
   843        """
   844
   845        self.logger.debug(
   846            "SecretHandler (%s %s): load secret %s in namespace %s"
   847            % (resource.kind, resource.name, secret_name, namespace)
   848        )
   849
   850        return None
   851
   852    def still_needed(self, resource: "IRResource", secret_name: str, namespace: str) -> None:
   853        """
   854        still_needed: remember that a given secret is still needed, so that we can tell watt to
   855        keep paying attention to it.
   856
   857        The default implementation doesn't do much of anything, because it assumes that we're
   858        not running in the watch_hook, so watt has already been told everything it needs to be
   859        told. This should be OK for everything that's not the watch_hook.
   860
   861        :param resource: referencing resource
   862        :param secret_name: name of the secret
   863        :param namespace: namespace of the secret
   864        :return: None
   865        """
   866
   867        self.logger.debug(
   868            "SecretHandler (%s %s): secret %s in namespace %s is still needed"
   869            % (resource.kind, resource.name, secret_name, namespace)
   870        )
   871
   872    def cache_secret(self, resource: "IRResource", secret_info: SecretInfo) -> SavedSecret:
   873        """
   874        cache_secret: stash the SecretInfo from load_secret into Ambassador’s internal cache,
   875        so that we don’t have to call load_secret again if we need it again.
   876
   877        The default implementation should be usable by everything that's not the watch_hook.
   878
   879        :param resource: referencing resource
   880        :param secret_info: SecretInfo returned from load_secret
   881        :return: SavedSecret
   882        """
   883
   884        name = secret_info.name
   885        namespace = secret_info.namespace
   886        tls_crt = secret_info.tls_crt
   887        tls_key = secret_info.tls_key
   888        user_key = secret_info.user_key
   889        root_crt = secret_info.root_crt
   890
   891        return self.cache_internal(name, namespace, tls_crt, tls_key, user_key, root_crt)
   892
   893    def cache_internal(
   894        self,
   895        name: str,
   896        namespace: str,
   897        tls_crt: Optional[str],
   898        tls_key: Optional[str],
   899        user_key: Optional[str],
   900        root_crt: Optional[str],
   901    ) -> SavedSecret:
   902        h = hashlib.new("sha1")
   903
   904        tls_crt_path = None
   905        tls_key_path = None
   906        user_key_path = None
   907        root_crt_path = None
   908        cert_data = None
   909
   910        # Don't save if it has neither a tls_crt or a user_key or the root_crt
   911        if tls_crt or user_key or root_crt:
   912            for el in [tls_crt, tls_key, user_key]:
   913                if el:
   914                    h.update(el.encode("utf-8"))
   915
   916            hd = h.hexdigest().upper()
   917
   918            secret_dir = os.path.join(self.cache_dir, namespace, "secrets-decoded", name)
   919
   920            try:
   921                os.makedirs(secret_dir)
   922            except FileExistsError:
   923                pass
   924
   925            if tls_crt:
   926                tls_crt_path = os.path.join(secret_dir, f"{hd}.crt")
   927                open(tls_crt_path, "w").write(tls_crt)
   928
   929            if tls_key:
   930                tls_key_path = os.path.join(secret_dir, f"{hd}.key")
   931                open(tls_key_path, "w").write(tls_key)
   932
   933            if user_key:
   934                user_key_path = os.path.join(secret_dir, f"{hd}.user")
   935                open(user_key_path, "w").write(user_key)
   936
   937            if root_crt:
   938                root_crt_path = os.path.join(secret_dir, f"{hd}.root.crt")
   939                open(root_crt_path, "w").write(root_crt)
   940
   941            cert_data = {
   942                "tls_crt": tls_crt,
   943                "tls_key": tls_key,
   944                "user_key": user_key,
   945                "root_crt": root_crt,
   946            }
   947
   948            self.logger.debug(
   949                f"saved secret {name}.{namespace}: {tls_crt_path}, {tls_key_path}, {root_crt_path}"
   950            )
   951
   952        return SavedSecret(
   953            name, namespace, tls_crt_path, tls_key_path, user_key_path, root_crt_path, cert_data
   954        )
   955
   956    def secret_info_from_k8s(
   957        self,
   958        resource: "IRResource",
   959        secret_name: str,
   960        namespace: str,
   961        source: str,
   962        serialization: Optional[str],
   963    ) -> Optional[SecretInfo]:
   964        """
   965        secret_info_from_k8s is NO LONGER USED.
   966        """
   967
   968        objects: Optional[List[Any]] = None
   969
   970        self.logger.debug(f"getting secret info for secret {secret_name} from k8s")
   971
   972        # If serialization is None or empty, we'll just return None.
   973
   974        if serialization:
   975            try:
   976                objects = parse_yaml(serialization)
   977            except yaml.error.YAMLError as e:
   978                self.logger.error(f"{resource.kind} {resource.name}: could not parse {source}: {e}")
   979
   980        if not objects:
   981            # Nothing in the serialization, we're done.
   982            return None
   983
   984        secret_type = None
   985        cert_data = None
   986        ocount = 0
   987        errors = 0
   988
   989        for obj in objects:
   990            ocount += 1
   991            kind = obj.get("kind", None)
   992
   993            if kind != "Secret":
   994                self.logger.error(
   995                    "%s %s: found K8s %s at %s.%d?"
   996                    % (resource.kind, resource.name, kind, source, ocount)
   997                )
   998                errors += 1
   999                continue
  1000
  1001            metadata = obj.get("metadata", None)
  1002
  1003            if not metadata:
  1004                self.logger.error(
  1005                    "%s %s: found K8s Secret with no metadata at %s.%d?"
  1006                    % (resource.kind, resource.name, source, ocount)
  1007                )
  1008                errors += 1
  1009                continue
  1010
  1011            secret_type = metadata.get("type", "kubernetes.io/tls")
  1012
  1013            if "data" in obj:
  1014                if cert_data:
  1015                    self.logger.error(
  1016                        "%s %s: found multiple Secrets in %s?"
  1017                        % (resource.kind, resource.name, source)
  1018                    )
  1019                    errors += 1
  1020                    continue
  1021
  1022                cert_data = obj["data"]
  1023
  1024        if errors:
  1025            # Bzzt.
  1026            return None
  1027
  1028        return SecretInfo.from_dict(
  1029            resource, secret_name, namespace, source, cert_data=cert_data, secret_type=secret_type
  1030        )
  1031
  1032
  1033class NullSecretHandler(SecretHandler):
  1034    def __init__(
  1035        self,
  1036        logger: logging.Logger,
  1037        source_root: Optional[str],
  1038        cache_dir: Optional[str],
  1039        version: str,
  1040    ) -> None:
  1041        """
  1042        Returns a valid SecretInfo (with fake keys) for any requested secret. Also, you can pass
  1043        None for source_root and cache_dir to use random temporary directories for them.
  1044        """
  1045
  1046        if not source_root:
  1047            self.tempdir_source = tempfile.TemporaryDirectory(
  1048                prefix="null-secret-", suffix="-source"
  1049            )
  1050            source_root = self.tempdir_source.name
  1051
  1052        if not cache_dir:
  1053            self.tempdir_cache = tempfile.TemporaryDirectory(prefix="null-secret-", suffix="-cache")
  1054            cache_dir = self.tempdir_cache.name
  1055
  1056        logger.info(f"NullSecretHandler using source_root {source_root}, cache_dir {cache_dir}")
  1057
  1058        super().__init__(logger, source_root, cache_dir, version)
  1059
  1060    def load_secret(
  1061        self, resource: "IRResource", secret_name: str, namespace: str
  1062    ) -> Optional[SecretInfo]:
  1063        # In the Real World, the secret loader should, y'know, load secrets..
  1064        # Here we're just gonna fake it.
  1065        self.logger.debug(
  1066            "NullSecretHandler (%s %s): load secret %s in namespace %s"
  1067            % (resource.kind, resource.name, secret_name, namespace)
  1068        )
  1069
  1070        return SecretInfo(
  1071            secret_name,
  1072            namespace,
  1073            "fake-secret",
  1074            "fake-tls-crt",
  1075            "fake-tls-key",
  1076            "fake-user-key",
  1077            decode_b64=False,
  1078        )
  1079
  1080
  1081class EmptySecretHandler(SecretHandler):
  1082    def __init__(
  1083        self,
  1084        logger: logging.Logger,
  1085        source_root: Optional[str],
  1086        cache_dir: Optional[str],
  1087        version: str,
  1088    ) -> None:
  1089        """
  1090        Returns a None to simulate no provided secrets
  1091        """
  1092        super().__init__(logger, "", "", version)
  1093
  1094    def load_secret(
  1095        self, resource: "IRResource", secret_name: str, namespace: str
  1096    ) -> Optional[SecretInfo]:
  1097        return None
  1098
  1099
  1100class FSSecretHandler(SecretHandler):
  1101    # XXX NO LONGER USED
  1102    def load_secret(
  1103        self, resource: "IRResource", secret_name: str, namespace: str
  1104    ) -> Optional[SecretInfo]:
  1105        self.logger.debug(
  1106            "FSSecretHandler (%s %s): load secret %s in namespace %s"
  1107            % (resource.kind, resource.name, secret_name, namespace)
  1108        )
  1109
  1110        source = os.path.join(self.source_root, namespace, "secrets", "%s.yaml" % secret_name)
  1111
  1112        serialization = None
  1113
  1114        try:
  1115            serialization = open(source, "r").read()
  1116        except IOError as e:
  1117            self.logger.error(
  1118                "%s %s: FSSecretHandler could not open %s" % (resource.kind, resource.name, source)
  1119            )
  1120
  1121        # Yes, this duplicates part of self.secret_info_from_k8s, but whatever.
  1122        objects: Optional[List[Any]] = None
  1123
  1124        # If serialization is None or empty, we'll just return None.
  1125        if serialization:
  1126            try:
  1127                objects = parse_yaml(serialization)
  1128            except yaml.error.YAMLError as e:
  1129                self.logger.error(
  1130                    "%s %s: could not parse %s: %s" % (resource.kind, resource.name, source, e)
  1131                )
  1132
  1133        if not objects:
  1134            # Nothing in the serialization, we're done.
  1135            return None
  1136
  1137        if len(objects) != 1:
  1138            self.logger.error(
  1139                "%s %s: found %d objects in %s instead of exactly 1"
  1140                % (resource.kind, resource.name, len(objects), source)
  1141            )
  1142            return None
  1143
  1144        obj = objects[0]
  1145
  1146        version = obj.get("apiVersion", None)
  1147        kind = obj.get("kind", None)
  1148
  1149        if (kind == "Secret") and (
  1150            version.startswith("ambassador") or version.startswith("getambassador.io")
  1151        ):
  1152            # It's an Ambassador Secret. It should have a public key and maybe a private key.
  1153            secret_type = obj.get("type", "kubernetes.io/tls")
  1154            return SecretInfo.from_dict(
  1155                resource, secret_name, namespace, source, cert_data=obj, secret_type=secret_type
  1156            )
  1157
  1158        # Didn't look like an Ambassador object. Try K8s.
  1159        return self.secret_info_from_k8s(resource, secret_name, namespace, source, serialization)
  1160
  1161
  1162class KubewatchSecretHandler(SecretHandler):
  1163    # XXX NO LONGER USED
  1164    def load_secret(
  1165        self, resource: "IRResource", secret_name: str, namespace: str
  1166    ) -> Optional[SecretInfo]:
  1167        self.logger.debug(
  1168            "FSSecretHandler (%s %s): load secret %s in namespace %s"
  1169            % (resource.kind, resource.name, secret_name, namespace)
  1170        )
  1171
  1172        source = "%s/secrets/%s/%s" % (self.source_root, namespace, secret_name)
  1173        serialization = load_url_contents(self.logger, source)
  1174
  1175        if not serialization:
  1176            self.logger.error(
  1177                "%s %s: SCC.url_reader could not load %s" % (resource.kind, resource.name, source)
  1178            )
  1179
  1180        return self.secret_info_from_k8s(resource, secret_name, namespace, source, serialization)
  1181
  1182
  1183# TODO(gsagula): This duplicates code from ircluster.py.
  1184class ParsedService:
  1185    def __init__(self, logger, service: str, allow_scheme=True, ctx_name: str = None) -> None:
  1186        original_service = service
  1187
  1188        originate_tls = False
  1189
  1190        self.scheme = "http"
  1191        self.errors: List[str] = []
  1192        self.name_fields: List[str] = []
  1193        self.ctx_name = ctx_name
  1194
  1195        if allow_scheme and service.lower().startswith("https://"):
  1196            service = service[len("https://") :]
  1197
  1198            originate_tls = True
  1199            self.name_fields.append("otls")
  1200
  1201        elif allow_scheme and service.lower().startswith("http://"):
  1202            service = service[len("http://") :]
  1203
  1204            if ctx_name:
  1205                self.errors.append(
  1206                    f"Originate-TLS context {ctx_name} being used even though service {service} lists HTTP"
  1207                )
  1208                originate_tls = True
  1209                self.name_fields.append("otls")
  1210            else:
  1211                originate_tls = False
  1212
  1213        elif ctx_name:
  1214            # No scheme (or schemes are ignored), but we have a context.
  1215            originate_tls = True
  1216            self.name_fields.append("otls")
  1217            self.name_fields.append(ctx_name)
  1218
  1219        if "://" in service:
  1220            idx = service.index("://")
  1221            scheme = service[0:idx]
  1222
  1223            if allow_scheme:
  1224                self.errors.append(
  1225                    f"service {service} has unknown scheme {scheme}, assuming {self.scheme}"
  1226                )
  1227            else:
  1228                self.errors.append(
  1229                    f"ignoring scheme {scheme} for service {service}, since it is being used for a non-HTTP mapping"
  1230                )
  1231
  1232            service = service[idx + 3 :]
  1233
  1234        # # XXX Should this be checking originate_tls? Why does it do that?
  1235        # if originate_tls and host_rewrite:
  1236        #     name_fields.append("hr-%s" % host_rewrite)
  1237
  1238        # Parse the service as a URL. Note that we have to supply a scheme to urllib's
  1239        # parser, because it's kind of stupid.
  1240
  1241        logger.debug(
  1242            f"Service: {original_service} otls {originate_tls} ctx {ctx_name} -> {self.scheme}, {service}"
  1243        )
  1244        p = urlparse("random://" + service)
  1245
  1246        # Is there any junk after the host?
  1247
  1248        if p.path or p.params or p.query or p.fragment:
  1249            self.errors.append(
  1250                f"service {service} has extra URL components; ignoring everything but the host and port"
  1251            )
  1252
  1253        # p is read-only, so break stuff out.
  1254
  1255        self.hostname = p.hostname
  1256        try:
  1257            self.port = p.port
  1258        except ValueError as e:
  1259            self.errors.append(
  1260                "found invalid port for service {}. Please specify a valid port between 0 and 65535 - {}. Service {} cluster will be ignored, please re-configure".format(
  1261                    service, e, service
  1262                )
  1263            )
  1264            self.port = 0
  1265
  1266        # If the port is unset, fix it up.
  1267        if not self.port:
  1268            self.port = 443 if originate_tls else 80
  1269
  1270        self.hostname_port = f"{self.hostname}:{self.port}"

View as plain text