1
2
3
4
5 package proto
6
7 import (
8 "errors"
9 "fmt"
10
11 "google.golang.org/protobuf/encoding/prototext"
12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/runtime/protoimpl"
14 )
15
// Wire type values of the protobuf binary wire format, listed in ascending
// numeric order. They are untyped integer constants so they can be compared
// against any integer type.
const (
	WireVarint     = 0
	WireFixed64    = 1
	WireBytes      = 2
	WireStartGroup = 3
	WireEndGroup   = 4
	WireFixed32    = 5
)
24
25
26 func EncodeVarint(v uint64) []byte {
27 return protowire.AppendVarint(nil, v)
28 }
29
30
31
32 func SizeVarint(v uint64) int {
33 return protowire.SizeVarint(v)
34 }
35
36
37
38
39 func DecodeVarint(b []byte) (uint64, int) {
40 v, n := protowire.ConsumeVarint(b)
41 if n < 0 {
42 return 0, 0
43 }
44 return v, n
45 }
46
47
48
// Buffer holds protobuf wire-format bytes for incremental encoding and
// decoding. Encode* methods and Marshal append to buf. Decode* methods and
// Unmarshal read from buf starting at idx. SetBuf and Reset allow reuse.
type Buffer struct {
	buf           []byte // encoded bytes; writers append here, readers consume from idx onward
	idx           int    // offset of the first unread byte in buf
	deterministic bool   // forwarded to marshalAppend by Marshal and EncodeMessage
}
54
55
56
57 func NewBuffer(buf []byte) *Buffer {
58 return &Buffer{buf: buf}
59 }
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 func (b *Buffer) SetDeterministic(deterministic bool) {
82 b.deterministic = deterministic
83 }
84
85
86
87 func (b *Buffer) SetBuf(buf []byte) {
88 b.buf = buf
89 b.idx = 0
90 }
91
92
93 func (b *Buffer) Reset() {
94 b.buf = b.buf[:0]
95 b.idx = 0
96 }
97
98
99 func (b *Buffer) Bytes() []byte {
100 return b.buf
101 }
102
103
104 func (b *Buffer) Unread() []byte {
105 return b.buf[b.idx:]
106 }
107
108
109 func (b *Buffer) Marshal(m Message) error {
110 var err error
111 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
112 return err
113 }
114
115
116
117
118 func (b *Buffer) Unmarshal(m Message) error {
119 err := UnmarshalMerge(b.Unread(), m)
120 b.idx = len(b.buf)
121 return err
122 }
123
124 type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields }
125
126 func (m *unknownFields) String() string { panic("not implemented") }
127 func (m *unknownFields) Reset() { panic("not implemented") }
128 func (m *unknownFields) ProtoMessage() { panic("not implemented") }
129
130
131
132 func (*Buffer) DebugPrint(s string, b []byte) {
133 m := MessageReflect(new(unknownFields))
134 m.SetUnknown(b)
135 b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface())
136 fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s)
137 }
138
139
140 func (b *Buffer) EncodeVarint(v uint64) error {
141 b.buf = protowire.AppendVarint(b.buf, v)
142 return nil
143 }
144
145
146 func (b *Buffer) EncodeZigzag32(v uint64) error {
147 return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31))))
148 }
149
150
151 func (b *Buffer) EncodeZigzag64(v uint64) error {
152 return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63))))
153 }
154
155
156 func (b *Buffer) EncodeFixed32(v uint64) error {
157 b.buf = protowire.AppendFixed32(b.buf, uint32(v))
158 return nil
159 }
160
161
162 func (b *Buffer) EncodeFixed64(v uint64) error {
163 b.buf = protowire.AppendFixed64(b.buf, uint64(v))
164 return nil
165 }
166
167
168 func (b *Buffer) EncodeRawBytes(v []byte) error {
169 b.buf = protowire.AppendBytes(b.buf, v)
170 return nil
171 }
172
173
174
175 func (b *Buffer) EncodeStringBytes(v string) error {
176 b.buf = protowire.AppendString(b.buf, v)
177 return nil
178 }
179
180
181 func (b *Buffer) EncodeMessage(m Message) error {
182 var err error
183 b.buf = protowire.AppendVarint(b.buf, uint64(Size(m)))
184 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
185 return err
186 }
187
188
189 func (b *Buffer) DecodeVarint() (uint64, error) {
190 v, n := protowire.ConsumeVarint(b.buf[b.idx:])
191 if n < 0 {
192 return 0, protowire.ParseError(n)
193 }
194 b.idx += n
195 return uint64(v), nil
196 }
197
198
199 func (b *Buffer) DecodeZigzag32() (uint64, error) {
200 v, err := b.DecodeVarint()
201 if err != nil {
202 return 0, err
203 }
204 return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil
205 }
206
207
208 func (b *Buffer) DecodeZigzag64() (uint64, error) {
209 v, err := b.DecodeVarint()
210 if err != nil {
211 return 0, err
212 }
213 return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil
214 }
215
216
217 func (b *Buffer) DecodeFixed32() (uint64, error) {
218 v, n := protowire.ConsumeFixed32(b.buf[b.idx:])
219 if n < 0 {
220 return 0, protowire.ParseError(n)
221 }
222 b.idx += n
223 return uint64(v), nil
224 }
225
226
227 func (b *Buffer) DecodeFixed64() (uint64, error) {
228 v, n := protowire.ConsumeFixed64(b.buf[b.idx:])
229 if n < 0 {
230 return 0, protowire.ParseError(n)
231 }
232 b.idx += n
233 return uint64(v), nil
234 }
235
236
237
238
239 func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) {
240 v, n := protowire.ConsumeBytes(b.buf[b.idx:])
241 if n < 0 {
242 return nil, protowire.ParseError(n)
243 }
244 b.idx += n
245 if alloc {
246 v = append([]byte(nil), v...)
247 }
248 return v, nil
249 }
250
251
252
253 func (b *Buffer) DecodeStringBytes() (string, error) {
254 v, n := protowire.ConsumeString(b.buf[b.idx:])
255 if n < 0 {
256 return "", protowire.ParseError(n)
257 }
258 b.idx += n
259 return v, nil
260 }
261
262
263
264 func (b *Buffer) DecodeMessage(m Message) error {
265 v, err := b.DecodeRawBytes(false)
266 if err != nil {
267 return err
268 }
269 return UnmarshalMerge(v, m)
270 }
271
272
273
274
275
276 func (b *Buffer) DecodeGroup(m Message) error {
277 v, n, err := consumeGroup(b.buf[b.idx:])
278 if err != nil {
279 return err
280 }
281 b.idx += n
282 return UnmarshalMerge(v, m)
283 }
284
285
286
287
288 func consumeGroup(b []byte) ([]byte, int, error) {
289 b0 := b
290 depth := 1
291 for {
292 _, wtyp, tagLen := protowire.ConsumeTag(b)
293 if tagLen < 0 {
294 return nil, 0, protowire.ParseError(tagLen)
295 }
296 b = b[tagLen:]
297
298 var valLen int
299 switch wtyp {
300 case protowire.VarintType:
301 _, valLen = protowire.ConsumeVarint(b)
302 case protowire.Fixed32Type:
303 _, valLen = protowire.ConsumeFixed32(b)
304 case protowire.Fixed64Type:
305 _, valLen = protowire.ConsumeFixed64(b)
306 case protowire.BytesType:
307 _, valLen = protowire.ConsumeBytes(b)
308 case protowire.StartGroupType:
309 depth++
310 case protowire.EndGroupType:
311 depth--
312 default:
313 return nil, 0, errors.New("proto: cannot parse reserved wire type")
314 }
315 if valLen < 0 {
316 return nil, 0, protowire.ParseError(valLen)
317 }
318 b = b[valLen:]
319
320 if depth == 0 {
321 return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil
322 }
323 }
324 }
325
View as plain text