1
2
3
4 package main
5
6 import (
7 "bytes"
8 "encoding/json"
9 "errors"
10 "flag"
11 "fmt"
12 "net"
13 "os"
14 "path/filepath"
15 "strings"
16
17 securejoin "github.com/cyphar/filepath-securejoin"
18 "github.com/opencontainers/runtime-spec/specs-go"
19 libseccomp "github.com/seccomp/libseccomp-golang"
20 "github.com/sirupsen/logrus"
21 "golang.org/x/sys/unix"
22 )
23
// Values of the command-line flags registered in main: the unix socket path
// the agent listens on (-socketfile) and, when non-empty, the file the agent
// writes its pid to (-pid-file).
var (
	socketFile, pidFile string
)
28
29 func closeStateFds(recvFds []int) {
30 for i := range recvFds {
31 unix.Close(i)
32 }
33 }
34
35
36
37
38
// parseStateFds picks the seccomp notification fd out of recvFds and closes
// every other fd in recvFds.
//
// stateFds is the list of fd names from the received container state; entry
// i names recvFds[i]. Exactly one entry must equal specs.SeccompFdName. On
// error nothing is closed, so the caller stays responsible for every fd in
// recvFds.
func parseStateFds(stateFds []string, recvFds []int) (uintptr, error) {
	seccompIdx := -1
	duplicated := false
	for i, name := range stateFds {
		if name != specs.SeccompFdName {
			continue
		}
		if seccompIdx == -1 {
			seccompIdx = i
		} else {
			// The seccomp fd name appears more than once: malformed state.
			duplicated = true
		}
	}

	if seccompIdx == -1 || duplicated {
		return 0, errors.New("seccomp fd not found or malformed containerProcessState.Fds")
	}
	// seccompIdx is non-negative here; only the upper bound needs checking.
	if seccompIdx >= len(recvFds) {
		return 0, errors.New("seccomp fd index out of range")
	}

	for i, other := range recvFds {
		if i != seccompIdx {
			unix.Close(other)
		}
	}
	return uintptr(recvFds[seccompIdx]), nil
}
76
// handleNewMessage receives one message from the connected unix socket
// sockfd. The payload is a JSON-encoded specs.ContainerProcessState, and the
// SCM_RIGHTS control message carries the fds named by that state's Fds
// field. It returns the seccomp notification fd and the state's Metadata.
//
// On success only the seccomp fd remains open (parseStateFds closes the other
// received fds). On any error after the message has been received, every fd
// it carried is closed before returning, so malformed or oversized messages
// cannot leak descriptors.
func handleNewMessage(sockfd int) (uintptr, string, error) {
	const maxNameLen = 4096
	stateBuf := make([]byte, maxNameLen)
	oobSpace := unix.CmsgSpace(4) // room for exactly one 4-byte fd
	oob := make([]byte, oobSpace)

	n, oobn, _, _, err := unix.Recvmsg(sockfd, stateBuf, oob, 0)
	if err != nil {
		return 0, "", err
	}
	stateBuf = stateBuf[:n]
	oob = oob[:oobn]

	// Any fd carried by the message is already open in this process. Unless
	// the seccomp fd is handed back to the caller, close all of them.
	handedOff := false
	defer func() {
		if !handedOff {
			closeReceivedFds(oob)
		}
	}()

	// A full buffer means the state may have been truncated.
	if n >= maxNameLen || oobn != oobSpace {
		return 0, "", fmt.Errorf("recvfd: incorrect number of bytes read (n=%d oobn=%d)", n, oobn)
	}

	scms, err := unix.ParseSocketControlMessage(oob)
	if err != nil {
		return 0, "", err
	}
	if len(scms) != 1 {
		return 0, "", fmt.Errorf("recvfd: number of SCMs is not 1: %d", len(scms))
	}

	fds, err := unix.ParseUnixRights(&scms[0])
	if err != nil {
		return 0, "", err
	}

	containerProcessState := &specs.ContainerProcessState{}
	if err := json.Unmarshal(stateBuf, containerProcessState); err != nil {
		return 0, "", fmt.Errorf("cannot parse OCI state: %w", err)
	}

	// On failure parseStateFds closes nothing, so the deferred cleanup
	// closes every received fd.
	fd, err := parseStateFds(containerProcessState.Fds, fds)
	if err != nil {
		return 0, "", err
	}

	handedOff = true
	return fd, containerProcessState.Metadata, nil
}

// closeReceivedFds closes, on a best-effort basis, every file descriptor
// carried by SCM_RIGHTS control messages in oob. Parse errors are ignored
// because it only runs on paths that are already failing.
func closeReceivedFds(oob []byte) {
	scms, err := unix.ParseSocketControlMessage(oob)
	if err != nil {
		return
	}
	for i := range scms {
		fds, err := unix.ParseUnixRights(&scms[i])
		if err != nil {
			continue // not an SCM_RIGHTS message
		}
		for _, fd := range fds {
			unix.Close(fd)
		}
	}
}
124
125 func readArgString(pid uint32, offset int64) (string, error) {
126 buffer := make([]byte, 4096)
127
128 memfd, err := unix.Open(fmt.Sprintf("/proc/%d/mem", pid), unix.O_RDONLY, 0o777)
129 if err != nil {
130 return "", err
131 }
132 defer unix.Close(memfd)
133
134 _, err = unix.Pread(memfd, buffer, offset)
135 if err != nil {
136 return "", err
137 }
138
139 buffer[len(buffer)-1] = 0
140 s := buffer[:bytes.IndexByte(buffer, 0)]
141 return string(s), nil
142 }
143
// runMkdirForContainer creates the directory "<fileName>-<metadata>" with the
// given mode on behalf of process pid.
//
// An absolute fileName is resolved under /proc/<pid>/root/. A relative one is
// resolved under /proc/<pid>/cwd/. The two are joined with
// securejoin.SecureJoin, which keeps the resolved path inside that base.
// metadata is expected to be a single path component (main sanitizes it
// before starting the handler).
//
// NOTE(review): the directory is created by the agent process itself, so
// ownership and umask are the agent's, not the target's. Confirm this is
// intended.
func runMkdirForContainer(pid uint32, fileName string, mode uint32, metadata string) error {
	base := "cwd"
	if strings.HasPrefix(fileName, "/") {
		base = "root"
	}
	root := fmt.Sprintf("/proc/%d/%s/", pid, base)

	target, err := securejoin.SecureJoin(root, fileName+"-"+metadata)
	if err != nil {
		return err
	}
	return unix.Mkdir(target, mode)
}
162
163
// notifHandler serves the seccomp user notifications delivered on fd. It
// stops when the listener reports that it will not deliver any more
// notifications (hang-up or error), or when polling it fails, and then
// closes fd.
//
// Policy applied to each notification:
//   - mkdir: performed by the agent via runMkdirForContainer, which appends
//     "-<metadata>" to the requested name; on failure the syscall fails
//     with ENOSYS.
//   - chmod, fchmod, fchmodat: fail with ENOMEDIUM.
//   - anything else: continues in the kernel (NotifRespFlagContinue).
func notifHandler(fd libseccomp.ScmpFd, metadata string) {
	defer unix.Close(int(fd))
	for {
		// Wait in poll(2) instead of blocking directly in NotifReceive. Once
		// no task uses the filter any more, the listener reports POLLHUP
		// (see seccomp_unotify(2)). Returning then lets the deferred Close
		// release the fd instead of leaking it and this goroutine forever.
		ready, err := waitForNotif(fd)
		if err != nil {
			logrus.Errorf("Error polling seccomp fd: %s", err)
			return
		}
		if !ready {
			logrus.Debugf("Seccomp fd %v hung up, stopping notification handler", fd)
			return
		}

		req, err := libseccomp.NotifReceive(fd)
		if err != nil {
			logrus.Errorf("Error in NotifReceive(): %s", err)
			continue
		}
		syscallName, err := req.Data.Syscall.GetName()
		if err != nil {
			logrus.Errorf("Error decoding syscall %v(): %s", req.Data.Syscall, err)
			continue
		}
		logrus.Debugf("Received syscall %q, pid %v, arch %q, args %+v", syscallName, req.Pid, req.Data.Arch, req.Data.Args)

		// Default: let the syscall run unchanged in the kernel.
		resp := &libseccomp.ScmpNotifResp{
			ID:    req.ID,
			Error: 0,
			Val:   0,
			Flags: libseccomp.NotifRespFlagContinue,
		}

		// Confirm the request is still valid before acting on req.Pid.
		if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
			logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
			continue
		}

		switch syscallName {
		case "mkdir":
			fileName, err := readArgString(req.Pid, int64(req.Data.Args[0]))
			if err != nil {
				logrus.Errorf("Cannot read argument: %s", err)
				resp.Error = int32(unix.ENOSYS)
				resp.Val = ^uint64(0) // -1
				break                 // leave the switch and send the error response
			}

			logrus.Debugf("mkdir: %q", fileName)

			// Re-check after reading the target's memory: the argument is only
			// trustworthy if the request is still valid.
			if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
				logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
				continue
			}

			err = runMkdirForContainer(req.Pid, fileName, uint32(req.Data.Args[1]), metadata)
			if err != nil {
				resp.Error = int32(unix.ENOSYS)
				resp.Val = ^uint64(0) // -1
			}
			resp.Flags = 0
		case "chmod", "fchmod", "fchmodat":
			resp.Error = int32(unix.ENOMEDIUM)
			resp.Val = ^uint64(0) // -1
			resp.Flags = 0
		}

		if err := libseccomp.NotifRespond(fd, resp); err != nil {
			logrus.Errorf("Error in notification response: %s", err)
		}
	}
}

// waitForNotif blocks until the seccomp listener fd has a notification ready
// to receive (returns true), or until it reports a hang-up/error condition
// meaning no notification will follow (returns false). EINTR is retried.
func waitForNotif(fd libseccomp.ScmpFd) (bool, error) {
	pfds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
	for {
		if _, err := unix.Poll(pfds, -1); err != nil {
			if errors.Is(err, unix.EINTR) {
				continue
			}
			return false, err
		}
		revents := pfds[0].Revents
		switch {
		case revents&unix.POLLIN != 0:
			// A pending notification takes priority over a hang-up report.
			return true, nil
		case revents&(unix.POLLHUP|unix.POLLERR|unix.POLLNVAL) != 0:
			return false, nil
		}
		// No relevant event reported: poll again.
	}
}
229
// main parses the command line, listens on the unix socket given by
// -socketfile and, for every connection, receives a seccomp notification fd
// plus the container metadata, then serves that fd in its own goroutine.
func main() {
	flag.StringVar(&socketFile, "socketfile", "/run/seccomp-agent.socket", "Socket file")
	flag.StringVar(&pidFile, "pid-file", "", "Pid file")
	logrus.SetLevel(logrus.DebugLevel)

	flag.Parse()
	if flag.NArg() != 0 {
		flag.PrintDefaults()
		logrus.Fatal("Invalid command")
	}

	// Remove a socket left behind by a previous run; a missing file is fine.
	err := os.Remove(socketFile)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		logrus.Fatalf("Cannot cleanup socket file: %v", err)
	}

	if pidFile != "" {
		content := []byte(fmt.Sprintf("%d", os.Getpid()))
		if err := os.WriteFile(pidFile, content, 0o644); err != nil {
			logrus.Fatalf("Cannot write pid file: %v", err)
		}
	}

	logrus.Info("Waiting for seccomp file descriptors")
	listener, err := net.Listen("unix", socketFile)
	if err != nil {
		logrus.Fatalf("Cannot listen: %s", err)
	}
	defer listener.Close()

	for {
		acceptSeccompFd(listener)
	}
}

// acceptSeccompFd accepts one connection on l, receives a seccomp fd and the
// container metadata from it, and starts a notifHandler goroutine for that
// fd. Failures are logged and the connection is dropped.
func acceptSeccompFd(l net.Listener) {
	conn, err := l.Accept()
	if err != nil {
		logrus.Errorf("Cannot accept connection: %s", err)
		return
	}
	// File returns a duplicate, so the net.Conn can be closed right away.
	socket, err := conn.(*net.UnixConn).File()
	conn.Close()
	if err != nil {
		logrus.Errorf("Cannot get socket: %v", err)
		return
	}
	newFd, metadata, err := handleNewMessage(int(socket.Fd()))
	socket.Close()
	if err != nil {
		logrus.Errorf("Error receiving seccomp file descriptor: %v", err)
		return
	}

	suffix := sanitizeMetadata(metadata)
	logrus.Infof("Received new seccomp fd: %v", newFd)
	go notifHandler(libseccomp.ScmpFd(newFd), suffix)
}

// sanitizeMetadata reduces metadata to a single path component so that it
// can safely be used as a file-name suffix.
func sanitizeMetadata(metadata string) string {
	name := filepath.Base(metadata)
	if strings.Contains(name, "/") {
		// filepath.Base returns "/" for a path made only of separators.
		return "agent-generated-suffix"
	}
	return name
}
292
View as plain text