1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding/binary"
6 "fmt"
7
8 "github.com/jackc/pgx/v5/internal/pgio"
9 )
10
// BitsScanner is implemented by types that can receive a decoded bit/varbit
// value.
type BitsScanner interface {
	ScanBits(bits Bits) error
}

// BitsValuer is implemented by types that can supply a bit/varbit value for
// encoding.
type BitsValuer interface {
	BitsValue() (Bits, error)
}

// Bits holds a bit/varbit value.
//
// Bytes stores the bits packed most significant bit first within each byte,
// Len is the number of bits in use, and Valid is false for SQL NULL.
type Bits struct {
	Bytes []byte
	Len   int32
	Valid bool
}

// ScanBits implements BitsScanner by replacing the receiver with bits.
func (dst *Bits) ScanBits(bits Bits) error {
	*dst = bits
	return nil
}

// BitsValue implements BitsValuer by returning a copy of the receiver.
func (src Bits) BitsValue() (Bits, error) {
	return src, nil
}
34
35
36 func (dst *Bits) Scan(src any) error {
37 if src == nil {
38 *dst = Bits{}
39 return nil
40 }
41
42 switch src := src.(type) {
43 case string:
44 return scanPlanTextAnyToBitsScanner{}.Scan([]byte(src), dst)
45 }
46
47 return fmt.Errorf("cannot scan %T", src)
48 }
49
50
51 func (src Bits) Value() (driver.Value, error) {
52 if !src.Valid {
53 return nil, nil
54 }
55
56 buf, err := BitsCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil)
57 if err != nil {
58 return nil, err
59 }
60 return string(buf), err
61 }
62
63 type BitsCodec struct{}
64
65 func (BitsCodec) FormatSupported(format int16) bool {
66 return format == TextFormatCode || format == BinaryFormatCode
67 }
68
69 func (BitsCodec) PreferredFormat() int16 {
70 return BinaryFormatCode
71 }
72
73 func (BitsCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
74 if _, ok := value.(BitsValuer); !ok {
75 return nil
76 }
77
78 switch format {
79 case BinaryFormatCode:
80 return encodePlanBitsCodecBinary{}
81 case TextFormatCode:
82 return encodePlanBitsCodecText{}
83 }
84
85 return nil
86 }
87
88 type encodePlanBitsCodecBinary struct{}
89
90 func (encodePlanBitsCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
91 bits, err := value.(BitsValuer).BitsValue()
92 if err != nil {
93 return nil, err
94 }
95
96 if !bits.Valid {
97 return nil, nil
98 }
99
100 buf = pgio.AppendInt32(buf, bits.Len)
101 return append(buf, bits.Bytes...), nil
102 }
103
104 type encodePlanBitsCodecText struct{}
105
106 func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
107 bits, err := value.(BitsValuer).BitsValue()
108 if err != nil {
109 return nil, err
110 }
111
112 if !bits.Valid {
113 return nil, nil
114 }
115
116 for i := int32(0); i < bits.Len; i++ {
117 byteIdx := i / 8
118 bitMask := byte(128 >> byte(i%8))
119 char := byte('0')
120 if bits.Bytes[byteIdx]&bitMask > 0 {
121 char = '1'
122 }
123 buf = append(buf, char)
124 }
125
126 return buf, nil
127 }
128
129 func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
130
131 switch format {
132 case BinaryFormatCode:
133 switch target.(type) {
134 case BitsScanner:
135 return scanPlanBinaryBitsToBitsScanner{}
136 }
137 case TextFormatCode:
138 switch target.(type) {
139 case BitsScanner:
140 return scanPlanTextAnyToBitsScanner{}
141 }
142 }
143
144 return nil
145 }
146
147 func (c BitsCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
148 return codecDecodeToTextFormat(c, m, oid, format, src)
149 }
150
151 func (c BitsCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
152 if src == nil {
153 return nil, nil
154 }
155
156 var box Bits
157 err := codecScan(c, m, oid, format, src, &box)
158 if err != nil {
159 return nil, err
160 }
161 return box, nil
162 }
163
164 type scanPlanBinaryBitsToBitsScanner struct{}
165
166 func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
167 scanner := (dst).(BitsScanner)
168
169 if src == nil {
170 return scanner.ScanBits(Bits{})
171 }
172
173 if len(src) < 4 {
174 return fmt.Errorf("invalid length for bit/varbit: %v", len(src))
175 }
176
177 bitLen := int32(binary.BigEndian.Uint32(src))
178 rp := 4
179 buf := make([]byte, len(src[rp:]))
180 copy(buf, src[rp:])
181
182 return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
183 }
184
185 type scanPlanTextAnyToBitsScanner struct{}
186
187 func (scanPlanTextAnyToBitsScanner) Scan(src []byte, dst any) error {
188 scanner := (dst).(BitsScanner)
189
190 if src == nil {
191 return scanner.ScanBits(Bits{})
192 }
193
194 bitLen := len(src)
195 byteLen := bitLen / 8
196 if bitLen%8 > 0 {
197 byteLen++
198 }
199 buf := make([]byte, byteLen)
200
201 for i, b := range src {
202 if b == '1' {
203 byteIdx := i / 8
204 bitIdx := uint(i % 8)
205 buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx)
206 }
207 }
208
209 return scanner.ScanBits(Bits{Bytes: buf, Len: int32(bitLen), Valid: true})
210 }
211
View as plain text