1 package awslambda
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/go-kit/kit/endpoint"
10 "github.com/go-kit/kit/transport"
11 "github.com/go-kit/log"
12 )
13
// key is the package-local type used for context keys in these tests. Because
// the type is unexported, its values never compare equal to context keys of
// any other type.
type key int

// Context keys shared between the HandlerBefore/HandlerAfter callbacks and the
// decode/encode helpers in this file.
const (
	KeyBeforeOne key = iota // set by the first HandlerBefore; read by decodeHelloRequestWithTwoBefores
	KeyBeforeTwo            // set by the second HandlerBefore; read by decodeHelloRequestWithTwoBefores
	KeyAfterOne             // set by one HandlerAfter; checked by the next one
	KeyEncMode              // when set to "fail_encode", encodeResponse returns an error
)
22
23
// Minimal API Gateway proxy envelope types; only the fields these tests use
// are declared.
type (
	// apiGatewayProxyRequest carries the JSON-encoded hello request in Body.
	apiGatewayProxyRequest struct {
		Body string `json:"body"`
	}

	// apiGatewayProxyResponse carries the status code and the JSON-encoded
	// payload produced by encodeResponse or by the tests' error encoders.
	apiGatewayProxyResponse struct {
		StatusCode int    `json:"statusCode"`
		Body       string `json:"body"`
	}
)
33
34 func TestDefaultErrorEncoder(t *testing.T) {
35 ctx := context.Background()
36 rootErr := fmt.Errorf("root")
37 b, err := DefaultErrorEncoder(ctx, rootErr)
38 if b != nil {
39 t.Fatalf("DefaultErrorEncoder should return nil as []byte")
40 }
41 if err != rootErr {
42 t.Fatalf("DefaultErrorEncoder expects return back the given error.")
43 }
44 }
45
46 func TestInvokeHappyPath(t *testing.T) {
47 svc := serviceTest01{}
48
49 helloHandler := NewHandler(
50 makeTest01HelloEndpoint(svc),
51 decodeHelloRequestWithTwoBefores,
52 encodeResponse,
53 HandlerErrorHandler(transport.NewLogErrorHandler(log.NewNopLogger())),
54 HandlerBefore(func(
55 ctx context.Context,
56 payload []byte,
57 ) context.Context {
58 ctx = context.WithValue(ctx, KeyBeforeOne, "bef1")
59 return ctx
60 }),
61 HandlerBefore(func(
62 ctx context.Context,
63 payload []byte,
64 ) context.Context {
65 ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2")
66 return ctx
67 }),
68 HandlerAfter(func(
69 ctx context.Context,
70 response interface{},
71 ) context.Context {
72 ctx = context.WithValue(ctx, KeyAfterOne, "af1")
73 return ctx
74 }),
75 HandlerAfter(func(
76 ctx context.Context,
77 response interface{},
78 ) context.Context {
79 if _, ok := ctx.Value(KeyAfterOne).(string); !ok {
80 t.Fatalf("Value was not set properly during multi HandlerAfter")
81 }
82 return ctx
83 }),
84 HandlerFinalizer(func(
85 _ context.Context,
86 resp []byte,
87 _ error,
88 ) {
89 apigwResp := apiGatewayProxyResponse{}
90 err := json.Unmarshal(resp, &apigwResp)
91 if err != nil {
92 t.Fatalf("Should have no error, but got: %+v", err)
93 }
94
95 response := helloResponse{}
96 err = json.Unmarshal([]byte(apigwResp.Body), &response)
97 if err != nil {
98 t.Fatalf("Should have no error, but got: %+v", err)
99 }
100
101 expectedGreeting := "hello john doe bef1 bef2"
102 if response.Greeting != expectedGreeting {
103 t.Fatalf(
104 "Expect: %s, Actual: %s", expectedGreeting, response.Greeting)
105 }
106 }),
107 )
108
109 ctx := context.Background()
110 req, _ := json.Marshal(apiGatewayProxyRequest{
111 Body: `{"name":"john doe"}`,
112 })
113 resp, err := helloHandler.Invoke(ctx, req)
114
115 if err != nil {
116 t.Fatalf("Should have no error, but got: %+v", err)
117 }
118
119 apigwResp := apiGatewayProxyResponse{}
120 err = json.Unmarshal(resp, &apigwResp)
121 if err != nil {
122 t.Fatalf("Should have no error, but got: %+v", err)
123 }
124
125 response := helloResponse{}
126 err = json.Unmarshal([]byte(apigwResp.Body), &response)
127 if err != nil {
128 t.Fatalf("Should have no error, but got: %+v", err)
129 }
130
131 expectedGreeting := "hello john doe bef1 bef2"
132 if response.Greeting != expectedGreeting {
133 t.Fatalf(
134 "Expect: %s, Actual: %s", expectedGreeting, response.Greeting)
135 }
136 }
137
138 func TestInvokeFailDecode(t *testing.T) {
139 svc := serviceTest01{}
140
141 helloHandler := NewHandler(
142 makeTest01HelloEndpoint(svc),
143 decodeHelloRequestWithTwoBefores,
144 encodeResponse,
145 HandlerErrorEncoder(func(
146 ctx context.Context,
147 err error,
148 ) ([]byte, error) {
149 apigwResp := apiGatewayProxyResponse{}
150 apigwResp.Body = `{"error":"yes"}`
151 apigwResp.StatusCode = 500
152 resp, err := json.Marshal(apigwResp)
153 return resp, err
154 }),
155 )
156
157 ctx := context.Background()
158 req, _ := json.Marshal(apiGatewayProxyRequest{
159 Body: `{"name":"john doe"}`,
160 })
161 resp, err := helloHandler.Invoke(ctx, req)
162
163 if err != nil {
164 t.Fatalf("Should have no error, but got: %+v", err)
165 }
166
167 apigwResp := apiGatewayProxyResponse{}
168 json.Unmarshal(resp, &apigwResp)
169 if apigwResp.StatusCode != 500 {
170 t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode)
171 }
172 }
173
174 func TestInvokeFailEndpoint(t *testing.T) {
175 svc := serviceTest01{}
176
177 helloHandler := NewHandler(
178 makeTest01FailEndpoint(svc),
179 decodeHelloRequestWithTwoBefores,
180 encodeResponse,
181 HandlerBefore(func(
182 ctx context.Context,
183 payload []byte,
184 ) context.Context {
185 ctx = context.WithValue(ctx, KeyBeforeOne, "bef1")
186 return ctx
187 }),
188 HandlerBefore(func(
189 ctx context.Context,
190 payload []byte,
191 ) context.Context {
192 ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2")
193 return ctx
194 }),
195 HandlerErrorEncoder(func(
196 ctx context.Context,
197 err error,
198 ) ([]byte, error) {
199 apigwResp := apiGatewayProxyResponse{}
200 apigwResp.Body = `{"error":"yes"}`
201 apigwResp.StatusCode = 500
202 resp, err := json.Marshal(apigwResp)
203 return resp, err
204 }),
205 )
206
207 ctx := context.Background()
208 req, _ := json.Marshal(apiGatewayProxyRequest{
209 Body: `{"name":"john doe"}`,
210 })
211 resp, err := helloHandler.Invoke(ctx, req)
212
213 if err != nil {
214 t.Fatalf("Should have no error, but got: %+v", err)
215 }
216
217 apigwResp := apiGatewayProxyResponse{}
218 json.Unmarshal(resp, &apigwResp)
219 if apigwResp.StatusCode != 500 {
220 t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode)
221 }
222 }
223
224 func TestInvokeFailEncode(t *testing.T) {
225 svc := serviceTest01{}
226
227 helloHandler := NewHandler(
228 makeTest01HelloEndpoint(svc),
229 decodeHelloRequestWithTwoBefores,
230 encodeResponse,
231 HandlerBefore(func(
232 ctx context.Context,
233 payload []byte,
234 ) context.Context {
235 ctx = context.WithValue(ctx, KeyBeforeOne, "bef1")
236 return ctx
237 }),
238 HandlerBefore(func(
239 ctx context.Context,
240 payload []byte,
241 ) context.Context {
242 ctx = context.WithValue(ctx, KeyBeforeTwo, "bef2")
243 return ctx
244 }),
245 HandlerAfter(func(
246 ctx context.Context,
247 response interface{},
248 ) context.Context {
249 ctx = context.WithValue(ctx, KeyEncMode, "fail_encode")
250 return ctx
251 }),
252 HandlerErrorEncoder(func(
253 ctx context.Context,
254 err error,
255 ) ([]byte, error) {
256
257 apigwResp := apiGatewayProxyResponse{}
258 apigwResp.Body = `{"error":"yes"}`
259 apigwResp.StatusCode = 500
260 resp, err := json.Marshal(apigwResp)
261 return resp, err
262 }),
263 )
264
265 ctx := context.Background()
266 req, _ := json.Marshal(apiGatewayProxyRequest{
267 Body: `{"name":"john doe"}`,
268 })
269 resp, err := helloHandler.Invoke(ctx, req)
270
271 if err != nil {
272 t.Fatalf("Should have no error, but got: %+v", err)
273 }
274
275 apigwResp := apiGatewayProxyResponse{}
276 json.Unmarshal(resp, &apigwResp)
277 if apigwResp.StatusCode != 500 {
278 t.Fatalf("Expect status code of 500, instead of %d", apigwResp.StatusCode)
279 }
280 }
281
282 func decodeHelloRequestWithTwoBefores(
283 ctx context.Context, req []byte,
284 ) (interface{}, error) {
285 apigwReq := apiGatewayProxyRequest{}
286 err := json.Unmarshal([]byte(req), &apigwReq)
287 if err != nil {
288 return apigwReq, err
289 }
290
291 request := helloRequest{}
292 err = json.Unmarshal([]byte(apigwReq.Body), &request)
293 if err != nil {
294 return request, err
295 }
296
297 valOne, ok := ctx.Value(KeyBeforeOne).(string)
298 if !ok {
299 return request, fmt.Errorf(
300 "Value was not set properly when multiple HandlerBefores are used")
301 }
302
303 valTwo, ok := ctx.Value(KeyBeforeTwo).(string)
304 if !ok {
305 return request, fmt.Errorf(
306 "Value was not set properly when multiple HandlerBefores are used")
307 }
308
309 request.Name += " " + valOne + " " + valTwo
310 return request, err
311 }
312
313 func encodeResponse(
314 ctx context.Context, response interface{},
315 ) ([]byte, error) {
316 apigwResp := apiGatewayProxyResponse{}
317
318 mode, ok := ctx.Value(KeyEncMode).(string)
319 if ok && mode == "fail_encode" {
320 return nil, fmt.Errorf("fail encoding")
321 }
322
323 respByte, err := json.Marshal(response)
324 if err != nil {
325 return nil, err
326 }
327
328 apigwResp.Body = string(respByte)
329 apigwResp.StatusCode = 200
330
331 resp, err := json.Marshal(apigwResp)
332 return resp, err
333 }
334
// Payload types for the hello endpoint.
type (
	// helloRequest is decoded from the API Gateway request Body.
	helloRequest struct {
		Name string `json:"name"`
	}

	// helloResponse is JSON-encoded into the API Gateway response Body.
	helloResponse struct {
		Greeting string `json:"greeting"`
	}
)
342
343 func makeTest01HelloEndpoint(svc serviceTest01) endpoint.Endpoint {
344 return func(_ context.Context, request interface{}) (interface{}, error) {
345 req := request.(helloRequest)
346 greeting := svc.hello(req.Name)
347 return helloResponse{greeting}, nil
348 }
349 }
350
351 func makeTest01FailEndpoint(_ serviceTest01) endpoint.Endpoint {
352 return func(_ context.Context, request interface{}) (interface{}, error) {
353 return nil, fmt.Errorf("test error endpoint")
354 }
355 }
356
// serviceTest01 is a stateless stand-in service used by the test endpoints.
type serviceTest01 struct{}

// hello returns "hello " followed by name.
//
// A value receiver is used because the type carries no state and callers hold
// it by value. This also makes hello callable on non-addressable values such
// as serviceTest01{}.
func (serviceTest01) hello(name string) string {
	return "hello " + name
}
362
View as plain text