//nolint:typecheck package agent import ( "bytes" "context" "errors" "fmt" "io" "log/slog" "maps" "os" "path/filepath" "testing" "time" . "github.com/onsi/ginkgo/v2" //nolint:revive . "github.com/onsi/gomega" //nolint:revive "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "github.com/containerd/containerd" containers "github.com/containerd/containerd/containers" "github.com/containerd/containerd/events" typeurl "github.com/containerd/typeurl/v2" "github.com/golang/mock/gomock" specs "github.com/opencontainers/runtime-spec/specs-go" testfs "gotest.tools/v3/fs" kerrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/dynamic/dynamicinformer" "k8s.io/client-go/rest" criruntime "k8s.io/cri-api/pkg/apis/runtime/v1" "edge-infra.dev/pkg/k8s/runtime/sap" dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1" v1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1" plugins "edge-infra.dev/pkg/sds/devices/k8s/device-plugins" "edge-infra.dev/pkg/sds/devices/logger" "edge-infra.dev/pkg/lib/kernel/devices" "edge-infra.dev/pkg/lib/kernel/netlink/socket" "edge-infra.dev/pkg/lib/kernel/udev" "edge-infra.dev/pkg/lib/kernel/udev/reader" "edge-infra.dev/pkg/lib/uuid" "edge-infra.dev/pkg/sds/devices/agent/cgroups" cc "edge-infra.dev/pkg/sds/devices/agent/common" "edge-infra.dev/pkg/sds/devices/agent/mocks" "edge-infra.dev/pkg/sds/devices/agent/udevproxy" "edge-infra.dev/pkg/sds/devices/class" ) //go:generate mockgen -destination=./mocks/mock_criclient.go -package=mocks k8s.io/cri-api/pkg/apis/runtime/v1 RuntimeServiceClient //go:generate mockgen -destination=./mocks/mock_containersvc.go -package=mocks github.com/containerd/containerd/containers Store //go:generate mockgen -destination=./mocks/mock_tasksvc.go -package=mocks github.com/containerd/containerd/api/services/tasks/v1 TasksClient //go:generate mockgen -destination=./mocks/mock_eventsvc.go -package=mocks github.com/containerd/containerd EventService //go:generate mockgen -destination=./mocks/mock_cgroups.go -package=mocks edge-infra.dev/pkg/sds/devices/agent/cgroups CgroupRequest //go:generate mockgen -destination=./mocks/mock_deviceplugin.go -package=mocks edge-infra.dev/pkg/sds/devices/k8s/device-plugins Plugin //go:generate mockgen -destination=./mocks/mock_udevproxy.go -package=mocks edge-infra.dev/pkg/sds/devices/agent/udevproxy UdevProxy //go:generate mockgen -destination=./mocks/mock_statuswatcher.go -package=mocks sigs.k8s.io/cli-utils/pkg/kstatus/watcher StatusWatcher //go:generate mockgen -destination=./mocks/mock_informer.go -package=mocks k8s.io/client-go/dynamic/dynamicinformer DynamicSharedInformerFactory // test configuration var ( timeout = time.Second * 10 cfg = Config{} ) // test data var ( usbPodName = "usb-pod" usbContainerName = "usb-container" socketServerAddr = "127.0.0.1:8082" cgroupPath = "kubepods-burstable-pod52da81bb_7d4a_49cf_ae91_424e9d35bb84.slice:cri-containerd:e58ea88f915f3ab51ba0f027084c384ec07491ef2f8b877db04dc94cac75c0f5" networkNamespacePath = "/proc/23342/ns/net" ) type ueventTestCase = []struct { inputEvent udev.UEvent expectedEvent string } type podTestCase map[string]podDefinition type podDefinition struct { name string namespace string annotations map[string]string containerTestCase []containerTestCase } type containerTestCase struct { name string containerID string cgroupPath string networkNamespacePath string state criruntime.ContainerState } var ueventTests = ueventTestCase{ { inputEvent: udev.UEvent{SequenceNumber: "1", EnvVars: map[string]string{"ACTION": "add", "SEQNUM": "1", "SUBSYSTEM": "usb", "DEVPATH": "testdata/sys/devices/usb"}}, expectedEvent: "^.*(libudev)(.*SUBSYSTEM=usb)?(.*ACTION=add)?(.*DEVPATH=testdata/sys/devices/usb)?(.*SEQNUM=1).*$", }, { inputEvent: udev.UEvent{SequenceNumber: "1", EnvVars: map[string]string{"ACTION": "remove", "SEQNUM": "2", "SUBSYSTEM": "usb", "DEVPATH": "testdata/sys/devices/usb"}}, expectedEvent: "^.*(libudev)(.*SUBSYSTEM=usb)?(.*ACTION=remove)?(.*DEVPATH=testdata/sys/devices/usb)?(.*SEQNUM=2).*$", }, } var podTests = podTestCase{ usbPodName: { name: usbPodName, namespace: usbPodName, annotations: map[string]string{ class.FmtClassLabel(usbContainerName): "usb", }, containerTestCase: []containerTestCase{ { name: usbContainerName, containerID: uuid.New().UUID, cgroupPath: cgroupPath, networkNamespacePath: networkNamespacePath, state: criruntime.ContainerState_CONTAINER_RUNNING, }, }, }, } var ctrl *gomock.Controller var mockInformerMethod = func(_ context.Context, _ client.Client, _ *rest.Config, _ chan *v1.DeviceClass, _ ...v1.ListOption) (dynamicinformer.DynamicSharedInformerFactory, error) { informerMock := mocks.NewMockDynamicSharedInformerFactory(ctrl) informerMock.EXPECT().Start(gomock.Any()) return informerMock, nil } func TestDeviceRead(t *testing.T) { testDir := testfs.NewDir(t, etcPath) defer testDir.Remove() kubeletPath = filepath.Join(testDir.Path(), "kubelet.sock") newDeviceClassInformer = mockInformerMethod cfg = Config{ ClassesPath: testDir.Path(), } ctrl = gomock.NewController(t) RegisterFailHandler(Fail) RunSpecs(t, "Device Tests") } var _ = BeforeEach(func() { _, _ = os.Create(kubeletPath) devices.SetTestEnvs() }) var _ = AfterEach(func() { _ = os.Remove(cfg.ClassesPath) _ = os.Remove(kubeletPath) os.Setenv("HOSTNAME", "worker-1") }) var _ = Describe("Device agent test suite", func() { It("No containers exist", func() { ctx := createContext() ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() defaultDeviceClass := defaultClass() defaultDeviceClassSet := defaultClassDeviceSet() k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet) m := generateMocks() eventChan := make(chan *events.Envelope) errChan := make(chan error) m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(eventChan, errChan) m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return([]containers.Container{}, nil) m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1) m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1) m.devicePluginSvc.EXPECT().Stop().MinTimes(1) pluginFn = func(_ *v1.DeviceClass) plugins.Plugin { return m.devicePluginSvc } udevproxy.UDevProxyFn = func() udevproxy.UdevProxy { return m.udevproxySvc } m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).AnyTimes() statusWatcher := mocks.NewMockStatusWatcher(ctrl) statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc)) client, err := containerd.New("", svcOpts) Expect(err).To(BeNil()) resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"}) agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, generateDecoder(), resourceManager) Expect(err).To(BeNil()) err = agent.Start(ctx) Expect(err).To(BeNil()) }) It("Container exists requesting usb device class", func() { ctx := createContext() ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() k8sClient := k8sClientWithObjects(usbClass(), defaultClass()) m := generateMocks() eventChan := make(chan *events.Envelope) errChan := make(chan error) m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(eventChan, errChan) m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return(containerListResponse(usbPodName), nil) m.criClient.EXPECT().PodSandboxStatus(gomock.Any(), gomock.Any()).Return(nil, errors.New("pod sandbox not found")) m.criClient.EXPECT().ContainerStatus(gomock.Any(), gomock.Any()).Return(runningStatusResponse(usbPodName, usbContainerName), nil) m.criClient.EXPECT().ListPodSandbox(gomock.Any(), podSandboxListRequest(usbPodName)).Return(podSandboxListResponse(usbPodName), nil) m.cgroupSvc.EXPECT().Apply(gomock.Any()).MinTimes(1) m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).AnyTimes() m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1) m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1) m.devicePluginSvc.EXPECT().Stop().MinTimes(1) pluginFn = func(_ *v1.DeviceClass) plugins.Plugin { return m.devicePluginSvc } udevproxy.UDevProxyFn = func() udevproxy.UdevProxy { return m.udevproxySvc } cgroupRequestFn = func(_, _, _, _ string, _ map[string]devices.Device, _ bool) cgroups.CgroupRequest { return m.cgroupSvc } statusWatcher := mocks.NewMockStatusWatcher(ctrl) statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"}) svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc)) client, err := containerd.New("", svcOpts) Expect(err).To(BeNil()) agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, generateDecoder(), resourceManager) Expect(err).To(BeNil()) err = agent.Start(ctx) Expect(err).To(BeNil()) }) }) var _ = Describe("Device agent uevent test suite", func() { It("Add and remove an input device for usb class", func() { ctx := createContext() ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() k8sClient := k8sClientWithObjects(defaultClass(), usbClass()) m := generateMocks() dataChan, _ := startSocketServer(ctx, socketServerAddr) time.Sleep(time.Second * 1) decoder, reader, err := createDecoder(socketServerAddr) Expect(err).To(BeNil()) defer reader.Close() m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(make(chan *events.Envelope), make(chan error)) m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return(containerListResponse(usbPodName), nil) m.criClient.EXPECT().PodSandboxStatus(gomock.Any(), gomock.Any()).Return(nil, errors.New("pod sandbox not found")) m.criClient.EXPECT().ContainerStatus(gomock.Any(), gomock.Any()).Return(runningStatusResponse(usbPodName, usbContainerName), nil) m.criClient.EXPECT().ListPodSandbox(gomock.Any(), podSandboxListRequest(usbPodName)).Return(podSandboxListResponse(usbPodName), nil) m.cgroupSvc.EXPECT().Apply(gomock.Any()).MinTimes(1) m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).MinTimes(1) m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1) m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1) m.devicePluginSvc.EXPECT().Stop().MinTimes(1) pluginFn = func(_ *v1.DeviceClass) plugins.Plugin { return m.devicePluginSvc } udevproxy.UDevProxyFn = func() udevproxy.UdevProxy { return m.udevproxySvc } cgroupRequestFn = func(_, _, _, _ string, _ map[string]devices.Device, _ bool) cgroups.CgroupRequest { return m.cgroupSvc } statusWatcher := mocks.NewMockStatusWatcher(ctrl) statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"}) svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc)) client, err := containerd.New("", svcOpts) Expect(err).To(BeNil()) go func() { agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, decoder, resourceManager) Expect(err).To(BeNil()) err = agent.Start(ctx) Expect(err).To(BeNil()) }() // send uevents to data channel for _, ue := range ueventTests { ueventData := toBytes(ue.inputEvent) *dataChan <- ueventData time.Sleep(time.Millisecond * 100) } }) }) var _ = Describe("Parsing sys fs and caching device classes", func() { It("Parse usb device class devices", func() { ctx := createContext() ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() defaultDeviceClass := defaultClass() defaultDeviceClassSet := defaultClassDeviceSet() usbDeviceClass := usbClass() usbDeviceClassSet := usbClassDeviceSet() k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet, usbDeviceClass, usbDeviceClassSet) deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient) Expect(err).To(BeNil()) Expect(deviceClasses).To(HaveLen(2)) usbDeviceIter := deviceClasses[usbDeviceClass.ClassName()].DeviceIter() devices := maps.Collect(usbDeviceIter) Expect(devices).To(HaveLen(8)) }) }) var _ = Describe("Apply device class statuses", func() { It("Format Device Names", func() { expectedName := "ACR-2199" Expect(fmtName(" ACR-2199\n")).To(BeEquivalentTo(expectedName)) Expect(fmtName("-ACR-2199 ")).To(BeEquivalentTo(expectedName)) Expect(fmtName(" \nACR-2199-")).To(BeEquivalentTo(expectedName)) }) It("Apply device class statuses", func() { ctx := createContext() ctx, cancelFn := context.WithTimeout(ctx, timeout) defer cancelFn() defaultDeviceClass := defaultClass() defaultDeviceClassSet := defaultClassDeviceSet() usbDeviceClass := usbClass() usbDeviceClassSet := usbClassDeviceSet() k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet, usbDeviceClass, usbDeviceClassSet) deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient) Expect(err).To(BeNil()) Expect(deviceClasses).To(HaveLen(2)) statusWatcher := mocks.NewMockStatusWatcher(ctrl) statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"}) err = applyDeviceClassStatuses(ctx, resourceManager, deviceClasses) Expect(kerrors.IsNotFound(err)).To(BeTrue()) }) }) type deviceMocks struct { eventSvc *mocks.MockEventService criClient *mocks.MockRuntimeServiceClient containerSvc *mocks.MockStore cgroupSvc *mocks.MockCgroupRequest devicePluginSvc *mocks.MockPlugin udevproxySvc *mocks.MockUdevProxy taskSvc *mocks.MockTasksClient } func generateMocks() deviceMocks { return deviceMocks{ eventSvc: mocks.NewMockEventService(ctrl), criClient: mocks.NewMockRuntimeServiceClient(ctrl), containerSvc: mocks.NewMockStore(ctrl), cgroupSvc: mocks.NewMockCgroupRequest(ctrl), devicePluginSvc: mocks.NewMockPlugin(ctrl), udevproxySvc: mocks.NewMockUdevProxy(ctrl), taskSvc: mocks.NewMockTasksClient(ctrl), } } func k8sClientWithObjects(objs ...client.Object) client.WithWatch { return fake.NewClientBuilder().WithScheme(createScheme()).WithObjects(objs...).Build() } func defaultClass() *dsv1.DeviceClass { className := "default" return generateDeviceClassTestCR(className, 1) } func defaultClassDeviceSet() *dsv1.DeviceSet { className := "default" return generateDeviceSetsTestCR(className, false, map[string]string{"SUBSYSTEM": "misc"}) } func usbClass() *dsv1.DeviceClass { return generateDeviceClassTestCR(usbSubsystem, 1) } func usbClassDeviceSet() *dsv1.DeviceSet { return generateDeviceSetsTestCR(usbSubsystem, false, map[string]string{"SUBSYSTEM": usbSubsystem}) } func generateDeviceClassTestCR(name string, generation int64) *dsv1.DeviceClass { return &dsv1.DeviceClass{ ObjectMeta: metav1.ObjectMeta{ Name: name, Generation: generation, }, Spec: dsv1.DeviceClassSpec{ Devices: []dsv1.DeviceRef{ { Name: name, }, }, Logging: dsv1.Logging{ Level: "info", }, }, } } func generateDeviceSetsTestCR(name string, shouldBlock bool, deviceSetProperties map[string]string) *dsv1.DeviceSet { deviceSet := dsv1.DeviceSetReference{ Name: name, Properties: []dsv1.Rule{}, } if shouldBlock { deviceSet.Blocking = &dsv1.Blocking{} } for key, value := range deviceSetProperties { deviceSet.Properties = append(deviceSet.Properties, dsv1.Rule{ Name: key, RegexValue: value, }) } return &dsv1.DeviceSet{ ObjectMeta: metav1.ObjectMeta{ Name: name, Generation: 1, }, Spec: dsv1.DeviceSpec{ DeviceSets: []dsv1.DeviceSetReference{ deviceSet, }, }, } } func generateDecoder() reader.Decoder { buf := bytes.NewBuffer([]byte{}) return reader.NewReader(buf) } func runningStatusResponse(podName, containerName string) *criruntime.ContainerStatusResponse { podDef := podTests[podName] for _, ctr := range podDef.containerTestCase { if ctr.name == containerName { return &criruntime.ContainerStatusResponse{Status: &criruntime.ContainerStatus{State: ctr.state}} } } return &criruntime.ContainerStatusResponse{Status: &criruntime.ContainerStatus{State: criruntime.ContainerState_CONTAINER_UNKNOWN}} } func containerListResponse(podName string) []containers.Container { containers := []containers.Container{} for _, ctr := range podTests[podName].containerTestCase { containers = append(containers, containerResponse(ctr, podName)) } return containers } func containerResponse(ctr containerTestCase, podName string) containers.Container { spec, _ := typeurl.MarshalAny(&specs.Spec{ Linux: &specs.Linux{ CgroupsPath: ctr.cgroupPath, Namespaces: []specs.LinuxNamespace{ { Type: specs.NetworkNamespace, Path: ctr.networkNamespacePath, }, }, }, }) return containers.Container{ID: ctr.containerID, Spec: spec, Labels: map[string]string{cc.AnnPodName: podName, cc.AnnContainerName: ctr.name}} } func podSandboxListResponse(podName string) *criruntime.ListPodSandboxResponse { response := &criruntime.ListPodSandboxResponse{Items: []*criruntime.PodSandbox{}} testCase := podTests[podName] response.Items = append(response.Items, &criruntime.PodSandbox{ Annotations: testCase.annotations, Metadata: &criruntime.PodSandboxMetadata{ Name: testCase.name, Namespace: testCase.namespace, }, }) return response } func podSandboxListRequest(podName string) *criruntime.ListPodSandboxRequest { return &criruntime.ListPodSandboxRequest{Filter: &criruntime.PodSandboxFilter{LabelSelector: map[string]string{cc.AnnPodName: podName}}} } func toBytes(ue udev.UEvent) []byte { delimeter := []byte("\x00") data := []byte{} data = append(data, []byte("libudev")...) data = append(data, delimeter...) for k, v := range ue.EnvVars { if k == "SEQNUM" { continue } data = append(data, []byte(fmt.Sprintf("%s=%s", k, v))...) data = append(data, delimeter...) } data = append(data, []byte(fmt.Sprintf("SEQNUM=%s", ue.SequenceNumber))...) data = append(data, delimeter...) return data } func startSocketServer(ctx context.Context, destination string) (*chan []byte, chan error) { errChan := make(chan error, 1) socketServer := new(socket.Server) socketServer.Setup(destination) go socketServer.Serve(ctx, errChan) return socketServer.GetDataChan(), errChan } func createDecoder(source string) (reader.Decoder, io.ReadCloser, error) { uEventReader, fd, err := udev.NewUEventReader(source) if err != nil { return nil, nil, err } decoder := reader.NewSocketReader(fd) return decoder, uEventReader, nil } func createContext() context.Context { ctx := context.Background() opts := []logger.Option{ logger.WithLevel(slog.LevelDebug), } log := logger.New(opts...) ctx = logger.IntoContext(ctx, log) return ctx }