1
16
17 package httpstream
18
19 import (
20 "errors"
21 "fmt"
22 "net/http"
23 "reflect"
24 "testing"
25 )
26
// responseWriter is a minimal in-memory http.ResponseWriter used by these
// tests. It records the response headers and the most recent status code
// passed to WriteHeader, and discards any body bytes written to it.
type responseWriter struct {
	// header holds the response headers set by the code under test.
	header http.Header
	// statusCode stays nil until WriteHeader is called; afterwards it points
	// at the most recently written code.
	statusCode *int
}

// Compile-time check that responseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*responseWriter)(nil)

// newResponseWriter returns a responseWriter with an empty, non-nil header
// map and no status code recorded.
func newResponseWriter() *responseWriter {
	return &responseWriter{
		header: make(http.Header),
	}
}

// Header returns the response header map; callers mutate it in place.
func (r *responseWriter) Header() http.Header {
	return r.header
}

// WriteHeader records code as the response status code.
func (r *responseWriter) WriteHeader(code int) {
	r.statusCode = &code
}

// Write discards p but reports it as fully written. The io.Writer contract
// requires a non-nil error whenever fewer than len(p) bytes are reported, so
// returning (0, nil) for a non-empty p would be a contract violation.
func (r *responseWriter) Write(p []byte) (int, error) {
	return len(p), nil
}
49
50 func TestHandshake(t *testing.T) {
51 tests := map[string]struct {
52 clientProtocols []string
53 serverProtocols []string
54 expectedProtocol string
55 expectError bool
56 }{
57 "no common protocol": {
58 clientProtocols: []string{"c"},
59 serverProtocols: []string{"a", "b"},
60 expectedProtocol: "",
61 expectError: true,
62 },
63 "no common protocol with comma separated list": {
64 clientProtocols: []string{"c, d"},
65 serverProtocols: []string{"a", "b"},
66 expectedProtocol: "",
67 expectError: true,
68 },
69 "common protocol": {
70 clientProtocols: []string{"b"},
71 serverProtocols: []string{"a", "b"},
72 expectedProtocol: "b",
73 },
74 "common protocol with comma separated list": {
75 clientProtocols: []string{"b, c"},
76 serverProtocols: []string{"a", "b"},
77 expectedProtocol: "b",
78 },
79 }
80
81 for name, test := range tests {
82 req, err := http.NewRequest("GET", "http://www.example.com/", nil)
83 if err != nil {
84 t.Fatalf("%s: error creating request: %v", name, err)
85 }
86
87 for _, p := range test.clientProtocols {
88 req.Header.Add(HeaderProtocolVersion, p)
89 }
90
91 w := newResponseWriter()
92 negotiated, err := Handshake(req, w, test.serverProtocols)
93
94
95 if e, a := test.expectedProtocol, negotiated; e != a {
96 t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
97 }
98
99 if test.expectError {
100 if err == nil {
101 t.Errorf("%s: expected error but did not get one", name)
102 }
103 if w.statusCode == nil {
104 t.Errorf("%s: expected w.statusCode to be set", name)
105 } else if e, a := http.StatusForbidden, *w.statusCode; e != a {
106 t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
107 }
108 if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
109 t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
110 }
111 continue
112 }
113 if !test.expectError && err != nil {
114 t.Errorf("%s: unexpected error: %v", name, err)
115 continue
116 }
117 if w.statusCode != nil {
118 t.Errorf("%s: unexpected non-nil w.statusCode: %d", name, w.statusCode)
119 }
120
121 if len(test.expectedProtocol) == 0 {
122 if len(w.Header()[HeaderProtocolVersion]) > 0 {
123 t.Errorf("%s: unexpected protocol version response header: %s", name, w.Header()[HeaderProtocolVersion])
124 }
125 continue
126 }
127
128
129 if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
130 t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
131 }
132 }
133 }
134
135 func TestIsUpgradeFailureError(t *testing.T) {
136 testCases := map[string]struct {
137 err error
138 expected bool
139 }{
140 "nil error should return false": {
141 err: nil,
142 expected: false,
143 },
144 "Non-upgrade error should return false": {
145 err: fmt.Errorf("this is not an upgrade error"),
146 expected: false,
147 },
148 "UpgradeFailure error should return true": {
149 err: &UpgradeFailureError{},
150 expected: true,
151 },
152 "Wrapped Non-UpgradeFailure error should return false": {
153 err: fmt.Errorf("%s: %w", "first error", errors.New("Non-upgrade error")),
154 expected: false,
155 },
156 "Wrapped UpgradeFailure error should return true": {
157 err: fmt.Errorf("%s: %w", "first error", &UpgradeFailureError{}),
158 expected: true,
159 },
160 }
161
162 for name, test := range testCases {
163 t.Run(name, func(t *testing.T) {
164 actual := IsUpgradeFailure(test.err)
165 if test.expected != actual {
166 t.Errorf("expected upgrade failure %t, got %t", test.expected, actual)
167 }
168 })
169 }
170 }
171
View as plain text