1 package genswagger
2
3 import (
4 "bytes"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "path/filepath"
9 "reflect"
10 "strings"
11
12 "github.com/golang/glog"
13 pbdescriptor "github.com/golang/protobuf/descriptor"
14 "github.com/golang/protobuf/proto"
15 protocdescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
16 plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
17 "github.com/golang/protobuf/ptypes/any"
18 "github.com/grpc-ecosystem/grpc-gateway/internal"
19 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
20 gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
21 swagger_options "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/options"
22 )
23
// errNoTargetService signals that a file has no service to document.
// Generate skips a file, instead of failing, when applyTemplate reports it.
var errNoTargetService = errors.New("no target service defined in the file")
27
28 type generator struct {
29 reg *descriptor.Registry
30 }
31
32 type wrapper struct {
33 fileName string
34 swagger *swaggerObject
35 }
36
37
38 func New(reg *descriptor.Registry) gen.Generator {
39 return &generator{reg: reg}
40 }
41
42
43 func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
44 var mergedTarget *wrapper
45 for _, f := range targets {
46 if mergedTarget == nil {
47 mergedTarget = &wrapper{
48 fileName: mergeFileName,
49 swagger: f.swagger,
50 }
51 } else {
52 for k, v := range f.swagger.Definitions {
53 mergedTarget.swagger.Definitions[k] = v
54 }
55 for k, v := range f.swagger.Paths {
56 mergedTarget.swagger.Paths[k] = v
57 }
58 for k, v := range f.swagger.SecurityDefinitions {
59 mergedTarget.swagger.SecurityDefinitions[k] = v
60 }
61 mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
62 }
63 }
64 return mergedTarget
65 }
66
67
68
69
70
71
72
73
74
75
76 func (so swaggerObject) MarshalJSON() ([]byte, error) {
77 type alias swaggerObject
78 return extensionMarshalJSON(alias(so), so.extensions)
79 }
80
81 func (so swaggerInfoObject) MarshalJSON() ([]byte, error) {
82 type alias swaggerInfoObject
83 return extensionMarshalJSON(alias(so), so.extensions)
84 }
85
86 func (so swaggerSecuritySchemeObject) MarshalJSON() ([]byte, error) {
87 type alias swaggerSecuritySchemeObject
88 return extensionMarshalJSON(alias(so), so.extensions)
89 }
90
91 func (so swaggerOperationObject) MarshalJSON() ([]byte, error) {
92 type alias swaggerOperationObject
93 return extensionMarshalJSON(alias(so), so.extensions)
94 }
95
96 func (so swaggerResponseObject) MarshalJSON() ([]byte, error) {
97 type alias swaggerResponseObject
98 return extensionMarshalJSON(alias(so), so.extensions)
99 }
100
101 func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) {
102
103
104
105
106
107
108
109
110
111
112
113
114 fields := []reflect.StructField{
115 reflect.StructField{
116 Name: "Embedded",
117 Type: reflect.TypeOf(so),
118 Anonymous: true,
119 },
120 }
121 for _, ext := range extensions {
122 fields = append(fields, reflect.StructField{
123 Name: fieldName(ext.key),
124 Type: reflect.TypeOf(ext.value),
125 Tag: reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)),
126 })
127 }
128
129 t := reflect.StructOf(fields)
130 s := reflect.New(t).Elem()
131 s.Field(0).Set(reflect.ValueOf(so))
132 for _, ext := range extensions {
133 s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value))
134 }
135 return json.Marshal(s.Interface())
136 }
137
138
139 func encodeSwagger(file *wrapper) (*plugin.CodeGeneratorResponse_File, error) {
140 var formatted bytes.Buffer
141 enc := json.NewEncoder(&formatted)
142 enc.SetIndent("", " ")
143 if err := enc.Encode(*file.swagger); err != nil {
144 return nil, err
145 }
146 name := file.fileName
147 ext := filepath.Ext(name)
148 base := strings.TrimSuffix(name, ext)
149 output := fmt.Sprintf("%s.swagger.json", base)
150 return &plugin.CodeGeneratorResponse_File{
151 Name: proto.String(output),
152 Content: proto.String(formatted.String()),
153 }, nil
154 }
155
156 func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
157 var files []*plugin.CodeGeneratorResponse_File
158 if g.reg.IsAllowMerge() {
159 var mergedTarget *descriptor.File
160
161 for _, f := range targets {
162 if proto.HasExtension(f.Options, swagger_options.E_Openapiv2Swagger) {
163 mergedTarget = f
164 break
165 }
166 }
167
168 for _, f := range targets {
169 if mergedTarget == nil {
170 mergedTarget = f
171 } else if mergedTarget != f {
172 mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...)
173 mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...)
174 mergedTarget.Services = append(mergedTarget.Services, f.Services...)
175 }
176 }
177
178 targets = nil
179 targets = append(targets, mergedTarget)
180 }
181
182 var swaggers []*wrapper
183 for _, file := range targets {
184 glog.V(1).Infof("Processing %s", file.GetName())
185 swagger, err := applyTemplate(param{File: file, reg: g.reg})
186 if err == errNoTargetService {
187 glog.V(1).Infof("%s: %v", file.GetName(), err)
188 continue
189 }
190 if err != nil {
191 return nil, err
192 }
193 swaggers = append(swaggers, &wrapper{
194 fileName: file.GetName(),
195 swagger: swagger,
196 })
197 }
198
199 if g.reg.IsAllowMerge() {
200 targetSwagger := mergeTargetFile(swaggers, g.reg.GetMergeFileName())
201 f, err := encodeSwagger(targetSwagger)
202 if err != nil {
203 return nil, fmt.Errorf("failed to encode swagger for %s: %s", g.reg.GetMergeFileName(), err)
204 }
205 files = append(files, f)
206 glog.V(1).Infof("New swagger file will emit")
207 } else {
208 for _, file := range swaggers {
209 f, err := encodeSwagger(file)
210 if err != nil {
211 return nil, fmt.Errorf("failed to encode swagger for %s: %s", file.fileName, err)
212 }
213 files = append(files, f)
214 glog.V(1).Infof("New swagger file will emit")
215 }
216 }
217 return files, nil
218 }
219
220
221 func AddStreamError(reg *descriptor.Registry) error {
222
223 any := fileDescriptorProtoForMessage(&any.Any{})
224 streamError := fileDescriptorProtoForMessage(&internal.StreamError{})
225 if err := reg.Load(&plugin.CodeGeneratorRequest{
226 ProtoFile: []*protocdescriptor.FileDescriptorProto{
227 any,
228 streamError,
229 },
230 }); err != nil {
231 return err
232 }
233 return nil
234 }
235
236 func fileDescriptorProtoForMessage(msg pbdescriptor.Message) *protocdescriptor.FileDescriptorProto {
237 fdp, _ := pbdescriptor.ForMessage(msg)
238 fdp.SourceCodeInfo = &protocdescriptor.SourceCodeInfo{}
239 return fdp
240 }
241
View as plain text