...
1 package fieldmask
2
3 import (
4 "fmt"
5 "strings"
6
7 "google.golang.org/protobuf/proto"
8 "google.golang.org/protobuf/reflect/protoreflect"
9 "google.golang.org/protobuf/types/known/fieldmaskpb"
10 )
11
12
13
14
15
16
17
18
19
20
21
22 func Update(mask *fieldmaskpb.FieldMask, dst, src proto.Message) {
23 dstReflect := dst.ProtoReflect()
24 srcReflect := src.ProtoReflect()
25 if dstReflect.Descriptor() != srcReflect.Descriptor() {
26 panic(fmt.Sprintf(
27 "dst (%s) and src (%s) messages have different types",
28 dstReflect.Descriptor().FullName(),
29 srcReflect.Descriptor().FullName(),
30 ))
31 }
32 switch {
33
34
35 case len(mask.GetPaths()) == 0:
36 updateWireSetFields(dstReflect, srcReflect)
37
38
39 case IsFullReplacement(mask):
40 proto.Reset(dst)
41 proto.Merge(dst, src)
42 default:
43 for _, path := range mask.GetPaths() {
44 segments := strings.Split(path, ".")
45 updateNamedField(dstReflect, srcReflect, segments)
46 }
47 }
48 }
49
50 func updateWireSetFields(dst, src protoreflect.Message) {
51 src.Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
52 switch {
53 case field.IsList():
54 dst.Set(field, value)
55 case field.IsMap():
56 dst.Set(field, value)
57 case field.Message() != nil && !dst.Has(field):
58 dst.Set(field, value)
59 case field.Message() != nil:
60 updateWireSetFields(dst.Get(field).Message(), value.Message())
61 default:
62 dst.Set(field, value)
63 }
64 return true
65 })
66 }
67
68 func updateNamedField(dst, src protoreflect.Message, segments []string) {
69 if len(segments) == 0 {
70 return
71 }
72 field := src.Descriptor().Fields().ByName(protoreflect.Name(segments[0]))
73 if field == nil {
74
75 return
76 }
77
78 if len(segments) == 1 {
79 if !src.Has(field) {
80 dst.Clear(field)
81 } else {
82 dst.Set(field, src.Get(field))
83 }
84 return
85 }
86
87
88 switch {
89 case field.IsList(), field.IsMap():
90
91 return
92 case field.Message() != nil:
93
94 if !dst.Has(field) {
95 dst.Set(field, dst.NewField(field))
96 }
97 if !src.Has(field) {
98 src.Set(field, src.NewField(field))
99 }
100 updateNamedField(dst.Get(field).Message(), src.Get(field).Message(), segments[1:])
101 default:
102 return
103 }
104 }
105
View as plain text