package websocket

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// init puts gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

func TestMain(m *testing.M) {
	os.Exit(m.Run())
}

// TestNewSubscriber checks that newSubscriber stores the "topic" query values
// it is given as the subscriber's filters.
func TestNewSubscriber(t *testing.T) {
	testCases := map[string]struct {
		filters []string
	}{
		"EmptyFilters": {
			filters: []string{},
		},
		"OneFilter": {
			filters: []string{"host"},
		},
		"MultipleFilters": {
			filters: []string{"host", "cluster", "other-topic"},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			vals := url.Values(map[string][]string{
				"topic": tc.filters,
			})

			sub := newSubscriber(vals)

			assert.Equal(t, tc.filters, sub.filters)
		})
	}
}

// TestIsSubscribedTo checks topic matching: a subscriber with no filters (nil
// or empty) is subscribed to every topic, otherwise only to topics that appear
// in its filters.
func TestIsSubscribedTo(t *testing.T) {
	testCases := map[string]struct {
		filters          []string
		target           string
		expectSubscribed bool
	}{
		"NilFilters": {
			filters:          nil,
			target:           "host",
			expectSubscribed: true,
		},
		"EmptyFilters": {
			filters:          []string{},
			target:           "host",
			expectSubscribed: true,
		},
		"IsSubscribed": {
			filters:          []string{"host", "other-topic"},
			target:           "host",
			expectSubscribed: true,
		},
		"IsNotSubscribed": {
			filters:          []string{"cluster", "other-topic"},
			target:           "host",
			expectSubscribed: false,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			sub := subscriber{
				filters: tc.filters,
			}

			isSubscribed := sub.isSubscribedTo(tc.target)

			assert.Equal(t, tc.expectSubscribed, isSubscribed)
		})
	}
}

// TestListen checks that an event pushed onto a subscriber's channel is
// written to its websocket connection by listen, and that the subscriber's
// channel ends up closed once that connection has been closed.
func TestListen(t *testing.T) {
	// Upper bound for every blocking step, so a misbehaving subscriber fails
	// the test instead of hanging it.
	const timeout = time.Second

	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, not the test
		// goroutine, so it must not call t.FailNow (as require.* does);
		// assert only records the failure.
		assert.NoError(t, socketHandler(w, r))
	}))
	defer s.Close()

	wsURL := "ws" + strings.TrimPrefix(s.URL, "http")
	c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	// Stop here on failure: c is nil and every step below uses it.
	require.NoError(t, err)
	defer c.Close()

	sub := newSubscriber(url.Values{})

	w := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(w)
	ctx.Request = &http.Request{
		Header: make(http.Header),
	}

	// set subscriber to listen for incoming events
	go sub.listen(ctx, c)

	input := &Event{
		Topic: "test-topic",
		Data:  "test-data",
	}

	// send event to subscriber channel
	select {
	case sub.channel <- *input:
	case <-time.After(timeout):
		t.Fatal("timed out sending event to subscriber channel")
	}

	// expect the message back from socketHandler within the timeout
	require.NoError(t, c.SetReadDeadline(time.Now().Add(timeout)))

	// confirm that socketHandler echoed back the same message that was sent
	// to the subscriber
	output := &Event{}
	require.NoError(t, c.ReadJSON(output))
	assert.Equal(t, input, output)

	// Close the websocket connection; listen is expected to exit and close
	// its channel in response. The send below hands listen one more event
	// to process against the closed connection.
	require.NoError(t, c.Close())

	select {
	case sub.channel <- *input:
	case <-time.After(timeout):
		t.Fatal("timed out sending event to subscriber channel after close")
	}

	// Only ok == false proves the channel was closed; a plain receive would
	// also succeed if a value were still sitting in the channel.
	select {
	case _, ok := <-sub.channel:
		assert.False(t, ok, "expected subscriber channel to be closed")
	case <-time.After(timeout):
		t.Fatal("subscriber channel was not closed after the connection closed")
	}
}

// socketHandler upgrades the request to a websocket connection, reads one
// JSON-encoded Event from it, and writes that same Event back before closing
// the connection. It returns the first error encountered, or nil.
func socketHandler(w http.ResponseWriter, r *http.Request) error {
	// Upgrade our raw HTTP connection to a websocket based one
	upgrader := websocket.Upgrader{}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	defer conn.Close()

	// receive event from listen
	event := &Event{}
	if err := conn.ReadJSON(event); err != nil {
		return err
	}

	// send event back over connection to assert correct event was received
	return conn.WriteJSON(event)
}