1 package runtime
2
3 import (
4 "errors"
5 "fmt"
6 "net/url"
7 "regexp"
8 "strconv"
9 "strings"
10 "time"
11
12 "github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
13 "google.golang.org/grpc/grpclog"
14 "google.golang.org/protobuf/encoding/protojson"
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/reflect/protoreflect"
17 "google.golang.org/protobuf/reflect/protoregistry"
18 "google.golang.org/protobuf/types/known/durationpb"
19 field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
20 "google.golang.org/protobuf/types/known/structpb"
21 "google.golang.org/protobuf/types/known/timestamppb"
22 "google.golang.org/protobuf/types/known/wrapperspb"
23 )
24
// valuesKeyRegexp recognises query keys written as "name[key]". Submatch 1 is
// the part before the final bracketed segment and submatch 2 is the text
// inside it; keys without a trailing "[...]" do not match.
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
26
27 var currentQueryParser QueryParameterParser = &DefaultQueryParser{}
28
29
30 type QueryParameterParser interface {
31 Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
32 }
33
34
35
36 func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
37 return currentQueryParser.Parse(msg, values, filter)
38 }
39
40
41
42
43
// DefaultQueryParser is the stock QueryParameterParser. It resolves dotted
// field paths by proto or JSON field name and understands the "field[key]"
// syntax for map entries.
type DefaultQueryParser struct{}
45
46
47
48 func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
49 for key, values := range values {
50 if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 {
51 key = match[1]
52 values = append([]string{match[2]}, values...)
53 }
54
55 msgValue := msg.ProtoReflect()
56 fieldPath := normalizeFieldPath(msgValue, strings.Split(key, "."))
57 if filter.HasCommonPrefix(fieldPath) {
58 continue
59 }
60 if err := populateFieldValueFromPath(msgValue, fieldPath, values); err != nil {
61 return err
62 }
63 }
64 return nil
65 }
66
67
68 func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
69 fieldPath := strings.Split(fieldPathString, ".")
70 return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
71 }
72
73 func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string {
74 newFieldPath := make([]string, 0, len(fieldPath))
75 for i, fieldName := range fieldPath {
76 fields := msgValue.Descriptor().Fields()
77 fieldDesc := fields.ByTextName(fieldName)
78 if fieldDesc == nil {
79 fieldDesc = fields.ByJSONName(fieldName)
80 }
81 if fieldDesc == nil {
82
83 return fieldPath
84 }
85
86 newFieldPath = append(newFieldPath, string(fieldDesc.Name()))
87
88
89 if i == len(fieldPath)-1 {
90 break
91 }
92
93
94 if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated {
95 return fieldPath
96 }
97
98
99 msgValue = msgValue.Get(fieldDesc).Message()
100 }
101
102 return newFieldPath
103 }
104
105 func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
106 if len(fieldPath) < 1 {
107 return errors.New("no field path")
108 }
109 if len(values) < 1 {
110 return errors.New("no value provided")
111 }
112
113 var fieldDescriptor protoreflect.FieldDescriptor
114 for i, fieldName := range fieldPath {
115 fields := msgValue.Descriptor().Fields()
116
117
118 fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
119 if fieldDescriptor == nil {
120 fieldDescriptor = fields.ByJSONName(fieldName)
121 if fieldDescriptor == nil {
122
123
124 grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
125 return nil
126 }
127 }
128
129
130 if i == len(fieldPath)-1 {
131 break
132 }
133
134
135 if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
136 return fmt.Errorf("invalid path: %q is not a message", fieldName)
137 }
138
139
140 msgValue = msgValue.Mutable(fieldDescriptor).Message()
141 }
142
143
144 if of := fieldDescriptor.ContainingOneof(); of != nil {
145 if f := msgValue.WhichOneof(of); f != nil {
146 return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
147 }
148 }
149
150 switch {
151 case fieldDescriptor.IsList():
152 return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
153 case fieldDescriptor.IsMap():
154 return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
155 }
156
157 if len(values) > 1 {
158 return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
159 }
160
161 return populateField(fieldDescriptor, msgValue, values[0])
162 }
163
164 func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
165 v, err := parseField(fieldDescriptor, value)
166 if err != nil {
167 return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
168 }
169
170 msgValue.Set(fieldDescriptor, v)
171 return nil
172 }
173
174 func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
175 for _, value := range values {
176 v, err := parseField(fieldDescriptor, value)
177 if err != nil {
178 return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
179 }
180 list.Append(v)
181 }
182
183 return nil
184 }
185
186 func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
187 if len(values) != 2 {
188 return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
189 }
190
191 key, err := parseField(fieldDescriptor.MapKey(), values[0])
192 if err != nil {
193 return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
194 }
195
196 value, err := parseField(fieldDescriptor.MapValue(), values[1])
197 if err != nil {
198 return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
199 }
200
201 mp.Set(key.MapKey(), value)
202
203 return nil
204 }
205
206 func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
207 switch fieldDescriptor.Kind() {
208 case protoreflect.BoolKind:
209 v, err := strconv.ParseBool(value)
210 if err != nil {
211 return protoreflect.Value{}, err
212 }
213 return protoreflect.ValueOfBool(v), nil
214 case protoreflect.EnumKind:
215 enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
216 if err != nil {
217 if errors.Is(err, protoregistry.NotFound) {
218 return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
219 }
220 return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
221 }
222
223 v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
224 if v == nil {
225 i, err := strconv.Atoi(value)
226 if err != nil {
227 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
228 }
229
230 if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil {
231 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
232 }
233 }
234 return protoreflect.ValueOfEnum(v.Number()), nil
235 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
236 v, err := strconv.ParseInt(value, 10, 32)
237 if err != nil {
238 return protoreflect.Value{}, err
239 }
240 return protoreflect.ValueOfInt32(int32(v)), nil
241 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
242 v, err := strconv.ParseInt(value, 10, 64)
243 if err != nil {
244 return protoreflect.Value{}, err
245 }
246 return protoreflect.ValueOfInt64(v), nil
247 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
248 v, err := strconv.ParseUint(value, 10, 32)
249 if err != nil {
250 return protoreflect.Value{}, err
251 }
252 return protoreflect.ValueOfUint32(uint32(v)), nil
253 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
254 v, err := strconv.ParseUint(value, 10, 64)
255 if err != nil {
256 return protoreflect.Value{}, err
257 }
258 return protoreflect.ValueOfUint64(v), nil
259 case protoreflect.FloatKind:
260 v, err := strconv.ParseFloat(value, 32)
261 if err != nil {
262 return protoreflect.Value{}, err
263 }
264 return protoreflect.ValueOfFloat32(float32(v)), nil
265 case protoreflect.DoubleKind:
266 v, err := strconv.ParseFloat(value, 64)
267 if err != nil {
268 return protoreflect.Value{}, err
269 }
270 return protoreflect.ValueOfFloat64(v), nil
271 case protoreflect.StringKind:
272 return protoreflect.ValueOfString(value), nil
273 case protoreflect.BytesKind:
274 v, err := Bytes(value)
275 if err != nil {
276 return protoreflect.Value{}, err
277 }
278 return protoreflect.ValueOfBytes(v), nil
279 case protoreflect.MessageKind, protoreflect.GroupKind:
280 return parseMessage(fieldDescriptor.Message(), value)
281 default:
282 panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
283 }
284 }
285
286 func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
287 var msg proto.Message
288 switch msgDescriptor.FullName() {
289 case "google.protobuf.Timestamp":
290 t, err := time.Parse(time.RFC3339Nano, value)
291 if err != nil {
292 return protoreflect.Value{}, err
293 }
294 msg = timestamppb.New(t)
295 case "google.protobuf.Duration":
296 d, err := time.ParseDuration(value)
297 if err != nil {
298 return protoreflect.Value{}, err
299 }
300 msg = durationpb.New(d)
301 case "google.protobuf.DoubleValue":
302 v, err := strconv.ParseFloat(value, 64)
303 if err != nil {
304 return protoreflect.Value{}, err
305 }
306 msg = wrapperspb.Double(v)
307 case "google.protobuf.FloatValue":
308 v, err := strconv.ParseFloat(value, 32)
309 if err != nil {
310 return protoreflect.Value{}, err
311 }
312 msg = wrapperspb.Float(float32(v))
313 case "google.protobuf.Int64Value":
314 v, err := strconv.ParseInt(value, 10, 64)
315 if err != nil {
316 return protoreflect.Value{}, err
317 }
318 msg = wrapperspb.Int64(v)
319 case "google.protobuf.Int32Value":
320 v, err := strconv.ParseInt(value, 10, 32)
321 if err != nil {
322 return protoreflect.Value{}, err
323 }
324 msg = wrapperspb.Int32(int32(v))
325 case "google.protobuf.UInt64Value":
326 v, err := strconv.ParseUint(value, 10, 64)
327 if err != nil {
328 return protoreflect.Value{}, err
329 }
330 msg = wrapperspb.UInt64(v)
331 case "google.protobuf.UInt32Value":
332 v, err := strconv.ParseUint(value, 10, 32)
333 if err != nil {
334 return protoreflect.Value{}, err
335 }
336 msg = wrapperspb.UInt32(uint32(v))
337 case "google.protobuf.BoolValue":
338 v, err := strconv.ParseBool(value)
339 if err != nil {
340 return protoreflect.Value{}, err
341 }
342 msg = wrapperspb.Bool(v)
343 case "google.protobuf.StringValue":
344 msg = wrapperspb.String(value)
345 case "google.protobuf.BytesValue":
346 v, err := Bytes(value)
347 if err != nil {
348 return protoreflect.Value{}, err
349 }
350 msg = wrapperspb.Bytes(v)
351 case "google.protobuf.FieldMask":
352 fm := &field_mask.FieldMask{}
353 fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
354 msg = fm
355 case "google.protobuf.Value":
356 var v structpb.Value
357 if err := protojson.Unmarshal([]byte(value), &v); err != nil {
358 return protoreflect.Value{}, err
359 }
360 msg = &v
361 case "google.protobuf.Struct":
362 var v structpb.Struct
363 if err := protojson.Unmarshal([]byte(value), &v); err != nil {
364 return protoreflect.Value{}, err
365 }
366 msg = &v
367 default:
368 return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
369 }
370
371 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
372 }
373
View as plain text