...

Source file src/k8s.io/apimachinery/pkg/util/httpstream/httpstream_test.go

Documentation: k8s.io/apimachinery/pkg/util/httpstream

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package httpstream
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"net/http"
    23  	"reflect"
    24  	"testing"
    25  )
    26  
    27  type responseWriter struct {
    28  	header     http.Header
    29  	statusCode *int
    30  }
    31  
    32  func newResponseWriter() *responseWriter {
    33  	return &responseWriter{
    34  		header: make(http.Header),
    35  	}
    36  }
    37  
    38  func (r *responseWriter) Header() http.Header {
    39  	return r.header
    40  }
    41  
    42  func (r *responseWriter) WriteHeader(code int) {
    43  	r.statusCode = &code
    44  }
    45  
    46  func (r *responseWriter) Write([]byte) (int, error) {
    47  	return 0, nil
    48  }
    49  
    50  func TestHandshake(t *testing.T) {
    51  	tests := map[string]struct {
    52  		clientProtocols  []string
    53  		serverProtocols  []string
    54  		expectedProtocol string
    55  		expectError      bool
    56  	}{
    57  		"no common protocol": {
    58  			clientProtocols:  []string{"c"},
    59  			serverProtocols:  []string{"a", "b"},
    60  			expectedProtocol: "",
    61  			expectError:      true,
    62  		},
    63  		"no common protocol with comma separated list": {
    64  			clientProtocols:  []string{"c, d"},
    65  			serverProtocols:  []string{"a", "b"},
    66  			expectedProtocol: "",
    67  			expectError:      true,
    68  		},
    69  		"common protocol": {
    70  			clientProtocols:  []string{"b"},
    71  			serverProtocols:  []string{"a", "b"},
    72  			expectedProtocol: "b",
    73  		},
    74  		"common protocol with comma separated list": {
    75  			clientProtocols:  []string{"b, c"},
    76  			serverProtocols:  []string{"a", "b"},
    77  			expectedProtocol: "b",
    78  		},
    79  	}
    80  
    81  	for name, test := range tests {
    82  		req, err := http.NewRequest("GET", "http://www.example.com/", nil)
    83  		if err != nil {
    84  			t.Fatalf("%s: error creating request: %v", name, err)
    85  		}
    86  
    87  		for _, p := range test.clientProtocols {
    88  			req.Header.Add(HeaderProtocolVersion, p)
    89  		}
    90  
    91  		w := newResponseWriter()
    92  		negotiated, err := Handshake(req, w, test.serverProtocols)
    93  
    94  		// verify negotiated protocol
    95  		if e, a := test.expectedProtocol, negotiated; e != a {
    96  			t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
    97  		}
    98  
    99  		if test.expectError {
   100  			if err == nil {
   101  				t.Errorf("%s: expected error but did not get one", name)
   102  			}
   103  			if w.statusCode == nil {
   104  				t.Errorf("%s: expected w.statusCode to be set", name)
   105  			} else if e, a := http.StatusForbidden, *w.statusCode; e != a {
   106  				t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
   107  			}
   108  			if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
   109  				t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
   110  			}
   111  			continue
   112  		}
   113  		if !test.expectError && err != nil {
   114  			t.Errorf("%s: unexpected error: %v", name, err)
   115  			continue
   116  		}
   117  		if w.statusCode != nil {
   118  			t.Errorf("%s: unexpected non-nil w.statusCode: %d", name, w.statusCode)
   119  		}
   120  
   121  		if len(test.expectedProtocol) == 0 {
   122  			if len(w.Header()[HeaderProtocolVersion]) > 0 {
   123  				t.Errorf("%s: unexpected protocol version response header: %s", name, w.Header()[HeaderProtocolVersion])
   124  			}
   125  			continue
   126  		}
   127  
   128  		// verify response headers
   129  		if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
   130  			t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
   131  		}
   132  	}
   133  }
   134  
   135  func TestIsUpgradeFailureError(t *testing.T) {
   136  	testCases := map[string]struct {
   137  		err      error
   138  		expected bool
   139  	}{
   140  		"nil error should return false": {
   141  			err:      nil,
   142  			expected: false,
   143  		},
   144  		"Non-upgrade error should return false": {
   145  			err:      fmt.Errorf("this is not an upgrade error"),
   146  			expected: false,
   147  		},
   148  		"UpgradeFailure error should return true": {
   149  			err:      &UpgradeFailureError{},
   150  			expected: true,
   151  		},
   152  		"Wrapped Non-UpgradeFailure error should return false": {
   153  			err:      fmt.Errorf("%s: %w", "first error", errors.New("Non-upgrade error")),
   154  			expected: false,
   155  		},
   156  		"Wrapped UpgradeFailure error should return true": {
   157  			err:      fmt.Errorf("%s: %w", "first error", &UpgradeFailureError{}),
   158  			expected: true,
   159  		},
   160  	}
   161  
   162  	for name, test := range testCases {
   163  		t.Run(name, func(t *testing.T) {
   164  			actual := IsUpgradeFailure(test.err)
   165  			if test.expected != actual {
   166  				t.Errorf("expected upgrade failure %t, got %t", test.expected, actual)
   167  			}
   168  		})
   169  	}
   170  }
   171  

View as plain text