...
1 package client
2
3 import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "fmt"
8 "net/http/httptest"
9 "net/textproto"
10 "strings"
11 )
12
// SSE is the handle returned by Client.SSE for reading a server-sent-events
// response one payload at a time.
type SSE struct {
	// Close releases whatever the subscription holds.
	Close func() error

	// Next advances the stream and decodes its next payload into the value
	// passed as the argument.
	Next func(interface{}) error
}
17
// SSEResponse is the JSON object carried by each "data:" line of the event
// stream. Label, Path and HasNext look like GraphQL incremental-delivery
// (@defer/@stream) metadata — NOTE(review): confirm against the server.
type SSEResponse struct {
	// Data is the decoded "data" member of the payload.
	Data interface{} `json:"data"`

	// Label is the "label" member.
	Label string `json:"label"`

	// Path is the "path" member; elements may be strings or numbers.
	Path []interface{} `json:"path"`

	// HasNext is the "hasNext" member.
	HasNext bool `json:"hasNext"`

	// Errors is left undecoded; Client.SSE treats any non-nil value as an
	// error result for the payload.
	Errors json.RawMessage `json:"errors"`

	// Extensions is the free-form "extensions" member.
	Extensions map[string]interface{} `json:"extensions"`
}
26
27 func errorSSE(err error) *SSE {
28 return &SSE{
29 Close: func() error { return nil },
30 Next: func(response interface{}) error {
31 return err
32 },
33 }
34 }
35
36 func (p *Client) SSE(ctx context.Context, query string, options ...Option) *SSE {
37 r, err := p.newRequest(query, options...)
38 if err != nil {
39 return errorSSE(fmt.Errorf("request: %w", err))
40 }
41 r = r.WithContext(ctx)
42
43 r.Header.Set("Accept", "text/event-stream")
44 r.Header.Set("Cache-Control", "no-cache")
45 r.Header.Set("Connection", "keep-alive")
46
47 srv := httptest.NewServer(p.h)
48 w := httptest.NewRecorder()
49 p.h.ServeHTTP(w, r)
50
51 reader := textproto.NewReader(bufio.NewReader(w.Body))
52 line, err := reader.ReadLine()
53 if err != nil {
54 return errorSSE(fmt.Errorf("response: %w", err))
55 }
56 if line != ":" {
57 return errorSSE(fmt.Errorf("expected :, got %s", line))
58 }
59
60 return &SSE{
61 Close: func() error {
62 srv.Close()
63 return nil
64 },
65 Next: func(response interface{}) error {
66 for {
67 line, err := reader.ReadLine()
68 if err != nil {
69 return err
70 }
71 kv := strings.SplitN(line, ": ", 2)
72
73 switch kv[0] {
74 case "":
75 continue
76 case "event":
77 switch kv[1] {
78 case "next":
79 continue
80 case "complete":
81 return nil
82 default:
83 return fmt.Errorf("expected event type: %#v", kv[1])
84 }
85 case "data":
86 var respDataRaw SSEResponse
87 if err = json.Unmarshal([]byte(kv[1]), &respDataRaw); err != nil {
88 return fmt.Errorf("decode: %w", err)
89 }
90
91
92 unpackErr := unpack(respDataRaw, response, p.dc)
93
94 if respDataRaw.Errors != nil {
95 return RawJsonError{respDataRaw.Errors}
96 }
97
98 return unpackErr
99 default:
100 return fmt.Errorf("unexpected sse field %s", kv[0])
101 }
102 }
103 },
104 }
105 }
106
View as plain text