1
16
17 package websocket
18
19 import (
20 "context"
21 "io"
22 "net/http"
23 "net/http/httptest"
24 "net/url"
25 "strings"
26 "testing"
27 "time"
28
29 "github.com/stretchr/testify/assert"
30 "github.com/stretchr/testify/require"
31
32 "k8s.io/apimachinery/pkg/util/httpstream"
33 "k8s.io/apimachinery/pkg/util/httpstream/wsstream"
34 "k8s.io/apimachinery/pkg/util/remotecommand"
35 restclient "k8s.io/client-go/rest"
36 )
37
38 func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
39
40 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
41 conns, err := webSocketServerStreams(req, w)
42 if err != nil {
43 t.Fatalf("error on webSocketServerStreams: %v", err)
44 }
45 defer conns.conn.Close()
46 }))
47 defer websocketServer.Close()
48
49
50 websocketLocation, err := url.Parse(websocketServer.URL)
51 require.NoError(t, err)
52 req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
53 require.NoError(t, err)
54 rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
55 require.NoError(t, err)
56 requestedProtocol := remotecommand.StreamProtocolV5Name
57 req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
58 _, err = rt.RoundTrip(req)
59 require.NoError(t, err)
60
61
62 actualProtocol := wsRt.Connection().Subprotocol()
63 assert.Equal(t, requestedProtocol, actualProtocol)
64
65 }
66
67 func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
68
69 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
70
71 _, err := webSocketServerStreams(req, w)
72 require.Error(t, err)
73 assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready"))
74 }))
75 defer websocketServer.Close()
76
77
78 websocketLocation, err := url.Parse(websocketServer.URL)
79 require.NoError(t, err)
80 req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
81 require.NoError(t, err)
82 rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
83 require.NoError(t, err)
84
85 requestedProtocol := remotecommand.StreamProtocolV1Name
86 req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
87 _, err = rt.RoundTrip(req)
88
89 require.Error(t, err)
90 assert.True(t, strings.Contains(err.Error(), "bad handshake"))
91 assert.True(t, httpstream.IsUpgradeFailure(err))
92 }
93
94 func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
95
96 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
97 conns, err := webSocketServerStreams(req, w)
98 if err != nil {
99 t.Fatalf("error on webSocketServerStreams: %v", err)
100 }
101 defer conns.conn.Close()
102 }))
103 defer websocketServer.Close()
104
105
106 websocketLocation, err := url.Parse(websocketServer.URL)
107 require.NoError(t, err)
108 req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
109 require.NoError(t, err)
110 rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
111 require.NoError(t, err)
112 requestedProtocol := remotecommand.StreamProtocolV5Name
113 conn, err := Negotiate(rt, wsRt, req, requestedProtocol)
114 require.NoError(t, err)
115
116 actualProtocol := conn.Subprotocol()
117 assert.Equal(t, requestedProtocol, actualProtocol)
118 }
119
120
// websocketStreams is the server-side result of a WebSocket upgrade in these
// tests. It only exposes what the handlers need: a way to close the
// connection.
type websocketStreams struct {
	// conn is the opened server connection; handlers close it when done.
	conn io.Closer
}
124
125 func webSocketServerStreams(req *http.Request, w http.ResponseWriter) (*websocketStreams, error) {
126 conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
127 remotecommand.StreamProtocolV5Name: {
128 Binary: true,
129 Channels: []wsstream.ChannelType{},
130 },
131 })
132 conn.SetIdleTimeout(4 * time.Hour)
133
134
135 _, _, err := conn.Open(w, req)
136 if err != nil {
137 return nil, err
138 }
139 return &websocketStreams{conn: conn}, nil
140 }
141
View as plain text