...

Source file src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn_test.go

Documentation: k8s.io/apimachinery/pkg/util/httpstream/wsstream

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wsstream
    18  
    19  import (
    20  	"encoding/base64"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"reflect"
    25  	"sync"
    26  	"testing"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  	"golang.org/x/net/websocket"
    31  )
    32  
    33  func newServer(handler http.Handler) (*httptest.Server, string) {
    34  	server := httptest.NewServer(handler)
    35  	serverAddr := server.Listener.Addr().String()
    36  	return server, serverAddr
    37  }
    38  
    39  func TestRawConn(t *testing.T) {
    40  	channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
    41  	conn := NewConn(NewDefaultChannelProtocols(channels))
    42  
    43  	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    44  		conn.Open(w, req)
    45  	}))
    46  	defer s.Close()
    47  
    48  	client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
    49  	if err != nil {
    50  		t.Fatal(err)
    51  	}
    52  	defer client.Close()
    53  
    54  	<-conn.ready
    55  	wg := sync.WaitGroup{}
    56  
    57  	// verify we can read a client write
    58  	wg.Add(1)
    59  	go func() {
    60  		defer wg.Done()
    61  		data, err := io.ReadAll(conn.channels[0])
    62  		if err != nil {
    63  			t.Error(err)
    64  			return
    65  		}
    66  		if !reflect.DeepEqual(data, []byte("client")) {
    67  			t.Errorf("unexpected server read: %v", data)
    68  		}
    69  	}()
    70  
    71  	if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
    72  		t.Fatalf("%d: %v", n, err)
    73  	}
    74  
    75  	// verify we can read a server write
    76  	wg.Add(1)
    77  	go func() {
    78  		defer wg.Done()
    79  		if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
    80  			t.Errorf("%d: %v", n, err)
    81  		}
    82  	}()
    83  
    84  	data := make([]byte, 1024)
    85  	if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
    86  		t.Fatalf("%d: %v", n, err)
    87  	}
    88  	if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
    89  		t.Errorf("unexpected client read: %v", data[:7])
    90  	}
    91  
    92  	// verify that an ignore channel is empty in both directions.
    93  	if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
    94  		t.Errorf("writes should be ignored")
    95  	}
    96  	data = make([]byte, 1024)
    97  	if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
    98  		t.Errorf("reads should be ignored")
    99  	}
   100  
   101  	// verify that a write to a Read channel doesn't block
   102  	if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
   103  		t.Errorf("writes should be ignored")
   104  	}
   105  
   106  	// verify that a read from a Write channel doesn't block
   107  	data = make([]byte, 1024)
   108  	if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
   109  		t.Errorf("reads should be ignored")
   110  	}
   111  
   112  	// verify that a client write to a Write channel doesn't block (is dropped)
   113  	if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
   114  		t.Fatalf("%d: %v", n, err)
   115  	}
   116  
   117  	client.Close()
   118  	wg.Wait()
   119  }
   120  
   121  func TestBase64Conn(t *testing.T) {
   122  	conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
   123  	s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   124  		conn.Open(w, req)
   125  	}))
   126  	defer s.Close()
   127  
   128  	config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	config.Protocol = []string{"base64.channel.k8s.io"}
   133  	client, err := websocket.DialConfig(config)
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	defer client.Close()
   138  
   139  	<-conn.ready
   140  	wg := sync.WaitGroup{}
   141  	wg.Add(1)
   142  	go func() {
   143  		defer wg.Done()
   144  		data, err := io.ReadAll(conn.channels[0])
   145  		if err != nil {
   146  			t.Error(err)
   147  			return
   148  		}
   149  		if !reflect.DeepEqual(data, []byte("client")) {
   150  			t.Errorf("unexpected server read: %s", string(data))
   151  		}
   152  	}()
   153  
   154  	clientData := base64.StdEncoding.EncodeToString([]byte("client"))
   155  	if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
   156  		t.Fatalf("%d: %v", n, err)
   157  	}
   158  
   159  	wg.Add(1)
   160  	go func() {
   161  		defer wg.Done()
   162  		if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
   163  			t.Errorf("%d: %v", n, err)
   164  		}
   165  	}()
   166  
   167  	data := make([]byte, 1024)
   168  	if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
   169  		t.Fatalf("%d: %v", n, err)
   170  	}
   171  	expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
   172  
   173  	if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
   174  		t.Errorf("unexpected client read: %v", data[:9])
   175  	}
   176  
   177  	client.Close()
   178  	wg.Wait()
   179  }
   180  
   181  type versionTest struct {
   182  	supported map[string]bool // protocol -> binary
   183  	requested []string
   184  	error     bool
   185  	expected  string
   186  }
   187  
   188  func versionTests() []versionTest {
   189  	const (
   190  		binary = true
   191  		base64 = false
   192  	)
   193  	return []versionTest{
   194  		{
   195  			supported: nil,
   196  			requested: []string{"raw"},
   197  			error:     true,
   198  		},
   199  		{
   200  			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
   201  			requested: nil,
   202  			expected:  "",
   203  		},
   204  		{
   205  			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
   206  			requested: []string{"v1.raw"},
   207  			error:     true,
   208  		},
   209  		{
   210  			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
   211  			requested: []string{"v1.raw", "v1.base64"},
   212  			error:     true,
   213  		}, {
   214  			supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
   215  			requested: []string{"v1.raw", "raw"},
   216  			expected:  "raw",
   217  		},
   218  		{
   219  			supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
   220  			requested: []string{"v1.raw"},
   221  			expected:  "v1.raw",
   222  		},
   223  		{
   224  			supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
   225  			requested: []string{"v2.base64"},
   226  			expected:  "v2.base64",
   227  		},
   228  	}
   229  }
   230  
   231  func TestVersionedConn(t *testing.T) {
   232  	for i, test := range versionTests() {
   233  		func() {
   234  			supportedProtocols := map[string]ChannelProtocolConfig{}
   235  			for p, binary := range test.supported {
   236  				supportedProtocols[p] = ChannelProtocolConfig{
   237  					Binary:   binary,
   238  					Channels: []ChannelType{ReadWriteChannel},
   239  				}
   240  			}
   241  			conn := NewConn(supportedProtocols)
   242  			// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
   243  			// we use a channel.
   244  			selectedProtocol := make(chan string)
   245  			s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   246  				p, _, _ := conn.Open(w, req)
   247  				selectedProtocol <- p
   248  			}))
   249  			defer s.Close()
   250  
   251  			config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
   252  			if err != nil {
   253  				t.Fatal(err)
   254  			}
   255  			config.Protocol = test.requested
   256  			client, err := websocket.DialConfig(config)
   257  			if err != nil {
   258  				if !test.error {
   259  					t.Fatalf("test %d: didn't expect error: %v", i, err)
   260  				} else {
   261  					return
   262  				}
   263  			}
   264  			defer client.Close()
   265  			if test.error && err == nil {
   266  				t.Fatalf("test %d: expected an error", i)
   267  			}
   268  
   269  			<-conn.ready
   270  			if got, expected := <-selectedProtocol, test.expected; got != expected {
   271  				t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
   272  			}
   273  		}()
   274  	}
   275  }
   276  
   277  func TestIsWebSocketRequestWithStreamCloseProtocol(t *testing.T) {
   278  	tests := map[string]struct {
   279  		headers  map[string]string
   280  		expected bool
   281  	}{
   282  		"No headers returns false": {
   283  			headers:  map[string]string{},
   284  			expected: false,
   285  		},
   286  		"Only connection upgrade header is false": {
   287  			headers: map[string]string{
   288  				"Connection": "upgrade",
   289  			},
   290  			expected: false,
   291  		},
   292  		"Only websocket upgrade header is false": {
   293  			headers: map[string]string{
   294  				"Upgrade": "websocket",
   295  			},
   296  			expected: false,
   297  		},
   298  		"Only websocket and connection upgrade headers is false": {
   299  			headers: map[string]string{
   300  				"Connection": "upgrade",
   301  				"Upgrade":    "websocket",
   302  			},
   303  			expected: false,
   304  		},
   305  		"Missing connection/upgrade header is false": {
   306  			headers: map[string]string{
   307  				"Upgrade":               "websocket",
   308  				WebSocketProtocolHeader: "v5.channel.k8s.io",
   309  			},
   310  			expected: false,
   311  		},
   312  		"Websocket connection upgrade headers with v5 protocol is true": {
   313  			headers: map[string]string{
   314  				"Connection":            "upgrade",
   315  				"Upgrade":               "websocket",
   316  				WebSocketProtocolHeader: "v5.channel.k8s.io",
   317  			},
   318  			expected: true,
   319  		},
   320  		"Websocket connection upgrade headers with wrong case v5 protocol is false": {
   321  			headers: map[string]string{
   322  				"Connection":            "upgrade",
   323  				"Upgrade":               "websocket",
   324  				WebSocketProtocolHeader: "v5.CHANNEL.k8s.io", // header value is case-sensitive
   325  			},
   326  			expected: false,
   327  		},
   328  		"Websocket connection upgrade headers with v4 protocol is false": {
   329  			headers: map[string]string{
   330  				"Connection":            "upgrade",
   331  				"Upgrade":               "websocket",
   332  				WebSocketProtocolHeader: "v4.channel.k8s.io",
   333  			},
   334  			expected: false,
   335  		},
   336  		"Websocket connection upgrade headers with multiple protocols but missing v5 is false": {
   337  			headers: map[string]string{
   338  				"Connection":            "upgrade",
   339  				"Upgrade":               "websocket",
   340  				WebSocketProtocolHeader: "v4.channel.k8s.io,v3.channel.k8s.io,v2.channel.k8s.io",
   341  			},
   342  			expected: false,
   343  		},
   344  		"Websocket connection upgrade headers with multiple protocols including v5 and spaces is true": {
   345  			headers: map[string]string{
   346  				"Connection":            "upgrade",
   347  				"Upgrade":               "websocket",
   348  				WebSocketProtocolHeader: "v5.channel.k8s.io,  v4.channel.k8s.io",
   349  			},
   350  			expected: true,
   351  		},
   352  		"Websocket connection upgrade headers with multiple protocols out of order including v5 and spaces is true": {
   353  			headers: map[string]string{
   354  				"Connection":            "upgrade",
   355  				"Upgrade":               "websocket",
   356  				WebSocketProtocolHeader: "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
   357  			},
   358  			expected: true,
   359  		},
   360  
   361  		"Websocket connection upgrade headers key is case-insensitive": {
   362  			headers: map[string]string{
   363  				"Connection":             "upgrade",
   364  				"Upgrade":                "websocket",
   365  				"sec-websocket-protocol": "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
   366  			},
   367  			expected: true,
   368  		},
   369  	}
   370  
   371  	for name, test := range tests {
   372  		req, err := http.NewRequest("GET", "http://www.example.com/", nil)
   373  		require.NoError(t, err)
   374  		for key, value := range test.headers {
   375  			req.Header.Add(key, value)
   376  		}
   377  		actual := IsWebSocketRequestWithStreamCloseProtocol(req)
   378  		assert.Equal(t, test.expected, actual, "%s: expected (%t), got (%t)", name, test.expected, actual)
   379  	}
   380  }
   381  
   382  func TestProtocolSupportsStreamClose(t *testing.T) {
   383  	tests := map[string]struct {
   384  		protocol string
   385  		expected bool
   386  	}{
   387  		"empty protocol returns false": {
   388  			protocol: "",
   389  			expected: false,
   390  		},
   391  		"not binary protocol returns false": {
   392  			protocol: "base64.channel.k8s.io",
   393  			expected: false,
   394  		},
   395  		"V1 protocol returns false": {
   396  			protocol: "channel.k8s.io",
   397  			expected: false,
   398  		},
   399  		"V4 protocol returns false": {
   400  			protocol: "v4.channel.k8s.io",
   401  			expected: false,
   402  		},
   403  		"V5 protocol returns true": {
   404  			protocol: "v5.channel.k8s.io",
   405  			expected: true,
   406  		},
   407  		"V5 protocol wrong case returns false": {
   408  			protocol: "V5.channel.K8S.io",
   409  			expected: false,
   410  		},
   411  	}
   412  
   413  	for name, test := range tests {
   414  		actual := protocolSupportsStreamClose(test.protocol)
   415  		assert.Equal(t, test.expected, actual,
   416  			"%s: expected (%t), got (%t)", name, test.expected, actual)
   417  	}
   418  }
   419  

View as plain text