...

Source file src/github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/genswagger/generator.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/genswagger

     1  package genswagger
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"path/filepath"
     9  	"reflect"
    10  	"strings"
    11  
    12  	"github.com/golang/glog"
    13  	pbdescriptor "github.com/golang/protobuf/descriptor"
    14  	"github.com/golang/protobuf/proto"
    15  	protocdescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
    16  	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
    17  	"github.com/golang/protobuf/ptypes/any"
    18  	"github.com/grpc-ecosystem/grpc-gateway/internal"
    19  	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
    20  	gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
    21  	swagger_options "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-swagger/options"
    22  )
    23  
    24  var (
    25  	errNoTargetService = errors.New("no target service defined in the file")
    26  )
    27  
    28  type generator struct {
    29  	reg *descriptor.Registry
    30  }
    31  
    32  type wrapper struct {
    33  	fileName string
    34  	swagger  *swaggerObject
    35  }
    36  
    37  // New returns a new generator which generates grpc gateway files.
    38  func New(reg *descriptor.Registry) gen.Generator {
    39  	return &generator{reg: reg}
    40  }
    41  
    42  // Merge a lot of swagger file (wrapper) to single one swagger file
    43  func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
    44  	var mergedTarget *wrapper
    45  	for _, f := range targets {
    46  		if mergedTarget == nil {
    47  			mergedTarget = &wrapper{
    48  				fileName: mergeFileName,
    49  				swagger:  f.swagger,
    50  			}
    51  		} else {
    52  			for k, v := range f.swagger.Definitions {
    53  				mergedTarget.swagger.Definitions[k] = v
    54  			}
    55  			for k, v := range f.swagger.Paths {
    56  				mergedTarget.swagger.Paths[k] = v
    57  			}
    58  			for k, v := range f.swagger.SecurityDefinitions {
    59  				mergedTarget.swagger.SecurityDefinitions[k] = v
    60  			}
    61  			mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
    62  		}
    63  	}
    64  	return mergedTarget
    65  }
    66  
    67  // Q: What's up with the alias types here?
    68  // A: We don't want to completely override how these structs are marshaled into
    69  //    JSON, we only want to add fields (see below, extensionMarshalJSON).
    70  //    An infinite recursion would happen if we'd call json.Marshal on the struct
    71  //    that has swaggerObject as an embedded field. To avoid that, we'll create
    72  //    type aliases, and those don't have the custom MarshalJSON methods defined
    73  //    on them. See http://choly.ca/post/go-json-marshalling/ (or, if it ever
    74  //    goes away, use
    75  //    https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/.
    76  func (so swaggerObject) MarshalJSON() ([]byte, error) {
    77  	type alias swaggerObject
    78  	return extensionMarshalJSON(alias(so), so.extensions)
    79  }
    80  
    81  func (so swaggerInfoObject) MarshalJSON() ([]byte, error) {
    82  	type alias swaggerInfoObject
    83  	return extensionMarshalJSON(alias(so), so.extensions)
    84  }
    85  
    86  func (so swaggerSecuritySchemeObject) MarshalJSON() ([]byte, error) {
    87  	type alias swaggerSecuritySchemeObject
    88  	return extensionMarshalJSON(alias(so), so.extensions)
    89  }
    90  
    91  func (so swaggerOperationObject) MarshalJSON() ([]byte, error) {
    92  	type alias swaggerOperationObject
    93  	return extensionMarshalJSON(alias(so), so.extensions)
    94  }
    95  
    96  func (so swaggerResponseObject) MarshalJSON() ([]byte, error) {
    97  	type alias swaggerResponseObject
    98  	return extensionMarshalJSON(alias(so), so.extensions)
    99  }
   100  
   101  func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) {
   102  	// To append arbitrary keys to the struct we'll render into json,
   103  	// we're creating another struct that embeds the original one, and
   104  	// its extra fields:
   105  	//
   106  	// The struct will look like
   107  	// struct {
   108  	//   *swaggerCore
   109  	//   XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"`
   110  	//   XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"`
   111  	// }
   112  	// and thus render into what we want -- the JSON of swaggerCore with the
   113  	// extensions appended.
   114  	fields := []reflect.StructField{
   115  		reflect.StructField{ // embedded
   116  			Name:      "Embedded",
   117  			Type:      reflect.TypeOf(so),
   118  			Anonymous: true,
   119  		},
   120  	}
   121  	for _, ext := range extensions {
   122  		fields = append(fields, reflect.StructField{
   123  			Name: fieldName(ext.key),
   124  			Type: reflect.TypeOf(ext.value),
   125  			Tag:  reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)),
   126  		})
   127  	}
   128  
   129  	t := reflect.StructOf(fields)
   130  	s := reflect.New(t).Elem()
   131  	s.Field(0).Set(reflect.ValueOf(so))
   132  	for _, ext := range extensions {
   133  		s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value))
   134  	}
   135  	return json.Marshal(s.Interface())
   136  }
   137  
   138  // encodeSwagger converts swagger file obj to plugin.CodeGeneratorResponse_File
   139  func encodeSwagger(file *wrapper) (*plugin.CodeGeneratorResponse_File, error) {
   140  	var formatted bytes.Buffer
   141  	enc := json.NewEncoder(&formatted)
   142  	enc.SetIndent("", "  ")
   143  	if err := enc.Encode(*file.swagger); err != nil {
   144  		return nil, err
   145  	}
   146  	name := file.fileName
   147  	ext := filepath.Ext(name)
   148  	base := strings.TrimSuffix(name, ext)
   149  	output := fmt.Sprintf("%s.swagger.json", base)
   150  	return &plugin.CodeGeneratorResponse_File{
   151  		Name:    proto.String(output),
   152  		Content: proto.String(formatted.String()),
   153  	}, nil
   154  }
   155  
   156  func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
   157  	var files []*plugin.CodeGeneratorResponse_File
   158  	if g.reg.IsAllowMerge() {
   159  		var mergedTarget *descriptor.File
   160  		// try to find proto leader
   161  		for _, f := range targets {
   162  			if proto.HasExtension(f.Options, swagger_options.E_Openapiv2Swagger) {
   163  				mergedTarget = f
   164  				break
   165  			}
   166  		}
   167  		// merge protos to leader
   168  		for _, f := range targets {
   169  			if mergedTarget == nil {
   170  				mergedTarget = f
   171  			} else if mergedTarget != f {
   172  				mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...)
   173  				mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...)
   174  				mergedTarget.Services = append(mergedTarget.Services, f.Services...)
   175  			}
   176  		}
   177  
   178  		targets = nil
   179  		targets = append(targets, mergedTarget)
   180  	}
   181  
   182  	var swaggers []*wrapper
   183  	for _, file := range targets {
   184  		glog.V(1).Infof("Processing %s", file.GetName())
   185  		swagger, err := applyTemplate(param{File: file, reg: g.reg})
   186  		if err == errNoTargetService {
   187  			glog.V(1).Infof("%s: %v", file.GetName(), err)
   188  			continue
   189  		}
   190  		if err != nil {
   191  			return nil, err
   192  		}
   193  		swaggers = append(swaggers, &wrapper{
   194  			fileName: file.GetName(),
   195  			swagger:  swagger,
   196  		})
   197  	}
   198  
   199  	if g.reg.IsAllowMerge() {
   200  		targetSwagger := mergeTargetFile(swaggers, g.reg.GetMergeFileName())
   201  		f, err := encodeSwagger(targetSwagger)
   202  		if err != nil {
   203  			return nil, fmt.Errorf("failed to encode swagger for %s: %s", g.reg.GetMergeFileName(), err)
   204  		}
   205  		files = append(files, f)
   206  		glog.V(1).Infof("New swagger file will emit")
   207  	} else {
   208  		for _, file := range swaggers {
   209  			f, err := encodeSwagger(file)
   210  			if err != nil {
   211  				return nil, fmt.Errorf("failed to encode swagger for %s: %s", file.fileName, err)
   212  			}
   213  			files = append(files, f)
   214  			glog.V(1).Infof("New swagger file will emit")
   215  		}
   216  	}
   217  	return files, nil
   218  }
   219  
   220  //AddStreamError Adds grpc.gateway.runtime.StreamError and google.protobuf.Any to registry for stream responses
   221  func AddStreamError(reg *descriptor.Registry) error {
   222  	//load internal protos
   223  	any := fileDescriptorProtoForMessage(&any.Any{})
   224  	streamError := fileDescriptorProtoForMessage(&internal.StreamError{})
   225  	if err := reg.Load(&plugin.CodeGeneratorRequest{
   226  		ProtoFile: []*protocdescriptor.FileDescriptorProto{
   227  			any,
   228  			streamError,
   229  		},
   230  	}); err != nil {
   231  		return err
   232  	}
   233  	return nil
   234  }
   235  
   236  func fileDescriptorProtoForMessage(msg pbdescriptor.Message) *protocdescriptor.FileDescriptorProto {
   237  	fdp, _ := pbdescriptor.ForMessage(msg)
   238  	fdp.SourceCodeInfo = &protocdescriptor.SourceCodeInfo{}
   239  	return fdp
   240  }
   241  

View as plain text