...
1
2
3
4
5 package dynamicpb
6
7 import (
8 "fmt"
9 "strings"
10 "sync"
11 "sync/atomic"
12
13 "google.golang.org/protobuf/internal/errors"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 "google.golang.org/protobuf/reflect/protoregistry"
16 )
17
// extField is the key of Types.extensionsByMessage: the full name of the
// message being extended paired with the extension's field number.
type extField struct {
	// name is the full name of the containing (extended) message.
	name protoreflect.FullName
	// number is the field number of the extension within that message.
	number protoreflect.FieldNumber
}
22
23
24
25
26
27
// Types resolves dynamic message, enum, and extension types from the
// descriptors held in a protoregistry.Files.
//
// Its lookup methods (FindMessageByName, FindMessageByURL,
// FindExtensionByName, FindExtensionByNumber) have the signatures required by
// protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver.
//
// NOTE(review): the lazily built extension-number index is guarded by extMu
// and published through atomicExtFiles, which suggests the methods are meant
// for concurrent use; this presumably assumes the underlying Files is not
// modified concurrently with lookups — confirm.
type Types struct {
	// atomicExtFiles is the number of files the registry held when
	// extensionsByMessage was last rebuilt. It is read and written via
	// sync/atomic and must remain the first field: on 32-bit platforms
	// sync/atomic only guarantees 64-bit alignment for the first word of an
	// allocated struct.
	atomicExtFiles uint64
	// extMu serializes rebuilds of extensionsByMessage.
	extMu sync.Mutex

	// files is the registry every lookup is resolved against.
	files *protoregistry.Files

	// extensionsByMessage indexes extensions by (containing message, field
	// number). It is built lazily by updateExtensions and stays nil until a
	// rebuild encounters at least one extension.
	extensionsByMessage map[extField]protoreflect.ExtensionDescriptor
}
42
43
44
45
46 func NewTypes(f *protoregistry.Files) *Types {
47 return &Types{
48 files: f,
49 }
50 }
51
52
53
54
55
56 func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) {
57 d, err := t.files.FindDescriptorByName(name)
58 if err != nil {
59 return nil, err
60 }
61 ed, ok := d.(protoreflect.EnumDescriptor)
62 if !ok {
63 return nil, errors.New("found wrong type: got %v, want enum", descName(d))
64 }
65 return NewEnumType(ed), nil
66 }
67
68
69
70
71
72
73
74 func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) {
75 d, err := t.files.FindDescriptorByName(name)
76 if err != nil {
77 return nil, err
78 }
79 xd, ok := d.(protoreflect.ExtensionDescriptor)
80 if !ok {
81 return nil, errors.New("found wrong type: got %v, want extension", descName(d))
82 }
83 return NewExtensionType(xd), nil
84 }
85
86
87
88
89
90 func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
91
92
93 if atomic.LoadUint64(&t.atomicExtFiles) != uint64(t.files.NumFiles()) {
94 t.updateExtensions()
95 }
96 xd := t.extensionsByMessage[extField{message, field}]
97 if xd == nil {
98 return nil, protoregistry.NotFound
99 }
100 return NewExtensionType(xd), nil
101 }
102
103
104
105
106
107 func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) {
108 d, err := t.files.FindDescriptorByName(name)
109 if err != nil {
110 return nil, err
111 }
112 md, ok := d.(protoreflect.MessageDescriptor)
113 if !ok {
114 return nil, errors.New("found wrong type: got %v, want message", descName(d))
115 }
116 return NewMessageType(md), nil
117 }
118
119
120
121
122
123 func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
124
125
126 message := protoreflect.FullName(url)
127 if i := strings.LastIndexByte(url, '/'); i >= 0 {
128 message = message[i+len("/"):]
129 }
130 return t.FindMessageByName(message)
131 }
132
133 func (t *Types) updateExtensions() {
134 t.extMu.Lock()
135 defer t.extMu.Unlock()
136 if atomic.LoadUint64(&t.atomicExtFiles) == uint64(t.files.NumFiles()) {
137 return
138 }
139 defer atomic.StoreUint64(&t.atomicExtFiles, uint64(t.files.NumFiles()))
140 t.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
141 t.registerExtensions(fd.Extensions())
142 t.registerExtensionsInMessages(fd.Messages())
143 return true
144 })
145 }
146
147 func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) {
148 count := mds.Len()
149 for i := 0; i < count; i++ {
150 md := mds.Get(i)
151 t.registerExtensions(md.Extensions())
152 t.registerExtensionsInMessages(md.Messages())
153 }
154 }
155
156 func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) {
157 count := xds.Len()
158 for i := 0; i < count; i++ {
159 xd := xds.Get(i)
160 field := xd.Number()
161 message := xd.ContainingMessage().FullName()
162 if t.extensionsByMessage == nil {
163 t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor)
164 }
165 t.extensionsByMessage[extField{message, field}] = xd
166 }
167 }
168
169 func descName(d protoreflect.Descriptor) string {
170 switch d.(type) {
171 case protoreflect.EnumDescriptor:
172 return "enum"
173 case protoreflect.EnumValueDescriptor:
174 return "enum value"
175 case protoreflect.MessageDescriptor:
176 return "message"
177 case protoreflect.ExtensionDescriptor:
178 return "extension"
179 case protoreflect.ServiceDescriptor:
180 return "service"
181 default:
182 return fmt.Sprintf("%T", d)
183 }
184 }
185
View as plain text