1
2
3
4
19
20 package util
21
22 import (
23 "context"
24 "fmt"
25 "net"
26 "net/url"
27 "os"
28 "path/filepath"
29
30 "golang.org/x/sys/unix"
31 "k8s.io/klog/v2"
32 )
33
// unixProtocol is the URL scheme / network name for unix domain sockets. It is
// the only protocol the helpers in this file accept, and the scheme assumed
// for endpoints given without one.
const unixProtocol = "unix"
38
39
40 func CreateListener(endpoint string) (net.Listener, error) {
41 protocol, addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol)
42 if err != nil {
43 return nil, err
44 }
45 if protocol != unixProtocol {
46 return nil, fmt.Errorf("only support unix socket endpoint")
47 }
48
49
50 err = unix.Unlink(addr)
51 if err != nil && !os.IsNotExist(err) {
52 return nil, fmt.Errorf("failed to unlink socket file %q: %v", addr, err)
53 }
54
55 if err := os.MkdirAll(filepath.Dir(addr), 0750); err != nil {
56 return nil, fmt.Errorf("error creating socket directory %q: %v", filepath.Dir(addr), err)
57 }
58
59
60 file, err := os.CreateTemp(filepath.Dir(addr), "")
61 if err != nil {
62 return nil, fmt.Errorf("failed to create temporary file: %v", err)
63 }
64
65 if err := os.Remove(file.Name()); err != nil {
66 return nil, fmt.Errorf("failed to remove temporary file: %v", err)
67 }
68
69 l, err := net.Listen(protocol, file.Name())
70 if err != nil {
71 return nil, err
72 }
73
74 if err = os.Rename(file.Name(), addr); err != nil {
75 return nil, fmt.Errorf("failed to move temporary file to addr %q: %v", addr, err)
76 }
77
78 return l, nil
79 }
80
81
82 func GetAddressAndDialer(endpoint string) (string, func(ctx context.Context, addr string) (net.Conn, error), error) {
83 protocol, addr, err := parseEndpointWithFallbackProtocol(endpoint, unixProtocol)
84 if err != nil {
85 return "", nil, err
86 }
87 if protocol != unixProtocol {
88 return "", nil, fmt.Errorf("only support unix socket endpoint")
89 }
90
91 return addr, dial, nil
92 }
93
94 func dial(ctx context.Context, addr string) (net.Conn, error) {
95 return (&net.Dialer{}).DialContext(ctx, unixProtocol, addr)
96 }
97
98 func parseEndpointWithFallbackProtocol(endpoint string, fallbackProtocol string) (protocol string, addr string, err error) {
99 if protocol, addr, err = parseEndpoint(endpoint); err != nil && protocol == "" {
100 fallbackEndpoint := fallbackProtocol + "://" + endpoint
101 protocol, addr, err = parseEndpoint(fallbackEndpoint)
102 if err == nil {
103 klog.InfoS("Using this endpoint is deprecated, please consider using full URL format", "endpoint", endpoint, "URL", fallbackEndpoint)
104 }
105 }
106 return
107 }
108
// parseEndpoint splits an endpoint URL into its protocol and address.
//
// Supported forms:
//   - "tcp://host:port"   -> ("tcp", "host:port", nil)
//   - "unix:///path/sock" -> ("unix", "/path/sock", nil)
//
// An endpoint that fails to parse, or that has no scheme, yields an empty
// protocol and an error. An unsupported scheme is returned as the protocol
// together with an error. A "unix" URL with no path (for example "unix:",
// "unix://", or the opaque form "unix:relative.sock") is rejected, rather
// than returning an empty socket address that would only fail later.
func parseEndpoint(endpoint string) (string, string, error) {
	u, err := url.Parse(endpoint)
	if err != nil {
		return "", "", err
	}

	switch u.Scheme {
	case "tcp":
		return "tcp", u.Host, nil

	case "unix":
		// NOTE(review): for "unix://name/rest" net/url places "name" in
		// u.Host, so only "/rest" is returned; absolute paths need "unix:///".
		if u.Path == "" {
			return "unix", "", fmt.Errorf("unix endpoint %q has an empty socket path", endpoint)
		}
		return "unix", u.Path, nil

	case "":
		return "", "", fmt.Errorf("using %q as endpoint is deprecated, please consider using full url format", endpoint)

	default:
		return u.Scheme, "", fmt.Errorf("protocol %q not supported", u.Scheme)
	}
}
129
130
131 func LocalEndpoint(path, file string) (string, error) {
132 u := url.URL{
133 Scheme: unixProtocol,
134 Path: path,
135 }
136 return filepath.Join(u.String(), file+".sock"), nil
137 }
138
139
// NormalizePath returns path unchanged; no normalization is applied here.
// NOTE(review): presumably a platform-specific counterpart of this function
// does real work elsewhere — confirm before relying on that.
func NormalizePath(path string) string {
	normalized := path
	return normalized
}
143
View as plain text