1 package api
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "net/http"
8 "strings"
9 "testing"
10
11 "github.com/cli/go-gh/v2/pkg/config"
12 "github.com/stretchr/testify/assert"
13 "gopkg.in/h2non/gock.v1"
14 )
15
16 func TestHTTPClient(t *testing.T) {
17 stubConfig(t, testConfig())
18 t.Cleanup(gock.Off)
19
20 gock.New("https://api.github.com").
21 Get("/some/test/path").
22 MatchHeader("Authorization", "token abc123").
23 Reply(200).
24 JSON(`{"message": "success"}`)
25
26 client, err := DefaultHTTPClient()
27 assert.NoError(t, err)
28
29 res, err := client.Get("https://api.github.com/some/test/path")
30 assert.NoError(t, err)
31 assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending()))
32 assert.Equal(t, 200, res.StatusCode)
33 }
34
35 func TestNewHTTPClient(t *testing.T) {
36 reflectHTTP := tripper{
37 roundTrip: func(req *http.Request) (*http.Response, error) {
38 header := req.Header.Clone()
39 body := "{}"
40 return &http.Response{
41 StatusCode: 200,
42 Header: header,
43 Body: io.NopCloser(bytes.NewBufferString(body)),
44 }, nil
45 },
46 }
47
48 tests := []struct {
49 name string
50 enableLog bool
51 log *bytes.Buffer
52 host string
53 headers map[string]string
54 skipHeaders bool
55 wantHeaders http.Header
56 }{
57 {
58 name: "sets default headers",
59 wantHeaders: defaultHeaders(),
60 },
61 {
62 name: "allows overriding default headers",
63 headers: map[string]string{
64 authorization: "token new_token",
65 accept: "application/vnd.github.test-preview",
66 },
67 wantHeaders: func() http.Header {
68 h := defaultHeaders()
69 h.Set(authorization, "token new_token")
70 h.Set(accept, "application/vnd.github.test-preview")
71 return h
72 }(),
73 },
74 {
75 name: "allows setting custom headers",
76 headers: map[string]string{
77 "custom": "testing",
78 },
79 wantHeaders: func() http.Header {
80 h := defaultHeaders()
81 h.Set("custom", "testing")
82 return h
83 }(),
84 },
85 {
86 name: "allows setting logger",
87 enableLog: true,
88 log: &bytes.Buffer{},
89 wantHeaders: defaultHeaders(),
90 },
91 {
92 name: "does not add an authorization header for non-matching host",
93 host: "notauthorized.com",
94 wantHeaders: func() http.Header {
95 h := defaultHeaders()
96 h.Del(authorization)
97 return h
98 }(),
99 },
100 {
101 name: "does not add an authorization header for non-matching host subdomain",
102 host: "test.company",
103 wantHeaders: func() http.Header {
104 h := defaultHeaders()
105 h.Del(authorization)
106 return h
107 }(),
108 },
109 {
110 name: "adds an authorization header for a matching host",
111 host: "test.com",
112 wantHeaders: defaultHeaders(),
113 },
114 {
115 name: "adds an authorization header if hosts match but differ in case",
116 host: "TeSt.CoM",
117 wantHeaders: defaultHeaders(),
118 },
119 {
120 name: "skips default headers",
121 skipHeaders: true,
122 wantHeaders: func() http.Header {
123 h := defaultHeaders()
124 h.Del(accept)
125 h.Del(contentType)
126 h.Del(timeZone)
127 h.Del(userAgent)
128 return h
129 }(),
130 },
131 }
132
133 for _, tt := range tests {
134 t.Run(tt.name, func(t *testing.T) {
135 if tt.host == "" {
136 tt.host = "test.com"
137 }
138 opts := ClientOptions{
139 Host: tt.host,
140 AuthToken: "oauth_token",
141 Headers: tt.headers,
142 SkipDefaultHeaders: tt.skipHeaders,
143 Transport: reflectHTTP,
144 LogIgnoreEnv: true,
145 }
146 if tt.enableLog {
147 opts.Log = tt.log
148 }
149 client, _ := NewHTTPClient(opts)
150 res, err := client.Get("https://test.com")
151 assert.NoError(t, err)
152 assert.Equal(t, tt.wantHeaders, res.Header)
153 if tt.enableLog {
154 assert.NotEmpty(t, tt.log)
155 }
156 })
157 }
158 }
159
// tripper is a test double that satisfies http.RoundTripper by handing each
// request to a caller-supplied callback.
type tripper struct {
	// roundTrip receives every request given to RoundTrip; its results are
	// returned unchanged.
	roundTrip func(*http.Request) (*http.Response, error)
}

// RoundTrip forwards r to the configured roundTrip callback and returns the
// callback's response and error as-is.
func (rt tripper) RoundTrip(r *http.Request) (*http.Response, error) {
	handler := rt.roundTrip
	return handler(r)
}
167
168 func defaultHeaders() http.Header {
169 h := http.Header{}
170 a := "application/vnd.github.merge-info-preview+json"
171 a += ", application/vnd.github.nebula-preview"
172 h.Set(contentType, jsonContentType)
173 h.Set(userAgent, "go-gh")
174 h.Set(authorization, fmt.Sprintf("token %s", "oauth_token"))
175 h.Set(timeZone, currentTimeZone())
176 h.Set(accept, a)
177 return h
178 }
179
180 func stubConfig(t *testing.T, cfgStr string) {
181 t.Helper()
182 old := config.Read
183 config.Read = func(_ *config.Config) (*config.Config, error) {
184 return config.ReadFromString(cfgStr), nil
185 }
186 t.Cleanup(func() {
187 config.Read = old
188 })
189 }
190
191 func printPendingMocks(mocks []gock.Mock) string {
192 paths := []string{}
193 for _, mock := range mocks {
194 paths = append(paths, mock.Request().URLStruct.String())
195 }
196 return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", "))
197 }
198
View as plain text