1
2 package agent
3
4 import (
5 "bytes"
6 "context"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "os"
13 "path/filepath"
14 "testing"
15 "time"
16
17 . "github.com/onsi/ginkgo/v2"
18 . "github.com/onsi/gomega"
19 "sigs.k8s.io/controller-runtime/pkg/client"
20 "sigs.k8s.io/controller-runtime/pkg/client/fake"
21
22 "github.com/containerd/containerd"
23 containers "github.com/containerd/containerd/containers"
24 "github.com/containerd/containerd/events"
25 typeurl "github.com/containerd/typeurl/v2"
26 "github.com/golang/mock/gomock"
27 specs "github.com/opencontainers/runtime-spec/specs-go"
28 testfs "gotest.tools/v3/fs"
29 kerrors "k8s.io/apimachinery/pkg/api/errors"
30 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
31 "k8s.io/client-go/dynamic/dynamicinformer"
32 "k8s.io/client-go/rest"
33 criruntime "k8s.io/cri-api/pkg/apis/runtime/v1"
34
35 "edge-infra.dev/pkg/k8s/runtime/sap"
36 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
37 v1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
38 plugins "edge-infra.dev/pkg/sds/devices/k8s/device-plugins"
39 "edge-infra.dev/pkg/sds/devices/logger"
40
41 "edge-infra.dev/pkg/lib/kernel/devices"
42 "edge-infra.dev/pkg/lib/kernel/netlink/socket"
43 "edge-infra.dev/pkg/lib/kernel/udev"
44 "edge-infra.dev/pkg/lib/kernel/udev/reader"
45 "edge-infra.dev/pkg/lib/uuid"
46 "edge-infra.dev/pkg/sds/devices/agent/cgroups"
47 cc "edge-infra.dev/pkg/sds/devices/agent/common"
48 "edge-infra.dev/pkg/sds/devices/agent/mocks"
49 "edge-infra.dev/pkg/sds/devices/agent/udevproxy"
50 "edge-infra.dev/pkg/sds/devices/class"
51 )
52
53
54
55
56
57
58
59
60
61
62
63
64 var (
65 timeout = time.Second * 10
66 cfg = Config{}
67 )
68
69
// Fixture identifiers shared by the pod/container fixtures and the uevent
// socket server used by the specs.
var (
	usbPodName           = "usb-pod"
	usbContainerName     = "usb-container"
	socketServerAddr     = "127.0.0.1:8082"
	cgroupPath           = "kubepods-burstable-pod52da81bb_7d4a_49cf_ae91_424e9d35bb84.slice:cri-containerd:e58ea88f915f3ab51ba0f027084c384ec07491ef2f8b877db04dc94cac75c0f5"
	networkNamespacePath = "/proc/23342/ns/net"
)
77
78 type ueventTestCase = []struct {
79 inputEvent udev.UEvent
80 expectedEvent string
81 }
82
83 type podTestCase map[string]podDefinition
84
85 type podDefinition struct {
86 name string
87 namespace string
88 annotations map[string]string
89 containerTestCase []containerTestCase
90 }
91
92 type containerTestCase struct {
93 name string
94 containerID string
95 cgroupPath string
96 networkNamespacePath string
97 state criruntime.ContainerState
98 }
99
100 var ueventTests = ueventTestCase{
101 {
102 inputEvent: udev.UEvent{SequenceNumber: "1", EnvVars: map[string]string{"ACTION": "add", "SEQNUM": "1", "SUBSYSTEM": "usb", "DEVPATH": "testdata/sys/devices/usb"}},
103 expectedEvent: "^.*(libudev)(.*SUBSYSTEM=usb)?(.*ACTION=add)?(.*DEVPATH=testdata/sys/devices/usb)?(.*SEQNUM=1).*$",
104 },
105 {
106 inputEvent: udev.UEvent{SequenceNumber: "1", EnvVars: map[string]string{"ACTION": "remove", "SEQNUM": "2", "SUBSYSTEM": "usb", "DEVPATH": "testdata/sys/devices/usb"}},
107 expectedEvent: "^.*(libudev)(.*SUBSYSTEM=usb)?(.*ACTION=remove)?(.*DEVPATH=testdata/sys/devices/usb)?(.*SEQNUM=2).*$",
108 },
109 }
110
111 var podTests = podTestCase{
112 usbPodName: {
113 name: usbPodName,
114 namespace: usbPodName,
115 annotations: map[string]string{
116 class.FmtClassLabel(usbContainerName): "usb",
117 },
118 containerTestCase: []containerTestCase{
119 {
120 name: usbContainerName,
121 containerID: uuid.New().UUID,
122 cgroupPath: cgroupPath,
123 networkNamespacePath: networkNamespacePath,
124 state: criruntime.ContainerState_CONTAINER_RUNNING,
125 },
126 },
127 },
128 }
129
130 var ctrl *gomock.Controller
131
132 var mockInformerMethod = func(_ context.Context, _ client.Client, _ *rest.Config, _ chan *v1.DeviceClass, _ ...v1.ListOption) (dynamicinformer.DynamicSharedInformerFactory, error) {
133 informerMock := mocks.NewMockDynamicSharedInformerFactory(ctrl)
134 informerMock.EXPECT().Start(gomock.Any())
135 return informerMock, nil
136 }
137
138 func TestDeviceRead(t *testing.T) {
139 testDir := testfs.NewDir(t, etcPath)
140 defer testDir.Remove()
141
142 kubeletPath = filepath.Join(testDir.Path(), "kubelet.sock")
143 newDeviceClassInformer = mockInformerMethod
144
145 cfg = Config{
146 ClassesPath: testDir.Path(),
147 }
148
149 ctrl = gomock.NewController(t)
150 RegisterFailHandler(Fail)
151 RunSpecs(t, "Device Tests")
152 }
153
154 var _ = BeforeEach(func() {
155 _, _ = os.Create(kubeletPath)
156 devices.SetTestEnvs()
157 })
158
159 var _ = AfterEach(func() {
160 _ = os.Remove(cfg.ClassesPath)
161 _ = os.Remove(kubeletPath)
162 os.Setenv("HOSTNAME", "worker-1")
163 })
164
165 var _ = Describe("Device agent test suite", func() {
166 It("No containers exist", func() {
167 ctx := createContext()
168 ctx, cancelFn := context.WithTimeout(ctx, timeout)
169 defer cancelFn()
170
171 defaultDeviceClass := defaultClass()
172 defaultDeviceClassSet := defaultClassDeviceSet()
173 k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet)
174
175 m := generateMocks()
176
177 eventChan := make(chan *events.Envelope)
178 errChan := make(chan error)
179
180 m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(eventChan, errChan)
181 m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return([]containers.Container{}, nil)
182 m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1)
183 m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1)
184 m.devicePluginSvc.EXPECT().Stop().MinTimes(1)
185
186 pluginFn = func(_ *v1.DeviceClass) plugins.Plugin {
187 return m.devicePluginSvc
188 }
189
190 udevproxy.UDevProxyFn = func() udevproxy.UdevProxy {
191 return m.udevproxySvc
192 }
193 m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).AnyTimes()
194
195 statusWatcher := mocks.NewMockStatusWatcher(ctrl)
196 statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
197
198 svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc))
199 client, err := containerd.New("", svcOpts)
200 Expect(err).To(BeNil())
201
202 resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"})
203 agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, generateDecoder(), resourceManager)
204 Expect(err).To(BeNil())
205 err = agent.Start(ctx)
206 Expect(err).To(BeNil())
207 })
208
209 It("Container exists requesting usb device class", func() {
210 ctx := createContext()
211 ctx, cancelFn := context.WithTimeout(ctx, timeout)
212 defer cancelFn()
213
214 k8sClient := k8sClientWithObjects(usbClass(), defaultClass())
215
216 m := generateMocks()
217
218 eventChan := make(chan *events.Envelope)
219 errChan := make(chan error)
220
221 m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(eventChan, errChan)
222 m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return(containerListResponse(usbPodName), nil)
223 m.criClient.EXPECT().PodSandboxStatus(gomock.Any(), gomock.Any()).Return(nil, errors.New("pod sandbox not found"))
224 m.criClient.EXPECT().ContainerStatus(gomock.Any(), gomock.Any()).Return(runningStatusResponse(usbPodName, usbContainerName), nil)
225 m.criClient.EXPECT().ListPodSandbox(gomock.Any(), podSandboxListRequest(usbPodName)).Return(podSandboxListResponse(usbPodName), nil)
226
227 m.cgroupSvc.EXPECT().Apply(gomock.Any()).MinTimes(1)
228 m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).AnyTimes()
229
230 m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1)
231 m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1)
232 m.devicePluginSvc.EXPECT().Stop().MinTimes(1)
233
234 pluginFn = func(_ *v1.DeviceClass) plugins.Plugin {
235 return m.devicePluginSvc
236 }
237
238 udevproxy.UDevProxyFn = func() udevproxy.UdevProxy {
239 return m.udevproxySvc
240 }
241
242 cgroupRequestFn = func(_, _, _, _ string, _ map[string]devices.Device, _ bool) cgroups.CgroupRequest {
243 return m.cgroupSvc
244 }
245
246 statusWatcher := mocks.NewMockStatusWatcher(ctrl)
247 statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
248
249 resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"})
250
251 svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc))
252 client, err := containerd.New("", svcOpts)
253 Expect(err).To(BeNil())
254
255 agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, generateDecoder(), resourceManager)
256 Expect(err).To(BeNil())
257
258 err = agent.Start(ctx)
259 Expect(err).To(BeNil())
260 })
261 })
262
263 var _ = Describe("Device agent uevent test suite", func() {
264 It("Add and remove an input device for usb class", func() {
265 ctx := createContext()
266 ctx, cancelFn := context.WithTimeout(ctx, timeout)
267 defer cancelFn()
268
269 k8sClient := k8sClientWithObjects(defaultClass(), usbClass())
270
271 m := generateMocks()
272
273 dataChan, _ := startSocketServer(ctx, socketServerAddr)
274 time.Sleep(time.Second * 1)
275
276 decoder, reader, err := createDecoder(socketServerAddr)
277 Expect(err).To(BeNil())
278 defer reader.Close()
279
280 m.eventSvc.EXPECT().Subscribe(gomock.Any(), []string{}).Return(make(chan *events.Envelope), make(chan error))
281 m.containerSvc.EXPECT().List(gomock.Any(), gomock.Any()).Return(containerListResponse(usbPodName), nil)
282
283 m.criClient.EXPECT().PodSandboxStatus(gomock.Any(), gomock.Any()).Return(nil, errors.New("pod sandbox not found"))
284 m.criClient.EXPECT().ContainerStatus(gomock.Any(), gomock.Any()).Return(runningStatusResponse(usbPodName, usbContainerName), nil)
285 m.criClient.EXPECT().ListPodSandbox(gomock.Any(), podSandboxListRequest(usbPodName)).Return(podSandboxListResponse(usbPodName), nil)
286
287 m.cgroupSvc.EXPECT().Apply(gomock.Any()).MinTimes(1)
288 m.udevproxySvc.EXPECT().Send(gomock.Any(), gomock.Any()).MinTimes(1)
289
290 m.devicePluginSvc.EXPECT().IsRunning().Return(true).MinTimes(1)
291 m.devicePluginSvc.EXPECT().Update(gomock.Any()).MinTimes(1)
292 m.devicePluginSvc.EXPECT().Stop().MinTimes(1)
293
294 pluginFn = func(_ *v1.DeviceClass) plugins.Plugin {
295 return m.devicePluginSvc
296 }
297
298 udevproxy.UDevProxyFn = func() udevproxy.UdevProxy {
299 return m.udevproxySvc
300 }
301
302 cgroupRequestFn = func(_, _, _, _ string, _ map[string]devices.Device, _ bool) cgroups.CgroupRequest {
303 return m.cgroupSvc
304 }
305
306 statusWatcher := mocks.NewMockStatusWatcher(ctrl)
307 statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
308 resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"})
309
310 svcOpts := containerd.WithServices(containerd.WithContainerStore(m.containerSvc), containerd.WithEventService(m.eventSvc), containerd.WithTaskClient(m.taskSvc))
311 client, err := containerd.New("", svcOpts)
312 Expect(err).To(BeNil())
313
314 go func() {
315 agent, err := NewDeviceAgent(ctx, nil, k8sClient, client, m.criClient, cfg, decoder, resourceManager)
316 Expect(err).To(BeNil())
317
318 err = agent.Start(ctx)
319 Expect(err).To(BeNil())
320 }()
321
322
323 for _, ue := range ueventTests {
324 ueventData := toBytes(ue.inputEvent)
325 *dataChan <- ueventData
326 time.Sleep(time.Millisecond * 100)
327 }
328 })
329 })
330
331 var _ = Describe("Parsing sys fs and caching device classes", func() {
332 It("Parse usb device class devices", func() {
333 ctx := createContext()
334 ctx, cancelFn := context.WithTimeout(ctx, timeout)
335 defer cancelFn()
336
337 defaultDeviceClass := defaultClass()
338 defaultDeviceClassSet := defaultClassDeviceSet()
339 usbDeviceClass := usbClass()
340 usbDeviceClassSet := usbClassDeviceSet()
341 k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet, usbDeviceClass, usbDeviceClassSet)
342
343 deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient)
344 Expect(err).To(BeNil())
345 Expect(deviceClasses).To(HaveLen(2))
346
347 usbDeviceIter := deviceClasses[usbDeviceClass.ClassName()].DeviceIter()
348 devices := maps.Collect(usbDeviceIter)
349 Expect(devices).To(HaveLen(8))
350 })
351 })
352
353 var _ = Describe("Apply device class statuses", func() {
354 It("Format Device Names", func() {
355 expectedName := "ACR-2199"
356 Expect(fmtName(" ACR-2199\n")).To(BeEquivalentTo(expectedName))
357 Expect(fmtName("-ACR-2199 ")).To(BeEquivalentTo(expectedName))
358 Expect(fmtName(" \nACR-2199-")).To(BeEquivalentTo(expectedName))
359 })
360 It("Apply device class statuses", func() {
361 ctx := createContext()
362 ctx, cancelFn := context.WithTimeout(ctx, timeout)
363 defer cancelFn()
364
365 defaultDeviceClass := defaultClass()
366 defaultDeviceClassSet := defaultClassDeviceSet()
367 usbDeviceClass := usbClass()
368 usbDeviceClassSet := usbClassDeviceSet()
369 k8sClient := k8sClientWithObjects(defaultDeviceClass, defaultDeviceClassSet, usbDeviceClass, usbDeviceClassSet)
370
371 deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient)
372 Expect(err).To(BeNil())
373 Expect(deviceClasses).To(HaveLen(2))
374
375 statusWatcher := mocks.NewMockStatusWatcher(ctrl)
376 statusWatcher.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
377
378 resourceManager := sap.NewResourceManager(k8sClient, statusWatcher, sap.Owner{Field: "device-agent"})
379 err = applyDeviceClassStatuses(ctx, resourceManager, deviceClasses)
380 Expect(kerrors.IsNotFound(err)).To(BeTrue())
381 })
382 })
383
384 type deviceMocks struct {
385 eventSvc *mocks.MockEventService
386 criClient *mocks.MockRuntimeServiceClient
387 containerSvc *mocks.MockStore
388 cgroupSvc *mocks.MockCgroupRequest
389 devicePluginSvc *mocks.MockPlugin
390 udevproxySvc *mocks.MockUdevProxy
391 taskSvc *mocks.MockTasksClient
392 }
393
394 func generateMocks() deviceMocks {
395 return deviceMocks{
396 eventSvc: mocks.NewMockEventService(ctrl),
397 criClient: mocks.NewMockRuntimeServiceClient(ctrl),
398 containerSvc: mocks.NewMockStore(ctrl),
399 cgroupSvc: mocks.NewMockCgroupRequest(ctrl),
400 devicePluginSvc: mocks.NewMockPlugin(ctrl),
401 udevproxySvc: mocks.NewMockUdevProxy(ctrl),
402 taskSvc: mocks.NewMockTasksClient(ctrl),
403 }
404 }
405
406 func k8sClientWithObjects(objs ...client.Object) client.WithWatch {
407 return fake.NewClientBuilder().WithScheme(createScheme()).WithObjects(objs...).Build()
408 }
409
410 func defaultClass() *dsv1.DeviceClass {
411 className := "default"
412 return generateDeviceClassTestCR(className, 1)
413 }
414
415 func defaultClassDeviceSet() *dsv1.DeviceSet {
416 className := "default"
417 return generateDeviceSetsTestCR(className, false, map[string]string{"SUBSYSTEM": "misc"})
418 }
419
420 func usbClass() *dsv1.DeviceClass {
421 return generateDeviceClassTestCR(usbSubsystem, 1)
422 }
423
424 func usbClassDeviceSet() *dsv1.DeviceSet {
425 return generateDeviceSetsTestCR(usbSubsystem, false, map[string]string{"SUBSYSTEM": usbSubsystem})
426 }
427
428 func generateDeviceClassTestCR(name string, generation int64) *dsv1.DeviceClass {
429 return &dsv1.DeviceClass{
430 ObjectMeta: metav1.ObjectMeta{
431 Name: name,
432 Generation: generation,
433 },
434 Spec: dsv1.DeviceClassSpec{
435 Devices: []dsv1.DeviceRef{
436 {
437 Name: name,
438 },
439 },
440 Logging: dsv1.Logging{
441 Level: "info",
442 },
443 },
444 }
445 }
446
447 func generateDeviceSetsTestCR(name string, shouldBlock bool, deviceSetProperties map[string]string) *dsv1.DeviceSet {
448 deviceSet := dsv1.DeviceSetReference{
449 Name: name,
450 Properties: []dsv1.Rule{},
451 }
452
453 if shouldBlock {
454 deviceSet.Blocking = &dsv1.Blocking{}
455 }
456
457 for key, value := range deviceSetProperties {
458 deviceSet.Properties = append(deviceSet.Properties, dsv1.Rule{
459 Name: key,
460 RegexValue: value,
461 })
462 }
463
464 return &dsv1.DeviceSet{
465 ObjectMeta: metav1.ObjectMeta{
466 Name: name,
467 Generation: 1,
468 },
469 Spec: dsv1.DeviceSpec{
470 DeviceSets: []dsv1.DeviceSetReference{
471 deviceSet,
472 },
473 },
474 }
475 }
476
477 func generateDecoder() reader.Decoder {
478 buf := bytes.NewBuffer([]byte{})
479 return reader.NewReader(buf)
480 }
481
482 func runningStatusResponse(podName, containerName string) *criruntime.ContainerStatusResponse {
483 podDef := podTests[podName]
484 for _, ctr := range podDef.containerTestCase {
485 if ctr.name == containerName {
486 return &criruntime.ContainerStatusResponse{Status: &criruntime.ContainerStatus{State: ctr.state}}
487 }
488 }
489 return &criruntime.ContainerStatusResponse{Status: &criruntime.ContainerStatus{State: criruntime.ContainerState_CONTAINER_UNKNOWN}}
490 }
491
492 func containerListResponse(podName string) []containers.Container {
493 containers := []containers.Container{}
494 for _, ctr := range podTests[podName].containerTestCase {
495 containers = append(containers, containerResponse(ctr, podName))
496 }
497 return containers
498 }
499
500 func containerResponse(ctr containerTestCase, podName string) containers.Container {
501 spec, _ := typeurl.MarshalAny(&specs.Spec{
502 Linux: &specs.Linux{
503 CgroupsPath: ctr.cgroupPath,
504 Namespaces: []specs.LinuxNamespace{
505 {
506 Type: specs.NetworkNamespace,
507 Path: ctr.networkNamespacePath,
508 },
509 },
510 },
511 })
512 return containers.Container{ID: ctr.containerID, Spec: spec, Labels: map[string]string{cc.AnnPodName: podName, cc.AnnContainerName: ctr.name}}
513 }
514
515 func podSandboxListResponse(podName string) *criruntime.ListPodSandboxResponse {
516 response := &criruntime.ListPodSandboxResponse{Items: []*criruntime.PodSandbox{}}
517 testCase := podTests[podName]
518 response.Items = append(response.Items, &criruntime.PodSandbox{
519 Annotations: testCase.annotations,
520 Metadata: &criruntime.PodSandboxMetadata{
521 Name: testCase.name,
522 Namespace: testCase.namespace,
523 },
524 })
525 return response
526 }
527
528 func podSandboxListRequest(podName string) *criruntime.ListPodSandboxRequest {
529 return &criruntime.ListPodSandboxRequest{Filter: &criruntime.PodSandboxFilter{LabelSelector: map[string]string{cc.AnnPodName: podName}}}
530 }
531
532 func toBytes(ue udev.UEvent) []byte {
533 delimeter := []byte("\x00")
534 data := []byte{}
535 data = append(data, []byte("libudev")...)
536 data = append(data, delimeter...)
537 for k, v := range ue.EnvVars {
538 if k == "SEQNUM" {
539 continue
540 }
541 data = append(data, []byte(fmt.Sprintf("%s=%s", k, v))...)
542 data = append(data, delimeter...)
543 }
544 data = append(data, []byte(fmt.Sprintf("SEQNUM=%s", ue.SequenceNumber))...)
545 data = append(data, delimeter...)
546 return data
547 }
548
549 func startSocketServer(ctx context.Context, destination string) (*chan []byte, chan error) {
550 errChan := make(chan error, 1)
551 socketServer := new(socket.Server)
552 socketServer.Setup(destination)
553 go socketServer.Serve(ctx, errChan)
554 return socketServer.GetDataChan(), errChan
555 }
556
557 func createDecoder(source string) (reader.Decoder, io.ReadCloser, error) {
558 uEventReader, fd, err := udev.NewUEventReader(source)
559 if err != nil {
560 return nil, nil, err
561 }
562 decoder := reader.NewSocketReader(fd)
563 return decoder, uEventReader, nil
564 }
565
566 func createContext() context.Context {
567 ctx := context.Background()
568 opts := []logger.Option{
569 logger.WithLevel(slog.LevelDebug),
570 }
571 log := logger.New(opts...)
572 ctx = logger.IntoContext(ctx, log)
573 return ctx
574 }
575
View as plain text