1 package websocket
2
3 import (
4 "net/http"
5 "net/http/httptest"
6 "net/url"
7 "os"
8 "strings"
9 "testing"
10 "time"
11
12 "github.com/gin-gonic/gin"
13 "github.com/gorilla/websocket"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 )
17
18 func init() {
19 gin.SetMode(gin.TestMode)
20 }
21
// TestMain is the entry point for this package's tests. It runs them through
// m.Run and terminates the process with the status code m.Run reports.
func TestMain(m *testing.M) {
	exitCode := m.Run()
	os.Exit(exitCode)
}
25
26 func TestNewSubscriber(t *testing.T) {
27 testCases := map[string]struct {
28 filters []string
29 }{
30 "EmptyFilters": {
31 filters: []string{},
32 },
33 "OneFilter": {
34 filters: []string{"host"},
35 },
36 "MultipleFilters": {
37 filters: []string{"host", "cluster", "other-topic"},
38 },
39 }
40
41 for name, tc := range testCases {
42 t.Run(name, func(t *testing.T) {
43 vals := url.Values(map[string][]string{
44 "topic": tc.filters,
45 })
46
47 sub := newSubscriber(vals)
48
49 assert.Equal(t, tc.filters, sub.filters)
50 })
51 }
52 }
53
54 func TestIsSubscribedTo(t *testing.T) {
55 testCases := map[string]struct {
56 filters []string
57 target string
58 expectSubscribed bool
59 }{
60 "NilFilters": {
61 filters: nil,
62 target: "host",
63 expectSubscribed: true,
64 },
65 "EmptyFilters": {
66 filters: []string{},
67 target: "host",
68 expectSubscribed: true,
69 },
70 "IsSubscribed": {
71 filters: []string{"host", "other-topic"},
72 target: "host",
73 expectSubscribed: true,
74 },
75 "IsNotSubscribed": {
76 filters: []string{"cluster", "other-topic"},
77 target: "host",
78 expectSubscribed: false,
79 },
80 }
81
82 for name, tc := range testCases {
83 t.Run(name, func(t *testing.T) {
84 sub := subscriber{
85 filters: tc.filters,
86 }
87
88 isSubscribed := sub.isSubscribedTo(tc.target)
89
90 assert.Equal(t, tc.expectSubscribed, isSubscribed)
91 })
92 }
93 }
94
95 func TestListen(t *testing.T) {
96 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
97 require.NoError(t, socketHandler(w, r))
98 }))
99 defer s.Close()
100
101 wsURL := strings.Replace(s.URL, "http", "ws", 1)
102
103 c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
104 if err != nil {
105 t.Error(err)
106 }
107 defer c.Close()
108
109 vals := url.Values{}
110 sub := newSubscriber(vals)
111
112 w := httptest.NewRecorder()
113 ctx, _ := gin.CreateTestContext(w)
114 ctx.Request = &http.Request{
115 Header: make(http.Header),
116 }
117
118
119 go sub.listen(ctx, c)
120
121
122 input := &Event{
123 Topic: "test-topic",
124 Data: "test-data",
125 }
126 sub.channel <- *input
127
128
129 err = c.SetReadDeadline(time.Now().Add(time.Second * 1))
130 require.NoError(t, err)
131
132
133
134 output := &Event{}
135 err = c.ReadJSON(output)
136 require.NoError(t, err)
137
138 assert.Equal(t, input, output)
139
140
141
142 err = c.Close()
143 require.NoError(t, err)
144
145 sub.channel <- *input
146
147
148
149 var closed bool
150 select {
151 case <-sub.channel:
152 closed = true
153 case <-time.After(1 * time.Second):
154 closed = false
155 }
156 assert.True(t, closed)
157 }
158
159 func socketHandler(w http.ResponseWriter, r *http.Request) error {
160
161 upgrader := websocket.Upgrader{}
162 conn, err := upgrader.Upgrade(w, r, nil)
163 if err != nil {
164 return err
165 }
166 defer conn.Close()
167
168 event := &Event{}
169
170 if err := conn.ReadJSON(event); err != nil {
171 return err
172 }
173
174 if err = conn.WriteJSON(event); err != nil {
175 return err
176 }
177 return nil
178 }
179
View as plain text