...

Source file src/github.com/hashicorp/go-retryablehttp/roundtripper_test.go

Documentation: github.com/hashicorp/go-retryablehttp

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package retryablehttp
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"reflect"
    15  	"sync/atomic"
    16  	"testing"
    17  )
    18  
    19  func TestRoundTripper_implements(t *testing.T) {
    20  	// Compile-time proof of interface satisfaction.
    21  	var _ http.RoundTripper = &RoundTripper{}
    22  }
    23  
    24  func TestRoundTripper_init(t *testing.T) {
    25  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    26  		w.WriteHeader(200)
    27  	}))
    28  	defer ts.Close()
    29  
    30  	// Start with a new empty RoundTripper.
    31  	rt := &RoundTripper{}
    32  
    33  	// RoundTrip once.
    34  	req, _ := http.NewRequest("GET", ts.URL, nil)
    35  	if _, err := rt.RoundTrip(req); err != nil {
    36  		t.Fatal(err)
    37  	}
    38  
    39  	// Check that the Client was initialized.
    40  	if rt.Client == nil {
    41  		t.Fatal("expected rt.Client to be initialized")
    42  	}
    43  
    44  	// Save the Client for later comparison.
    45  	initialClient := rt.Client
    46  
    47  	// RoundTrip again.
    48  	req, _ = http.NewRequest("GET", ts.URL, nil)
    49  	if _, err := rt.RoundTrip(req); err != nil {
    50  		t.Fatal(err)
    51  	}
    52  
    53  	// Check that the underlying Client is unchanged.
    54  	if rt.Client != initialClient {
    55  		t.Fatalf("expected %v, got %v", initialClient, rt.Client)
    56  	}
    57  }
    58  
    59  func TestRoundTripper_RoundTrip(t *testing.T) {
    60  	var reqCount int32 = 0
    61  
    62  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    63  		reqNo := atomic.AddInt32(&reqCount, 1)
    64  		if reqNo < 3 {
    65  			w.WriteHeader(404)
    66  		} else {
    67  			w.WriteHeader(200)
    68  			w.Write([]byte("success!"))
    69  		}
    70  	}))
    71  	defer ts.Close()
    72  
    73  	// Make a client with some custom settings to verify they are used.
    74  	retryClient := NewClient()
    75  	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) {
    76  		return resp.StatusCode == 404, nil
    77  	}
    78  
    79  	// Get the standard client and execute the request.
    80  	client := retryClient.StandardClient()
    81  	resp, err := client.Get(ts.URL)
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	defer resp.Body.Close()
    86  
    87  	// Check the response to ensure the client behaved as expected.
    88  	if resp.StatusCode != 200 {
    89  		t.Fatalf("expected 200, got %d", resp.StatusCode)
    90  	}
    91  	if v, err := ioutil.ReadAll(resp.Body); err != nil {
    92  		t.Fatal(err)
    93  	} else if string(v) != "success!" {
    94  		t.Fatalf("expected %q, got %q", "success!", v)
    95  	}
    96  }
    97  
    98  func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
    99  	// Make a client with some custom settings to verify they are used.
   100  	retryClient := NewClient()
   101  	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
   102  		if err != nil {
   103  			return true, err
   104  		}
   105  
   106  		return false, nil
   107  	}
   108  
   109  	retryClient.ErrorHandler = PassthroughErrorHandler
   110  
   111  	expectedError := &url.Error{
   112  		Op:  "Get",
   113  		URL: "http://999.999.999.999:999/",
   114  		Err: &net.OpError{
   115  			Op:  "dial",
   116  			Net: "tcp",
   117  			Err: &net.DNSError{
   118  				Name:       "999.999.999.999",
   119  				Err:        "no such host",
   120  				IsNotFound: true,
   121  			},
   122  		},
   123  	}
   124  
   125  	// Get the standard client and execute the request.
   126  	client := retryClient.StandardClient()
   127  	_, err := client.Get("http://999.999.999.999:999/")
   128  
   129  	// assert expectations
   130  	if !reflect.DeepEqual(expectedError, normalizeError(err)) {
   131  		t.Fatalf("expected %q, got %q", expectedError, err)
   132  	}
   133  }
   134  
   135  func normalizeError(err error) error {
   136  	var dnsError *net.DNSError
   137  
   138  	if errors.As(err, &dnsError) {
   139  		// this field is populated with the DNS server on on CI, but not locally
   140  		dnsError.Server = ""
   141  	}
   142  
   143  	return err
   144  }
   145  

View as plain text