...

Source file src/github.com/jackc/pgtype/inet.go

Documentation: github.com/jackc/pgtype

     1  package pgtype
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"encoding"
     6  	"fmt"
     7  	"net"
     8  	"strings"
     9  )
    10  
    11  // Network address family is dependent on server socket.h value for AF_INET.
    12  // In practice, all platforms appear to have the same value. See
    13  // src/include/utils/inet.h for more information.
    14  const (
    15  	defaultAFInet  = 2
    16  	defaultAFInet6 = 3
    17  )
    18  
// Inet represents both inet and cidr PostgreSQL types.
//
// IPNet carries the address (IPNet.IP) together with its netmask
// (IPNet.Mask). Status tells whether the value is Present, Null or Undefined;
// the methods in this file only read IPNet when Status is Present, and every
// method that stores a Null value leaves IPNet nil.
type Inet struct {
	IPNet  *net.IPNet
	Status Status
}
    24  
    25  func (dst *Inet) Set(src interface{}) error {
    26  	if src == nil {
    27  		*dst = Inet{Status: Null}
    28  		return nil
    29  	}
    30  
    31  	if value, ok := src.(interface{ Get() interface{} }); ok {
    32  		value2 := value.Get()
    33  		if value2 != value {
    34  			return dst.Set(value2)
    35  		}
    36  	}
    37  
    38  	switch value := src.(type) {
    39  	case net.IPNet:
    40  		*dst = Inet{IPNet: &value, Status: Present}
    41  	case net.IP:
    42  		if len(value) == 0 {
    43  			*dst = Inet{Status: Null}
    44  		} else {
    45  			bitCount := len(value) * 8
    46  			mask := net.CIDRMask(bitCount, bitCount)
    47  			*dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present}
    48  		}
    49  	case string:
    50  		ip, ipnet, err := net.ParseCIDR(value)
    51  		if err != nil {
    52  			ip := net.ParseIP(value)
    53  			if ip == nil {
    54  				return fmt.Errorf("unable to parse inet address: %s", value)
    55  			}
    56  
    57  			if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil {
    58  				ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
    59  			} else {
    60  				ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
    61  			}
    62  		} else {
    63  			ipnet.IP = ip
    64  			if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil {
    65  				ipnet.IP = ipv4
    66  				if len(ipnet.Mask) == 16 {
    67  					ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed.
    68  				}
    69  			}
    70  		}
    71  
    72  		*dst = Inet{IPNet: ipnet, Status: Present}
    73  	case *net.IPNet:
    74  		if value == nil {
    75  			*dst = Inet{Status: Null}
    76  		} else {
    77  			return dst.Set(*value)
    78  		}
    79  	case *net.IP:
    80  		if value == nil {
    81  			*dst = Inet{Status: Null}
    82  		} else {
    83  			return dst.Set(*value)
    84  		}
    85  	case *string:
    86  		if value == nil {
    87  			*dst = Inet{Status: Null}
    88  		} else {
    89  			return dst.Set(*value)
    90  		}
    91  	default:
    92  		if tv, ok := src.(encoding.TextMarshaler); ok {
    93  			text, err := tv.MarshalText()
    94  			if err != nil {
    95  				return fmt.Errorf("cannot marshal %v: %w", value, err)
    96  			}
    97  			return dst.Set(string(text))
    98  		}
    99  		if sv, ok := src.(fmt.Stringer); ok {
   100  			return dst.Set(sv.String())
   101  		}
   102  		if originalSrc, ok := underlyingPtrType(src); ok {
   103  			return dst.Set(originalSrc)
   104  		}
   105  		return fmt.Errorf("cannot convert %v to Inet", value)
   106  	}
   107  
   108  	return nil
   109  }
   110  
   111  // Convert the net.IP to IPv4, if appropriate.
   112  //
   113  // When parsing a string to a net.IP using net.ParseIP() and the like, we get a
   114  // 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function
   115  // calls To4() to convert them to a 4 byte slice. This is useful as it allows
   116  // users of the net.IP check for IPv4 addresses based on the length and makes
   117  // it clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6
   118  // addresses.
   119  func maybeGetIPv4(input string, ip net.IP) net.IP {
   120  	// Do not do this if the provided input looks like IPv6. This is because
   121  	// To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behave
   122  	// different in some cases.
   123  	if strings.Contains(input, ":") {
   124  		return nil
   125  	}
   126  
   127  	return ip.To4()
   128  }
   129  
   130  func (dst Inet) Get() interface{} {
   131  	switch dst.Status {
   132  	case Present:
   133  		return dst.IPNet
   134  	case Null:
   135  		return nil
   136  	default:
   137  		return dst.Status
   138  	}
   139  }
   140  
   141  func (src *Inet) AssignTo(dst interface{}) error {
   142  	switch src.Status {
   143  	case Present:
   144  		switch v := dst.(type) {
   145  		case *net.IPNet:
   146  			*v = net.IPNet{
   147  				IP:   make(net.IP, len(src.IPNet.IP)),
   148  				Mask: make(net.IPMask, len(src.IPNet.Mask)),
   149  			}
   150  			copy(v.IP, src.IPNet.IP)
   151  			copy(v.Mask, src.IPNet.Mask)
   152  			return nil
   153  		case *net.IP:
   154  			if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount {
   155  				return fmt.Errorf("cannot assign %v to %T", src, dst)
   156  			}
   157  			*v = make(net.IP, len(src.IPNet.IP))
   158  			copy(*v, src.IPNet.IP)
   159  			return nil
   160  		default:
   161  			if tv, ok := dst.(encoding.TextUnmarshaler); ok {
   162  				if err := tv.UnmarshalText([]byte(src.IPNet.String())); err != nil {
   163  					return fmt.Errorf("cannot unmarshal %v to %T: %w", src, dst, err)
   164  				}
   165  				return nil
   166  			}
   167  			if nextDst, retry := GetAssignToDstType(dst); retry {
   168  				return src.AssignTo(nextDst)
   169  			}
   170  			return fmt.Errorf("unable to assign to %T", dst)
   171  		}
   172  	case Null:
   173  		return NullAssignTo(dst)
   174  	}
   175  
   176  	return fmt.Errorf("cannot decode %#v into %T", src, dst)
   177  }
   178  
   179  func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
   180  	if src == nil {
   181  		*dst = Inet{Status: Null}
   182  		return nil
   183  	}
   184  
   185  	var ipnet *net.IPNet
   186  	var err error
   187  
   188  	if ip := net.ParseIP(string(src)); ip != nil {
   189  		if ipv4 := ip.To4(); ipv4 != nil {
   190  			ip = ipv4
   191  		}
   192  		bitCount := len(ip) * 8
   193  		mask := net.CIDRMask(bitCount, bitCount)
   194  		ipnet = &net.IPNet{Mask: mask, IP: ip}
   195  	} else {
   196  		ip, ipnet, err = net.ParseCIDR(string(src))
   197  		if err != nil {
   198  			return err
   199  		}
   200  		if ipv4 := ip.To4(); ipv4 != nil {
   201  			ip = ipv4
   202  		}
   203  		ones, _ := ipnet.Mask.Size()
   204  		*ipnet = net.IPNet{IP: ip, Mask: net.CIDRMask(ones, len(ip)*8)}
   205  	}
   206  
   207  	*dst = Inet{IPNet: ipnet, Status: Present}
   208  	return nil
   209  }
   210  
   211  func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
   212  	if src == nil {
   213  		*dst = Inet{Status: Null}
   214  		return nil
   215  	}
   216  
   217  	if len(src) != 8 && len(src) != 20 {
   218  		return fmt.Errorf("Received an invalid size for an inet: %d", len(src))
   219  	}
   220  
   221  	// ignore family
   222  	bits := src[1]
   223  	// ignore is_cidr
   224  	addressLength := src[3]
   225  
   226  	var ipnet net.IPNet
   227  	ipnet.IP = make(net.IP, int(addressLength))
   228  	copy(ipnet.IP, src[4:])
   229  	if ipv4 := ipnet.IP.To4(); ipv4 != nil {
   230  		ipnet.IP = ipv4
   231  	}
   232  	ipnet.Mask = net.CIDRMask(int(bits), len(ipnet.IP)*8)
   233  
   234  	*dst = Inet{IPNet: &ipnet, Status: Present}
   235  
   236  	return nil
   237  }
   238  
   239  func (src Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
   240  	switch src.Status {
   241  	case Null:
   242  		return nil, nil
   243  	case Undefined:
   244  		return nil, errUndefined
   245  	}
   246  
   247  	return append(buf, src.IPNet.String()...), nil
   248  }
   249  
   250  // EncodeBinary encodes src into w.
   251  func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
   252  	switch src.Status {
   253  	case Null:
   254  		return nil, nil
   255  	case Undefined:
   256  		return nil, errUndefined
   257  	}
   258  
   259  	var family byte
   260  	switch len(src.IPNet.IP) {
   261  	case net.IPv4len:
   262  		family = defaultAFInet
   263  	case net.IPv6len:
   264  		family = defaultAFInet6
   265  	default:
   266  		return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP))
   267  	}
   268  
   269  	buf = append(buf, family)
   270  
   271  	ones, _ := src.IPNet.Mask.Size()
   272  	buf = append(buf, byte(ones))
   273  
   274  	// is_cidr is ignored on server
   275  	buf = append(buf, 0)
   276  
   277  	buf = append(buf, byte(len(src.IPNet.IP)))
   278  
   279  	return append(buf, src.IPNet.IP...), nil
   280  }
   281  
   282  // Scan implements the database/sql Scanner interface.
   283  func (dst *Inet) Scan(src interface{}) error {
   284  	if src == nil {
   285  		*dst = Inet{Status: Null}
   286  		return nil
   287  	}
   288  
   289  	switch src := src.(type) {
   290  	case string:
   291  		return dst.DecodeText(nil, []byte(src))
   292  	case []byte:
   293  		srcCopy := make([]byte, len(src))
   294  		copy(srcCopy, src)
   295  		return dst.DecodeText(nil, srcCopy)
   296  	}
   297  
   298  	return fmt.Errorf("cannot scan %T", src)
   299  }
   300  
   301  // Value implements the database/sql/driver Valuer interface.
   302  func (src Inet) Value() (driver.Value, error) {
   303  	return EncodeValueText(src)
   304  }
   305  

View as plain text