1 package pgtype
2
3 import (
4 "database/sql/driver"
5 "encoding"
6 "fmt"
7 "net"
8 "strings"
9 )
10
11
12
13
// Address-family codes written as the first byte of the binary inet encoding
// (see EncodeBinary). The IPv6 code is one more than the IPv4 code; this
// presumably mirrors PostgreSQL's PGSQL_AF_INET/PGSQL_AF_INET6 — confirm
// against the server's utils/inet.h if these ever need to change.
const (
	defaultAFInet  = 2
	defaultAFInet6 = defaultAFInet + 1
)
18
19
20 type Inet struct {
21 IPNet *net.IPNet
22 Status Status
23 }
24
25 func (dst *Inet) Set(src interface{}) error {
26 if src == nil {
27 *dst = Inet{Status: Null}
28 return nil
29 }
30
31 if value, ok := src.(interface{ Get() interface{} }); ok {
32 value2 := value.Get()
33 if value2 != value {
34 return dst.Set(value2)
35 }
36 }
37
38 switch value := src.(type) {
39 case net.IPNet:
40 *dst = Inet{IPNet: &value, Status: Present}
41 case net.IP:
42 if len(value) == 0 {
43 *dst = Inet{Status: Null}
44 } else {
45 bitCount := len(value) * 8
46 mask := net.CIDRMask(bitCount, bitCount)
47 *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present}
48 }
49 case string:
50 ip, ipnet, err := net.ParseCIDR(value)
51 if err != nil {
52 ip := net.ParseIP(value)
53 if ip == nil {
54 return fmt.Errorf("unable to parse inet address: %s", value)
55 }
56
57 if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil {
58 ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
59 } else {
60 ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
61 }
62 } else {
63 ipnet.IP = ip
64 if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil {
65 ipnet.IP = ipv4
66 if len(ipnet.Mask) == 16 {
67 ipnet.Mask = ipnet.Mask[12:]
68 }
69 }
70 }
71
72 *dst = Inet{IPNet: ipnet, Status: Present}
73 case *net.IPNet:
74 if value == nil {
75 *dst = Inet{Status: Null}
76 } else {
77 return dst.Set(*value)
78 }
79 case *net.IP:
80 if value == nil {
81 *dst = Inet{Status: Null}
82 } else {
83 return dst.Set(*value)
84 }
85 case *string:
86 if value == nil {
87 *dst = Inet{Status: Null}
88 } else {
89 return dst.Set(*value)
90 }
91 default:
92 if tv, ok := src.(encoding.TextMarshaler); ok {
93 text, err := tv.MarshalText()
94 if err != nil {
95 return fmt.Errorf("cannot marshal %v: %w", value, err)
96 }
97 return dst.Set(string(text))
98 }
99 if sv, ok := src.(fmt.Stringer); ok {
100 return dst.Set(sv.String())
101 }
102 if originalSrc, ok := underlyingPtrType(src); ok {
103 return dst.Set(originalSrc)
104 }
105 return fmt.Errorf("cannot convert %v to Inet", value)
106 }
107
108 return nil
109 }
110
111
112
113
114
115
116
117
118
// maybeGetIPv4 returns ip in its 4-byte IPv4 form when input (the text ip was
// parsed from) contains no ':' and ip is convertible via To4; otherwise nil.
//
// net.IP.To4 also succeeds for IPv4-mapped IPv6 addresses such as
// "::ffff:192.168.0.1". Checking the original text for a colon keeps those
// mapped addresses in their IPv6 form.
func maybeGetIPv4(input string, ip net.IP) net.IP {
	if strings.IndexByte(input, ':') >= 0 {
		return nil
	}
	return ip.To4()
}
129
130 func (dst Inet) Get() interface{} {
131 switch dst.Status {
132 case Present:
133 return dst.IPNet
134 case Null:
135 return nil
136 default:
137 return dst.Status
138 }
139 }
140
141 func (src *Inet) AssignTo(dst interface{}) error {
142 switch src.Status {
143 case Present:
144 switch v := dst.(type) {
145 case *net.IPNet:
146 *v = net.IPNet{
147 IP: make(net.IP, len(src.IPNet.IP)),
148 Mask: make(net.IPMask, len(src.IPNet.Mask)),
149 }
150 copy(v.IP, src.IPNet.IP)
151 copy(v.Mask, src.IPNet.Mask)
152 return nil
153 case *net.IP:
154 if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount {
155 return fmt.Errorf("cannot assign %v to %T", src, dst)
156 }
157 *v = make(net.IP, len(src.IPNet.IP))
158 copy(*v, src.IPNet.IP)
159 return nil
160 default:
161 if tv, ok := dst.(encoding.TextUnmarshaler); ok {
162 if err := tv.UnmarshalText([]byte(src.IPNet.String())); err != nil {
163 return fmt.Errorf("cannot unmarshal %v to %T: %w", src, dst, err)
164 }
165 return nil
166 }
167 if nextDst, retry := GetAssignToDstType(dst); retry {
168 return src.AssignTo(nextDst)
169 }
170 return fmt.Errorf("unable to assign to %T", dst)
171 }
172 case Null:
173 return NullAssignTo(dst)
174 }
175
176 return fmt.Errorf("cannot decode %#v into %T", src, dst)
177 }
178
179 func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
180 if src == nil {
181 *dst = Inet{Status: Null}
182 return nil
183 }
184
185 var ipnet *net.IPNet
186 var err error
187
188 if ip := net.ParseIP(string(src)); ip != nil {
189 if ipv4 := ip.To4(); ipv4 != nil {
190 ip = ipv4
191 }
192 bitCount := len(ip) * 8
193 mask := net.CIDRMask(bitCount, bitCount)
194 ipnet = &net.IPNet{Mask: mask, IP: ip}
195 } else {
196 ip, ipnet, err = net.ParseCIDR(string(src))
197 if err != nil {
198 return err
199 }
200 if ipv4 := ip.To4(); ipv4 != nil {
201 ip = ipv4
202 }
203 ones, _ := ipnet.Mask.Size()
204 *ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)}
205 }
206
207 *dst = Inet{IPNet: ipnet, Status: Present}
208 return nil
209 }
210
211 func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
212 if src == nil {
213 *dst = Inet{Status: Null}
214 return nil
215 }
216
217 if len(src) != 8 && len(src) != 20 {
218 return fmt.Errorf("Received an invalid size for an inet: %d", len(src))
219 }
220
221
222 bits := src[1]
223
224 addressLength := src[3]
225
226 var ipnet net.IPNet
227 ipnet.IP = make(net.IP, int(addressLength))
228 copy(ipnet.IP, src[4:])
229 if ipv4 := ipnet.IP.To4(); ipv4 != nil {
230 ipnet.IP = ipv4
231 }
232 ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8)
233
234 *dst = Inet{IPNet: &ipnet, Status: Present}
235
236 return nil
237 }
238
239 func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
240 switch src.Status {
241 case Null:
242 return nil, nil
243 case Undefined:
244 return nil, errUndefined
245 }
246
247 return append(buf, src.IPNet.String()...), nil
248 }
249
250
251 func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
252 switch src.Status {
253 case Null:
254 return nil, nil
255 case Undefined:
256 return nil, errUndefined
257 }
258
259 var family byte
260 switch len(src.IPNet.IP) {
261 case net.IPv4len:
262 family = defaultAFInet
263 case net.IPv6len:
264 family = defaultAFInet6
265 default:
266 return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP))
267 }
268
269 buf = append(buf, family)
270
271 ones, _ := src.IPNet.Mask.Size()
272 buf = append(buf, byte(ones))
273
274
275 buf = append(buf, 0)
276
277 buf = append(buf, byte(len(src.IPNet.IP)))
278
279 return append(buf, src.IPNet.IP...), nil
280 }
281
282
283 func (dst *Inet) Scan(src interface{}) error {
284 if src == nil {
285 *dst = Inet{Status: Null}
286 return nil
287 }
288
289 switch src := src.(type) {
290 case string:
291 return dst.DecodeText(nil, []byte(src))
292 case []byte:
293 srcCopy := make([]byte, len(src))
294 copy(srcCopy, src)
295 return dst.DecodeText(nil, srcCopy)
296 }
297
298 return fmt.Errorf("cannot scan %T", src)
299 }
300
301
302 func (src Inet) Value() (driver.Value, error) {
303 return EncodeValueText(src)
304 }
305
View as plain text