...

Source file src/edge-infra.dev/pkg/sds/interlock/websocket/websocket.go

Documentation: edge-infra.dev/pkg/sds/interlock/websocket

     1  package websocket
     2  
     3  import (
     4  	"net/http"
     5  	"sync"
     6  
     7  	"github.com/gin-gonic/gin"
     8  	"github.com/gorilla/websocket"
     9  
    10  	"edge-infra.dev/pkg/lib/fog"
    11  	"edge-infra.dev/pkg/sds/interlock/internal/errors"
    12  	"edge-infra.dev/pkg/sds/interlock/internal/observability"
    13  )
    14  
    15  var (
    16  	TopicQueryParam = "topic"
    17  )
    18  
    19  // Manager manages the subscribers to the websocket
    20  type Manager struct {
    21  	mutex       sync.Mutex
    22  	subscribers []subscriber
    23  }
    24  
    25  // NewManager creates a new Manager and records the subscriber count as 0
    26  func NewManager() *Manager {
    27  	observability.RecordSubscribers(0)
    28  	return &Manager{
    29  		subscribers: []subscriber{},
    30  	}
    31  }
    32  
    33  // Registers endpoints for the websocket manager
    34  func (m *Manager) RegisterEndpoints(r *gin.Engine) {
    35  	v1 := r.Group("/v1/subscribe")
    36  	v1.GET("", m.subscribe)
    37  }
    38  
    39  // swagger:route GET /v1/subscribe subscribe SubscribeToTopics
    40  // Subscribe to topics to receive event driven updates about state changes
    41  // responses:
    42  //   500: ErrorResponse
    43  //   default: EventResponse
    44  
// NOTE(review): SubscriptionParameterWrapper is not referenced anywhere else in
// this file; it appears to exist only to carry go-swagger annotations for the
// SubscribeToTopics route. The comment attached directly to the type below is
// parsed by go-swagger, which is why it is not in Go-doc "TypeName ..." form —
// keep the swagger:parameters / in:query / example directives intact.

// The SubscribeToTopics parameters
//
// swagger:parameters SubscribeToTopics
type SubscriptionParameterWrapper struct {
	// The topics to subscribe to. Defaults to all
	//
	// in:query
	// example: ["host"]
	Topics []string `json:"topic"`
}
    55  
// NOTE(review): EventResponseWrapper is not referenced anywhere else in this
// file; it appears to exist only to describe the EventResponse body for
// go-swagger, so the directive comments below must stay intact. The Event
// field is described as "host state", but events are routed by Event.Topic
// and clients may subscribe to any topic — confirm the description still fits
// non-host topics.

// The request was successful
//
// swagger:response EventResponse
type EventResponseWrapper struct {
	// The current host state
	//
	// in:body
	Event Event `json:"event"`
}
    65  
    66  // subscribe creates a new subscriber and upgrades the HTTP connection to
    67  // websocket protocol. The new subscriber listens on the its channel for
    68  // incoming events to send back over the websocket connection
    69  func (m *Manager) subscribe(c *gin.Context) {
    70  	log := fog.FromContext(c.Request.Context())
    71  	log.Info("new subscription request receieved")
    72  
    73  	subscriber := newSubscriber(c.Request.URL.Query())
    74  	m.add(subscriber)
    75  	defer m.remove(subscriber)
    76  
    77  	// upgrade http connection to websocket protocol
    78  	upgrader := websocket.Upgrader{}
    79  	upgrader.CheckOrigin = func(_ *http.Request) bool { return true }
    80  	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
    81  	if err != nil {
    82  		log.Error(err, "failed to upgrade the HTTP server connection to websocket protocol")
    83  		c.JSON(http.StatusInternalServerError, errors.NewErrorResponse(errors.NewError(http.StatusText(http.StatusInternalServerError))))
    84  		return
    85  	}
    86  	defer conn.Close()
    87  	subscriber.listen(c, conn)
    88  }
    89  
    90  // add the provided subscriber to the Manager subscriber list
    91  func (m *Manager) add(sub subscriber) {
    92  	m.mutex.Lock()
    93  	defer m.mutex.Unlock()
    94  	m.subscribers = append(m.subscribers, sub)
    95  	observability.RecordSubscribers(len(m.subscribers))
    96  }
    97  
    98  // remove the provided subscriber from the Manager subscriber list
    99  func (m *Manager) remove(sub subscriber) {
   100  	m.mutex.Lock()
   101  	defer m.mutex.Unlock()
   102  	for i, s := range m.subscribers {
   103  		if s.channel == sub.channel {
   104  			m.subscribers = append(m.subscribers[:i], m.subscribers[i+1:]...)
   105  			break
   106  		}
   107  	}
   108  	observability.RecordSubscribers(len(m.subscribers))
   109  }
   110  
   111  // Send event to all subscribers that are subscribed to the event topic
   112  func (m *Manager) Send(event Event) {
   113  	m.mutex.Lock()
   114  	defer m.mutex.Unlock()
   115  	for _, s := range m.subscribers {
   116  		if s.isSubscribedTo(event.Topic) {
   117  			s.channel <- event
   118  		}
   119  	}
   120  	observability.RecordEvent(event.Topic)
   121  }
   122  

View as plain text