1
2
3 package events
4
5 import (
6 "context"
7 "errors"
8 "regexp"
9 "strings"
10 "time"
11
12 "github.com/containerd/containerd"
13 "github.com/containerd/containerd/containers"
14 typeurl "github.com/containerd/typeurl/v2"
15 "github.com/opencontainers/runtime-spec/specs-go"
16
17 "edge-infra.dev/pkg/lib/kernel/udev"
18 cc "edge-infra.dev/pkg/sds/devices/agent/common"
19 devctrs "edge-infra.dev/pkg/sds/devices/agent/containers"
20 "edge-infra.dev/pkg/sds/devices/agent/udevproxy"
21 "edge-infra.dev/pkg/sds/devices/class"
22 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
23 "edge-infra.dev/pkg/sds/devices/logger"
24 )
25
26 var udevProxyFn = udevproxy.ReplayUEventsToContainers
27
const (
	// defaultExecutablePath is the device-system executable used when a
	// container does not set executablePathEnvVar in its environment.
	defaultExecutablePath = "/usr/local/bin/device-system-executable"
	// executablePathEnvVar is the container environment variable whose value
	// overrides defaultExecutablePath.
	executablePathEnvVar = "DEVICE_SYSTEM_EXECUTABLE_PATH"
)

// executablePathRegex matches a "KEY=value" environment entry whose key is
// exactly executablePathEnvVar. The pattern is derived from the constant so
// the two cannot drift apart. MustCompile makes a bad pattern fail loudly at
// package initialisation instead of leaving a nil *Regexp that would panic on
// first use.
var executablePathRegex = regexp.MustCompile("^" + regexp.QuoteMeta(executablePathEnvVar) + "=.*$")
34
35
36
37
38 func UDevEventConstructor(ctrClient *containerd.Client, deviceClasses map[string]*dsv1.DeviceClass, allContainers map[string]*containers.Container) func(ctx context.Context, udevEvent *udev.UEvent) (DeviceEvent, error) {
39 return func(ctx context.Context, udevEvent *udev.UEvent) (DeviceEvent, error) {
40 if udevEvent == nil {
41 return nil, errors.New("error, attempted to parse nil uevent")
42 }
43 switch udevEvent.Action {
44 case udev.AddAction:
45 addDeviceFromUEvent(ctx, udevEvent, deviceClasses)
46 return newUDevEvent(ctx, ctrClient, udevEvent, allContainers, deviceClasses)
47 case udev.RemoveAction:
48 event, err := newUDevEvent(ctx, ctrClient, udevEvent, allContainers, deviceClasses)
49 if err != nil {
50 return nil, err
51 }
52 removeDeviceFromUEvent(udevEvent, deviceClasses)
53 udevproxy.ReplayRemoveUEvents(ctx, []*udev.UEvent{udevEvent}, allContainers)
54 return event, nil
55 default:
56 return newUDevEvent(ctx, ctrClient, udevEvent, allContainers, deviceClasses)
57 }
58 }
59 }
60
61
62 func newUDevEvent(ctx context.Context, ctrClient *containerd.Client, uevent *udev.UEvent, allContainers map[string]*containers.Container, allDeviceClasses map[string]*dsv1.DeviceClass) (DeviceEvent, error) {
63 if uevent == nil {
64 return nil, errors.New("uevent cannot be nil")
65 }
66
67 event := &udevEvent{
68 event: &event{
69 containers: map[string]*containers.Container{},
70 postHookFn: func(context.Context) {},
71 timestamp: time.Now(),
72 },
73 }
74 log := logger.FromContext(ctx)
75
76 containerExecFns := map[string]func(){}
77 for _, ctr := range allContainers {
78 ctrCtx := devctrs.WithContainerLogger(ctx, ctr)
79 if !ueventMatchesContainer(ctrCtx, uevent, ctr, allDeviceClasses) {
80 continue
81 }
82
83 executablePath := fetchExecutablePath(ctrCtx, ctr)
84 rootPath, err := devctrs.FetchContainerRootPath(ctrCtx, ctrClient, ctr)
85 if err != nil {
86 log.Debug("could not find container root path", "error", err)
87 } else {
88 containerExecFns[ctr.ID] = func() {
89 devctrs.NewExecFn(ctrCtx, ctr.Labels[cc.AnnContainerName], rootPath, executablePath, uevent.EnvVars)
90 }
91 }
92
93 ctr.Labels[class.DefaultClass] = requested
94 event.event.containers[ctr.ID] = ctr
95 }
96
97 event.event.postHookFn = func(ctx context.Context) {
98 udevProxyFn(ctx, []*udev.UEvent{uevent}, event.event.containers)
99 for _, fn := range containerExecFns {
100 fn()
101 }
102 }
103 return event, nil
104 }
105
106
107
108 func fetchExecutablePath(ctx context.Context, ctr *containers.Container) string {
109 log := logger.FromContext(ctx)
110
111 spec := &specs.Spec{}
112 if err := typeurl.UnmarshalTo(ctr.Spec, spec); err != nil {
113 log.Debug("could not parse container spec", "container", ctr.Labels[cc.AnnContainerName], "error", err)
114 return defaultExecutablePath
115 }
116
117 for _, envVar := range spec.Process.Env {
118 if executablePathRegex.Match([]byte(envVar)) {
119 splitPath := strings.Split(envVar, "=")
120 if len(splitPath) != 2 {
121 return defaultExecutablePath
122 }
123 return splitPath[1]
124 }
125 }
126 return defaultExecutablePath
127 }
128
129
130 func addDeviceFromUEvent(ctx context.Context, udevEvent *udev.UEvent, deviceClasses map[string]*dsv1.DeviceClass) {
131 log := logger.FromContext(ctx)
132 devicePath := prefixSysPath(udevEvent.SysPath)
133 for _, devClass := range deviceClasses {
134 if _, err := devClass.AddDeviceIfMatched(devicePath); err != nil {
135 log.Error("error adding device to class", "class", devClass.ClassName(), "path", udevEvent.SysPath, "error", err)
136 continue
137 }
138 }
139 }
140
141
142 func removeDeviceFromUEvent(udevEvent *udev.UEvent, deviceClasses map[string]*dsv1.DeviceClass) {
143 devicePath := prefixSysPath(udevEvent.SysPath)
144 for _, devClass := range deviceClasses {
145 devClass.RemoveDevice(devicePath)
146 }
147 }
148
View as plain text