1 package jsonrpc
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9
10 httptransport "github.com/go-kit/kit/transport/http"
11 "github.com/go-kit/log"
12 )
13
// requestIDKeyType is the unexported type of the context key under which
// ServeHTTP stores the decoded JSON-RPC request ID. Because the type is
// private to this package, no other package can construct the same key.
type requestIDKeyType struct{}

// requestIDKey is the one value of requestIDKeyType used as that context key.
var requestIDKey = requestIDKeyType{}
17
18
// Server wraps an EndpointCodecMap and implements http.Handler via its
// ServeHTTP method, serving JSON-RPC requests sent with HTTP POST.
// Construct it with NewServer and configure it with ServerOption values.
type Server struct {
	ecm          EndpointCodecMap                   // method name -> endpoint and codecs used for dispatch
	before       []httptransport.RequestFunc        // run on the raw *http.Request before the JSON body is decoded
	beforeCodec  []RequestFunc                      // run after the JSON-RPC request is decoded, before method lookup
	after        []httptransport.ServerResponseFunc // run after a successful endpoint call, before the result is encoded
	errorEncoder httptransport.ErrorEncoder         // writes errors to the client; NewServer defaults it to DefaultErrorEncoder
	finalizer    httptransport.ServerFinalizerFunc  // optional; nil means no finalizer is run
	logger       log.Logger                         // records errors; NewServer defaults it to a no-op logger
}
28
29
30 func NewServer(
31 ecm EndpointCodecMap,
32 options ...ServerOption,
33 ) *Server {
34 s := &Server{
35 ecm: ecm,
36 errorEncoder: DefaultErrorEncoder,
37 logger: log.NewNopLogger(),
38 }
39 for _, option := range options {
40 option(s)
41 }
42 return s
43 }
44
45
// ServerOption sets an optional parameter on a Server. NewServer applies the
// options it is given in order, after installing its defaults.
type ServerOption func(*Server)
47
48
49
50 func ServerBefore(before ...httptransport.RequestFunc) ServerOption {
51 return func(s *Server) { s.before = append(s.before, before...) }
52 }
53
54
55
56
57
58 func ServerBeforeCodec(beforeCodec ...RequestFunc) ServerOption {
59 return func(s *Server) { s.beforeCodec = append(s.beforeCodec, beforeCodec...) }
60 }
61
62
63
64 func ServerAfter(after ...httptransport.ServerResponseFunc) ServerOption {
65 return func(s *Server) { s.after = append(s.after, after...) }
66 }
67
68
69
70
71
72 func ServerErrorEncoder(ee httptransport.ErrorEncoder) ServerOption {
73 return func(s *Server) { s.errorEncoder = ee }
74 }
75
76
77
78
79
80
81 func ServerErrorLogger(logger log.Logger) ServerOption {
82 return func(s *Server) { s.logger = logger }
83 }
84
85
86
87 func ServerFinalizer(f httptransport.ServerFinalizerFunc) ServerOption {
88 return func(s *Server) { s.finalizer = f }
89 }
90
91
// ServeHTTP implements http.Handler. It reads a single JSON-RPC request from
// the body of a POST request and dispatches it to the EndpointCodec
// registered for its method. It then writes the JSON-RPC response.
//
// A non-POST request gets a plain-text 405 response with an Allow header.
// Every other failure is logged and handed to the configured error encoder:
// malformed JSON, an unknown method, or a params-decode, endpoint, or
// result-encode error. The finalizer, if set, runs for every request,
// including rejected ones.
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Install the finalizer before any early return so it also observes
	// rejected (405) requests. The deferred closure reads ctx when it runs,
	// so it sees values added by the request funcs below.
	if s.finalizer != nil {
		iw := &interceptingWriter{w, http.StatusOK}
		defer func() { s.finalizer(ctx, iw.code, r) }()
		w = iw
	}

	if r.Method != http.MethodPost {
		// RFC 9110 section 15.5.6: a 405 response must list the allowed methods.
		w.Header().Set("Allow", http.MethodPost)
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must POST\n")
		return
	}

	for _, f := range s.before {
		ctx = f(ctx, r)
	}

	// fail logs err and passes it to the error encoder. It captures ctx and w
	// by reference, so it always uses the most recent context.
	fail := func(err error) {
		s.logger.Log("err", err)
		s.errorEncoder(ctx, err, w)
	}

	// Decode the JSON-RPC request envelope.
	var req Request
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fail(parseError("JSON could not be decoded: " + err.Error()))
		return
	}

	ctx = context.WithValue(ctx, requestIDKey, req.ID)
	ctx = context.WithValue(ctx, ContextKeyRequestMethod, req.Method)

	for _, f := range s.beforeCodec {
		ctx = f(ctx, r, req)
	}

	// Look up the endpoint and codecs registered for the requested method.
	ec, ok := s.ecm[req.Method]
	if !ok {
		fail(methodNotFoundError(fmt.Sprintf("Method %s was not found.", req.Method)))
		return
	}

	reqParams, err := ec.Decode(ctx, req.Params)
	if err != nil {
		fail(err)
		return
	}

	response, err := ec.Endpoint(ctx, reqParams)
	if err != nil {
		fail(err)
		return
	}

	for _, f := range s.after {
		ctx = f(ctx, w)
	}

	resParams, err := ec.Encode(ctx, response)
	if err != nil {
		fail(err)
		return
	}

	w.Header().Set("Content-Type", ContentType)
	res := Response{
		ID:      req.ID,
		JSONRPC: Version,
		Result:  resParams,
	}
	if err := json.NewEncoder(w).Encode(res); err != nil {
		// The status line (and possibly part of the body) has already been
		// sent, so the error encoder can no longer produce a valid response.
		// Record the failure instead of dropping it silently.
		s.logger.Log("err", err)
	}
}
176
177
178
179
180
181
182
183 func DefaultErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) {
184 w.Header().Set("Content-Type", ContentType)
185 if headerer, ok := err.(httptransport.Headerer); ok {
186 for k := range headerer.Headers() {
187 w.Header().Set(k, headerer.Headers().Get(k))
188 }
189 }
190
191 e := Error{
192 Code: InternalError,
193 Message: err.Error(),
194 }
195 if sc, ok := err.(ErrorCoder); ok {
196 e.Code = sc.ErrorCode()
197 }
198
199 w.WriteHeader(http.StatusOK)
200
201 var requestID *RequestID
202 if v := ctx.Value(requestIDKey); v != nil {
203 requestID = v.(*RequestID)
204 }
205 _ = json.NewEncoder(w).Encode(Response{
206 ID: requestID,
207 JSONRPC: Version,
208 Error: &e,
209 })
210 }
211
212
213
214
215
216
// ErrorCoder is implemented by errors that carry a JSON-RPC error code.
// DefaultErrorEncoder uses ErrorCode() as the code of the error object it
// writes. Errors that do not implement this interface (checked with a direct
// type assertion, not through wrapping) are reported as InternalError.
type ErrorCoder interface {
	// ErrorCode returns the JSON-RPC error code for this error.
	ErrorCode() int
}
220
221
222
// interceptingWriter wraps an http.ResponseWriter and records the status code
// passed to WriteHeader, so that ServeHTTP can report it to the finalizer.
// ServeHTTP initialises code to http.StatusOK, which is the status net/http
// sends when a handler writes a body without calling WriteHeader.
//
// NOTE(review): only the http.ResponseWriter methods are promoted through the
// embedded interface. Optional interfaces of the underlying writer (for
// example http.Flusher or http.Hijacker) are hidden by this wrapper.
type interceptingWriter struct {
	http.ResponseWriter
	code int // last status code passed to WriteHeader
}
227
228
229
230 func (w *interceptingWriter) WriteHeader(code int) {
231 w.code = code
232 w.ResponseWriter.WriteHeader(code)
233 }
234
View as plain text