...

Source file src/github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/internal/gengateway/generator.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/internal/gengateway

     1  package gengateway
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/format"
     7  	"path"
     8  	"path/filepath"
     9  	"strings"
    10  
    11  	"github.com/golang/glog"
    12  	"github.com/golang/protobuf/proto"
    13  	plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
    14  	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
    15  	gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
    16  )
    17  
    18  var (
    19  	errNoTargetService = errors.New("no target service defined in the file")
    20  )
    21  
    22  type pathType int
    23  
    24  const (
    25  	pathTypeImport pathType = iota
    26  	pathTypeSourceRelative
    27  )
    28  
    29  type generator struct {
    30  	reg                *descriptor.Registry
    31  	baseImports        []descriptor.GoPackage
    32  	useRequestContext  bool
    33  	registerFuncSuffix string
    34  	pathType           pathType
    35  	modulePath         string
    36  	allowPatchFeature  bool
    37  }
    38  
    39  // New returns a new generator which generates grpc gateway files.
    40  func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString, modulePathString string, allowPatchFeature bool) gen.Generator {
    41  	var imports []descriptor.GoPackage
    42  	for _, pkgpath := range []string{
    43  		"context",
    44  		"io",
    45  		"net/http",
    46  		"github.com/grpc-ecosystem/grpc-gateway/runtime",
    47  		"github.com/grpc-ecosystem/grpc-gateway/utilities",
    48  		"github.com/golang/protobuf/descriptor",
    49  		"github.com/golang/protobuf/proto",
    50  		"google.golang.org/grpc",
    51  		"google.golang.org/grpc/codes",
    52  		"google.golang.org/grpc/grpclog",
    53  		"google.golang.org/grpc/metadata",
    54  		"google.golang.org/grpc/status",
    55  	} {
    56  		pkg := descriptor.GoPackage{
    57  			Path: pkgpath,
    58  			Name: path.Base(pkgpath),
    59  		}
    60  		if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
    61  			for i := 0; ; i++ {
    62  				alias := fmt.Sprintf("%s_%d", pkg.Name, i)
    63  				if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
    64  					continue
    65  				}
    66  				pkg.Alias = alias
    67  				break
    68  			}
    69  		}
    70  		imports = append(imports, pkg)
    71  	}
    72  
    73  	var pathType pathType
    74  	switch pathTypeString {
    75  	case "", "import":
    76  		// paths=import is default
    77  	case "source_relative":
    78  		pathType = pathTypeSourceRelative
    79  	default:
    80  		glog.Fatalf(`Unknown path type %q: want "import" or "source_relative".`, pathTypeString)
    81  	}
    82  
    83  	return &generator{
    84  		reg:                reg,
    85  		baseImports:        imports,
    86  		useRequestContext:  useRequestContext,
    87  		registerFuncSuffix: registerFuncSuffix,
    88  		pathType:           pathType,
    89  		modulePath:         modulePathString,
    90  		allowPatchFeature:  allowPatchFeature,
    91  	}
    92  }
    93  
    94  func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
    95  	var files []*plugin.CodeGeneratorResponse_File
    96  	for _, file := range targets {
    97  		glog.V(1).Infof("Processing %s", file.GetName())
    98  		code, err := g.generate(file)
    99  		if err == errNoTargetService {
   100  			glog.V(1).Infof("%s: %v", file.GetName(), err)
   101  			continue
   102  		}
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		formatted, err := format.Source([]byte(code))
   107  		if err != nil {
   108  			glog.Errorf("%v: %s", err, code)
   109  			return nil, err
   110  		}
   111  		name, err := g.getFilePath(file)
   112  		if err != nil {
   113  			glog.Errorf("%v: %s", err, code)
   114  			return nil, err
   115  		}
   116  		ext := filepath.Ext(name)
   117  		base := strings.TrimSuffix(name, ext)
   118  		output := fmt.Sprintf("%s.pb.gw.go", base)
   119  		files = append(files, &plugin.CodeGeneratorResponse_File{
   120  			Name:    proto.String(output),
   121  			Content: proto.String(string(formatted)),
   122  		})
   123  		glog.V(1).Infof("Will emit %s", output)
   124  	}
   125  	return files, nil
   126  }
   127  
   128  func (g *generator) getFilePath(file *descriptor.File) (string, error) {
   129  	name := file.GetName()
   130  	switch {
   131  	case g.modulePath != "" && g.pathType != pathTypeImport:
   132  		return "", errors.New("cannot use module= with paths=")
   133  
   134  	case g.modulePath != "":
   135  		trimPath, pkgPath := g.modulePath+"/", file.GoPkg.Path+"/"
   136  		if !strings.HasPrefix(pkgPath, trimPath) {
   137  			return "", fmt.Errorf("%v: file go path does not match module prefix: %v", file.GoPkg.Path, trimPath)
   138  		}
   139  		return filepath.Join(strings.TrimPrefix(pkgPath, trimPath), filepath.Base(name)), nil
   140  
   141  	case g.pathType == pathTypeImport && file.GoPkg.Path != "":
   142  		return fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name)), nil
   143  
   144  	default:
   145  		return name, nil
   146  	}
   147  }
   148  
   149  func (g *generator) generate(file *descriptor.File) (string, error) {
   150  	pkgSeen := make(map[string]bool)
   151  	var imports []descriptor.GoPackage
   152  	for _, pkg := range g.baseImports {
   153  		pkgSeen[pkg.Path] = true
   154  		imports = append(imports, pkg)
   155  	}
   156  	for _, svc := range file.Services {
   157  		for _, m := range svc.Methods {
   158  			imports = append(imports, g.addEnumPathParamImports(file, m, pkgSeen)...)
   159  			pkg := m.RequestType.File.GoPkg
   160  			if len(m.Bindings) == 0 ||
   161  				pkg == file.GoPkg || pkgSeen[pkg.Path] {
   162  				continue
   163  			}
   164  			pkgSeen[pkg.Path] = true
   165  			imports = append(imports, pkg)
   166  		}
   167  	}
   168  	params := param{
   169  		File:               file,
   170  		Imports:            imports,
   171  		UseRequestContext:  g.useRequestContext,
   172  		RegisterFuncSuffix: g.registerFuncSuffix,
   173  		AllowPatchFeature:  g.allowPatchFeature,
   174  	}
   175  	if g.reg != nil {
   176  		params.OmitPackageDoc = g.reg.GetOmitPackageDoc()
   177  	}
   178  	return applyTemplate(params, g.reg)
   179  }
   180  
   181  // addEnumPathParamImports handles adding import of enum path parameter go packages
   182  func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
   183  	var imports []descriptor.GoPackage
   184  	for _, b := range m.Bindings {
   185  		for _, p := range b.PathParams {
   186  			e, err := g.reg.LookupEnum("", p.Target.GetTypeName())
   187  			if err != nil {
   188  				continue
   189  			}
   190  			pkg := e.File.GoPkg
   191  			if pkg == file.GoPkg || pkgSeen[pkg.Path] {
   192  				continue
   193  			}
   194  			pkgSeen[pkg.Path] = true
   195  			imports = append(imports, pkg)
   196  		}
   197  	}
   198  	return imports
   199  }
   200  

View as plain text