1
2
3 package agent
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "path/filepath"
12 "slices"
13 "time"
14
15 "github.com/containerd/containerd"
16 "github.com/containerd/containerd/containers"
17 fsnotify "github.com/fsnotify/fsnotify"
18 "k8s.io/apimachinery/pkg/runtime"
19 utilruntime "k8s.io/apimachinery/pkg/util/runtime"
20 "k8s.io/client-go/dynamic/dynamicinformer"
21 clientgoscheme "k8s.io/client-go/kubernetes/scheme"
22 "k8s.io/client-go/rest"
23 criruntime "k8s.io/cri-api/pkg/apis/runtime/v1"
24 "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
25 "sigs.k8s.io/controller-runtime/pkg/client"
26
27 "edge-infra.dev/pkg/k8s/runtime/sap"
28 "edge-infra.dev/pkg/lib/kernel/udev/reader"
29 cc "edge-infra.dev/pkg/sds/devices/agent/common"
30 devicecontainers "edge-infra.dev/pkg/sds/devices/agent/containers"
31 "edge-infra.dev/pkg/sds/devices/agent/events"
32 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
33 v1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
34 plugins "edge-infra.dev/pkg/sds/devices/k8s/device-plugins"
35 "edge-infra.dev/pkg/sds/devices/logger"
36 )
37
const (
	// batchTimeSince is the minimum age, in milliseconds, of the newest
	// event in eventBatch before checkIfProcessDeviceEventBatch flushes the
	// batch. It is compared against time.Since(...).Milliseconds(), so
	// events arriving close together are coalesced into one job.
	batchTimeSince int64 = 200

	// batchWaitInterval is the idle wake-up period, in milliseconds, of the
	// select loop in Start (it is multiplied by time.Millisecond there). A
	// pending batch is re-checked even when no new events arrive.
	batchWaitInterval = 10

	// etcPath is the /etc directory path. It is not referenced in the code
	// visible in this file; presumably used elsewhere in the package.
	etcPath = "/etc"
)
47
48 var (
49
50 kubeletPath = filepath.Dir(v1beta1.KubeletSocket)
51 )
52
53
54 var pluginFn = plugins.NewPlugin
55
56
57 var newDeviceClassInformer = dsv1.WatchFromClient
58
59 var (
60
61 eventBatch = []events.DeviceEvent{}
62 )
63
// DeviceAgent is the entry point of the device agent process.
type DeviceAgent interface {
	// Start runs the agent using ctx for cancellation. The implementation
	// in this file blocks until ctx is cancelled.
	Start(ctx context.Context) error
}
67
// deviceAgent is the DeviceAgent implementation returned by NewDeviceAgent.
//
// NOTE: keep the field order in sync with any unkeyed composite literals of
// this type.
type deviceAgent struct {
	// k8sClient lists DeviceClasses and backs the DeviceClass informer.
	k8sClient client.Client
	// ctrClient is the containerd client; Start subscribes to its event
	// service and passes it to the udev/container event constructors.
	ctrClient *containerd.Client
	// runtimeClient is the CRI runtime client passed to the container
	// event constructor.
	runtimeClient criruntime.RuntimeServiceClient
	// cfg is the agent configuration (e.g. Workers sizes the job queue,
	// ClassesPath is the DeviceClass persistence path).
	cfg Config
	// decoder decodes the udev uevent stream read in Start.
	decoder reader.Decoder
	// servers holds one device plugin server per entry of deviceClasses,
	// using the same map keys.
	servers map[string]plugins.Plugin
	// deviceClasses holds the known DeviceClasses (keys presumably the class
	// names — confirm against dsv1.ListFromClient).
	deviceClasses map[string]*dsv1.DeviceClass
	// containers holds the containers returned by FetchAllContainers (key
	// presumably the container ID — confirm).
	containers map[string]*containers.Container
	// deviceClassChan receives DeviceClass updates from the informer.
	deviceClassChan <-chan *dsv1.DeviceClass
	// containerDeviceRuleJobQueue is initialised empty by NewDeviceAgent and
	// is not read in the code visible in this file.
	containerDeviceRuleJobQueue map[string]chan func(context.Context)
	// watcher is the fsnotify watcher on kubeletPath; closed when Start returns.
	watcher *fsnotify.Watcher
	// resourceManager is handed to the DeviceClass patching worker.
	resourceManager *sap.ResourceManager
	// informer is the DeviceClass informer factory started by Start.
	informer dynamicinformer.DynamicSharedInformerFactory
	// clientOpts are the list options used to list/watch DeviceClasses.
	clientOpts []dsv1.ListOption
}
84
85
86 func NewDeviceAgent(ctx context.Context, config *rest.Config, k8sClient client.Client, ctrClient *containerd.Client, rc criruntime.RuntimeServiceClient, cfg Config, decoder reader.Decoder, resourceManager *sap.ResourceManager) (DeviceAgent, error) {
87 log := logger.FromContext(ctx)
88 log = log.WithGroup("device-agent-startup")
89
90
91 opts := []dsv1.ListOption{
92 dsv1.WithPersistence(true),
93 dsv1.WithPersistencePath(cfg.ClassesPath),
94 }
95
96 deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient, opts...)
97 if err != nil {
98 return nil, err
99 }
100
101 containers, err := devicecontainers.FetchAllContainers(ctx, ctrClient, rc)
102 if err != nil {
103 return nil, fmt.Errorf("failed to fetch all containers: %w", err)
104 }
105 log.Info("loaded containers requesting devices", "containers", fetchContainerNames(slices.Collect(maps.Values(containers))))
106
107 deviceClassChan := make(chan *dsv1.DeviceClass, 1)
108 informer, err := newDeviceClassInformer(ctx, k8sClient, config, deviceClassChan, opts...)
109 if err != nil {
110 return nil, fmt.Errorf("failed to start Device resources informer: %w", err)
111 }
112
113 watcher, err := newKubeletSocketWatcher()
114 if err != nil {
115 return nil, fmt.Errorf("failed to start kubelet socket watcher: %w", err)
116 }
117
118 containerDeviceRuleJobQueue := map[string]chan func(context.Context){}
119 return &deviceAgent{k8sClient, ctrClient, rc, cfg, decoder, map[string]plugins.Plugin{}, deviceClasses, containers, deviceClassChan, containerDeviceRuleJobQueue, watcher, resourceManager, informer, opts}, nil
120 }
121
122
123
124
125
126
127 func (da *deviceAgent) Start(ctx context.Context) error {
128 log := logger.FromContext(ctx)
129 log = log.WithGroup("device-agent")
130 log.Info("starting device agent",
131 "device classes", slices.Collect(maps.Keys(da.deviceClasses)),
132 "device plugin servers", slices.Collect(maps.Keys(da.servers)),
133 "configuration", da.cfg,
134 )
135
136 ctx, cancelFn := context.WithCancel(ctx)
137 defer cancelFn()
138 defer da.watcher.Close()
139
140 eventService := da.ctrClient.EventService()
141 ctrEventChan, ctrStreamErrChan := eventService.Subscribe(ctx, []string{}...)
142 udevEventChan, udevEventErrChan := reader.StreamUEvents(ctx, da.decoder)
143
144 da.reconcileDevicePluginServers(ctx)
145
146 go reportDiskSizeMetrics(ctx)
147
148 deviceClassPatchQueue := make(chan map[string]*dsv1.DeviceClass, 1)
149 go startDeviceClassPatchingWorker(ctx, da.resourceManager, deviceClassPatchQueue)
150
151 log.Info("starting DeviceClass informer")
152 go da.informer.Start(ctx.Done())
153
154 deviceUpdateQueue := make(chan deviceConfigJob, da.cfg.Workers)
155 go da.startDeviceJobQueue(ctx, deviceUpdateQueue)
156
157
158 containers, postHookFns := fetchAllContainersToUpdate(ctx, da.containers)
159 deviceUpdateQueue <- newJob(containers, postHookFns, da.deviceClasses)
160 deviceClassPatchQueue <- da.deviceClasses
161
162 classEventConstructor := events.ClassEventConstructor(da.servers, da.deviceClasses, da.containers)
163 udevEventConstructor := events.UDevEventConstructor(da.ctrClient, da.deviceClasses, da.containers)
164 containerEventConstructor := events.ContainerEventConstructor(da.containers, da.ctrClient, da.runtimeClient)
165
166 for {
167 da.checkIfProcessDeviceEventBatch(ctx, da.deviceClasses, deviceUpdateQueue)
168
169 select {
170 case deviceClass := <-da.deviceClassChan:
171 eventBatch = append(eventBatch, classEventConstructor(ctx, deviceClass))
172 deviceClassPatchQueue <- da.deviceClasses
173 case udevEvent := <-udevEventChan:
174 event, err := udevEventConstructor(ctx, udevEvent)
175 if err != nil {
176 log.Error("error generating udev event", "error", err)
177 continue
178 }
179 eventBatch = append(eventBatch, event)
180 deviceClassPatchQueue <- da.deviceClasses
181 case ctrEvent := <-ctrEventChan:
182 event, err := containerEventConstructor(ctx, ctrEvent)
183 if err != nil {
184 log.Error("error generating container event", "error", err)
185 continue
186 }
187 if event == nil {
188 continue
189 }
190 eventBatch = append(eventBatch, event)
191 case kubeletRestartEvent := <-da.watcher.Events:
192 if kubeletRestartEvent.Op&fsnotify.Create != fsnotify.Create || kubeletRestartEvent.Name != v1beta1.KubeletSocket {
193 continue
194 }
195 log.Info("kubelet has restarted or device classes have been updated, updating device plugins")
196 eventBatch = append(eventBatch, events.NewKubeletEvent())
197 case err := <-da.watcher.Errors:
198 log.Error("error fs watcher", "error", err)
199 case err := <-ctrStreamErrChan:
200 log.Error("error container events watcher", "error", err)
201 case err := <-udevEventErrChan:
202 if errors.Is(err, io.EOF) {
203 continue
204 }
205 log.Error("error has occurred monitoring udev events", "error", err)
206 case <-ctx.Done():
207 log.Info("stopping all device plugin servers")
208 da.stopAllDevicePlugins()
209 return nil
210 case <-time.After(time.Millisecond * batchWaitInterval):
211 continue
212 }
213 }
214 }
215
216
217 func (da deviceAgent) checkIfProcessDeviceEventBatch(ctx context.Context, deviceClasses map[string]*dsv1.DeviceClass, queue chan deviceConfigJob) {
218 if len(eventBatch) >= 1 {
219
220 ts := time.Since(eventBatch[len(eventBatch)-1].Timestamp()).Milliseconds()
221 if ts >= batchTimeSince {
222 da.reconcileDevicePluginServers(ctx)
223 containers, postHookFns := compactDeviceEvents(eventBatch)
224 eventBatch = make([]events.DeviceEvent, 0)
225 queue <- newJob(containers, postHookFns, deviceClasses)
226 }
227 }
228 }
229
230
231
232 func (da deviceAgent) reconcileDevicePluginServers(ctx context.Context) {
233 for name, devClass := range da.deviceClasses {
234 if srv, ok := da.servers[name]; !ok || srv != nil {
235 da.servers[name] = pluginFn(devClass)
236 }
237 if !da.servers[name].IsRunning() {
238 go da.servers[name].Run(ctx)
239 }
240 da.servers[name].Update(devClass)
241 }
242 }
243
244
245 func (da deviceAgent) startDeviceJobQueue(ctx context.Context, jobQueueChan chan deviceConfigJob) {
246 for {
247 select {
248 case <-ctx.Done():
249 return
250 case job := <-jobQueueChan:
251 job.Run(ctx)
252 }
253 }
254 }
255
256
257
258 func (da deviceAgent) stopAllDevicePlugins() {
259 for name, srv := range da.servers {
260 srv.Stop()
261 delete(da.servers, name)
262 }
263 }
264
265
266 func compactDeviceEvents(events []events.DeviceEvent) (map[string]*containers.Container, []func(context.Context)) {
267 containers := map[string]*containers.Container{}
268 postHookFns := []func(context.Context){}
269 for _, updateEvent := range events {
270
271 maps.Insert(containers, maps.All(updateEvent.Containers()))
272 if updateEvent.PosthookFunc() != nil {
273 postHookFns = append(postHookFns, updateEvent.PosthookFunc())
274 }
275 }
276 return containers, postHookFns
277 }
278
279
280 func fetchAllContainersToUpdate(ctx context.Context, allContainers map[string]*containers.Container) (map[string]*containers.Container, []func(context.Context)) {
281 log := logger.FromContext(ctx)
282 deviceEvents := []events.DeviceEvent{}
283 for _, ctr := range allContainers {
284 event, err := events.NewContainerEvent(ctr)
285 if err != nil {
286 log.Error("error generating container event", "error", err)
287 continue
288 }
289 deviceEvents = append(deviceEvents, event)
290 }
291 return compactDeviceEvents(deviceEvents)
292 }
293
294
295
296 func newKubeletSocketWatcher() (*fsnotify.Watcher, error) {
297 watcher, err := fsnotify.NewWatcher()
298 if err != nil {
299 return nil, err
300 }
301 if err = watcher.Add(kubeletPath); err != nil {
302 watcher.Close()
303 return nil, err
304 }
305 return watcher, nil
306 }
307
308 func fetchContainerNames(containers []*containers.Container) []string {
309 names := make([]string, len(containers))
310 for idx, ctr := range containers {
311 name := ctr.Labels[cc.AnnContainerName]
312 names[idx] = name
313 }
314 return names
315 }
316
317
318
319 func createScheme() *runtime.Scheme {
320 scheme := runtime.NewScheme()
321 utilruntime.Must(clientgoscheme.AddToScheme(scheme))
322 utilruntime.Must(v1.AddToScheme(scheme))
323 utilruntime.Must(dsv1.AddToScheme(scheme))
324 return scheme
325 }
326
View as plain text