...

Source file src/github.com/google/go-containerregistry/pkg/v1/remote/transport/retry_test.go

Documentation: github.com/google/go-containerregistry/pkg/v1/remote/transport

     1  // Copyright 2018 Google LLC All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/google/go-containerregistry/internal/retry"
    28  )
    29  
    30  type mockTransport struct {
    31  	errs  []error
    32  	resps []*http.Response
    33  	count int
    34  }
    35  
    36  func (t *mockTransport) RoundTrip(_ *http.Request) (out *http.Response, err error) {
    37  	defer func() { t.count++ }()
    38  	if t.count < len(t.resps) {
    39  		out = t.resps[t.count]
    40  	}
    41  	if t.count < len(t.errs) {
    42  		err = t.errs[t.count]
    43  	}
    44  	return
    45  }
    46  
    47  type perm struct{}
    48  
    49  func (e perm) Error() string {
    50  	return "permanent error"
    51  }
    52  
    53  type temp struct{}
    54  
    55  func (e temp) Error() string {
    56  	return "temporary error"
    57  }
    58  
    59  func (e temp) Temporary() bool {
    60  	return true
    61  }
    62  
    63  func resp(code int) *http.Response {
    64  	return &http.Response{
    65  		StatusCode: code,
    66  		Body:       io.NopCloser(strings.NewReader("hi")),
    67  	}
    68  }
    69  
    70  func TestRetryTransport(t *testing.T) {
    71  	for _, test := range []struct {
    72  		errs  []error
    73  		resps []*http.Response
    74  		ctx   context.Context
    75  		count int
    76  	}{{
    77  		// Don't retry retry.Never.
    78  		errs:  []error{temp{}},
    79  		ctx:   retry.Never(context.Background()),
    80  		count: 1,
    81  	}, {
    82  		// Don't retry permanent.
    83  		errs:  []error{perm{}},
    84  		count: 1,
    85  	}, {
    86  		// Do retry temp.
    87  		errs:  []error{temp{}, perm{}},
    88  		count: 2,
    89  	}, {
    90  		// Stop at some max.
    91  		errs:  []error{temp{}, temp{}, temp{}, temp{}, temp{}},
    92  		count: 3,
    93  	}, {
    94  		// Retry http errors.
    95  		errs: []error{nil, nil, temp{}, temp{}, temp{}},
    96  		resps: []*http.Response{
    97  			resp(http.StatusRequestTimeout),
    98  			resp(http.StatusInternalServerError),
    99  			nil,
   100  		},
   101  		count: 3,
   102  	}} {
   103  		mt := mockTransport{
   104  			errs:  test.errs,
   105  			resps: test.resps,
   106  		}
   107  
   108  		tr := NewRetry(&mt,
   109  			WithRetryBackoff(retry.Backoff{Steps: 3}),
   110  			WithRetryPredicate(retry.IsTemporary),
   111  			WithRetryStatusCodes(http.StatusRequestTimeout, http.StatusInternalServerError),
   112  		)
   113  
   114  		ctx := context.Background()
   115  		if test.ctx != nil {
   116  			ctx = test.ctx
   117  		}
   118  		req, err := http.NewRequestWithContext(ctx, "GET", "example.com", nil)
   119  		if err != nil {
   120  			t.Fatal(err)
   121  		}
   122  		tr.RoundTrip(req)
   123  		if mt.count != test.count {
   124  			t.Errorf("wrong count, wanted %d, got %d", test.count, mt.count)
   125  		}
   126  	}
   127  }
   128  
   129  func TestRetryDefaults(t *testing.T) {
   130  	tr := NewRetry(http.DefaultTransport)
   131  	rt, ok := tr.(*retryTransport)
   132  	if !ok {
   133  		t.Fatal("could not cast to retryTransport")
   134  	}
   135  
   136  	if rt.backoff != defaultBackoff {
   137  		t.Fatalf("default backoff wrong: %v", rt.backoff)
   138  	}
   139  
   140  	if rt.predicate == nil {
   141  		t.Fatal("default predicate not set")
   142  	}
   143  }
   144  
   145  func TestTimeoutContext(t *testing.T) {
   146  	tr := NewRetry(http.DefaultTransport)
   147  
   148  	slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   149  		// hanging request
   150  		time.Sleep(time.Second * 1)
   151  	}))
   152  	defer func() { go func() { slowServer.Close() }() }()
   153  
   154  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*20))
   155  	defer cancel()
   156  	req, err := http.NewRequest("GET", slowServer.URL, nil)
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	req = req.WithContext(ctx)
   161  
   162  	result := make(chan error)
   163  
   164  	go func() {
   165  		_, err := tr.RoundTrip(req)
   166  		result <- err
   167  	}()
   168  
   169  	select {
   170  	case err := <-result:
   171  		if !errors.Is(err, context.DeadlineExceeded) {
   172  			t.Fatalf("got: %v, want: %v", err, context.DeadlineExceeded)
   173  		}
   174  	case <-time.After(time.Millisecond * 100):
   175  		t.Fatalf("deadline was not recognized by transport")
   176  	}
   177  }
   178  

View as plain text