1
2
3
4
5 package proto
6
7 import (
8 "errors"
9 "fmt"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/encoding/messageset"
13 "google.golang.org/protobuf/internal/order"
14 "google.golang.org/protobuf/internal/pragma"
15 "google.golang.org/protobuf/reflect/protoreflect"
16 "google.golang.org/protobuf/runtime/protoiface"
17
18 protoerrors "google.golang.org/protobuf/internal/errors"
19 )
20
21
22
23
24
25
26 type MarshalOptions struct {
27 pragma.NoUnkeyedLiterals
28
29
30
31
32 AllowPartial bool
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55 Deterministic bool
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75 UseCachedSize bool
76 }
77
78
79
80
81
82 func (o MarshalOptions) flags() protoiface.MarshalInputFlags {
83 var flags protoiface.MarshalInputFlags
84
85
86
87
88 if o.Deterministic {
89 flags |= protoiface.MarshalDeterministic
90 }
91
92 if o.UseCachedSize {
93 flags |= protoiface.MarshalUseCachedSize
94 }
95
96 return flags
97 }
98
99
100
101
102
103
104 func Marshal(m Message) ([]byte, error) {
105
106 if m == nil {
107 return nil, nil
108 }
109
110 out, err := MarshalOptions{}.marshal(nil, m.ProtoReflect())
111 if len(out.Buf) == 0 && err == nil {
112 out.Buf = emptyBytesForMessage(m)
113 }
114 return out.Buf, err
115 }
116
117
118 func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
119
120 if m == nil {
121 return nil, nil
122 }
123
124 out, err := o.marshal(nil, m.ProtoReflect())
125 if len(out.Buf) == 0 && err == nil {
126 out.Buf = emptyBytesForMessage(m)
127 }
128 return out.Buf, err
129 }
130
131
132
133
134
135
136
137
138
139
140 func emptyBytesForMessage(m Message) []byte {
141 if m == nil || !m.ProtoReflect().IsValid() {
142 return nil
143 }
144 return emptyBuf[:]
145 }
146
147
148
149
150
151
152 func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
153
154 if m == nil {
155 return b, nil
156 }
157
158 out, err := o.marshal(b, m.ProtoReflect())
159 return out.Buf, err
160 }
161
162
163
164
165
166 func (o MarshalOptions) MarshalState(in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
167 return o.marshal(in.Buf, in.Message)
168 }
169
170
171
172
173 func (o MarshalOptions) marshal(b []byte, m protoreflect.Message) (out protoiface.MarshalOutput, err error) {
174 allowPartial := o.AllowPartial
175 o.AllowPartial = true
176 if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
177 !(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
178 in := protoiface.MarshalInput{
179 Message: m,
180 Buf: b,
181 Flags: o.flags(),
182 }
183 if methods.Size != nil {
184 sout := methods.Size(protoiface.SizeInput{
185 Message: m,
186 Flags: in.Flags,
187 })
188 if cap(b) < len(b)+sout.Size {
189 in.Buf = make([]byte, len(b), growcap(cap(b), len(b)+sout.Size))
190 copy(in.Buf, b)
191 }
192 in.Flags |= protoiface.MarshalUseCachedSize
193 }
194 out, err = methods.Marshal(in)
195 } else {
196 out.Buf, err = o.marshalMessageSlow(b, m)
197 }
198 if err != nil {
199 var mismatch *protoerrors.SizeMismatchError
200 if errors.As(err, &mismatch) {
201 return out, fmt.Errorf("marshaling %s: %v", string(m.Descriptor().FullName()), err)
202 }
203 return out, err
204 }
205 if allowPartial {
206 return out, nil
207 }
208 return out, checkInitialized(m)
209 }
210
211 func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
212 out, err := o.marshal(b, m)
213 return out.Buf, err
214 }
215
216
217
218
219
220
221
// growcap chooses the capacity to allocate when a buffer of capacity oldcap
// must hold at least wantcap bytes.
//
// Policy:
//   - wantcap is more than double oldcap: allocate exactly wantcap;
//   - oldcap is below 1024: double it;
//   - otherwise: grow oldcap in 25% steps until it reaches wantcap, falling
//     back to wantcap itself if the stepping overflows to a non-positive
//     value.
func growcap(oldcap, wantcap int) (newcap int) {
	switch {
	case wantcap > oldcap*2:
		return wantcap
	case oldcap < 1024:
		return oldcap * 2
	}

	newcap = oldcap
	for newcap > 0 && newcap < wantcap {
		newcap += newcap / 4
	}
	if newcap <= 0 {
		// Integer overflow while stepping; settle for the exact request.
		return wantcap
	}
	return newcap
}
241
242 func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) {
243 if messageset.IsMessageSet(m.Descriptor()) {
244 return o.marshalMessageSet(b, m)
245 }
246 fieldOrder := order.AnyFieldOrder
247 if o.Deterministic {
248
249
250
251 fieldOrder = order.LegacyFieldOrder
252 }
253 var err error
254 order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
255 b, err = o.marshalField(b, fd, v)
256 return err == nil
257 })
258 if err != nil {
259 return b, err
260 }
261 b = append(b, m.GetUnknown()...)
262 return b, nil
263 }
264
265 func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
266 switch {
267 case fd.IsList():
268 return o.marshalList(b, fd, value.List())
269 case fd.IsMap():
270 return o.marshalMap(b, fd, value.Map())
271 default:
272 b = protowire.AppendTag(b, fd.Number(), wireTypes[fd.Kind()])
273 return o.marshalSingular(b, fd, value)
274 }
275 }
276
277 func (o MarshalOptions) marshalList(b []byte, fd protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
278 if fd.IsPacked() && list.Len() > 0 {
279 b = protowire.AppendTag(b, fd.Number(), protowire.BytesType)
280 b, pos := appendSpeculativeLength(b)
281 for i, llen := 0, list.Len(); i < llen; i++ {
282 var err error
283 b, err = o.marshalSingular(b, fd, list.Get(i))
284 if err != nil {
285 return b, err
286 }
287 }
288 b = finishSpeculativeLength(b, pos)
289 return b, nil
290 }
291
292 kind := fd.Kind()
293 for i, llen := 0, list.Len(); i < llen; i++ {
294 var err error
295 b = protowire.AppendTag(b, fd.Number(), wireTypes[kind])
296 b, err = o.marshalSingular(b, fd, list.Get(i))
297 if err != nil {
298 return b, err
299 }
300 }
301 return b, nil
302 }
303
304 func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
305 keyf := fd.MapKey()
306 valf := fd.MapValue()
307 keyOrder := order.AnyKeyOrder
308 if o.Deterministic {
309 keyOrder = order.GenericKeyOrder
310 }
311 var err error
312 order.RangeEntries(mapv, keyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool {
313 b = protowire.AppendTag(b, fd.Number(), protowire.BytesType)
314 var pos int
315 b, pos = appendSpeculativeLength(b)
316
317 b, err = o.marshalField(b, keyf, key.Value())
318 if err != nil {
319 return false
320 }
321 b, err = o.marshalField(b, valf, value)
322 if err != nil {
323 return false
324 }
325 b = finishSpeculativeLength(b, pos)
326 return true
327 })
328 return b, err
329 }
330
331
332
333
// speculativeLength is the number of placeholder bytes reserved for a
// length prefix before the size of the data that follows is known.
const speculativeLength = 1

// appendSpeculativeLength appends speculativeLength zero bytes to b as a
// placeholder for a length prefix. It returns the extended buffer and the
// offset at which the placeholder starts.
func appendSpeculativeLength(b []byte) ([]byte, int) {
	const placeholder = "\x00\x00\x00\x00"
	start := len(b)
	return append(b, placeholder[:speculativeLength]...), start
}
341
342 func finishSpeculativeLength(b []byte, pos int) []byte {
343 mlen := len(b) - pos - speculativeLength
344 msiz := protowire.SizeVarint(uint64(mlen))
345 if msiz != speculativeLength {
346 for i := 0; i < msiz-speculativeLength; i++ {
347 b = append(b, 0)
348 }
349 copy(b[pos+msiz:], b[pos+speculativeLength:])
350 b = b[:pos+msiz+mlen]
351 }
352 protowire.AppendVarint(b[:pos], uint64(mlen))
353 return b
354 }
355
View as plain text