//go:build linux

package sockets

import (
	"errors"
	"io"
	"net"
	"strconv"
	"strings"

	"golang.org/x/sys/unix"
)

// ErrInvalidIP is returned when the host part of an address is not an IPv4 address.
var ErrInvalidIP = errors.New("invalid ip address")

// ErrInvalidPort is returned when the port part of an address is not a valid port number.
var ErrInvalidPort = errors.New("invalid port")

// ErrInvalidSocketAddress is returned when an address is not of the form <ip>:<port>.
var ErrInvalidSocketAddress = errors.New("address must be in the format of <ip>:<port>")

// Mode identifies a uevent netlink group; its numeric value is used directly
// as the netlink Groups mask when binding a socket.
type Mode uint32

const (
	// KernelEvent selects events as broadcast by the kernel itself
	// (value 1; presumably the raw kernel uevent group — confirm).
	KernelEvent Mode = iota + 1
	// UdevEvent selects events that have been processed by udev - much richer,
	// with more attributes (such as vendor info, serial numbers and more).
	// Its value is 2.
	UdevEvent
)

// NewUEventNetlinkReader opens a NETLINK_KOBJECT_UEVENT socket in the
// binary's own network namespace, binds it to the UdevEvent group, and
// returns a reader over it together with its raw file descriptor.
//
// On error the returned reader is nil and the descriptor is 0. Any socket
// created along the way has already been closed.
func NewUEventNetlinkReader() (io.ReadCloser, int, error) {
	fd, err := newUEventNetlinkSocket()
	if err != nil {
		return nil, 0, err
	}

	sockAddr := unix.SockaddrNetlink{
		Family: unix.AF_NETLINK,
		Groups: uint32(UdevEvent),
	}

	if err := unix.Bind(fd, &sockAddr); err != nil {
		// fd is closed at this point. Returning a monitor that wraps it would let
		// a later Close (presumably closing m.fd) hit whatever file has since
		// reused that descriptor number.
		unix.Close(fd)
		return nil, 0, err
	}
	return &monitor{fd: fd}, fd, nil
}

// NewUEventRemoteReader dials the TCP endpoint given as an IPv4 "<ip>:<port>"
// address and returns a reader over the connection plus its raw file
// descriptor. On failure it returns a nil reader and descriptor 0.
func NewUEventRemoteReader(address string) (io.ReadCloser, int, error) {
	sock, dialErr := newUEventRemoteSocket(address)
	if dialErr == nil {
		return &monitor{fd: sock}, sock, nil
	}
	return nil, 0, dialErr
}

// newUEventNetlinkSocket opens a raw NETLINK_KOBJECT_UEVENT socket and
// returns its file descriptor, or 0 and the error from socket(2).
//
// Close-on-exec is requested through SOCK_CLOEXEC so it is set atomically
// with socket creation. Setting it afterwards leaves a window in which a
// concurrent fork/exec from another goroutine would inherit the descriptor.
func newUEventNetlinkSocket() (int, error) {
	fd, err := unix.Socket(
		unix.AF_NETLINK,
		unix.SOCK_RAW|unix.SOCK_CLOEXEC,
		unix.NETLINK_KOBJECT_UEVENT,
	)
	if err != nil {
		return 0, err
	}
	return fd, nil
}

// newUEventRemoteSocket instantiates a new remote socket connection
// and returns the file descriptor
func newUEventRemoteSocket(address string) (int, error) {
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP)
	if err != nil {
		return 0, err
	}

	port, ip, err := parseAddress(address)
	if err != nil {
		return 0, err
	}

	addr := unix.SockaddrInet4{
		Port: port,
		Addr: ip,
	}

	if err := unix.Connect(fd, &addr); err != nil {
		unix.Close(fd)
		return 0, err
	}
	// avoid leaking the file-descriptor to child processes
	unix.CloseOnExec(fd)
	return fd, nil
}

// parseAddress validates an IPv4 socket address such as "0.0.0.0:8080" and
// returns its port and 4-byte IP.
//
// Errors:
//   - ErrInvalidSocketAddress if address does not contain exactly one ':'.
//   - ErrInvalidPort if the port is not an unsigned decimal number in 1..65535.
//   - ErrInvalidIP if the host part is not an IPv4 address.
func parseAddress(address string) (port int, ip [4]byte, err error) {
	parts := strings.Split(address, ":")
	if len(parts) != 2 {
		return 0, ip, ErrInvalidSocketAddress
	}

	// bitSize 16 rejects signs and values above 65535 here, instead of passing
	// an out-of-range value on to unix.SockaddrInet4. Port 0 is never a valid
	// connect target.
	p, perr := strconv.ParseUint(parts[1], 10, 16)
	if perr != nil || p == 0 {
		return 0, ip, ErrInvalidPort
	}

	ipAddr := net.ParseIP(parts[0]).To4()
	if ipAddr == nil {
		return 0, ip, ErrInvalidIP
	}
	copy(ip[:], ipAddr)
	return int(p), ip, nil
}