...

Source file src/github.com/cli/go-gh/v2/pkg/api/http_client_test.go

Documentation: github.com/cli/go-gh/v2/pkg/api

     1  package api
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/cli/go-gh/v2/pkg/config"
    12  	"github.com/stretchr/testify/assert"
    13  	"gopkg.in/h2non/gock.v1"
    14  )
    15  
    16  func TestHTTPClient(t *testing.T) {
    17  	stubConfig(t, testConfig())
    18  	t.Cleanup(gock.Off)
    19  
    20  	gock.New("https://api.github.com").
    21  		Get("/some/test/path").
    22  		MatchHeader("Authorization", "token abc123").
    23  		Reply(200).
    24  		JSON(`{"message": "success"}`)
    25  
    26  	client, err := DefaultHTTPClient()
    27  	assert.NoError(t, err)
    28  
    29  	res, err := client.Get("https://api.github.com/some/test/path")
    30  	assert.NoError(t, err)
    31  	assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending()))
    32  	assert.Equal(t, 200, res.StatusCode)
    33  }
    34  
    35  func TestNewHTTPClient(t *testing.T) {
    36  	reflectHTTP := tripper{
    37  		roundTrip: func(req *http.Request) (*http.Response, error) {
    38  			header := req.Header.Clone()
    39  			body := "{}"
    40  			return &http.Response{
    41  				StatusCode: 200,
    42  				Header:     header,
    43  				Body:       io.NopCloser(bytes.NewBufferString(body)),
    44  			}, nil
    45  		},
    46  	}
    47  
    48  	tests := []struct {
    49  		name        string
    50  		enableLog   bool
    51  		log         *bytes.Buffer
    52  		host        string
    53  		headers     map[string]string
    54  		skipHeaders bool
    55  		wantHeaders http.Header
    56  	}{
    57  		{
    58  			name:        "sets default headers",
    59  			wantHeaders: defaultHeaders(),
    60  		},
    61  		{
    62  			name: "allows overriding default headers",
    63  			headers: map[string]string{
    64  				authorization: "token new_token",
    65  				accept:        "application/vnd.github.test-preview",
    66  			},
    67  			wantHeaders: func() http.Header {
    68  				h := defaultHeaders()
    69  				h.Set(authorization, "token new_token")
    70  				h.Set(accept, "application/vnd.github.test-preview")
    71  				return h
    72  			}(),
    73  		},
    74  		{
    75  			name: "allows setting custom headers",
    76  			headers: map[string]string{
    77  				"custom": "testing",
    78  			},
    79  			wantHeaders: func() http.Header {
    80  				h := defaultHeaders()
    81  				h.Set("custom", "testing")
    82  				return h
    83  			}(),
    84  		},
    85  		{
    86  			name:        "allows setting logger",
    87  			enableLog:   true,
    88  			log:         &bytes.Buffer{},
    89  			wantHeaders: defaultHeaders(),
    90  		},
    91  		{
    92  			name: "does not add an authorization header for non-matching host",
    93  			host: "notauthorized.com",
    94  			wantHeaders: func() http.Header {
    95  				h := defaultHeaders()
    96  				h.Del(authorization)
    97  				return h
    98  			}(),
    99  		},
   100  		{
   101  			name: "does not add an authorization header for non-matching host subdomain",
   102  			host: "test.company",
   103  			wantHeaders: func() http.Header {
   104  				h := defaultHeaders()
   105  				h.Del(authorization)
   106  				return h
   107  			}(),
   108  		},
   109  		{
   110  			name:        "adds an authorization header for a matching host",
   111  			host:        "test.com",
   112  			wantHeaders: defaultHeaders(),
   113  		},
   114  		{
   115  			name:        "adds an authorization header if hosts match but differ in case",
   116  			host:        "TeSt.CoM",
   117  			wantHeaders: defaultHeaders(),
   118  		},
   119  		{
   120  			name:        "skips default headers",
   121  			skipHeaders: true,
   122  			wantHeaders: func() http.Header {
   123  				h := defaultHeaders()
   124  				h.Del(accept)
   125  				h.Del(contentType)
   126  				h.Del(timeZone)
   127  				h.Del(userAgent)
   128  				return h
   129  			}(),
   130  		},
   131  	}
   132  
   133  	for _, tt := range tests {
   134  		t.Run(tt.name, func(t *testing.T) {
   135  			if tt.host == "" {
   136  				tt.host = "test.com"
   137  			}
   138  			opts := ClientOptions{
   139  				Host:               tt.host,
   140  				AuthToken:          "oauth_token",
   141  				Headers:            tt.headers,
   142  				SkipDefaultHeaders: tt.skipHeaders,
   143  				Transport:          reflectHTTP,
   144  				LogIgnoreEnv:       true,
   145  			}
   146  			if tt.enableLog {
   147  				opts.Log = tt.log
   148  			}
   149  			client, _ := NewHTTPClient(opts)
   150  			res, err := client.Get("https://test.com")
   151  			assert.NoError(t, err)
   152  			assert.Equal(t, tt.wantHeaders, res.Header)
   153  			if tt.enableLog {
   154  				assert.NotEmpty(t, tt.log)
   155  			}
   156  		})
   157  	}
   158  }
   159  
   160  type tripper struct {
   161  	roundTrip func(*http.Request) (*http.Response, error)
   162  }
   163  
   164  func (tr tripper) RoundTrip(req *http.Request) (*http.Response, error) {
   165  	return tr.roundTrip(req)
   166  }
   167  
   168  func defaultHeaders() http.Header {
   169  	h := http.Header{}
   170  	a := "application/vnd.github.merge-info-preview+json"
   171  	a += ", application/vnd.github.nebula-preview"
   172  	h.Set(contentType, jsonContentType)
   173  	h.Set(userAgent, "go-gh")
   174  	h.Set(authorization, fmt.Sprintf("token %s", "oauth_token"))
   175  	h.Set(timeZone, currentTimeZone())
   176  	h.Set(accept, a)
   177  	return h
   178  }
   179  
   180  func stubConfig(t *testing.T, cfgStr string) {
   181  	t.Helper()
   182  	old := config.Read
   183  	config.Read = func(_ *config.Config) (*config.Config, error) {
   184  		return config.ReadFromString(cfgStr), nil
   185  	}
   186  	t.Cleanup(func() {
   187  		config.Read = old
   188  	})
   189  }
   190  
   191  func printPendingMocks(mocks []gock.Mock) string {
   192  	paths := []string{}
   193  	for _, mock := range mocks {
   194  		paths = append(paths, mock.Request().URLStruct.String())
   195  	}
   196  	return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", "))
   197  }
   198  

View as plain text