...
1 package pgproto3
2
3 import (
4 "bytes"
5 "encoding/binary"
6 "encoding/hex"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "math"
11
12 "github.com/jackc/pgx/v5/internal/pgio"
13 )
14
// Bind is the protocol message carrying parameter values for a prepared
// statement; Encode writes it with the message type byte 'B'.
type Bind struct {
	// DestinationPortal and PreparedStatement are the portal and prepared
	// statement names, each sent on the wire as a NUL-terminated string.
	DestinationPortal, PreparedStatement string

	// ParameterFormatCodes holds zero, one, or one-per-parameter format codes.
	// When rendering JSON, a single code applies to every parameter, 0 is
	// treated as text and any other value as binary.
	ParameterFormatCodes []int16

	// Parameters holds the raw parameter values. A nil entry is sent with a
	// length of -1; values produced by Decode alias the buffer it was given.
	Parameters [][]byte

	// ResultFormatCodes is sent as a uint16 count followed by int16 codes.
	ResultFormatCodes []int16
}

// Frontend has an empty body. It appears to act as a marker identifying Bind
// as a frontend message, presumably to satisfy an interface declared
// elsewhere in the package.
func (*Bind) Frontend() {}
25
26
27
28 func (dst *Bind) Decode(src []byte) error {
29 *dst = Bind{}
30
31 idx := bytes.IndexByte(src, 0)
32 if idx < 0 {
33 return &invalidMessageFormatErr{messageType: "Bind"}
34 }
35 dst.DestinationPortal = string(src[:idx])
36 rp := idx + 1
37
38 idx = bytes.IndexByte(src[rp:], 0)
39 if idx < 0 {
40 return &invalidMessageFormatErr{messageType: "Bind"}
41 }
42 dst.PreparedStatement = string(src[rp : rp+idx])
43 rp += idx + 1
44
45 if len(src[rp:]) < 2 {
46 return &invalidMessageFormatErr{messageType: "Bind"}
47 }
48 parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
49 rp += 2
50
51 if parameterFormatCodeCount > 0 {
52 dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
53
54 if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
55 return &invalidMessageFormatErr{messageType: "Bind"}
56 }
57 for i := 0; i < parameterFormatCodeCount; i++ {
58 dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
59 rp += 2
60 }
61 }
62
63 if len(src[rp:]) < 2 {
64 return &invalidMessageFormatErr{messageType: "Bind"}
65 }
66 parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
67 rp += 2
68
69 if parameterCount > 0 {
70 dst.Parameters = make([][]byte, parameterCount)
71
72 for i := 0; i < parameterCount; i++ {
73 if len(src[rp:]) < 4 {
74 return &invalidMessageFormatErr{messageType: "Bind"}
75 }
76
77 msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
78 rp += 4
79
80
81 if msgSize == -1 {
82 continue
83 }
84
85 if len(src[rp:]) < msgSize {
86 return &invalidMessageFormatErr{messageType: "Bind"}
87 }
88
89 dst.Parameters[i] = src[rp : rp+msgSize]
90 rp += msgSize
91 }
92 }
93
94 if len(src[rp:]) < 2 {
95 return &invalidMessageFormatErr{messageType: "Bind"}
96 }
97 resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
98 rp += 2
99
100 dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
101 if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
102 return &invalidMessageFormatErr{messageType: "Bind"}
103 }
104 for i := 0; i < resultFormatCodeCount; i++ {
105 dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
106 rp += 2
107 }
108
109 return nil
110 }
111
112
113 func (src *Bind) Encode(dst []byte) ([]byte, error) {
114 dst, sp := beginMessage(dst, 'B')
115
116 dst = append(dst, src.DestinationPortal...)
117 dst = append(dst, 0)
118 dst = append(dst, src.PreparedStatement...)
119 dst = append(dst, 0)
120
121 if len(src.ParameterFormatCodes) > math.MaxUint16 {
122 return nil, errors.New("too many parameter format codes")
123 }
124 dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
125 for _, fc := range src.ParameterFormatCodes {
126 dst = pgio.AppendInt16(dst, fc)
127 }
128
129 if len(src.Parameters) > math.MaxUint16 {
130 return nil, errors.New("too many parameters")
131 }
132 dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
133 for _, p := range src.Parameters {
134 if p == nil {
135 dst = pgio.AppendInt32(dst, -1)
136 continue
137 }
138
139 dst = pgio.AppendInt32(dst, int32(len(p)))
140 dst = append(dst, p...)
141 }
142
143 if len(src.ResultFormatCodes) > math.MaxUint16 {
144 return nil, errors.New("too many result format codes")
145 }
146 dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
147 for _, fc := range src.ResultFormatCodes {
148 dst = pgio.AppendInt16(dst, fc)
149 }
150
151 return finishMessage(dst, sp)
152 }
153
154
155 func (src Bind) MarshalJSON() ([]byte, error) {
156 formattedParameters := make([]map[string]string, len(src.Parameters))
157 for i, p := range src.Parameters {
158 if p == nil {
159 continue
160 }
161
162 textFormat := true
163 if len(src.ParameterFormatCodes) == 1 {
164 textFormat = src.ParameterFormatCodes[0] == 0
165 } else if len(src.ParameterFormatCodes) > 1 {
166 textFormat = src.ParameterFormatCodes[i] == 0
167 }
168
169 if textFormat {
170 formattedParameters[i] = map[string]string{"text": string(p)}
171 } else {
172 formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
173 }
174 }
175
176 return json.Marshal(struct {
177 Type string
178 DestinationPortal string
179 PreparedStatement string
180 ParameterFormatCodes []int16
181 Parameters []map[string]string
182 ResultFormatCodes []int16
183 }{
184 Type: "Bind",
185 DestinationPortal: src.DestinationPortal,
186 PreparedStatement: src.PreparedStatement,
187 ParameterFormatCodes: src.ParameterFormatCodes,
188 Parameters: formattedParameters,
189 ResultFormatCodes: src.ResultFormatCodes,
190 })
191 }
192
193
194 func (dst *Bind) UnmarshalJSON(data []byte) error {
195
196 if string(data) == "null" {
197 return nil
198 }
199
200 var msg struct {
201 DestinationPortal string
202 PreparedStatement string
203 ParameterFormatCodes []int16
204 Parameters []map[string]string
205 ResultFormatCodes []int16
206 }
207 err := json.Unmarshal(data, &msg)
208 if err != nil {
209 return err
210 }
211 dst.DestinationPortal = msg.DestinationPortal
212 dst.PreparedStatement = msg.PreparedStatement
213 dst.ParameterFormatCodes = msg.ParameterFormatCodes
214 dst.Parameters = make([][]byte, len(msg.Parameters))
215 dst.ResultFormatCodes = msg.ResultFormatCodes
216 for n, parameter := range msg.Parameters {
217 dst.Parameters[n], err = getValueFromJSON(parameter)
218 if err != nil {
219 return fmt.Errorf("cannot get param %d: %w", n, err)
220 }
221 }
222 return nil
223 }
224
View as plain text