package containers

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/containerd/containerd"
	"github.com/containerd/containerd/api/services/tasks/v1"
	"github.com/containerd/containerd/containers"
	"k8s.io/apimachinery/pkg/util/wait"
	criruntime "k8s.io/cri-api/pkg/apis/runtime/v1"

	cc "edge-infra.dev/pkg/sds/devices/agent/common"
	"edge-infra.dev/pkg/sds/devices/class"
	"edge-infra.dev/pkg/sds/devices/logger"
)

// notRunningStates holds the CRI runtime states for containers that are not running.
// Only the keys are used; the values are ignored.
var notRunningStates = map[criruntime.ContainerState]string{
	criruntime.ContainerState_CONTAINER_EXITED: "",
}

// defaultBackoff is the retry policy used for container lookups.
// Nominal waits (before jitter): 10ms, 20ms, 40ms, 80ms, 160ms, 320ms, 640ms, 1280ms.
var defaultBackoff = wait.Backoff{
	Steps:    9,
	Duration: 10 * time.Millisecond,
	Factor:   2,
	Jitter:   0.1,
	Cap:      time.Millisecond * 2560,
}

var (
	// ErrPodSandboxDoesNotExist is returned if the pod sandbox expected
	// for a container does not exist.
	ErrPodSandboxDoesNotExist = errors.New("pod sandbox does not exist")

	// ignoreNamespaces lists the namespaces to ignore from cgroup application.
	ignoreNamespaces = map[string]string{"kube-system": "", "device-system": ""}

	// ignoreContainers lists the containers to ignore from cgroup application.
	ignoreContainers = map[string]string{"etcd": "", "linkerd-init": "", "linkerd-proxy": ""}
)

// Label keys read from a container's labels.
const (
	annContainerName = "io.kubernetes.container.name"
	annPodName       = "io.kubernetes.pod.name"
	annPodNamespace  = "io.kubernetes.pod.namespace"
)

// FetchAllContainers returns the running containers that request a device class,
// keyed by container ID.
//
// Containers whose state cannot be determined, or whose device requests cannot
// be resolved, are logged and skipped rather than failing the whole listing.
// An error is returned only when the container list itself cannot be fetched.
func FetchAllContainers(ctx context.Context, ctrClient *containerd.Client, runtimeClient criruntime.RuntimeServiceClient) (map[string]*containers.Container, error) {
	log := logger.FromContext(ctx)

	containersStore, err := ctrClient.ContainerService().List(ctx)
	if err != nil {
		return nil, err
	}

	requested := map[string]*containers.Container{}
	for i := range containersStore {
		// Point at the slice element instead of the range variable: the pointer
		// is stored in the result map, and before Go 1.22 the range variable is
		// shared by all iterations, so every entry would alias the same value.
		ctr := &containersStore[i]

		running, err := containerIsRunning(ctx, runtimeClient, ctr.ID)
		if err != nil {
			log.Error("error checking container state", "ctrID", ctr.ID, "error", err)
			continue
		}
		if !running {
			continue
		}

		requestCtr, err := addContainerDeviceRequests(ctx, runtimeClient, ctr)
		if err != nil {
			log.Error("error fetching device requests", "ctrID", ctr.ID, "error", err)
			continue
		}
		if requestCtr == nil {
			// container is ignored or requests no devices
			continue
		}
		requested[requestCtr.ID] = requestCtr
	}
	return requested, nil
}

// FetchContainer will attempt to fetch a container given the container ID and populate it
// with the device class request labels from the pod sandbox.
//
// It returns (nil, nil) when the container is not running, is ignored, or does
// not request any device class.
func FetchContainer(ctx context.Context, ctrClient *containerd.Client, runtimeClient criruntime.RuntimeServiceClient, ctrID string) (*containers.Container, error) {
	running, err := containerIsRunning(ctx, runtimeClient, ctrID)
	if err != nil || !running {
		return nil, err
	}

	ctr, err := getContainer(ctx, ctrClient, ctrID)
	if err != nil {
		return nil, err
	}

	// addContainerDeviceRequests already yields a nil container on error and
	// for containers without device requests, so its result is returned as is.
	return addContainerDeviceRequests(ctx, runtimeClient, &ctr)
}

// FetchContainerParentProcessID will fetch the parent process id for a given container.
// The id returned is the pid of the first process listed by the task service.
func FetchContainerParentProcessID(ctx context.Context, ctrClient *containerd.Client, ctrID string) (string, error) {
	taskResult, err := ctrClient.TaskService().ListPids(ctx, &tasks.ListPidsRequest{
		ContainerID: ctrID,
	})
	if err != nil {
		return "", fmt.Errorf("could not find container process ids: %w", err)
	}
	if len(taskResult.Processes) == 0 {
		return "", errors.New("could not find container process ids")
	}
	return strconv.Itoa(int(taskResult.Processes[0].Pid)), nil
}

// FetchContainerRootPath fetches the container's root path, expressed as the
// /proc/<pid>/root path of the container's parent process.
func FetchContainerRootPath(ctx context.Context, ctrClient *containerd.Client, ctr *containers.Container) (string, error) {
	pid, err := FetchContainerParentProcessID(ctx, ctrClient, ctr.ID)
	if err != nil {
		return "", fmt.Errorf("could not find container %s process id: %w", ctr.ID, err)
	}
	return fmt.Sprintf("/proc/%s/root", pid), nil
}

// WithContainerLogger returns a context carrying a device logger instantiated with
// container information (container name and ID, pod name and namespace). The log
// level is read from the container's device log level label.
func WithContainerLogger(ctx context.Context, ctr *containers.Container) context.Context {
	logLevel := ctr.Labels[cc.AnnDeviceLogLevel]
	ctrName := ctr.Labels[cc.AnnContainerName]
	podName := ctr.Labels[annPodName]
	podNamespace := ctr.Labels[annPodNamespace]

	opts := []logger.Option{
		logger.WithLevel(logger.ToLevel(logLevel)),
	}
	ctrLogger := logger.New(opts...).WithGroup(ctrName).With(
		"container", ctrName,
		"containerId", ctr.ID,
		"pod", podName,
		"namespace", podNamespace,
	)
	return logger.IntoContext(ctx, ctrLogger)
}

// containerIsRunning will check that the container is running. If the container
// status cannot be fetched, the client will retry with exponential backoff.
//
// IDs that resolve to a pod sandbox are reported as not running. Any state not
// listed in notRunningStates is treated as running.
func containerIsRunning(ctx context.Context, runtimeClient criruntime.RuntimeServiceClient, ctrID string) (bool, error) {
	// ignore container ids that are pod sandbox
	podSandbox, err := runtimeClient.PodSandboxStatus(ctx, &criruntime.PodSandboxStatusRequest{PodSandboxId: ctrID})
	if err == nil && podSandbox != nil {
		return false, nil
	}

	var (
		state   criruntime.ContainerState
		lastErr error
	)
	err = wait.ExponentialBackoffWithContext(ctx, defaultBackoff, func(ctx context.Context) (bool, error) {
		resp, err := runtimeClient.ContainerStatus(ctx, &criruntime.ContainerStatusRequest{ContainerId: ctrID})
		if err != nil {
			// remember the failure and retry instead of aborting the backoff
			lastErr = err
			return false, nil
		}
		status := resp.GetStatus()
		if status == nil {
			// a response without a status cannot be classified; retry rather
			// than dereferencing a nil pointer
			lastErr = errors.New("container status response has no status")
			return false, nil
		}
		state = status.GetState()
		return true, nil
	})
	if err != nil {
		// lastErr is nil when the backoff ends before any attempt was made
		// (for example the context was already cancelled).
		if lastErr != nil {
			return false, fmt.Errorf("error fetching container runtime state: %w, %w", err, lastErr)
		}
		return false, fmt.Errorf("error fetching container runtime state: %w", err)
	}

	_, notRunning := notRunningStates[state]
	return !notRunning, nil
}

// getContainer fetches a container by ID, retrying with exponential backoff while the lookup fails.
func getContainer(ctx context.Context, ctrClient *containerd.Client, ctrID string) (containers.Container, error) { var container containers.Container var lastError error containerService := ctrClient.ContainerService() if err := wait.ExponentialBackoffWithContext(ctx, defaultBackoff, func(_ context.Context) (done bool, err error) { container, err = containerService.Get(ctx, ctrID) if err != nil { lastError = err return false, nil } return true, nil }); err != nil { return container, fmt.Errorf("could not fetch container %w: %s: %w", err, ctrID, lastError) } return container, nil } // addContainerDeviceRequests will add the containers device request by doing a lookup on the pod sandbox from container cri func addContainerDeviceRequests(ctx context.Context, runtimeClient criruntime.RuntimeServiceClient, ctr *containers.Container) (*containers.Container, error) { ctrName := ctr.Labels[cc.AnnContainerName] podName := ctr.Labels[cc.AnnPodName] podNamespace := ctr.Labels[cc.AnnPodNamespace] if _, ok := ignoreContainers[ctrName]; ok { return nil, nil } if _, ok := ignoreNamespaces[podNamespace]; ok { return nil, nil } if ctrName == "" || podName == "" { return nil, fmt.Errorf("missing pod or container name, container: %s, pod: %s", ctrName, podName) } podSandboxList, err := runtimeClient.ListPodSandbox(ctx, &criruntime.ListPodSandboxRequest{ Filter: &criruntime.PodSandboxFilter{ LabelSelector: map[string]string{ cc.AnnPodName: podName, }, }}, ) if err != nil { return nil, err } if podSandboxList == nil || len(podSandboxList.Items) == 0 { return nil, ErrPodSandboxDoesNotExist } pod := podSandboxList.Items[0] requestsDevice := false for key, value := range pod.Annotations { if !class.IsDeviceClass(key) { continue } classes := strings.Split(value, ",") parsedCtrName, err := parseContainerName(key) if err != nil { return nil, err } if parsedCtrName != ctrName { continue } for _, className := range classes { if className == "" { continue } requestsDevice = true 
ctr.Labels[class.FmtClassLabel(className)] = "requested" } } // ignore containers that do no request devices if !requestsDevice { return nil, nil } ctr.Labels[cc.AnnDeviceLogLevel] = "info" logLevel, ok := pod.Annotations[cc.AnnDeviceLogLevel] if ok { ctr.Labels[cc.AnnDeviceLogLevel] = logLevel } return ctr, nil } // parseContainerName will return the container name from the device class annotation func parseContainerName(className string) (string, error) { return class.BaseName(className) }