...
1
16
17 package httpstream
18
19 import (
20 "errors"
21 "fmt"
22 "io"
23 "net/http"
24 "strings"
25 "time"
26 )
27
// Headers consulted when deciding whether a request asks for a connection upgrade.
const (
	// HeaderConnection is the HTTP header whose values IsUpgradeRequest inspects.
	HeaderConnection = "Connection"
	// HeaderUpgrade is the value IsUpgradeRequest looks for in the Connection header.
	HeaderUpgrade = "Upgrade"
)

// Headers used by Handshake to negotiate the stream protocol version.
const (
	// HeaderProtocolVersion carries protocol versions in both directions: the
	// client lists the versions it supports (comma-separated, possibly repeated)
	// and Handshake echoes the negotiated version back on the response.
	HeaderProtocolVersion = "X-Stream-Protocol-Version"
	// HeaderAcceptedProtocolVersions is added to the response by Handshake, once
	// per server protocol, when no common version could be negotiated.
	HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
)
34
35
36
37
38 type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error
39
40
41
42 func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil }
43
44
45 type Dialer interface {
46
47
48
49 Dial(protocols ...string) (Connection, string, error)
50 }
51
52
53
54
55
56 type UpgradeRoundTripper interface {
57 http.RoundTripper
58
59 NewConnection(resp *http.Response) (Connection, error)
60 }
61
62
63
64 type ResponseUpgrader interface {
65
66
67
68 UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
69 }
70
71
72 type Connection interface {
73
74 CreateStream(headers http.Header) (Stream, error)
75
76 Close() error
77
78 CloseChan() <-chan bool
79
80
81 SetIdleTimeout(timeout time.Duration)
82
83 RemoveStreams(streams ...Stream)
84 }
85
86
87
// Stream is one stream created on a Connection. It can be read, written and
// closed like any io.ReadWriteCloser, and it additionally exposes an
// identifier, its headers and a Reset operation.
type Stream interface {
	io.ReadWriteCloser

	// Identifier returns the stream's numeric identifier.
	Identifier() uint32

	// Headers returns the headers associated with the stream.
	// NOTE(review): presumably these are the headers the stream was created
	// with via Connection.CreateStream — confirm against the implementations.
	Headers() http.Header

	// Reset resets the stream.
	// NOTE(review): presumably an abortive shutdown as opposed to Close —
	// confirm against the implementations.
	Reset() error
}
98
99
100
// UpgradeFailureError reports that a streaming upgrade failed and carries the
// underlying reason. Use IsUpgradeFailure to detect it anywhere in an error
// chain.
type UpgradeFailureError struct {
	// Cause is the underlying error that made the upgrade fail.
	Cause error
}

// Error implements the error interface by prefixing Cause with a fixed
// message. %v is used so that a nil Cause renders as "<nil>" rather than the
// "%!s(<nil>)" verb error that %s would produce; for non-nil causes the output
// is the same.
func (u *UpgradeFailureError) Error() string {
	return fmt.Sprintf("unable to upgrade streaming request: %v", u.Cause)
}

// Unwrap returns Cause so that errors.Is and errors.As can inspect the
// underlying failure through an *UpgradeFailureError.
func (u *UpgradeFailureError) Unwrap() error {
	return u.Cause
}

// IsUpgradeFailure reports whether err is, or wraps, an *UpgradeFailureError.
// It returns false for a nil err, because errors.As already does.
func IsUpgradeFailure(err error) bool {
	var upgradeErr *UpgradeFailureError
	return errors.As(err, &upgradeErr)
}
118
119
120 func IsUpgradeRequest(req *http.Request) bool {
121 for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] {
122 if strings.Contains(strings.ToLower(h), strings.ToLower(HeaderUpgrade)) {
123 return true
124 }
125 }
126 return false
127 }
128
// negotiateProtocol returns the first entry of clientProtocols that also
// appears in serverProtocols, so the client's order of preference wins. It
// returns "" when the two lists share no entry (including when either is
// empty).
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
	serverSupports := func(candidate string) bool {
		for _, offered := range serverProtocols {
			if offered == candidate {
				return true
			}
		}
		return false
	}

	for _, candidate := range clientProtocols {
		if serverSupports(candidate) {
			return candidate
		}
	}
	return ""
}
139
// commaSeparatedHeaderValues flattens the values of a possibly repeated,
// comma-separated HTTP header into individual entries. Each value is split on
// commas, surrounding whitespace is trimmed and empty entries are dropped. It
// returns nil when no entries remain.
//
// strings.TrimSpace is used instead of trimming only spaces because HTTP
// optional whitespace (OWS) also includes horizontal tabs. With space-only
// trimming, "v1,\tv2" would yield "\tv2", which can never match a protocol
// name.
func commaSeparatedHeaderValues(header []string) []string {
	var values []string
	for _, h := range header {
		for _, v := range strings.Split(h, ",") {
			if v = strings.TrimSpace(v); v != "" {
				values = append(values, v)
			}
		}
	}
	return values
}
151
152
153
154
155
156
157
158 func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
159 clientProtocols := commaSeparatedHeaderValues(req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)])
160 if len(clientProtocols) == 0 {
161 return "", fmt.Errorf("unable to upgrade: %s is required", HeaderProtocolVersion)
162 }
163
164 if len(serverProtocols) == 0 {
165 panic(fmt.Errorf("unable to upgrade: serverProtocols is required"))
166 }
167
168 negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
169 if len(negotiatedProtocol) == 0 {
170 for i := range serverProtocols {
171 w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
172 }
173 err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
174 http.Error(w, err.Error(), http.StatusForbidden)
175 return "", err
176 }
177
178 w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
179 return negotiatedProtocol, nil
180 }
181
View as plain text