1 package jsonrpc
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "io/ioutil"
8 "net/http"
9 "net/url"
10 "sync/atomic"
11
12 "github.com/go-kit/kit/endpoint"
13 httptransport "github.com/go-kit/kit/transport/http"
14 )
15
16
17 type Client struct {
18 client httptransport.HTTPClient
19
20
21 tgt *url.URL
22
23
24 method string
25
26 enc EncodeRequestFunc
27 dec DecodeResponseFunc
28 before []httptransport.RequestFunc
29 after []httptransport.ClientResponseFunc
30 finalizer httptransport.ClientFinalizerFunc
31 requestID RequestIDGenerator
32 bufferedStream bool
33 }
34
// clientRequest is the JSON-RPC request envelope that Client.Endpoint encodes
// and sends as the HTTP request body.
type clientRequest struct {
	JSONRPC string          `json:"jsonrpc"` // protocol version; Endpoint sets it to Version
	Method  string          `json:"method"`  // remote method name
	Params  json.RawMessage `json:"params"`  // pre-encoded parameters, embedded verbatim
	ID      interface{}     `json:"id"`      // value produced by the RequestIDGenerator
}
41
42
43 func NewClient(
44 tgt *url.URL,
45 method string,
46 options ...ClientOption,
47 ) *Client {
48 c := &Client{
49 client: http.DefaultClient,
50 method: method,
51 tgt: tgt,
52 enc: DefaultRequestEncoder,
53 dec: DefaultResponseDecoder,
54 before: []httptransport.RequestFunc{},
55 after: []httptransport.ClientResponseFunc{},
56 requestID: NewAutoIncrementID(0),
57 bufferedStream: false,
58 }
59 for _, option := range options {
60 option(c)
61 }
62 return c
63 }
64
65
// DefaultRequestEncoder is the default EncodeRequestFunc. It marshals req with
// encoding/json; the resulting bytes become the "params" member of the request.
// The context is ignored.
func DefaultRequestEncoder(_ context.Context, req interface{}) (json.RawMessage, error) {
	encoded, err := json.Marshal(req)
	return encoded, err
}
69
70
71
72 func DefaultResponseDecoder(_ context.Context, res Response) (interface{}, error) {
73 if res.Error != nil {
74 return nil, *res.Error
75 }
76 var result interface{}
77 err := json.Unmarshal(res.Result, &result)
78 if err != nil {
79 return nil, err
80 }
81 return result, nil
82 }
83
84
85 type ClientOption func(*Client)
86
87
88
89 func SetClient(client httptransport.HTTPClient) ClientOption {
90 return func(c *Client) { c.client = client }
91 }
92
93
94
95 func ClientBefore(before ...httptransport.RequestFunc) ClientOption {
96 return func(c *Client) { c.before = append(c.before, before...) }
97 }
98
99
100
101
102 func ClientAfter(after ...httptransport.ClientResponseFunc) ClientOption {
103 return func(c *Client) { c.after = append(c.after, after...) }
104 }
105
106
107
108 func ClientFinalizer(f httptransport.ClientFinalizerFunc) ClientOption {
109 return func(c *Client) { c.finalizer = f }
110 }
111
112
113
114 func ClientRequestEncoder(enc EncodeRequestFunc) ClientOption {
115 return func(c *Client) { c.enc = enc }
116 }
117
118
119
120 func ClientResponseDecoder(dec DecodeResponseFunc) ClientOption {
121 return func(c *Client) { c.dec = dec }
122 }
123
124
// RequestIDGenerator supplies the "id" member of each JSON-RPC request. The
// endpoint calls Generate once per invocation.
type RequestIDGenerator interface {
	// Generate returns the ID to use for the next request.
	Generate() (id interface{})
}
128
129
130
131
132 func ClientRequestIDGenerator(g RequestIDGenerator) ClientOption {
133 return func(c *Client) { c.requestID = g }
134 }
135
136
137
138 func BufferedStream(buffered bool) ClientOption {
139 return func(c *Client) { c.bufferedStream = buffered }
140 }
141
142
143 func (c Client) Endpoint() endpoint.Endpoint {
144 return func(ctx context.Context, request interface{}) (interface{}, error) {
145 ctx, cancel := context.WithCancel(ctx)
146 defer cancel()
147
148 var (
149 resp *http.Response
150 err error
151 )
152 if c.finalizer != nil {
153 defer func() {
154 if resp != nil {
155 ctx = context.WithValue(ctx, httptransport.ContextKeyResponseHeaders, resp.Header)
156 ctx = context.WithValue(ctx, httptransport.ContextKeyResponseSize, resp.ContentLength)
157 }
158 c.finalizer(ctx, err)
159 }()
160 }
161
162 ctx = context.WithValue(ctx, ContextKeyRequestMethod, c.method)
163
164 var params json.RawMessage
165 if params, err = c.enc(ctx, request); err != nil {
166 return nil, err
167 }
168 rpcReq := clientRequest{
169 JSONRPC: Version,
170 Method: c.method,
171 Params: params,
172 ID: c.requestID.Generate(),
173 }
174
175 req, err := http.NewRequest("POST", c.tgt.String(), nil)
176 if err != nil {
177 return nil, err
178 }
179
180 req.Header.Set("Content-Type", "application/json; charset=utf-8")
181 var b bytes.Buffer
182 req.Body = ioutil.NopCloser(&b)
183 err = json.NewEncoder(&b).Encode(rpcReq)
184 if err != nil {
185 return nil, err
186 }
187
188 for _, f := range c.before {
189 ctx = f(ctx, req)
190 }
191
192 resp, err = c.client.Do(req.WithContext(ctx))
193 if err != nil {
194 return nil, err
195 }
196
197 if !c.bufferedStream {
198 defer resp.Body.Close()
199 }
200
201 for _, f := range c.after {
202 ctx = f(ctx, resp)
203 }
204
205
206 var rpcRes Response
207 err = json.NewDecoder(resp.Body).Decode(&rpcRes)
208 if err != nil {
209 return nil, err
210 }
211
212 response, err := c.dec(ctx, rpcRes)
213 if err != nil {
214 return nil, err
215 }
216
217 return response, nil
218 }
219 }
220
221
222
223
224
225
226
// ClientFinalizerFunc is a function called at the end of a client call, with
// the call's context and the error it produced (if any). Note that the
// ClientFinalizer option in this package takes an
// httptransport.ClientFinalizerFunc, not this type.
type ClientFinalizerFunc func(context.Context, error)
228
229
230
// autoIncrementID is the RequestIDGenerator built by NewAutoIncrementID. v
// points at a counter that Generate increments with sync/atomic, returning the
// incremented value.
type autoIncrementID struct{ v *uint64 }
234
235
236
237 func NewAutoIncrementID(init uint64) RequestIDGenerator {
238
239 v := init - 1
240 return &autoIncrementID{v: &v}
241 }
242
243
244 func (i *autoIncrementID) Generate() interface{} {
245 id := atomic.AddUint64(i.v, 1)
246 return id
247 }
248
View as plain text