...

Source file src/github.com/opencontainers/runc/contrib/cmd/seccompagent/seccompagent.go

Documentation: github.com/opencontainers/runc/contrib/cmd/seccompagent

     1  //go:build linux && seccomp
     2  // +build linux,seccomp
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"net"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  
    17  	securejoin "github.com/cyphar/filepath-securejoin"
    18  	"github.com/opencontainers/runtime-spec/specs-go"
    19  	libseccomp "github.com/seccomp/libseccomp-golang"
    20  	"github.com/sirupsen/logrus"
    21  	"golang.org/x/sys/unix"
    22  )
    23  
    24  var (
    25  	socketFile string
    26  	pidFile    string
    27  )
    28  
    29  func closeStateFds(recvFds []int) {
    30  	for i := range recvFds {
    31  		unix.Close(i)
    32  	}
    33  }
    34  
    35  // parseStateFds returns the seccomp-fd and closes the rest of the fds in recvFds.
    36  // In case of error, no fd is closed.
    37  // StateFds is assumed to be formatted as specs.ContainerProcessState.Fds and
    38  // recvFds the corresponding list of received fds in the same SCM_RIGHT message.
    39  func parseStateFds(stateFds []string, recvFds []int) (uintptr, error) {
    40  	// Let's find the index in stateFds of the seccomp-fd.
    41  	idx := -1
    42  	err := false
    43  
    44  	for i, name := range stateFds {
    45  		if name == specs.SeccompFdName && idx == -1 {
    46  			idx = i
    47  			continue
    48  		}
    49  
    50  		// We found the seccompFdName twice. Error out!
    51  		if name == specs.SeccompFdName && idx != -1 {
    52  			err = true
    53  		}
    54  	}
    55  
    56  	if idx == -1 || err {
    57  		return 0, errors.New("seccomp fd not found or malformed containerProcessState.Fds")
    58  	}
    59  
    60  	if idx >= len(recvFds) || idx < 0 {
    61  		return 0, errors.New("seccomp fd index out of range")
    62  	}
    63  
    64  	fd := uintptr(recvFds[idx])
    65  
    66  	for i := range recvFds {
    67  		if i == idx {
    68  			continue
    69  		}
    70  
    71  		unix.Close(recvFds[i])
    72  	}
    73  
    74  	return fd, nil
    75  }
    76  
    77  func handleNewMessage(sockfd int) (uintptr, string, error) {
    78  	const maxNameLen = 4096
    79  	stateBuf := make([]byte, maxNameLen)
    80  	oobSpace := unix.CmsgSpace(4)
    81  	oob := make([]byte, oobSpace)
    82  
    83  	n, oobn, _, _, err := unix.Recvmsg(sockfd, stateBuf, oob, 0)
    84  	if err != nil {
    85  		return 0, "", err
    86  	}
    87  	if n >= maxNameLen || oobn != oobSpace {
    88  		return 0, "", fmt.Errorf("recvfd: incorrect number of bytes read (n=%d oobn=%d)", n, oobn)
    89  	}
    90  
    91  	// Truncate.
    92  	stateBuf = stateBuf[:n]
    93  	oob = oob[:oobn]
    94  
    95  	scms, err := unix.ParseSocketControlMessage(oob)
    96  	if err != nil {
    97  		return 0, "", err
    98  	}
    99  	if len(scms) != 1 {
   100  		return 0, "", fmt.Errorf("recvfd: number of SCMs is not 1: %d", len(scms))
   101  	}
   102  	scm := scms[0]
   103  
   104  	fds, err := unix.ParseUnixRights(&scm)
   105  	if err != nil {
   106  		return 0, "", err
   107  	}
   108  
   109  	containerProcessState := &specs.ContainerProcessState{}
   110  	err = json.Unmarshal(stateBuf, containerProcessState)
   111  	if err != nil {
   112  		closeStateFds(fds)
   113  		return 0, "", fmt.Errorf("cannot parse OCI state: %w", err)
   114  	}
   115  
   116  	fd, err := parseStateFds(containerProcessState.Fds, fds)
   117  	if err != nil {
   118  		closeStateFds(fds)
   119  		return 0, "", err
   120  	}
   121  
   122  	return fd, containerProcessState.Metadata, nil
   123  }
   124  
   125  func readArgString(pid uint32, offset int64) (string, error) {
   126  	buffer := make([]byte, 4096) // PATH_MAX
   127  
   128  	memfd, err := unix.Open(fmt.Sprintf("/proc/%d/mem", pid), unix.O_RDONLY, 0o777)
   129  	if err != nil {
   130  		return "", err
   131  	}
   132  	defer unix.Close(memfd)
   133  
   134  	_, err = unix.Pread(memfd, buffer, offset)
   135  	if err != nil {
   136  		return "", err
   137  	}
   138  
   139  	buffer[len(buffer)-1] = 0
   140  	s := buffer[:bytes.IndexByte(buffer, 0)]
   141  	return string(s), nil
   142  }
   143  
   144  func runMkdirForContainer(pid uint32, fileName string, mode uint32, metadata string) error {
   145  	// We validated before that metadata is not a string that can make
   146  	// newFile a file in a different location other than root.
   147  	newFile := fmt.Sprintf("%s-%s", fileName, metadata)
   148  	root := fmt.Sprintf("/proc/%d/cwd/", pid)
   149  
   150  	if strings.HasPrefix(fileName, "/") {
   151  		// If it starts with /, use the rootfs as base
   152  		root = fmt.Sprintf("/proc/%d/root/", pid)
   153  	}
   154  
   155  	path, err := securejoin.SecureJoin(root, newFile)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	return unix.Mkdir(path, mode)
   161  }
   162  
   163  // notifHandler handles seccomp notifications and responses
   164  func notifHandler(fd libseccomp.ScmpFd, metadata string) {
   165  	defer unix.Close(int(fd))
   166  	for {
   167  		req, err := libseccomp.NotifReceive(fd)
   168  		if err != nil {
   169  			logrus.Errorf("Error in NotifReceive(): %s", err)
   170  			continue
   171  		}
   172  		syscallName, err := req.Data.Syscall.GetName()
   173  		if err != nil {
   174  			logrus.Errorf("Error decoding syscall %v(): %s", req.Data.Syscall, err)
   175  			continue
   176  		}
   177  		logrus.Debugf("Received syscall %q, pid %v, arch %q, args %+v", syscallName, req.Pid, req.Data.Arch, req.Data.Args)
   178  
   179  		resp := &libseccomp.ScmpNotifResp{
   180  			ID:    req.ID,
   181  			Error: 0,
   182  			Val:   0,
   183  			Flags: libseccomp.NotifRespFlagContinue,
   184  		}
   185  
   186  		// TOCTOU check
   187  		if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
   188  			logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
   189  			continue
   190  		}
   191  
   192  		switch syscallName {
   193  		case "mkdir":
   194  			fileName, err := readArgString(req.Pid, int64(req.Data.Args[0]))
   195  			if err != nil {
   196  				logrus.Errorf("Cannot read argument: %s", err)
   197  				resp.Error = int32(unix.ENOSYS)
   198  				resp.Val = ^uint64(0) // -1
   199  				goto sendResponse
   200  			}
   201  
   202  			logrus.Debugf("mkdir: %q", fileName)
   203  
   204  			// TOCTOU check
   205  			if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
   206  				logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
   207  				continue
   208  			}
   209  
   210  			err = runMkdirForContainer(req.Pid, fileName, uint32(req.Data.Args[1]), metadata)
   211  			if err != nil {
   212  				resp.Error = int32(unix.ENOSYS)
   213  				resp.Val = ^uint64(0) // -1
   214  			}
   215  			resp.Flags = 0
   216  		case "chmod", "fchmod", "fchmodat":
   217  			resp.Error = int32(unix.ENOMEDIUM)
   218  			resp.Val = ^uint64(0) // -1
   219  			resp.Flags = 0
   220  		}
   221  
   222  	sendResponse:
   223  		if err = libseccomp.NotifRespond(fd, resp); err != nil {
   224  			logrus.Errorf("Error in notification response: %s", err)
   225  			continue
   226  		}
   227  	}
   228  }
   229  
   230  func main() {
   231  	flag.StringVar(&socketFile, "socketfile", "/run/seccomp-agent.socket", "Socket file")
   232  	flag.StringVar(&pidFile, "pid-file", "", "Pid file")
   233  	logrus.SetLevel(logrus.DebugLevel)
   234  
   235  	// Parse arguments
   236  	flag.Parse()
   237  	if flag.NArg() > 0 {
   238  		flag.PrintDefaults()
   239  		logrus.Fatal("Invalid command")
   240  	}
   241  
   242  	if err := os.Remove(socketFile); err != nil && !errors.Is(err, os.ErrNotExist) {
   243  		logrus.Fatalf("Cannot cleanup socket file: %v", err)
   244  	}
   245  
   246  	if pidFile != "" {
   247  		pid := fmt.Sprintf("%d", os.Getpid())
   248  		if err := os.WriteFile(pidFile, []byte(pid), 0o644); err != nil {
   249  			logrus.Fatalf("Cannot write pid file: %v", err)
   250  		}
   251  	}
   252  
   253  	logrus.Info("Waiting for seccomp file descriptors")
   254  	l, err := net.Listen("unix", socketFile)
   255  	if err != nil {
   256  		logrus.Fatalf("Cannot listen: %s", err)
   257  	}
   258  	defer l.Close()
   259  
   260  	for {
   261  		conn, err := l.Accept()
   262  		if err != nil {
   263  			logrus.Errorf("Cannot accept connection: %s", err)
   264  			continue
   265  		}
   266  		socket, err := conn.(*net.UnixConn).File()
   267  		conn.Close()
   268  		if err != nil {
   269  			logrus.Errorf("Cannot get socket: %v", err)
   270  			continue
   271  		}
   272  		newFd, metadata, err := handleNewMessage(int(socket.Fd()))
   273  		socket.Close()
   274  		if err != nil {
   275  			logrus.Errorf("Error receiving seccomp file descriptor: %v", err)
   276  			continue
   277  		}
   278  
   279  		// Make sure we don't allow strings like "/../p", as that means
   280  		// a file in a different location than expected. We just want
   281  		// safe things to use as a suffix for a file name.
   282  		metadata = filepath.Base(metadata)
   283  		if strings.Contains(metadata, "/") {
   284  			// Fallback to a safe string.
   285  			metadata = "agent-generated-suffix"
   286  		}
   287  
   288  		logrus.Infof("Received new seccomp fd: %v", newFd)
   289  		go notifHandler(libseccomp.ScmpFd(newFd), metadata)
   290  	}
   291  }
   292  

View as plain text