1 package proxy
2
import (
	"encoding/binary"
	"errors"
	"net"
	"strings"
	"sync"
	"syscall"
	"time"
)
11
const (
	// UDPConnTrackTimeout is how long a per-client backend connection may go
	// without receiving a reply from the backend before it is torn down and
	// its connection-tracking entry removed (90 seconds).
	UDPConnTrackTimeout = time.Minute + 30*time.Second

	// UDPBufSize is the size of the datagram read buffers: the largest UDP
	// payload that fits in an IPv4 datagram, i.e. 65535 bytes minus the
	// 20-byte IPv4 header and the 8-byte UDP header (65507).
	UDPBufSize = 65535 - 20 - 8
)
18
19
20
// connTrackKey identifies a client by IP address and port. The IP is stored
// as the high and low 64 bits of its 16-byte form so that the key is a small,
// comparable value usable as a map key.
type connTrackKey struct {
	IPHigh uint64
	IPLow  uint64
	Port   int
}

// newConnTrackKey builds the connection-tracking key for addr.
//
// The IP is normalized to its 16-byte form first. As a result:
//   - an IPv4 client yields the same key whether its address is given as
//     4 bytes or as an IPv4-mapped IPv6 address (::ffff:a.b.c.d);
//   - an IPv4 address can no longer collide with the distinct IPv6 address
//     ::a.b.c.d.
//
// An IP of invalid length (e.g. nil) maps to the all-zero address rather
// than panicking. The zone of link-local IPv6 addresses is not part of the key.
func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
	ip := addr.IP.To16()
	if ip == nil {
		return &connTrackKey{Port: addr.Port}
	}
	return &connTrackKey{
		IPHigh: binary.BigEndian.Uint64(ip[:8]),
		IPLow:  binary.BigEndian.Uint64(ip[8:]),
		Port:   addr.Port,
	}
}

// connTrackMap maps a client key to the UDP connection dialed to the backend
// on that client's behalf.
type connTrackMap map[connTrackKey]*net.UDPConn
43
44
45
46
47 type UDPProxy struct {
48 Logger logger
49 listener *net.UDPConn
50 frontendAddr *net.UDPAddr
51 backendAddr *net.UDPAddr
52 connTrackTable connTrackMap
53 connTrackLock sync.Mutex
54 }
55
56
57 func NewUDPProxy(frontendAddr, backendAddr *net.UDPAddr, ops ...func(*UDPProxy)) (*UDPProxy, error) {
58 listener, err := net.ListenUDP("udp", frontendAddr)
59 if err != nil {
60 return nil, err
61 }
62
63 proxy := &UDPProxy{
64 listener: listener,
65 frontendAddr: listener.LocalAddr().(*net.UDPAddr),
66 backendAddr: backendAddr,
67 connTrackTable: make(connTrackMap),
68 Logger: &noopLogger{},
69 }
70
71 for _, op := range ops {
72 op(proxy)
73 }
74
75 return proxy, nil
76 }
77
78 func (proxy *UDPProxy) replyLoop(proxyConn *net.UDPConn, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
79 defer func() {
80 proxy.connTrackLock.Lock()
81 delete(proxy.connTrackTable, *clientKey)
82 proxy.connTrackLock.Unlock()
83 _ = proxyConn.Close()
84 }()
85
86 readBuf := make([]byte, UDPBufSize)
87 for {
88 _ = proxyConn.SetReadDeadline(time.Now().Add(UDPConnTrackTimeout))
89 again:
90 read, err := proxyConn.Read(readBuf)
91 if err != nil {
92 if err, ok := err.(*net.OpError); ok && err.Err == syscall.ECONNREFUSED {
93
94
95
96
97
98 goto again
99 }
100 return
101 }
102 for i := 0; i != read; {
103 written, err := proxy.listener.WriteToUDP(readBuf[i:read], clientAddr)
104 if err != nil {
105 return
106 }
107 i += written
108 }
109 }
110 }
111
112
113 func (proxy *UDPProxy) Run() {
114 readBuf := make([]byte, UDPBufSize)
115 for {
116 read, from, err := proxy.listener.ReadFromUDP(readBuf)
117 if err != nil {
118
119
120
121 if !isClosedError(err) {
122 proxy.Logger.Printf("Stopping proxy on udp/%v for udp/%v (%s)", proxy.frontendAddr, proxy.backendAddr, err)
123 }
124 break
125 }
126
127 fromKey := newConnTrackKey(from)
128 proxy.connTrackLock.Lock()
129 proxyConn, hit := proxy.connTrackTable[*fromKey]
130 if !hit {
131 proxyConn, err = net.DialUDP("udp", nil, proxy.backendAddr)
132 if err != nil {
133 proxy.Logger.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
134 proxy.connTrackLock.Unlock()
135 continue
136 }
137 proxy.connTrackTable[*fromKey] = proxyConn
138 go proxy.replyLoop(proxyConn, from, fromKey)
139 }
140 proxy.connTrackLock.Unlock()
141 for i := 0; i != read; {
142 written, err := proxyConn.Write(readBuf[i:read])
143 if err != nil {
144 proxy.Logger.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
145 break
146 }
147 i += written
148 }
149 }
150 }
151
152
153 func (proxy *UDPProxy) Close() {
154 _ = proxy.listener.Close()
155 proxy.connTrackLock.Lock()
156 defer proxy.connTrackLock.Unlock()
157 for _, conn := range proxy.connTrackTable {
158 _ = conn.Close()
159 }
160 }
161
162
163 func (proxy *UDPProxy) FrontendAddr() net.Addr { return proxy.frontendAddr }
164
165
166 func (proxy *UDPProxy) BackendAddr() net.Addr { return proxy.backendAddr }
167
// isClosedError reports whether err's message ends with
// "use of closed network connection", the text the net package uses when an
// operation is attempted on a connection that has already been closed.
// err must be non-nil.
func isClosedError(err error) bool {
	const closedSuffix = "use of closed network connection"
	msg := err.Error()
	return strings.HasSuffix(msg, closedSuffix)
}
177
View as plain text