...
1 package watcherx
2
3 import (
4 "context"
5 "fmt"
6 "net"
7 "net/http"
8 "net/url"
9 "strings"
10 "sync"
11
12 "github.com/gorilla/websocket"
13
14 "github.com/ory/herodot"
15 )
16
17 type (
18 eventChannelSlice struct {
19 sync.Mutex
20 cs []EventChannel
21 }
22 websocketWatcher struct {
23 wsWriteLock sync.Mutex
24 wsReadLock sync.Mutex
25 wsClientChannels eventChannelSlice
26 }
27 )
28
// messageSendNow is the text message a client sends to ask the watcher to
// dispatch its current values immediately (see readWebsocket).
const messageSendNow = "send values now"

// messageSendNowDone is the format of the server's reply once such a dispatch
// has completed; %d receives the value read from the DispatchNow result channel.
const messageSendNowDone = "done sending %d values"
33
34 func WatchAndServeWS(ctx context.Context, u *url.URL, writer herodot.Writer) (http.HandlerFunc, error) {
35 c := make(EventChannel)
36 watcher, err := Watch(ctx, u, c)
37 if err != nil {
38 return nil, err
39 }
40 w := &websocketWatcher{
41 wsClientChannels: eventChannelSlice{},
42 }
43 go w.broadcaster(ctx, c)
44 return w.serveWS(ctx, writer, watcher), nil
45 }
46
47 func (ww *websocketWatcher) broadcaster(ctx context.Context, c EventChannel) {
48 for {
49 select {
50 case <-ctx.Done():
51 return
52 case e := <-c:
53 ww.wsClientChannels.Lock()
54 for _, cc := range ww.wsClientChannels.cs {
55 cc <- e
56 }
57 ww.wsClientChannels.Unlock()
58 }
59 }
60 }
61
62 func (ww *websocketWatcher) readWebsocket(ws *websocket.Conn, c chan<- struct{}, watcher Watcher) {
63 for {
64
65 ww.wsReadLock.Lock()
66 _, msg, err := ws.ReadMessage()
67 ww.wsReadLock.Unlock()
68
69 switch errTyped := err.(type) {
70 case nil:
71 if string(msg) == messageSendNow {
72 done, err := watcher.DispatchNow()
73 if err != nil {
74
75 ww.wsWriteLock.Lock()
76 _ = ws.WriteJSON(&ErrorEvent{
77 error: err,
78 source: "",
79 })
80 ww.wsWriteLock.Unlock()
81 }
82
83 go func() {
84 eventsSend := <-done
85
86 ww.wsWriteLock.Lock()
87 defer ww.wsWriteLock.Unlock()
88
89
90 _ = ws.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf(messageSendNowDone, eventsSend)))
91 }()
92 }
93 case *websocket.CloseError:
94 if errTyped.Code == websocket.CloseNormalClosure {
95 close(c)
96 return
97 }
98 case *net.OpError:
99 if errTyped.Op == "read" && strings.Contains(errTyped.Err.Error(), "closed") {
100
101 close(c)
102 return
103 }
104 default:
105
106 return
107 }
108 }
109 }
110
111 func (ww *websocketWatcher) serveWS(ctx context.Context, writer herodot.Writer, watcher Watcher) func(w http.ResponseWriter, r *http.Request) {
112 return func(w http.ResponseWriter, r *http.Request) {
113 ws, err := (&websocket.Upgrader{
114 ReadBufferSize: 256,
115 WriteBufferSize: 1024,
116 }).Upgrade(w, r, nil)
117 if err != nil {
118 writer.WriteError(w, r, err)
119 return
120 }
121
122
123 c := make(EventChannel)
124 ww.wsClientChannels.Lock()
125 ww.wsClientChannels.cs = append(ww.wsClientChannels.cs, c)
126 ww.wsClientChannels.Unlock()
127
128 wsClosed := make(chan struct{})
129 go ww.readWebsocket(ws, wsClosed, watcher)
130
131 defer func() {
132
133
134 ww.wsWriteLock.Lock()
135 _ = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "server context canceled"))
136 ww.wsWriteLock.Unlock()
137
138 _ = ws.Close()
139
140 ww.wsClientChannels.Lock()
141 for i, cc := range ww.wsClientChannels.cs {
142 if c == cc {
143 ww.wsClientChannels.cs[i] = ww.wsClientChannels.cs[len(ww.wsClientChannels.cs)-1]
144 ww.wsClientChannels.cs[len(ww.wsClientChannels.cs)-1] = nil
145 ww.wsClientChannels.cs = ww.wsClientChannels.cs[:len(ww.wsClientChannels.cs)-1]
146 }
147 }
148 ww.wsClientChannels.Unlock()
149 close(c)
150 }()
151
152 for {
153 select {
154 case <-ctx.Done():
155 return
156 case <-wsClosed:
157 return
158 case e, ok := <-c:
159 if !ok {
160 return
161 }
162
163 ww.wsWriteLock.Lock()
164 err := ws.WriteJSON(e)
165 ww.wsWriteLock.Unlock()
166
167 if err != nil {
168 return
169 }
170 }
171 }
172 }
173 }
174
View as plain text