...
1 package websocket
2
3 import (
4 "net/http"
5 "sync"
6
7 "github.com/gin-gonic/gin"
8 "github.com/gorilla/websocket"
9
10 "edge-infra.dev/pkg/lib/fog"
11 "edge-infra.dev/pkg/sds/interlock/internal/errors"
12 "edge-infra.dev/pkg/sds/interlock/internal/observability"
13 )
14
// TopicQueryParam holds the literal key "topic".
//
// NOTE(review): presumably this is the URL query parameter clients use to
// choose topics when subscribing; its consumer is not visible in this file,
// so confirm before relying on that.
var TopicQueryParam = "topic"
18
19
20 type Manager struct {
21 mutex sync.Mutex
22 subscribers []subscriber
23 }
24
25
26 func NewManager() *Manager {
27 observability.RecordSubscribers(0)
28 return &Manager{
29 subscribers: []subscriber{},
30 }
31 }
32
33
34 func (m *Manager) RegisterEndpoints(r *gin.Engine) {
35 v1 := r.Group("/v1/subscribe")
36 v1.GET("", m.subscribe)
37 }
38
39
40
41
42
43
44
45
46
47
// SubscriptionParameterWrapper carries a list of topics under the JSON key
// "topic".
//
// NOTE(review): nothing in this file references the type; it presumably
// exists to describe the subscribe endpoint's parameters for generated API
// documentation — confirm.
type SubscriptionParameterWrapper struct {
	// Topics is the list of topic names, serialized as "topic".
	Topics []string `json:"topic"`
}
55
56
57
58
59 type EventResponseWrapper struct {
60
61
62
63 Event Event `json:"event"`
64 }
65
66
67
68
69 func (m *Manager) subscribe(c *gin.Context) {
70 log := fog.FromContext(c.Request.Context())
71 log.Info("new subscription request receieved")
72
73 subscriber := newSubscriber(c.Request.URL.Query())
74 m.add(subscriber)
75 defer m.remove(subscriber)
76
77
78 upgrader := websocket.Upgrader{}
79 upgrader.CheckOrigin = func(_ *http.Request) bool { return true }
80 conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
81 if err != nil {
82 log.Error(err, "failed to upgrade the HTTP server connection to websocket protocol")
83 c.JSON(http.StatusInternalServerError, errors.NewErrorResponse(errors.NewError(http.StatusText(http.StatusInternalServerError))))
84 return
85 }
86 defer conn.Close()
87 subscriber.listen(c, conn)
88 }
89
90
91 func (m *Manager) add(sub subscriber) {
92 m.mutex.Lock()
93 defer m.mutex.Unlock()
94 m.subscribers = append(m.subscribers, sub)
95 observability.RecordSubscribers(len(m.subscribers))
96 }
97
98
99 func (m *Manager) remove(sub subscriber) {
100 m.mutex.Lock()
101 defer m.mutex.Unlock()
102 for i, s := range m.subscribers {
103 if s.channel == sub.channel {
104 m.subscribers = append(m.subscribers[:i], m.subscribers[i+1:]...)
105 break
106 }
107 }
108 observability.RecordSubscribers(len(m.subscribers))
109 }
110
111
112 func (m *Manager) Send(event Event) {
113 m.mutex.Lock()
114 defer m.mutex.Unlock()
115 for _, s := range m.subscribers {
116 if s.isSubscribedTo(event.Topic) {
117 s.channel <- event
118 }
119 }
120 observability.RecordEvent(event.Topic)
121 }
122
View as plain text