...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway/generator.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway

     1  package gengateway
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/format"
     7  	"path"
     8  
     9  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
    10  	gen "github.com/grpc-ecosystem/grpc-gateway/v2/internal/generator"
    11  	"google.golang.org/grpc/grpclog"
    12  	"google.golang.org/protobuf/proto"
    13  	"google.golang.org/protobuf/types/pluginpb"
    14  )
    15  
    16  var errNoTargetService = errors.New("no target service defined in the file")
    17  
    18  type generator struct {
    19  	reg                *descriptor.Registry
    20  	baseImports        []descriptor.GoPackage
    21  	useRequestContext  bool
    22  	registerFuncSuffix string
    23  	allowPatchFeature  bool
    24  	standalone         bool
    25  }
    26  
    27  // New returns a new generator which generates grpc gateway files.
    28  func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix string,
    29  	allowPatchFeature, standalone bool) gen.Generator {
    30  	var imports []descriptor.GoPackage
    31  	for _, pkgpath := range []string{
    32  		"context",
    33  		"io",
    34  		"net/http",
    35  		"github.com/grpc-ecosystem/grpc-gateway/v2/runtime",
    36  		"github.com/grpc-ecosystem/grpc-gateway/v2/utilities",
    37  		"google.golang.org/protobuf/proto",
    38  		"google.golang.org/grpc",
    39  		"google.golang.org/grpc/codes",
    40  		"google.golang.org/grpc/grpclog",
    41  		"google.golang.org/grpc/metadata",
    42  		"google.golang.org/grpc/status",
    43  	} {
    44  		pkg := descriptor.GoPackage{
    45  			Path: pkgpath,
    46  			Name: path.Base(pkgpath),
    47  		}
    48  		if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
    49  			for i := 0; ; i++ {
    50  				alias := fmt.Sprintf("%s_%d", pkg.Name, i)
    51  				if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
    52  					continue
    53  				}
    54  				pkg.Alias = alias
    55  				break
    56  			}
    57  		}
    58  		imports = append(imports, pkg)
    59  	}
    60  
    61  	return &generator{
    62  		reg:                reg,
    63  		baseImports:        imports,
    64  		useRequestContext:  useRequestContext,
    65  		registerFuncSuffix: registerFuncSuffix,
    66  		allowPatchFeature:  allowPatchFeature,
    67  		standalone:         standalone,
    68  	}
    69  }
    70  
    71  func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error) {
    72  	var files []*descriptor.ResponseFile
    73  	for _, file := range targets {
    74  		if grpclog.V(1) {
    75  			grpclog.Infof("Processing %s", file.GetName())
    76  		}
    77  
    78  		code, err := g.generate(file)
    79  		if errors.Is(err, errNoTargetService) {
    80  			if grpclog.V(1) {
    81  				grpclog.Infof("%s: %v", file.GetName(), err)
    82  			}
    83  			continue
    84  		}
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		formatted, err := format.Source([]byte(code))
    89  		if err != nil {
    90  			grpclog.Errorf("%v: %s", err, code)
    91  			return nil, err
    92  		}
    93  		files = append(files, &descriptor.ResponseFile{
    94  			GoPkg: file.GoPkg,
    95  			CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{
    96  				Name:    proto.String(file.GeneratedFilenamePrefix + ".pb.gw.go"),
    97  				Content: proto.String(string(formatted)),
    98  			},
    99  		})
   100  	}
   101  	return files, nil
   102  }
   103  
   104  func (g *generator) generate(file *descriptor.File) (string, error) {
   105  	pkgSeen := make(map[string]bool)
   106  	var imports []descriptor.GoPackage
   107  	for _, pkg := range g.baseImports {
   108  		pkgSeen[pkg.Path] = true
   109  		imports = append(imports, pkg)
   110  	}
   111  
   112  	if g.standalone {
   113  		imports = append(imports, file.GoPkg)
   114  	}
   115  
   116  	for _, svc := range file.Services {
   117  		for _, m := range svc.Methods {
   118  			imports = append(imports, g.addEnumPathParamImports(file, m, pkgSeen)...)
   119  			pkg := m.RequestType.File.GoPkg
   120  			if len(m.Bindings) == 0 ||
   121  				pkg == file.GoPkg || pkgSeen[pkg.Path] {
   122  				continue
   123  			}
   124  			pkgSeen[pkg.Path] = true
   125  			imports = append(imports, pkg)
   126  		}
   127  	}
   128  	params := param{
   129  		File:               file,
   130  		Imports:            imports,
   131  		UseRequestContext:  g.useRequestContext,
   132  		RegisterFuncSuffix: g.registerFuncSuffix,
   133  		AllowPatchFeature:  g.allowPatchFeature,
   134  	}
   135  	if g.reg != nil {
   136  		params.OmitPackageDoc = g.reg.GetOmitPackageDoc()
   137  	}
   138  	return applyTemplate(params, g.reg)
   139  }
   140  
   141  // addEnumPathParamImports handles adding import of enum path parameter go packages
   142  func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
   143  	var imports []descriptor.GoPackage
   144  	for _, b := range m.Bindings {
   145  		for _, p := range b.PathParams {
   146  			e, err := g.reg.LookupEnum("", p.Target.GetTypeName())
   147  			if err != nil {
   148  				continue
   149  			}
   150  			pkg := e.File.GoPkg
   151  			if pkg == file.GoPkg || pkgSeen[pkg.Path] {
   152  				continue
   153  			}
   154  			pkgSeen[pkg.Path] = true
   155  			imports = append(imports, pkg)
   156  		}
   157  	}
   158  	return imports
   159  }
   160  

View as plain text