...
# Copyright 2018 Datawire. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# This is based on the blog post described here, but with a few
# differences noted below:
#
#   https://adambard.com/blog/implementing-multimethods-in-python/
#
# The differences are:
#
#  1. The "default default" action is to raise a TypeError rather than
#     to be a no-op. This avoids many classes of silent failure.
#
#  2. The naming is a bit different:
#
#     a) the initial decorator is the same:
#
#          @multi
#          def foo(...): ...
#
#     b) rather than adding actions with @method(foo, case), this uses @foo.when(case)
#     c) rather than specifying the default action with @method(foo), this uses @foo.default
#
#  3. You can specify multiple cases at once, e.g.:
#
#       @foo.when(a, b, c)
#       def foo(...): ...
#
#  4. If the dispatch function is a generator, then the dispatch logic checks
#     each key it yields in turn until a match is found. This allows for more
#     flexible dispatch logic like this:
#
#       @multi
#       def fib(x):
#           yield x        # first dispatch on the value of x itself
#           yield type(x)  # if there are no matches, then dispatch on the type of x
#
#       @fib.when(0, 1)
#       def fib(x):
#           return x
#
#       @fib.when(int)
#       def fib(x):
#           return fib(x-1) + fib(x-2)
55
import functools
import inspect
58
59
def _error(multifun, keys, args, kwargs):
    """Raise ``TypeError`` for a call to *multifun* that matched no action.

    :param multifun: the multimethod that was called; its ``__name__`` and
        its registered keys (``__multi__``) are included in the message.
    :param keys: iterable of the dispatch keys that were searched.
    :param args: positional arguments of the failed call.
    :param kwargs: keyword arguments of the failed call.
    :raises TypeError: always.
    """
    rendered = [repr(arg) for arg in args]
    # Show keyword arguments as name=repr(value), mirroring call syntax.
    # (Previously the format string was never applied, so every kwarg was
    # rendered as the literal text "%s=%r".)
    rendered += ["%s=%r" % (name, value) for name, value in kwargs.items()]
    raise TypeError(
        "no match found for multi function %s(%s): known keys %r, searched keys %r"
        % (multifun.__name__, ", ".join(rendered), tuple(multifun.__multi__.keys()), tuple(keys))
    )
66
67
def multi(dispatch_fn):
    """Turn *dispatch_fn* into a multimethod that dispatches on its result.

    The returned function calls *dispatch_fn* with the caller's arguments to
    compute a dispatch key. It then invokes the action registered for that key
    via ``@fn.when(key, ...)`` with the same arguments. If *dispatch_fn* is a
    generator function, each yielded key is tried in order and the first
    registered one wins.

    When no registered key matches, the default action is called with the
    caller's arguments; ``@fn.default`` replaces it. The built-in
    ("default default") action raises ``TypeError`` naming the known keys and
    the keys that were searched.

    Attributes installed on the returned function:
        when -- ``when(*keys)`` returns a decorator registering an action.
        default -- decorator replacing the default action.
        __multi__ -- dict mapping dispatch keys to actions.
        __multi_default__ -- the current default action.
    """
    gen = inspect.isgeneratorfunction(dispatch_fn)

    def no_match(*args, **kwargs):
        # "Default default" action. It stays usable on its own (someone may
        # call __multi_default__ directly), so it recomputes the keys here.
        keys = dispatch_fn(*args, **kwargs)
        _error(multifun, keys if gen else [keys], args, kwargs)

    def fallback(searched, args, kwargs):
        # Run the default action for a call that matched no key. For the
        # built-in default, report the keys already computed rather than
        # running dispatch_fn a second time.
        action = multifun.__multi_default__
        if action is no_match:
            _error(multifun, searched, args, kwargs)  # always raises TypeError
        return action(*args, **kwargs)

    if gen:

        def multifun(*args, **kwargs):
            searched = []
            for key in dispatch_fn(*args, **kwargs):
                searched.append(key)
                try:
                    action = multifun.__multi__[key]
                    break
                except KeyError:
                    continue
            else:
                # Generator exhausted without a registered key.
                return fallback(searched, args, kwargs)
            return action(*args, **kwargs)

    else:

        def multifun(*args, **kwargs):
            key = dispatch_fn(*args, **kwargs)
            try:
                action = multifun.__multi__[key]
            except KeyError:
                return fallback([key], args, kwargs)
            return action(*args, **kwargs)

    # Copy name/doc/__dict__ from dispatch_fn first, so nothing in its
    # __dict__ can overwrite the multimethod machinery installed below.
    functools.update_wrapper(multifun, dispatch_fn)
    multifun.when = lambda *keys: _when(multifun, keys)
    multifun.default = _default(multifun)
    multifun.__multi__ = {}
    multifun.__multi_default__ = no_match
    return multifun
104
105
def _when(multifun, keys):
    """Return a decorator that registers its target under every key in *keys*.

    Each key in *keys* is mapped to the decorated action in
    ``multifun.__multi__``; an existing entry for the same key is replaced.
    The decorator returns *multifun* itself, so the decorated name keeps
    referring to the multimethod rather than to the bare action.
    """

    def register(action):
        multifun.__multi__.update((key, action) for key in keys)
        return multifun

    return register
113
114
def _default(multifun):
    """Return a decorator that installs its target as *multifun*'s default.

    The decorated action is stored as ``multifun.__multi_default__``. The
    decorator returns *multifun* so the decorated name still refers to the
    multimethod.
    """

    def install(action):
        setattr(multifun, "__multi_default__", action)
        return multifun

    return install
View as plain text