1
16
17 package connect
18
19 import (
20 "fmt"
21 "net"
22 "os"
23 "strings"
24 "syscall"
25 "time"
26
27 "github.com/ishidawataru/sctp"
28 "github.com/spf13/cobra"
29 )
30
31
32 var CmdConnect = &cobra.Command{
33 Use: "connect [host:port]",
34 Short: "Attempts a TCP, UDP or SCTP connection and returns useful errors",
35 Long: `Tries to open a TCP, UDP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for:
36
37 * UNKNOWN - Generic/unknown (non-network) error (eg, bad arguments)
38 * TIMEOUT - The connection attempt timed out
39 * DNS - An error in DNS resolution
40 * REFUSED - Connection refused
41 * OTHER - Other networking error (eg, "no route to host")`,
42 Args: cobra.ExactArgs(1),
43 Run: main,
44 }
45
// timeout holds the value of the --timeout flag and is handed to whichever
// protocol-specific connect helper main dispatches to.
var timeout time.Duration

// protocol holds the value of the --protocol flag and selects which connect
// helper main runs.
var protocol string

// udpData holds the value of the --udp-data flag: the payload connectUDP
// sends to the server.
var udpData string
51
52 func init() {
53 CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error")
54 CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp")
55 CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload send to the server")
56 }
57
58 func main(cmd *cobra.Command, args []string) {
59 dest := args[0]
60 switch protocol {
61 case "", "tcp":
62 connectTCP(dest, timeout)
63 case "udp":
64 connectUDP(dest, timeout, udpData)
65 case "sctp":
66 connectSCTP(dest, timeout)
67 default:
68 fmt.Fprint(os.Stderr, "Unsupported protocol\n", protocol)
69 os.Exit(1)
70 }
71 }
72
// connectTCP dials dest over TCP and always terminates the process, reporting
// the outcome through the exit status and stderr:
//   - exit 0, no output: the connection was established (and then closed);
//   - "UNKNOWN: ..." when dest is not a valid host:port pair;
//   - "DNS: ..." when dest cannot be resolved as a TCP address;
//   - "TIMEOUT" when the dial fails with a timeout *net.OpError;
//   - "REFUSED" when the dial fails with ECONNREFUSED;
//   - "OTHER: ..." for any other dial failure.
//
// timeout is passed straight to net.DialTimeout.
func connectTCP(dest string, timeout time.Duration) {
	// fail reports a failure on stderr and exits with status 1; it never
	// returns control to the caller.
	fail := func(format string, a ...interface{}) {
		fmt.Fprintf(os.Stderr, format, a...)
		os.Exit(1)
	}

	if _, _, splitErr := net.SplitHostPort(dest); splitErr != nil {
		fail("UNKNOWN: %v\n", splitErr)
	}
	if _, resolveErr := net.ResolveTCPAddr("tcp", dest); resolveErr != nil {
		fail("DNS: %v\n", resolveErr)
	}

	c, dialErr := net.DialTimeout("tcp", dest, timeout)
	if dialErr == nil {
		c.Close()
		os.Exit(0)
	}

	opErr, isOpErr := dialErr.(*net.OpError)
	if isOpErr && opErr.Timeout() {
		fail("TIMEOUT\n")
	}
	if isOpErr {
		if sysErr, isSys := opErr.Err.(*os.SyscallError); isSys && sysErr.Err == syscall.ECONNREFUSED {
			fail("REFUSED\n")
		}
	}

	fail("OTHER: %v\n", dialErr)
}
105
106 func connectSCTP(dest string, timeout time.Duration) {
107 addr, err := sctp.ResolveSCTPAddr("sctp", dest)
108 if err != nil {
109 fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
110 os.Exit(1)
111 }
112
113 timeoutCh := time.After(timeout)
114 errCh := make(chan error)
115
116 go func() {
117 conn, err := sctp.DialSCTP("sctp", nil, addr)
118 if err == nil {
119 conn.Close()
120 }
121 errCh <- err
122 }()
123
124 select {
125 case err := <-errCh:
126 if err != nil {
127 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
128 os.Exit(1)
129 }
130 case <-timeoutCh:
131 fmt.Fprint(os.Stderr, "TIMEOUT\n")
132 os.Exit(1)
133 }
134 }
135
136 func connectUDP(dest string, timeout time.Duration, data string) {
137 var (
138 readBytes int
139 buf = make([]byte, 1024)
140 )
141
142 if _, err := net.ResolveUDPAddr("udp", dest); err != nil {
143 fmt.Fprintf(os.Stderr, "DNS: %v\n", err)
144 os.Exit(1)
145 }
146
147 conn, err := net.Dial("udp", dest)
148 if err != nil {
149 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
150 os.Exit(1)
151 }
152
153 if timeout > 0 {
154 if err = conn.SetDeadline(time.Now().Add(timeout)); err != nil {
155 fmt.Fprintf(os.Stderr, "OTHER: %v\n", err)
156 os.Exit(1)
157 }
158 }
159
160 if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil {
161 parseUDPErrorAndExit(err)
162 }
163
164 if readBytes, err = conn.Read(buf); err != nil {
165 parseUDPErrorAndExit(err)
166 }
167
168
169 if readBytes == 0 {
170 fmt.Fprintf(os.Stderr, "OTHER: No data received from the server. Cannot guarantee the server received the request.\n")
171 os.Exit(1)
172 }
173 }
174
// parseUDPErrorAndExit classifies an error from a UDP write or read, prints it
// to stderr with the matching prefix and terminates the process with status 1:
//   - "TIMEOUT: ..." when err is a net.Error whose Timeout() reports true;
//   - "REFUSED: ..." when the error text contains "connection refused";
//   - "UNKNOWN: ..." for everything else.
//
// NOTE(review): CmdConnect documents OTHER for networking errors and UNKNOWN
// for non-network ones, but this fallback uses UNKNOWN; confirm no caller
// depends on that before aligning it.
func parseUDPErrorAndExit(err error) {
	netErr, isNetErr := err.(net.Error)
	switch {
	case isNetErr && netErr.Timeout():
		fmt.Fprintf(os.Stderr, "TIMEOUT: %v\n", err)
	case strings.Contains(err.Error(), "connection refused"):
		fmt.Fprintf(os.Stderr, "REFUSED: %v\n", err)
	default:
		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err)
	}
	os.Exit(1)
}
186
View as plain text