1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "errors"
7 "fmt"
8 "net/netip"
9 )
10
11
12
13
// Address family codes written into the PostgreSQL inet binary format.
// They match PostgreSQL's PGSQL_AF_INET / PGSQL_AF_INET6 (src/include/utils/inet.h),
// where the IPv6 code is defined as the IPv4 code plus one. They are fixed
// values rather than the client's socket constants.
const (
	defaultAFInet  = 2
	defaultAFInet6 = defaultAFInet + 1
)
18
type (
	// NetipPrefixScanner is implemented by scan targets that can receive a
	// decoded inet value as a netip.Prefix. The inet scan plans deliver the
	// zero netip.Prefix when the source is nil.
	NetipPrefixScanner interface {
		ScanNetipPrefix(v netip.Prefix) error
	}

	// NetipPrefixValuer is implemented by values that can supply a
	// netip.Prefix to be encoded as an inet. The inet encode plans return a
	// nil buffer when the supplied prefix is invalid (e.g. the zero value).
	NetipPrefixValuer interface {
		NetipPrefixValue() (netip.Prefix, error)
	}
)
26
27
28
29 type InetCodec struct{}
30
31 func (InetCodec) FormatSupported(format int16) bool {
32 return format == TextFormatCode || format == BinaryFormatCode
33 }
34
35 func (InetCodec) PreferredFormat() int16 {
36 return BinaryFormatCode
37 }
38
39 func (InetCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
40 if _, ok := value.(NetipPrefixValuer); !ok {
41 return nil
42 }
43
44 switch format {
45 case BinaryFormatCode:
46 return encodePlanInetCodecBinary{}
47 case TextFormatCode:
48 return encodePlanInetCodecText{}
49 }
50
51 return nil
52 }
53
54 type encodePlanInetCodecBinary struct{}
55
56 func (encodePlanInetCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
57 prefix, err := value.(NetipPrefixValuer).NetipPrefixValue()
58 if err != nil {
59 return nil, err
60 }
61
62 if !prefix.IsValid() {
63 return nil, nil
64 }
65
66 var family byte
67 if prefix.Addr().Is4() {
68 family = defaultAFInet
69 } else {
70 family = defaultAFInet6
71 }
72
73 buf = append(buf, family)
74
75 ones := prefix.Bits()
76 buf = append(buf, byte(ones))
77
78
79 buf = append(buf, 0)
80
81 if family == defaultAFInet {
82 buf = append(buf, byte(4))
83 b := prefix.Addr().As4()
84 buf = append(buf, b[:]...)
85 } else {
86 buf = append(buf, byte(16))
87 b := prefix.Addr().As16()
88 buf = append(buf, b[:]...)
89 }
90
91 return buf, nil
92 }
93
94 type encodePlanInetCodecText struct{}
95
96 func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
97 prefix, err := value.(NetipPrefixValuer).NetipPrefixValue()
98 if err != nil {
99 return nil, err
100 }
101
102 if !prefix.IsValid() {
103 return nil, nil
104 }
105
106 return append(buf, prefix.String()...), nil
107 }
108
109 func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
110
111 switch format {
112 case BinaryFormatCode:
113 switch target.(type) {
114 case NetipPrefixScanner:
115 return scanPlanBinaryInetToNetipPrefixScanner{}
116 }
117 case TextFormatCode:
118 switch target.(type) {
119 case NetipPrefixScanner:
120 return scanPlanTextAnyToNetipPrefixScanner{}
121 }
122 }
123
124 return nil
125 }
126
127 func (c InetCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
128 return codecDecodeToTextFormat(c, m, oid, format, src)
129 }
130
131 func (c InetCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
132 if src == nil {
133 return nil, nil
134 }
135
136 var prefix netip.Prefix
137 err := codecScan(c, m, oid, format, src, (*netipPrefixWrapper)(&prefix))
138 if err != nil {
139 return nil, err
140 }
141
142 if !prefix.IsValid() {
143 return nil, nil
144 }
145
146 return prefix, nil
147 }
148
149 type scanPlanBinaryInetToNetipPrefixScanner struct{}
150
151 func (scanPlanBinaryInetToNetipPrefixScanner) Scan(src []byte, dst any) error {
152 scanner := (dst).(NetipPrefixScanner)
153
154 if src == nil {
155 return scanner.ScanNetipPrefix(netip.Prefix{})
156 }
157
158 if len(src) != 8 && len(src) != 20 {
159 return fmt.Errorf("Received an invalid size for an inet: %d", len(src))
160 }
161
162
163 bits := src[1]
164
165
166
167 addr, ok := netip.AddrFromSlice(src[4:])
168 if !ok {
169 return errors.New("netip.AddrFromSlice failed")
170 }
171
172 return scanner.ScanNetipPrefix(netip.PrefixFrom(addr, int(bits)))
173 }
174
175 type scanPlanTextAnyToNetipPrefixScanner struct{}
176
177 func (scanPlanTextAnyToNetipPrefixScanner) Scan(src []byte, dst any) error {
178 scanner := (dst).(NetipPrefixScanner)
179
180 if src == nil {
181 return scanner.ScanNetipPrefix(netip.Prefix{})
182 }
183
184 var prefix netip.Prefix
185 if bytes.IndexByte(src, '/') == -1 {
186 addr, err := netip.ParseAddr(string(src))
187 if err != nil {
188 return err
189 }
190 prefix = netip.PrefixFrom(addr, addr.BitLen())
191 } else {
192 var err error
193 prefix, err = netip.ParsePrefix(string(src))
194 if err != nil {
195 return err
196 }
197 }
198
199 return scanner.ScanNetipPrefix(prefix)
200 }
201
View as plain text