1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
46 package embedcheck
47
import (
	"fmt"
	"os"
	"sort"

	"github.com/gogo/protobuf/gogoproto"
	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
)
55
56 type plugin struct {
57 *generator.Generator
58 }
59
60 func NewPlugin() *plugin {
61 return &plugin{}
62 }
63
64 func (p *plugin) Name() string {
65 return "embedcheck"
66 }
67
68 func (p *plugin) Init(g *generator.Generator) {
69 p.Generator = g
70 }
71
72 var overwriters []map[string]gogoproto.EnableFunc = []map[string]gogoproto.EnableFunc{
73 {
74 "stringer": gogoproto.IsStringer,
75 },
76 {
77 "gostring": gogoproto.HasGoString,
78 },
79 {
80 "equal": gogoproto.HasEqual,
81 },
82 {
83 "verboseequal": gogoproto.HasVerboseEqual,
84 },
85 {
86 "size": gogoproto.IsSizer,
87 "protosizer": gogoproto.IsProtoSizer,
88 },
89 {
90 "unmarshaler": gogoproto.IsUnmarshaler,
91 "unsafe_unmarshaler": gogoproto.IsUnsafeUnmarshaler,
92 },
93 {
94 "marshaler": gogoproto.IsMarshaler,
95 "unsafe_marshaler": gogoproto.IsUnsafeMarshaler,
96 },
97 }
98
99 func (p *plugin) Generate(file *generator.FileDescriptor) {
100 for _, msg := range file.Messages() {
101 for _, os := range overwriters {
102 possible := true
103 for _, overwriter := range os {
104 if overwriter(file.FileDescriptorProto, msg.DescriptorProto) {
105 possible = false
106 }
107 }
108 if possible {
109 p.checkOverwrite(msg, os)
110 }
111 }
112 p.checkNameSpace(msg)
113 for _, field := range msg.GetField() {
114 if gogoproto.IsEmbed(field) && gogoproto.IsCustomName(field) {
115 fmt.Fprintf(os.Stderr, "ERROR: field %v with custom name %v cannot be embedded", *field.Name, gogoproto.GetCustomName(field))
116 os.Exit(1)
117 }
118 }
119 p.checkRepeated(msg)
120 }
121 for _, e := range file.GetExtension() {
122 if gogoproto.IsEmbed(e) {
123 fmt.Fprintf(os.Stderr, "ERROR: extended field %v cannot be embedded", generator.CamelCase(*e.Name))
124 os.Exit(1)
125 }
126 }
127 }
128
129 func (p *plugin) checkNameSpace(message *generator.Descriptor) map[string]bool {
130 ccTypeName := generator.CamelCaseSlice(message.TypeName())
131 names := make(map[string]bool)
132 for _, field := range message.Field {
133 fieldname := generator.CamelCase(*field.Name)
134 if field.IsMessage() && gogoproto.IsEmbed(field) {
135 desc := p.ObjectNamed(field.GetTypeName())
136 moreNames := p.checkNameSpace(desc.(*generator.Descriptor))
137 for another := range moreNames {
138 if names[another] {
139 fmt.Fprintf(os.Stderr, "ERROR: duplicate embedded fieldname %v in type %v\n", fieldname, ccTypeName)
140 os.Exit(1)
141 }
142 names[another] = true
143 }
144 } else {
145 if names[fieldname] {
146 fmt.Fprintf(os.Stderr, "ERROR: duplicate embedded fieldname %v in type %v\n", fieldname, ccTypeName)
147 os.Exit(1)
148 }
149 names[fieldname] = true
150 }
151 }
152 return names
153 }
154
155 func (p *plugin) checkOverwrite(message *generator.Descriptor, enablers map[string]gogoproto.EnableFunc) {
156 ccTypeName := generator.CamelCaseSlice(message.TypeName())
157 names := []string{}
158 for name := range enablers {
159 names = append(names, name)
160 }
161 for _, field := range message.Field {
162 if field.IsMessage() && gogoproto.IsEmbed(field) {
163 fieldname := generator.CamelCase(*field.Name)
164 desc := p.ObjectNamed(field.GetTypeName())
165 msg := desc.(*generator.Descriptor)
166 for errStr, enabled := range enablers {
167 if enabled(msg.File().FileDescriptorProto, msg.DescriptorProto) {
168 fmt.Fprintf(os.Stderr, "WARNING: found non-%v %v with embedded %v %v\n", names, ccTypeName, errStr, fieldname)
169 }
170 }
171 p.checkOverwrite(msg, enablers)
172 }
173 }
174 }
175
176 func (p *plugin) checkRepeated(message *generator.Descriptor) {
177 ccTypeName := generator.CamelCaseSlice(message.TypeName())
178 for _, field := range message.Field {
179 if !gogoproto.IsEmbed(field) {
180 continue
181 }
182 if field.IsBytes() {
183 fieldname := generator.CamelCase(*field.Name)
184 fmt.Fprintf(os.Stderr, "ERROR: found embedded bytes field %s in message %s\n", fieldname, ccTypeName)
185 os.Exit(1)
186 }
187 if !field.IsRepeated() {
188 continue
189 }
190 fieldname := generator.CamelCase(*field.Name)
191 fmt.Fprintf(os.Stderr, "ERROR: found repeated embedded field %s in message %s\n", fieldname, ccTypeName)
192 os.Exit(1)
193 }
194 }
195
196 func (p *plugin) GenerateImports(*generator.FileDescriptor) {}
197
198 func init() {
199 generator.RegisterPlugin(NewPlugin())
200 }
201
View as plain text