...
1 package client
2
3 import (
4 "encoding/json"
5 "fmt"
6 "io"
7 "net/http/httptest"
8 "reflect"
9 "strings"
10
11 "github.com/gorilla/websocket"
12 )
13
// Values of operationMessage.Type used by this websocket client.
const (
	// Written by the client.
	connectionInitMsg = "connection_init"
	startMsg          = "start"

	// Read from the server.
	connectionAckMsg = "connection_ack"
	connectionKaMsg  = "ka"
	dataMsg          = "data"
	errorMsg         = "error"
)
22
23 type operationMessage struct {
24 Payload json.RawMessage `json:"payload,omitempty"`
25 ID string `json:"id,omitempty"`
26 Type string `json:"type"`
27 }
28
// Subscription is a handle on a websocket subscription, exposed as a pair of
// function fields so that failed and live subscriptions share one type.
type Subscription struct {
	// Close releases whatever the subscription holds.
	Close func() error
	// Next decodes the next result into response, or returns an error.
	Next func(response interface{}) error
}
33
34 func errorSubscription(err error) *Subscription {
35 return &Subscription{
36 Close: func() error { return nil },
37 Next: func(response interface{}) error {
38 return err
39 },
40 }
41 }
42
43 func (p *Client) Websocket(query string, options ...Option) *Subscription {
44 return p.WebsocketWithPayload(query, nil, options...)
45 }
46
47
// WebsocketOnce opens a subscription for query, reads one result into resp
// via Next, and closes the subscription before returning. The error returned
// by Close is ignored.
//
// resp should be a pointer. Any other value (including nil) is still
// accepted: Next then receives a pointer to this function's own copy of resp,
// so whatever it stores is not visible to the caller.
func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
	sub := p.Websocket(query, options...)
	defer sub.Close()

	target := resp
	if reflect.ValueOf(resp).Kind() != reflect.Ptr {
		target = &resp
	}
	return sub.Next(target)
}
57
// WebsocketWithPayload starts a subscription for query against the client's
// handler p.h and returns a handle for reading results.
//
// It serves p.h on a fresh httptest server and dials it over websocket, using
// the URL path and headers of the request built from query and options. It
// then performs the handshake:
//  1. Send connection_init, carrying initPayload as the payload when non-nil.
//  2. Expect connection_ack, then ka.
//  3. Send start with ID "1" and the request body as the payload.
//
// Setup errors are not returned directly. Any failure yields a Subscription
// whose Next returns that error. The server and websocket connection are
// released before that error Subscription is returned.
//
// On success, Next skips ka messages. It returns an error message's payload as
// the error text. For a data message, it decodes the Data field into response,
// then returns RawJsonError if the Errors field is non-nil. Close shuts down
// the test server and the websocket connection.
func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
	r, err := p.newRequest(query, options...)
	if err != nil {
		return errorSubscription(fmt.Errorf("request: %w", err))
	}

	requestBody, err := io.ReadAll(r.Body)
	if err != nil {
		return errorSubscription(fmt.Errorf("parse body: %w", err))
	}

	// Encode the init payload before starting anything, so a bad payload fails
	// without creating a server or connection that would need tearing down.
	initMessage := operationMessage{Type: connectionInitMsg}
	if initPayload != nil {
		initMessage.Payload, err = json.Marshal(initPayload)
		if err != nil {
			return errorSubscription(fmt.Errorf("parse payload: %w", err))
		}
	}

	srv := httptest.NewServer(p.h)
	host := strings.ReplaceAll(srv.URL, "http://", "ws://")
	c, resp, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header)
	if err != nil {
		srv.Close()
		return errorSubscription(fmt.Errorf("dial: %w", err))
	}
	defer resp.Body.Close()

	// fail releases the server and connection before reporting cause. The
	// cleanup error is dropped on purpose: cause is the one worth surfacing.
	fail := func(cause error) *Subscription {
		srv.Close()
		_ = c.Close()
		return errorSubscription(cause)
	}

	if err = c.WriteJSON(initMessage); err != nil {
		return fail(fmt.Errorf("init: %w", err))
	}

	var ack operationMessage
	if err = c.ReadJSON(&ack); err != nil {
		return fail(fmt.Errorf("ack: %w", err))
	}
	if ack.Type != connectionAckMsg {
		return fail(fmt.Errorf("expected ack message, got %#v", ack))
	}

	var ka operationMessage
	if err = c.ReadJSON(&ka); err != nil {
		return fail(fmt.Errorf("ka: %w", err))
	}
	if ka.Type != connectionKaMsg {
		return fail(fmt.Errorf("expected ka message, got %#v", ka))
	}

	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
		return fail(fmt.Errorf("start: %w", err))
	}

	return &Subscription{
		Close: func() error {
			srv.Close()
			return c.Close()
		},
		Next: func(response interface{}) error {
			for {
				var op operationMessage
				if err := c.ReadJSON(&op); err != nil {
					return err
				}

				switch op.Type {
				case dataMsg:
					// Decoded below the switch.
				case connectionKaMsg:
					continue
				case errorMsg:
					// The payload is server-supplied text; never use it as a
					// format string (a '%' in it would corrupt the message).
					return fmt.Errorf("%s", op.Payload)
				default:
					return fmt.Errorf("expected data message, got %#v", op)
				}

				var respDataRaw Response
				if err := json.Unmarshal(op.Payload, &respDataRaw); err != nil {
					return fmt.Errorf("decode: %w", err)
				}

				// unpack runs before the Errors check, so response is populated
				// even when the message also carries errors.
				unpackErr := unpack(respDataRaw.Data, response, p.dc)
				if respDataRaw.Errors != nil {
					return RawJsonError{respDataRaw.Errors}
				}
				return unpackErr
			}
		},
	}
}
152
View as plain text