1from enum import Enum, auto
2from io import StringIO
3from typing import Any, Callable, Mapping, Sequence, Type
4
5from yaml import (
6 MappingNode,
7 Node,
8 ScalarNode,
9 SequenceNode,
10 add_representer,
11 compose,
12 compose_all,
13 dump_all,
14)
15
16
class ViewMode(Enum):
    """How a view presents the scalars it reaches.

    PYTHON converts scalars to Python values, STRING returns their raw text,
    and NODE hands back the view objects themselves.
    """

    PYTHON = 1
    STRING = 2
    NODE = 3
21
22
class Tag(Enum):
    """Core YAML tags, read off PyYAML's resolver by composing tiny probe documents.

    Each member's value is the full tag string of its probe's root node, so it
    compares directly against ``Node.tag``.
    """

    SEQUENCE = compose("- item").tag
    MAPPING = compose("key: value").tag
    STRING = compose("text").tag
    INT = compose("42").tag
    FLOAT = compose("2.5").tag
    BOOL = compose("false").tag
    NULL = compose("~").tag
31
32
class View:
    """Base wrapper pairing a PyYAML node with the ViewMode used to present it."""

    def __init__(self, node: Node, mode: ViewMode) -> None:
        # ``node`` here is the wrapped PyYAML node (it shadows the module-level
        # ``node()`` helper only inside this method).
        self.node, self.mode = node, mode

    @property
    def tag(self):
        """The wrapped node's tag as a ``Tag`` member (ValueError if it is not one)."""
        raw_tag = self.node.tag
        return Tag(raw_tag)

    def view(self, obj):
        """Wrap ``obj`` with the module-level ``view`` using this view's mode."""
        current_mode = self.mode
        return view(obj, current_mode)

    def mode_ify(self):
        """Return the mode-specific presentation; the base class presents itself."""
        return self
47
48
class MappingView(View, Mapping):
    """Dict-like view over a YAML ``MappingNode``.

    Keys are matched against the raw scalar text of each key node
    (``k.value``); a View passed as a key is matched by its own raw value.
    Values come back through ``View.view``, i.e. in this view's mode.
    The node may hold duplicate keys (see ``merge``); lookups use the first.
    """

    def _index(self, key):
        """Return the position of the first pair whose key text equals ``key``, else -1."""
        if isinstance(key, View):
            # NODE-mode iteration yields key views; unwrap them so lookups with
            # those keys (e.g. Mapping's ``values()`` mixin) can match.
            key = key.node.value
        for idx, (k, _) in enumerate(self.node.value):
            if k.value == key:
                return idx
        return -1

    def get(self, key, default=None):
        """Return the value for ``key`` in this view's mode, or ``default``."""
        idx = self._index(key)
        if idx < 0:
            return default
        return self.view(self.node.value[idx][1])

    def __contains__(self, key):
        return self._index(key) >= 0

    def __getitem__(self, key):
        idx = self._index(key)
        if idx < 0:
            raise KeyError(key)
        return self.view(self.node.value[idx][1])

    def __setitem__(self, key, value):
        """Replace the first matching pair's value, or append a new pair."""
        idx = self._index(key)
        if idx < 0:
            self.node.value.append((node(key), node(value)))
            return
        # Keep the existing key node: re-coercing ``key`` would drop its tag and
        # style (an int-tagged ``3:`` key would be re-emitted as the string '3').
        key_node = self.node.value[idx][0]
        self.node.value[idx] = (key_node, node(value))

    def update(self, other=(), **kwargs):
        """Assign pairs from ``other`` and then ``kwargs``, like ``dict.update``.

        ``other`` may be a mapping (anything with ``items()``) or an iterable
        of (key, value) pairs.
        """
        pairs = other.items() if hasattr(other, "items") else other
        # Snapshot first: ``m.update(m)`` would otherwise iterate a node list
        # that grows every time a key fails to match.
        for k, v in list(pairs):
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def merge(self, other):
        """Append every pair of ``other`` (a view, node, or dict) without de-duplicating keys."""
        self.node.value.extend(node(other).value)

    def keys(self):
        """Return the set of raw key texts (unordered; duplicate keys collapse)."""
        return {k.value for k, _ in self.node.value}

    def items(self):
        """Yield (key, value) pairs, both presented in this view's mode."""
        for k, v in self.node.value:
            yield self.view(k), self.view(v)

    def __iter__(self):
        for k, _ in self.node.value:
            yield self.view(k)

    def __len__(self):
        # Counts pairs, so duplicate keys are counted separately.
        return len(self.node.value)

    def __repr__(self):
        # Always rendered with PYTHON-mode keys/values, whatever this view's mode.
        return "{%s}" % ", ".join(
            "%r: %r" % (view(k, ViewMode.PYTHON), view(v, ViewMode.PYTHON))
            for k, v in self.node.value
        )
102
103
class SequenceView(View, Sequence):
    """List-like view over a YAML ``SequenceNode``.

    Items come back through ``View.view`` (i.e. in this view's mode); values
    written in are coerced to nodes with ``node()``.
    """

    def __getitem__(self, idx):
        # A slice selects a list of nodes, which ``view`` wraps in a new sequence node.
        return self.view(self.node.value[idx])

    def __setitem__(self, idx, value):
        if isinstance(idx, slice):
            # Slice assignment needs the child nodes, not a single (non-iterable)
            # SequenceNode; ``list`` copies so ``s[:] = s`` is safe.
            self.node.value[idx] = list(node(value).value)
        else:
            self.node.value[idx] = node(value)

    def __delitem__(self, idx):
        del self.node.value[idx]

    def insert(self, idx, value):
        """Insert ``value`` (coerced to a node) before position ``idx``."""
        self.node.value.insert(idx, node(value))

    def append(self, value):
        """Append ``value`` coerced to a node."""
        self.node.value.append(node(value))

    def __len__(self):
        return len(self.node.value)

    def __iter__(self):
        for item in self.node.value:
            yield self.view(item)

    def extend(self, items):
        """Append every element of ``items``, coerced to nodes."""
        # Coerce everything before mutating: ``s.extend(s)`` would otherwise
        # iterate the very list it appends to and never terminate.
        new_nodes = [node(i) for i in items]
        self.node.value.extend(new_nodes)

    def merge(self, other):
        """Append the child nodes of ``other`` (a view or node, or a list/tuple coerced via ``node()``)."""
        self.node.value.extend(node(other).value)

    def __repr__(self):
        return repr(list(self))
130
131
def _yaml_int(text):
    """Convert the text of an int-tagged scalar to ``int`` using YAML 1.1 rules.

    Handles an optional sign, ``_`` digit separators, ``0b`` binary, ``0x`` hex,
    leading-zero octal and base-60 (``1:30``) forms, which ``int()`` alone
    rejects or (for octal) misreads as decimal. Malformed text still raises
    ValueError.
    """
    digits = text.replace("_", "")
    sign = 1
    if digits[:1] in ("-", "+"):
        sign = -1 if digits[0] == "-" else 1
        digits = digits[1:]
    if digits.startswith("0b"):
        return sign * int(digits[2:], 2)
    if digits.startswith("0x"):
        return sign * int(digits[2:], 16)
    if len(digits) > 1 and digits[0] == "0":
        return sign * int(digits, 8)
    if ":" in digits:
        total = 0
        for part in digits.split(":"):
            total = total * 60 + int(part)
        return sign * total
    return sign * int(digits)


def _yaml_float(text):
    """Convert the text of a float-tagged scalar to ``float`` using YAML 1.1 rules.

    Adds the ``.inf``/``-.inf``/``.nan`` spellings and base-60 forms that
    ``float()`` rejects; everything else (including Python's own ``inf``/``nan``
    spellings) goes through ``float()``. Malformed text still raises ValueError.
    """
    lowered = text.replace("_", "").lower()
    sign = 1.0
    if lowered[:1] in ("-", "+"):
        sign = -1.0 if lowered[0] == "-" else 1.0
        lowered = lowered[1:]
    if lowered == ".inf":
        return sign * float("inf")
    if lowered == ".nan":
        return float("nan")
    if ":" in lowered:
        total = 0.0
        for part in lowered.split(":"):
            total = total * 60 + float(part)
        return sign * total
    return sign * float(lowered)


# Converters from a scalar's raw text to a Python value, keyed by the scalar's Tag.
PYJECTIONS = {
    Tag.INT: _yaml_int,
    Tag.FLOAT: _yaml_float,
    Tag.STRING: lambda x: x,
    Tag.BOOL: lambda x: x.lower() in ("y", "yes", "true", "on"),
    Tag.NULL: lambda x: None,
}
139
140
class ScalarView(View):
    """View over a YAML ``ScalarNode``.

    ``mode_ify`` yields a Python value (PYTHON), the raw scalar text (STRING),
    or the view itself (NODE).
    """

    def mode_ify(self):
        if self.mode == ViewMode.PYTHON:
            try:
                projection = PYJECTIONS[Tag(self.node.tag)]
            except (ValueError, KeyError):
                # The tag is not a Tag member (e.g. a custom ``!tag``) or has no
                # Python projection: present the raw text instead of raising.
                # NOTE(review): PyYAML's resolver may also emit tags outside Tag
                # (presumably timestamps) — those land here too.
                return self.node.value
            # Kept outside the try so malformed text (e.g. ``!!int abc``) still raises.
            return projection(self.node.value)
        if self.mode == ViewMode.STRING:
            return self.node.value
        return self

    def __repr__(self):
        # The raw scalar text, unquoted.
        return self.node.value
152
153
# Which View subclass wraps each PyYAML node class (looked up by exact type in ``view``).
VIEWS: Mapping[Type[Node], Type[View]] = dict(
    [
        (MappingNode, MappingView),
        (SequenceNode, SequenceView),
        (ScalarNode, ScalarView),
    ]
)
159
160
def view(value: Any, mode: ViewMode) -> Any:
    """Coerce ``value`` to a node, wrap it in the matching View, and present it in ``mode``.

    Scalars may come back as plain Python values or strings depending on
    ``mode``; mappings and sequences come back as view objects.
    """
    target = node(value)
    wrapper_cls = VIEWS[type(target)]
    wrapped = wrapper_cls(target, mode)
    return wrapped.mode_ify()
164
165
# Converters to PyYAML nodes, keyed by the exact Python type of the input.
COERCIONS: Mapping[Type, Callable[[Any], Node]] = {
    # Nodes pass through untouched.
    MappingNode: lambda n: n,
    SequenceNode: lambda n: n,
    ScalarNode: lambda n: n,
    # Views unwrap to the node they wrap (shared, not copied).
    MappingView: lambda v: v.node,
    SequenceView: lambda v: v.node,
    ScalarView: lambda v: v.node,
    # Plain Python values become fresh nodes, recursively for containers.
    list: lambda items: SequenceNode(Tag.SEQUENCE.value, [node(i) for i in items]),
    tuple: lambda items: SequenceNode(Tag.SEQUENCE.value, [node(i) for i in items]),
    str: lambda s: ScalarNode(Tag.STRING.value, str(s)),
    bool: lambda b: ScalarNode(Tag.BOOL.value, str(b)),
    int: lambda i: ScalarNode(Tag.INT.value, str(i)),
    float: lambda f: ScalarNode(Tag.FLOAT.value, str(f)),
    dict: lambda d: MappingNode(
        Tag.MAPPING.value, [(node(k), node(v)) for k, v in d.items()]
    ),
    # None maps to a null scalar, so values read in PYTHON mode (null -> None)
    # can be written back instead of raising KeyError.
    type(None): lambda _: ScalarNode(Tag.NULL.value, "null"),
}
181
182
def node(value: Any) -> Node:
    """Coerce ``value`` (a node, a view, or a plain Python value) to a PyYAML node.

    Dispatch is on the exact type through COERCIONS, so unsupported types —
    including subclasses of supported ones — raise KeyError.
    """
    coerce = COERCIONS[type(value)]
    return coerce(value)
185
186
def load(name: str, value: Any, *allowed: Tag) -> SequenceView:
    """Compose every YAML document in ``value`` and return them as a sequence view.

    :param name: label attached to the input stream as its ``name`` attribute
        (best effort; presumably shown by PyYAML in error marks).
    :param value: YAML text, or any stream accepted by ``compose_all``.
    :param allowed: Tag members a document's root node may carry.
    :returns: a PYTHON-mode SequenceView with one item per document.
    :raises ValueError: if a document's root tag is not in ``allowed``
        (with no ``allowed`` tags, every document is rejected).
    """
    if isinstance(value, str):
        value = StringIO(value)
    try:
        value.name = name
    except AttributeError:
        # Real file objects have a read-only ``name`` and bytes take no
        # attributes; keep the input's own name rather than failing to load.
        pass
    documents = list(compose_all(value))
    allowed_tags = {tag.value for tag in allowed}
    for doc in documents:
        # Compare raw tag strings so tags outside ``Tag`` (custom ``!tags``) get
        # this message instead of an opaque "not a valid Tag" error.
        if doc.tag not in allowed_tags:
            raise ValueError(
                "expecting %s, got %s" % (", ".join(t.name for t in allowed), doc.tag)
            )
    return view(SequenceNode(Tag.SEQUENCE.value, documents), ViewMode.PYTHON)
198
199
def dump(value: SequenceView):
    """Serialize a sequence of documents to YAML text in block style.

    The result always begins with an explicit ``---`` marker; it is prepended
    when the emitter did not produce one for the first document.
    """
    text = dump_all(value, default_flow_style=False)
    return text if text.startswith("---") else "---\n" + text
205
206
def view_representer(dumper, data):
    """PyYAML representer for views: hand back the node the view already wraps.

    ``dumper`` is unused because the wrapped node needs no further representing.
    """
    wrapped = data.node
    return wrapped
209
210
# Register view_representer so dump_all can serialize view objects by emitting
# the nodes they wrap.
for _view_cls in (SequenceView, MappingView, ScalarView):
    add_representer(_view_cls, view_representer)
del _view_cls
View as plain text