1 package grpc
2
3 import (
4 "context"
5
6 "google.golang.org/grpc"
7 "google.golang.org/grpc/metadata"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/transport"
11 "github.com/go-kit/log"
12 )
13
14
15
16
// Handler is the server-side gRPC abstraction implemented by Server.
//
// ServeGRPC receives the request context and a gRPC request value. It returns
// a context (possibly updated while serving), the gRPC response value, and an
// error.
type Handler interface {
	ServeGRPC(ctx context.Context, request interface{}) (newCtx context.Context, response interface{}, err error)
}
20
21
22 type Server struct {
23 e endpoint.Endpoint
24 dec DecodeRequestFunc
25 enc EncodeResponseFunc
26 before []ServerRequestFunc
27 after []ServerResponseFunc
28 finalizer []ServerFinalizerFunc
29 errorHandler transport.ErrorHandler
30 }
31
32
33
34
35
36
37 func NewServer(
38 e endpoint.Endpoint,
39 dec DecodeRequestFunc,
40 enc EncodeResponseFunc,
41 options ...ServerOption,
42 ) *Server {
43 s := &Server{
44 e: e,
45 dec: dec,
46 enc: enc,
47 errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()),
48 }
49 for _, option := range options {
50 option(s)
51 }
52 return s
53 }
54
55
56 type ServerOption func(*Server)
57
58
59
60 func ServerBefore(before ...ServerRequestFunc) ServerOption {
61 return func(s *Server) { s.before = append(s.before, before...) }
62 }
63
64
65
66 func ServerAfter(after ...ServerResponseFunc) ServerOption {
67 return func(s *Server) { s.after = append(s.after, after...) }
68 }
69
70
71
72
73 func ServerErrorLogger(logger log.Logger) ServerOption {
74 return func(s *Server) { s.errorHandler = transport.NewLogErrorHandler(logger) }
75 }
76
77
78
79 func ServerErrorHandler(errorHandler transport.ErrorHandler) ServerOption {
80 return func(s *Server) { s.errorHandler = errorHandler }
81 }
82
83
84
85 func ServerFinalizer(f ...ServerFinalizerFunc) ServerOption {
86 return func(s *Server) { s.finalizer = append(s.finalizer, f...) }
87 }
88
89
90 func (s Server) ServeGRPC(ctx context.Context, req interface{}) (retctx context.Context, resp interface{}, err error) {
91
92 md, ok := metadata.FromIncomingContext(ctx)
93 if !ok {
94 md = metadata.MD{}
95 }
96
97 if len(s.finalizer) > 0 {
98 defer func() {
99 for _, f := range s.finalizer {
100 f(ctx, err)
101 }
102 }()
103 }
104
105 for _, f := range s.before {
106 ctx = f(ctx, md)
107 }
108
109 var (
110 request interface{}
111 response interface{}
112 grpcResp interface{}
113 )
114
115 request, err = s.dec(ctx, req)
116 if err != nil {
117 s.errorHandler.Handle(ctx, err)
118 return ctx, nil, err
119 }
120
121 response, err = s.e(ctx, request)
122 if err != nil {
123 s.errorHandler.Handle(ctx, err)
124 return ctx, nil, err
125 }
126
127 var mdHeader, mdTrailer metadata.MD
128 for _, f := range s.after {
129 ctx = f(ctx, &mdHeader, &mdTrailer)
130 }
131
132 grpcResp, err = s.enc(ctx, response)
133 if err != nil {
134 s.errorHandler.Handle(ctx, err)
135 return ctx, nil, err
136 }
137
138 if len(mdHeader) > 0 {
139 if err = grpc.SendHeader(ctx, mdHeader); err != nil {
140 s.errorHandler.Handle(ctx, err)
141 return ctx, nil, err
142 }
143 }
144
145 if len(mdTrailer) > 0 {
146 if err = grpc.SetTrailer(ctx, mdTrailer); err != nil {
147 s.errorHandler.Handle(ctx, err)
148 return ctx, nil, err
149 }
150 }
151
152 return ctx, grpcResp, nil
153 }
154
155
156
// ServerFinalizerFunc is run when a Server finishes serving a request. It
// receives the request's final context and the error being returned (nil on
// success).
type ServerFinalizerFunc func(ctx context.Context, err error)
158
159
160
161
162
163 func Interceptor(
164 ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
165 ) (resp interface{}, err error) {
166 ctx = context.WithValue(ctx, ContextKeyRequestMethod, info.FullMethod)
167 return handler(ctx, req)
168 }
169
View as plain text