1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/binary"
7 "fmt"
8 "math"
9 "strconv"
10 "strings"
11
12 "github.com/jackc/pgx/v5/internal/pgio"
13 )
14
// Vec2 is a pair of float64 coordinates.
type Vec2 struct {
	X, Y float64
}
19
20 type PointScanner interface {
21 ScanPoint(v Point) error
22 }
23
24 type PointValuer interface {
25 PointValue() (Point, error)
26 }
27
28 type Point struct {
29 P Vec2
30 Valid bool
31 }
32
33 func (p *Point) ScanPoint(v Point) error {
34 *p = v
35 return nil
36 }
37
38 func (p Point) PointValue() (Point, error) {
39 return p, nil
40 }
41
42 func parsePoint(src []byte) (*Point, error) {
43 if src == nil || bytes.Equal(src, []byte("null")) {
44 return &Point{}, nil
45 }
46
47 if len(src) < 5 {
48 return nil, fmt.Errorf("invalid length for point: %v", len(src))
49 }
50 if src[0] == '"' && src[len(src)-1] == '"' {
51 src = src[1 : len(src)-1]
52 }
53 sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
54 if !found {
55 return nil, fmt.Errorf("invalid format for point")
56 }
57
58 x, err := strconv.ParseFloat(sx, 64)
59 if err != nil {
60 return nil, err
61 }
62
63 y, err := strconv.ParseFloat(sy, 64)
64 if err != nil {
65 return nil, err
66 }
67
68 return &Point{P: Vec2{x, y}, Valid: true}, nil
69 }
70
71
72 func (dst *Point) Scan(src any) error {
73 if src == nil {
74 *dst = Point{}
75 return nil
76 }
77
78 switch src := src.(type) {
79 case string:
80 return scanPlanTextAnyToPointScanner{}.Scan([]byte(src), dst)
81 }
82
83 return fmt.Errorf("cannot scan %T", src)
84 }
85
86
87 func (src Point) Value() (driver.Value, error) {
88 if !src.Valid {
89 return nil, nil
90 }
91
92 buf, err := PointCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil)
93 if err != nil {
94 return nil, err
95 }
96 return string(buf), err
97 }
98
99 func (src Point) MarshalJSON() ([]byte, error) {
100 if !src.Valid {
101 return []byte("null"), nil
102 }
103
104 var buff bytes.Buffer
105 buff.WriteByte('"')
106 buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y))
107 buff.WriteByte('"')
108 return buff.Bytes(), nil
109 }
110
111 func (dst *Point) UnmarshalJSON(point []byte) error {
112 p, err := parsePoint(point)
113 if err != nil {
114 return err
115 }
116 *dst = *p
117 return nil
118 }
119
120 type PointCodec struct{}
121
122 func (PointCodec) FormatSupported(format int16) bool {
123 return format == TextFormatCode || format == BinaryFormatCode
124 }
125
126 func (PointCodec) PreferredFormat() int16 {
127 return BinaryFormatCode
128 }
129
130 func (PointCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
131 if _, ok := value.(PointValuer); !ok {
132 return nil
133 }
134
135 switch format {
136 case BinaryFormatCode:
137 return encodePlanPointCodecBinary{}
138 case TextFormatCode:
139 return encodePlanPointCodecText{}
140 }
141
142 return nil
143 }
144
145 type encodePlanPointCodecBinary struct{}
146
147 func (encodePlanPointCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
148 point, err := value.(PointValuer).PointValue()
149 if err != nil {
150 return nil, err
151 }
152
153 if !point.Valid {
154 return nil, nil
155 }
156
157 buf = pgio.AppendUint64(buf, math.Float64bits(point.P.X))
158 buf = pgio.AppendUint64(buf, math.Float64bits(point.P.Y))
159 return buf, nil
160 }
161
162 type encodePlanPointCodecText struct{}
163
164 func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
165 point, err := value.(PointValuer).PointValue()
166 if err != nil {
167 return nil, err
168 }
169
170 if !point.Valid {
171 return nil, nil
172 }
173
174 return append(buf, fmt.Sprintf(`(%s,%s)`,
175 strconv.FormatFloat(point.P.X, 'f', -1, 64),
176 strconv.FormatFloat(point.P.Y, 'f', -1, 64),
177 )...), nil
178 }
179
180 func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
181
182 switch format {
183 case BinaryFormatCode:
184 switch target.(type) {
185 case PointScanner:
186 return scanPlanBinaryPointToPointScanner{}
187 }
188 case TextFormatCode:
189 switch target.(type) {
190 case PointScanner:
191 return scanPlanTextAnyToPointScanner{}
192 }
193 }
194
195 return nil
196 }
197
198 func (c PointCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
199 return codecDecodeToTextFormat(c, m, oid, format, src)
200 }
201
202 func (c PointCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
203 if src == nil {
204 return nil, nil
205 }
206
207 var point Point
208 err := codecScan(c, m, oid, format, src, &point)
209 if err != nil {
210 return nil, err
211 }
212 return point, nil
213 }
214
215 type scanPlanBinaryPointToPointScanner struct{}
216
217 func (scanPlanBinaryPointToPointScanner) Scan(src []byte, dst any) error {
218 scanner := (dst).(PointScanner)
219
220 if src == nil {
221 return scanner.ScanPoint(Point{})
222 }
223
224 if len(src) != 16 {
225 return fmt.Errorf("invalid length for point: %v", len(src))
226 }
227
228 x := binary.BigEndian.Uint64(src)
229 y := binary.BigEndian.Uint64(src[8:])
230
231 return scanner.ScanPoint(Point{
232 P: Vec2{math.Float64frombits(x), math.Float64frombits(y)},
233 Valid: true,
234 })
235 }
236
237 type scanPlanTextAnyToPointScanner struct{}
238
239 func (scanPlanTextAnyToPointScanner) Scan(src []byte, dst any) error {
240 scanner := (dst).(PointScanner)
241
242 if src == nil {
243 return scanner.ScanPoint(Point{})
244 }
245
246 if len(src) < 5 {
247 return fmt.Errorf("invalid length for point: %v", len(src))
248 }
249
250 sx, sy, found := strings.Cut(string(src[1:len(src)-1]), ",")
251 if !found {
252 return fmt.Errorf("invalid format for point")
253 }
254
255 x, err := strconv.ParseFloat(sx, 64)
256 if err != nil {
257 return err
258 }
259
260 y, err := strconv.ParseFloat(sy, 64)
261 if err != nil {
262 return err
263 }
264
265 return scanner.ScanPoint(Point{P: Vec2{x, y}, Valid: true})
266 }
267
View as plain text