1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7 "strconv"
8 "strings"
9
10 "github.com/jackc/pgx/v5/internal/pgio"
11 )
12
13 type TIDScanner interface {
14 ScanTID(v TID) error
15 }
16
17 type TIDValuer interface {
18 TIDValue() (TID, error)
19 }
20
21
22
23
24
25
26
27
28
29
30
31
// TID is PostgreSQL's tuple identifier (tid) type, the data type of the
// hidden ctid system column. It is a pair of unsigned integers: a block
// number and an offset number within that block.
//
// Valid is false for SQL NULL; the zero value is therefore a NULL TID.
type TID struct {
	BlockNumber  uint32
	OffsetNumber uint16
	Valid        bool
}

// ScanTID implements TIDScanner by overwriting the receiver with tid.
func (dst *TID) ScanTID(tid TID) error {
	*dst = tid
	return nil
}

// TIDValue implements TIDValuer by returning a copy of the receiver.
// It never fails.
func (src TID) TIDValue() (TID, error) {
	return src, nil
}
46
47
48 func (dst *TID) Scan(src any) error {
49 if src == nil {
50 *dst = TID{}
51 return nil
52 }
53
54 switch src := src.(type) {
55 case string:
56 return scanPlanTextAnyToTIDScanner{}.Scan([]byte(src), dst)
57 }
58
59 return fmt.Errorf("cannot scan %T", src)
60 }
61
62
63 func (src TID) Value() (driver.Value, error) {
64 if !src.Valid {
65 return nil, nil
66 }
67
68 buf, err := TIDCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil)
69 if err != nil {
70 return nil, err
71 }
72 return string(buf), err
73 }
74
75 type TIDCodec struct{}
76
77 func (TIDCodec) FormatSupported(format int16) bool {
78 return format == TextFormatCode || format == BinaryFormatCode
79 }
80
81 func (TIDCodec) PreferredFormat() int16 {
82 return BinaryFormatCode
83 }
84
85 func (TIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
86 if _, ok := value.(TIDValuer); !ok {
87 return nil
88 }
89
90 switch format {
91 case BinaryFormatCode:
92 return encodePlanTIDCodecBinary{}
93 case TextFormatCode:
94 return encodePlanTIDCodecText{}
95 }
96
97 return nil
98 }
99
100 type encodePlanTIDCodecBinary struct{}
101
102 func (encodePlanTIDCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
103 tid, err := value.(TIDValuer).TIDValue()
104 if err != nil {
105 return nil, err
106 }
107
108 if !tid.Valid {
109 return nil, nil
110 }
111
112 buf = pgio.AppendUint32(buf, tid.BlockNumber)
113 buf = pgio.AppendUint16(buf, tid.OffsetNumber)
114 return buf, nil
115 }
116
117 type encodePlanTIDCodecText struct{}
118
119 func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
120 tid, err := value.(TIDValuer).TIDValue()
121 if err != nil {
122 return nil, err
123 }
124
125 if !tid.Valid {
126 return nil, nil
127 }
128
129 buf = append(buf, fmt.Sprintf(`(%d,%d)`, tid.BlockNumber, tid.OffsetNumber)...)
130 return buf, nil
131 }
132
133 func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
134
135 switch format {
136 case BinaryFormatCode:
137 switch target.(type) {
138 case TIDScanner:
139 return scanPlanBinaryTIDToTIDScanner{}
140 case TextScanner:
141 return scanPlanBinaryTIDToTextScanner{}
142 }
143 case TextFormatCode:
144 switch target.(type) {
145 case TIDScanner:
146 return scanPlanTextAnyToTIDScanner{}
147 }
148 }
149
150 return nil
151 }
152
153 type scanPlanBinaryTIDToTIDScanner struct{}
154
155 func (scanPlanBinaryTIDToTIDScanner) Scan(src []byte, dst any) error {
156 scanner := (dst).(TIDScanner)
157
158 if src == nil {
159 return scanner.ScanTID(TID{})
160 }
161
162 if len(src) != 6 {
163 return fmt.Errorf("invalid length for tid: %v", len(src))
164 }
165
166 return scanner.ScanTID(TID{
167 BlockNumber: binary.BigEndian.Uint32(src),
168 OffsetNumber: binary.BigEndian.Uint16(src[4:]),
169 Valid: true,
170 })
171 }
172
173 type scanPlanBinaryTIDToTextScanner struct{}
174
175 func (scanPlanBinaryTIDToTextScanner) Scan(src []byte, dst any) error {
176 scanner := (dst).(TextScanner)
177
178 if src == nil {
179 return scanner.ScanText(Text{})
180 }
181
182 if len(src) != 6 {
183 return fmt.Errorf("invalid length for tid: %v", len(src))
184 }
185
186 blockNumber := binary.BigEndian.Uint32(src)
187 offsetNumber := binary.BigEndian.Uint16(src[4:])
188
189 return scanner.ScanText(Text{
190 String: fmt.Sprintf(`(%d,%d)`, blockNumber, offsetNumber),
191 Valid: true,
192 })
193 }
194
195 type scanPlanTextAnyToTIDScanner struct{}
196
197 func (scanPlanTextAnyToTIDScanner) Scan(src []byte, dst any) error {
198 scanner := (dst).(TIDScanner)
199
200 if src == nil {
201 return scanner.ScanTID(TID{})
202 }
203
204 if len(src) < 5 {
205 return fmt.Errorf("invalid length for tid: %v", len(src))
206 }
207
208 block, offset, found := strings.Cut(string(src[1:len(src)-1]), ",")
209 if !found {
210 return fmt.Errorf("invalid format for tid")
211 }
212
213 blockNumber, err := strconv.ParseUint(block, 10, 32)
214 if err != nil {
215 return err
216 }
217
218 offsetNumber, err := strconv.ParseUint(offset, 10, 16)
219 if err != nil {
220 return err
221 }
222
223 return scanner.ScanTID(TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Valid: true})
224 }
225
226 func (c TIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
227 return codecDecodeToTextFormat(c, m, oid, format, src)
228 }
229
230 func (c TIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
231 if src == nil {
232 return nil, nil
233 }
234
235 var tid TID
236 err := codecScan(c, m, oid, format, src, &tid)
237 if err != nil {
238 return nil, err
239 }
240 return tid, nil
241 }
242
View as plain text