1
2
3
4
5 package socks_test
6
7 import (
8 "context"
9 "io"
10 "math/rand"
11 "net"
12 "os"
13 "testing"
14 "time"
15
16 "golang.org/x/net/internal/socks"
17 "golang.org/x/net/internal/sockstest"
18 )
19
20 func TestDial(t *testing.T) {
21 t.Run("Connect", func(t *testing.T) {
22 ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
23 if err != nil {
24 t.Fatal(err)
25 }
26 defer ss.Close()
27 d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
28 d.AuthMethods = []socks.AuthMethod{
29 socks.AuthMethodNotRequired,
30 socks.AuthMethodUsernamePassword,
31 }
32 d.Authenticate = (&socks.UsernamePassword{
33 Username: "username",
34 Password: "password",
35 }).Authenticate
36 c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
37 if err != nil {
38 t.Fatal(err)
39 }
40 c.(*socks.Conn).BoundAddr()
41 c.Close()
42 })
43 t.Run("ConnectWithConn", func(t *testing.T) {
44 ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
45 if err != nil {
46 t.Fatal(err)
47 }
48 defer ss.Close()
49 c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
50 if err != nil {
51 t.Fatal(err)
52 }
53 defer c.Close()
54 d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
55 d.AuthMethods = []socks.AuthMethod{
56 socks.AuthMethodNotRequired,
57 socks.AuthMethodUsernamePassword,
58 }
59 d.Authenticate = (&socks.UsernamePassword{
60 Username: "username",
61 Password: "password",
62 }).Authenticate
63 a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
64 if err != nil {
65 t.Fatal(err)
66 }
67 if _, ok := a.(*socks.Addr); !ok {
68 t.Fatalf("got %+v; want socks.Addr", a)
69 }
70 })
71 t.Run("Cancel", func(t *testing.T) {
72 ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
73 if err != nil {
74 t.Fatal(err)
75 }
76 defer ss.Close()
77 d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
78 ctx, cancel := context.WithCancel(context.Background())
79 defer cancel()
80 dialErr := make(chan error)
81 go func() {
82 c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
83 if err == nil {
84 c.Close()
85 }
86 dialErr <- err
87 }()
88 time.Sleep(100 * time.Millisecond)
89 cancel()
90 err = <-dialErr
91 if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
92 t.Fatalf("got %v; want context.Canceled or equivalent", err)
93 }
94 })
95 t.Run("Deadline", func(t *testing.T) {
96 ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
97 if err != nil {
98 t.Fatal(err)
99 }
100 defer ss.Close()
101 d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
102 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
103 defer cancel()
104 c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
105 if err == nil {
106 c.Close()
107 }
108 if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
109 t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
110 }
111 })
112 t.Run("WithRogueServer", func(t *testing.T) {
113 ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
114 if err != nil {
115 t.Fatal(err)
116 }
117 defer ss.Close()
118 d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
119 for i := 0; i < 2*len(rogueCmdList); i++ {
120 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
121 defer cancel()
122 c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
123 if err == nil {
124 t.Log(c.(*socks.Conn).BoundAddr())
125 c.Close()
126 t.Error("should fail")
127 }
128 }
129 })
130 }
131
132 func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
133 if _, err := sockstest.ParseCmdRequest(b); err != nil {
134 return err
135 }
136 var bb [1]byte
137 for {
138 if _, err := rw.Read(bb[:]); err != nil {
139 return err
140 }
141 }
142 }
143
144 func rogueCmdFunc(rw io.ReadWriter, b []byte) error {
145 if _, err := sockstest.ParseCmdRequest(b); err != nil {
146 return err
147 }
148 rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))])
149 return nil
150 }
151
// rogueCmdList holds the raw replies that rogueCmdFunc chooses from. The
// WithRogueServer case of TestDial expects a dial to fail for each of them.
//
// NOTE(review): the per-entry notes assume the SOCKS5 reply layout
// VER, REP, RSV, ATYP, BND.ADDR, BND.PORT (RFC 1928) — confirm.
var rogueCmdList = [][]byte{
	// A single byte and nothing more.
	{0x05},
	// First byte is 0x06 instead of 0x05.
	{0x06, 0x00, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b},
	// Third byte is 0xff instead of 0x00.
	{0x05, 0x00, 0xff, 0x01, 192, 0, 2, 2, 0x17, 0x4b},
	// Ends right after the four address bytes.
	{0x05, 0x00, 0x00, 0x01, 192, 0, 2, 3},
	// Ends right after the length-prefixed name "FQDN".
	{0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'},
}
159
// parseDialError strips the wrappers from a dial error.
//
// If err is a *net.OpError, nerr is that OpError and unwrapping continues
// with its Err field; otherwise nerr is nil. If the error reached that way is
// an *os.SyscallError, perr is its Err field; otherwise perr is the error
// itself.
func parseDialError(err error) (perr, nerr error) {
	inner := err
	if opErr, isOp := inner.(*net.OpError); isOp {
		nerr, inner = opErr, opErr.Err
	}
	if sysErr, isSys := inner.(*os.SyscallError); isSys {
		inner = sysErr.Err
	}
	return inner, nerr
}
171
View as plain text