...
1import socket
2import sys
3from threading import Thread
4
5
def server(dst_host: str, dst_port: int, src_host: str, src_port: int) -> None:
    """Accept TCP clients on (src_host, src_port) and relay each to (dst_host, dst_port).

    Runs until interrupted. Every accepted client is handed to a daemon thread
    that opens its own IPv4 connection to the destination. A slow or refused
    upstream connect therefore affects only that client; it neither blocks nor
    crashes the accept loop. The listening socket is closed when this returns
    or raises (e.g. on KeyboardInterrupt).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listen_sock:
        listen_sock.bind((src_host, src_port))
        listen_sock.listen(5)

        while True:
            src_sock, _peer = listen_sock.accept()
            Thread(
                target=_serve_client,
                args=(src_sock, dst_host, dst_port),
                daemon=True,
            ).start()


def _serve_client(src_sock: socket.socket, dst_host: str, dst_port: int) -> None:
    """Connect to the destination, relay both directions, then close both sockets."""
    dst_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        try:
            dst_sock.connect((dst_host, dst_port))
        except OSError as err:
            # Report and drop only this client; the accept loop keeps running.
            print(f"cannot connect to {dst_host}:{dst_port}: {err}", file=sys.stderr)
            return
        # Destination -> client is pumped by a helper thread.
        reply_pump = Thread(
            target=forward, kwargs={"dst": src_sock, "src": dst_sock}, daemon=True
        )
        reply_pump.start()
        # Client -> destination is pumped on this thread; then wait for the
        # other direction so both sockets can be released exactly once.
        forward(dst=dst_sock, src=src_sock)
        reply_pump.join()
    finally:
        for sock in (src_sock, dst_sock):
            try:
                # shutdown() before close() so a thread still blocked in recv()
                # on this socket is woken; close() alone is not guaranteed to.
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # never connected or peer already gone: nothing to shut down
            sock.close()
17
18
def forward(dst: socket.socket, src: socket.socket, *, bufsize: int = 64 * 1024) -> None:
    """Copy bytes from *src* to *dst* until *src* hits EOF or either socket fails.

    Note the argument order: the destination comes first.

    When copying stops, EOF is propagated by half-closing: *dst* is shut down
    for writing and *src* for reading. A connection error on either socket
    (reset by peer, broken pipe, ...) ends the copy the same way instead of
    escaping as an exception.

    Args:
        dst: socket the data is written to.
        src: socket the data is read from.
        bufsize: maximum bytes per recv() call.

    Raises:
        ValueError: if *bufsize* is not positive.
    """
    if bufsize <= 0:
        # recv(0) returns b"", which would be mistaken for EOF.
        raise ValueError(f"bufsize must be positive, got {bufsize}")

    try:
        while True:
            data = src.recv(bufsize)
            if not data:  # orderly EOF from the sender
                break
            dst.sendall(data)
    except OSError:
        # Peer reset / broken pipe: treat like EOF and tear down below so the
        # other side still learns the stream has ended.
        pass

    # Shut each side down independently. A failure on one, typically the
    # already-disconnected origin, must not skip the other.
    try:
        dst.shutdown(socket.SHUT_WR)
    except OSError:
        pass  # destination already disconnected
    try:
        src.shutdown(socket.SHUT_RD)
    except OSError:
        pass  # origin already disconnected
32
33
if __name__ == "__main__":
    # Command line: DST_HOST DST_PORT SRC_HOST SRC_PORT
    usage = f"usage: {sys.argv[0]} DST_HOST DST_PORT SRC_HOST SRC_PORT"
    if len(sys.argv) != 5:
        sys.exit(usage)
    dst_host, dst_port_str, src_host, src_port_str = sys.argv[1:]
    try:
        dst_port = int(dst_port_str)
        src_port = int(src_port_str)
    except ValueError:
        sys.exit(usage)
    try:
        server(dst_host, dst_port, src_host, src_port)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the forwarder; exit quietly
View as plain text