...
1
2
3
4
5
6
7 package ctxhttp
8
import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)
17
18 func TestGo17Context(t *testing.T) {
19 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 io.WriteString(w, "ok")
21 }))
22 defer ts.Close()
23 ctx := context.Background()
24 resp, err := Get(ctx, http.DefaultClient, ts.URL)
25 if resp == nil || err != nil {
26 t.Fatalf("error received from client: %v %v", err, resp)
27 }
28 resp.Body.Close()
29 }
30
// requestDuration is how long okHandler stalls before writing its reply.
const requestDuration = 100 * time.Millisecond

// requestBody is the payload okHandler writes and TestNoTimeout expects to read.
const requestBody = "ok"
35
36 func okHandler(w http.ResponseWriter, r *http.Request) {
37 time.Sleep(requestDuration)
38 io.WriteString(w, requestBody)
39 }
40
41 func TestNoTimeout(t *testing.T) {
42 ts := httptest.NewServer(http.HandlerFunc(okHandler))
43 defer ts.Close()
44
45 ctx := context.Background()
46 res, err := Get(ctx, nil, ts.URL)
47 if err != nil {
48 t.Fatal(err)
49 }
50 defer res.Body.Close()
51 slurp, err := io.ReadAll(res.Body)
52 if err != nil {
53 t.Fatal(err)
54 }
55 if string(slurp) != requestBody {
56 t.Errorf("body = %q; want %q", slurp, requestBody)
57 }
58 }
59
60 func TestCancelBeforeHeaders(t *testing.T) {
61 ctx, cancel := context.WithCancel(context.Background())
62
63 blockServer := make(chan struct{})
64 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65 cancel()
66 <-blockServer
67 io.WriteString(w, requestBody)
68 }))
69 defer ts.Close()
70 defer close(blockServer)
71
72 res, err := Get(ctx, nil, ts.URL)
73 if err == nil {
74 res.Body.Close()
75 t.Fatal("Get returned unexpected nil error")
76 }
77 if err != context.Canceled {
78 t.Errorf("err = %v; want %v", err, context.Canceled)
79 }
80 }
81
82 func TestCancelAfterHangingRequest(t *testing.T) {
83 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
84 w.WriteHeader(http.StatusOK)
85 w.(http.Flusher).Flush()
86 <-w.(http.CloseNotifier).CloseNotify()
87 }))
88 defer ts.Close()
89
90 ctx, cancel := context.WithCancel(context.Background())
91 resp, err := Get(ctx, nil, ts.URL)
92 if err != nil {
93 t.Fatalf("unexpected error in Get: %v", err)
94 }
95
96
97
98
99 cancel()
100
101 done := make(chan struct{})
102
103 go func() {
104 b, err := io.ReadAll(resp.Body)
105 if len(b) != 0 || err == nil {
106 t.Errorf(`Read got (%q, %v); want ("", error)`, b, err)
107 }
108 close(done)
109 }()
110
111 select {
112 case <-time.After(1 * time.Second):
113 t.Errorf("Test timed out")
114 case <-done:
115 }
116 }
117
View as plain text