1
2
3
4
5 package jsonrpc2
6
7 import (
8 "bufio"
9 "context"
10 "encoding/json"
11 "fmt"
12 "io"
13 "net"
14 "strconv"
15 "strings"
16 )
17
18
19
20
21
22
23
24 type Stream interface {
25
26 Read(context.Context) (Message, int64, error)
27
28 Write(context.Context, Message) (int64, error)
29
30
31 Close() error
32 }
33
34
35
36
37 type Framer func(conn net.Conn) Stream
38
39
40
41
42 func NewRawStream(conn net.Conn) Stream {
43 return &rawStream{
44 conn: conn,
45 in: json.NewDecoder(conn),
46 }
47 }
48
49 type rawStream struct {
50 conn net.Conn
51 in *json.Decoder
52 }
53
54 func (s *rawStream) Read(ctx context.Context) (Message, int64, error) {
55 select {
56 case <-ctx.Done():
57 return nil, 0, ctx.Err()
58 default:
59 }
60 var raw json.RawMessage
61 if err := s.in.Decode(&raw); err != nil {
62 return nil, 0, err
63 }
64 msg, err := DecodeMessage(raw)
65 return msg, int64(len(raw)), err
66 }
67
68 func (s *rawStream) Write(ctx context.Context, msg Message) (int64, error) {
69 select {
70 case <-ctx.Done():
71 return 0, ctx.Err()
72 default:
73 }
74 data, err := json.Marshal(msg)
75 if err != nil {
76 return 0, fmt.Errorf("marshaling message: %v", err)
77 }
78 n, err := s.conn.Write(data)
79 return int64(n), err
80 }
81
82 func (s *rawStream) Close() error {
83 return s.conn.Close()
84 }
85
86
87
88
89 func NewHeaderStream(conn net.Conn) Stream {
90 return &headerStream{
91 conn: conn,
92 in: bufio.NewReader(conn),
93 }
94 }
95
96 type headerStream struct {
97 conn net.Conn
98 in *bufio.Reader
99 }
100
101 func (s *headerStream) Read(ctx context.Context) (Message, int64, error) {
102 select {
103 case <-ctx.Done():
104 return nil, 0, ctx.Err()
105 default:
106 }
107 var total, length int64
108
109 for {
110 line, err := s.in.ReadString('\n')
111 total += int64(len(line))
112 if err != nil {
113 return nil, total, fmt.Errorf("failed reading header line: %w", err)
114 }
115 line = strings.TrimSpace(line)
116
117 if line == "" {
118 break
119 }
120 colon := strings.IndexRune(line, ':')
121 if colon < 0 {
122 return nil, total, fmt.Errorf("invalid header line %q", line)
123 }
124 name, value := line[:colon], strings.TrimSpace(line[colon+1:])
125 switch name {
126 case "Content-Length":
127 if length, err = strconv.ParseInt(value, 10, 32); err != nil {
128 return nil, total, fmt.Errorf("failed parsing Content-Length: %v", value)
129 }
130 if length <= 0 {
131 return nil, total, fmt.Errorf("invalid Content-Length: %v", length)
132 }
133 default:
134
135 }
136 }
137 if length == 0 {
138 return nil, total, fmt.Errorf("missing Content-Length header")
139 }
140 data := make([]byte, length)
141 if _, err := io.ReadFull(s.in, data); err != nil {
142 return nil, total, err
143 }
144 total += length
145 msg, err := DecodeMessage(data)
146 return msg, total, err
147 }
148
149 func (s *headerStream) Write(ctx context.Context, msg Message) (int64, error) {
150 select {
151 case <-ctx.Done():
152 return 0, ctx.Err()
153 default:
154 }
155 data, err := json.Marshal(msg)
156 if err != nil {
157 return 0, fmt.Errorf("marshaling message: %v", err)
158 }
159 n, err := fmt.Fprintf(s.conn, "Content-Length: %v\r\n\r\n", len(data))
160 total := int64(n)
161 if err == nil {
162 n, err = s.conn.Write(data)
163 total += int64(n)
164 }
165 return total, err
166 }
167
168 func (s *headerStream) Close() error {
169 return s.conn.Close()
170 }
171
View as plain text