//go:build linux

package agent

import (
	"context"
	"errors"
	"fmt"
	"io"
	"maps"
	"path/filepath"
	"slices"
	"time"

	"github.com/containerd/containerd"
	"github.com/containerd/containerd/containers"
	fsnotify "github.com/fsnotify/fsnotify"
	"k8s.io/apimachinery/pkg/runtime"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/client-go/dynamic/dynamicinformer"
	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
	"k8s.io/client-go/rest"
	criruntime "k8s.io/cri-api/pkg/apis/runtime/v1"
	"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"edge-infra.dev/pkg/k8s/runtime/sap"
	"edge-infra.dev/pkg/lib/kernel/udev/reader"
	cc "edge-infra.dev/pkg/sds/devices/agent/common"
	devicecontainers "edge-infra.dev/pkg/sds/devices/agent/containers"
	"edge-infra.dev/pkg/sds/devices/agent/events"
	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
	v1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
	plugins "edge-infra.dev/pkg/sds/devices/k8s/device-plugins"
	"edge-infra.dev/pkg/sds/devices/logger"
)

const (
	// batchTimeSince is the period (ms) the device agent waits after the most
	// recent device event before it processes the batch
	batchTimeSince int64 = 200
	// batchWaitInterval is the wait time (ms) between checks for new events
	batchWaitInterval = 10
	// etcPath is the path to the /etc directory
	etcPath = "/etc"
)

var (
	// kubeletPath is the directory containing the kubelet socket file
	kubeletPath = filepath.Dir(v1beta1.KubeletSocket)
)

// pluginFn is the function called to create a new plugin
var pluginFn = plugins.NewPlugin

// newDeviceClassInformer is the function called to create a new dsv1.DeviceClass informer
var newDeviceClassInformer = dsv1.WatchFromClient

var (
	// eventBatch is the batch of device events waiting to be processed
	eventBatch = []events.DeviceEvent{}
)

type DeviceAgent interface {
	Start(ctx context.Context) error
}

type deviceAgent struct {
	k8sClient                   client.Client
	ctrClient                   *containerd.Client
	runtimeClient               criruntime.RuntimeServiceClient
	cfg                         Config
	decoder                     reader.Decoder
	servers                     map[string]plugins.Plugin
	deviceClasses               map[string]*dsv1.DeviceClass
	containers                  map[string]*containers.Container
	deviceClassChan             <-chan *dsv1.DeviceClass
	containerDeviceRuleJobQueue map[string]chan func(context.Context)
	watcher                     *fsnotify.Watcher
	resourceManager             *sap.ResourceManager
	informer                    dynamicinformer.DynamicSharedInformerFactory
	clientOpts                  []dsv1.ListOption
}

// NewDeviceAgent instantiates a new device agent runnable
func NewDeviceAgent(ctx context.Context, config *rest.Config, k8sClient client.Client, ctrClient *containerd.Client, rc criruntime.RuntimeServiceClient, cfg Config, decoder reader.Decoder, resourceManager *sap.ResourceManager) (DeviceAgent, error) {
	log := logger.FromContext(ctx)
	log = log.WithGroup("device-agent-startup")

	// opts is the list of options used when querying device classes
	opts := []dsv1.ListOption{
		dsv1.WithPersistence(true),
		dsv1.WithPersistencePath(cfg.ClassesPath),
	}

	deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient, opts...)
	if err != nil {
		return nil, err
	}

	containers, err := devicecontainers.FetchAllContainers(ctx, ctrClient, rc)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch all containers: %w", err)
	}
	log.Info("loaded containers requesting devices", "containers", fetchContainerNames(slices.Collect(maps.Values(containers))))

	deviceClassChan := make(chan *dsv1.DeviceClass, 1)
	informer, err := newDeviceClassInformer(ctx, k8sClient, config, deviceClassChan, opts...)
	if err != nil {
		return nil, fmt.Errorf("failed to start Device resources informer: %w", err)
	}

	watcher, err := newKubeletSocketWatcher()
	if err != nil {
		return nil, fmt.Errorf("failed to start kubelet socket watcher: %w", err)
	}

	containerDeviceRuleJobQueue := map[string]chan func(context.Context){}

	return &deviceAgent{k8sClient, ctrClient, rc, cfg, decoder, map[string]plugins.Plugin{}, deviceClasses, containers, deviceClassChan, containerDeviceRuleJobQueue, watcher, resourceManager, informer, opts}, nil
}

// Start triggers the device agent to start watching for device change events,
// including container starts, udev events, and kubelet restarts.
//
// These device events are processed in batches, and the resulting changes are
// propagated to the device plugins and to container cgroup updates.
func (da *deviceAgent) Start(ctx context.Context) error {
	log := logger.FromContext(ctx)
	log = log.WithGroup("device-agent")
	log.Info("starting device agent",
		"device classes", slices.Collect(maps.Keys(da.deviceClasses)),
		"device plugin servers", slices.Collect(maps.Keys(da.servers)),
		"configuration", da.cfg,
	)

	ctx, cancelFn := context.WithCancel(ctx)
	defer cancelFn()
	defer da.watcher.Close()

	eventService := da.ctrClient.EventService()
	ctrEventChan, ctrStreamErrChan := eventService.Subscribe(ctx, []string{}...)
	udevEventChan, udevEventErrChan := reader.StreamUEvents(ctx, da.decoder)

	da.reconcileDevicePluginServers(ctx)

	go reportDiskSizeMetrics(ctx)

	deviceClassPatchQueue := make(chan map[string]*dsv1.DeviceClass, 1)
	go startDeviceClassPatchingWorker(ctx, da.resourceManager, deviceClassPatchQueue)

	log.Info("starting DeviceClass informer")
	go da.informer.Start(ctx.Done())

	deviceUpdateQueue := make(chan deviceConfigJob, da.cfg.Workers)
	go da.startDeviceJobQueue(ctx, deviceUpdateQueue)

	// update all containers when the device agent starts
	containers, postHookFns := fetchAllContainersToUpdate(ctx, da.containers)
	deviceUpdateQueue <- newJob(containers, postHookFns, da.deviceClasses)
	deviceClassPatchQueue <- da.deviceClasses

	classEventConstructor := events.ClassEventConstructor(da.servers, da.deviceClasses, da.containers)
	udevEventConstructor := events.UDevEventConstructor(da.ctrClient, da.deviceClasses, da.containers)
	containerEventConstructor := events.ContainerEventConstructor(da.containers, da.ctrClient, da.runtimeClient)

	for {
		da.checkIfProcessDeviceEventBatch(ctx, da.deviceClasses, deviceUpdateQueue)
		select {
		case deviceClass := <-da.deviceClassChan:
			eventBatch = append(eventBatch, classEventConstructor(ctx, deviceClass))
			deviceClassPatchQueue <- da.deviceClasses
		case udevEvent := <-udevEventChan:
			event, err := udevEventConstructor(ctx, udevEvent)
			if err != nil {
				log.Error("error generating udev event", "error", err)
				continue
			}
			eventBatch = append(eventBatch, event)
			deviceClassPatchQueue <- da.deviceClasses
		case ctrEvent := <-ctrEventChan:
			event, err := containerEventConstructor(ctx, ctrEvent)
			if err != nil {
				log.Error("error generating container event", "error", err)
				continue
			}
			if event == nil {
				continue
			}
			eventBatch = append(eventBatch, event)
		case kubeletRestartEvent := <-da.watcher.Events:
			if kubeletRestartEvent.Op&fsnotify.Create != fsnotify.Create || kubeletRestartEvent.Name != v1beta1.KubeletSocket {
				continue
			}
			log.Info("kubelet has restarted or device classes have been updated, updating device plugins")
			eventBatch = append(eventBatch, events.NewKubeletEvent())
		case err := <-da.watcher.Errors:
			log.Error("error from fs watcher", "error", err)
		case err := <-ctrStreamErrChan:
			log.Error("error from container events watcher", "error", err)
		case err := <-udevEventErrChan:
			if errors.Is(err, io.EOF) {
				continue
			}
			log.Error("error monitoring udev events", "error", err)
		case <-ctx.Done():
			log.Info("stopping all device plugin servers")
			da.stopAllDevicePlugins()
			return nil
		case <-time.After(time.Millisecond * batchWaitInterval):
			continue
		}
	}
}
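
// Batching behaviour (informal note, derived from the constants above): the main loop wakes
// at least every batchWaitInterval (10ms) and only flushes the accumulated eventBatch once
// batchTimeSince (200ms) has elapsed since the timestamp of the most recent event. A burst of
// udev or container events arriving within 200ms of each other is therefore coalesced into a
// single deviceConfigJob roughly 200-210ms after the last event in the burst, assuming no
// further events arrive in the meantime.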

// checkIfProcessDeviceEventBatch checks whether batchTimeSince has elapsed since the most
// recent device event and, if so, generates a new device configuration job to process
func (da deviceAgent) checkIfProcessDeviceEventBatch(ctx context.Context, deviceClasses map[string]*dsv1.DeviceClass, queue chan deviceConfigJob) {
	if len(eventBatch) >= 1 {
		// check if batchTimeSince has passed without new uevents
		ts := time.Since(eventBatch[len(eventBatch)-1].Timestamp()).Milliseconds()
		if ts >= batchTimeSince {
			da.reconcileDevicePluginServers(ctx)
			containers, postHookFns := compactDeviceEvents(eventBatch)
			eventBatch = make([]events.DeviceEvent, 0)
			queue <- newJob(containers, postHookFns, deviceClasses)
		}
	}
}

// reconcileDevicePluginServers instantiates the device plugin server for each class if one
// does not already exist, starts the server if it is not running, and sends an update to the
// device plugin. This ensures that each plugin is up to date with the latest list of devices.
func (da deviceAgent) reconcileDevicePluginServers(ctx context.Context) {
	for name, devClass := range da.deviceClasses {
		// create the plugin server only when it is missing or nil
		if srv, ok := da.servers[name]; !ok || srv == nil {
			da.servers[name] = pluginFn(devClass)
		}
		if !da.servers[name].IsRunning() {
			go da.servers[name].Run(ctx)
		}
		da.servers[name].Update(devClass)
	}
}

// startDeviceJobQueue takes the device event jobs from the queue and processes them
func (da deviceAgent) startDeviceJobQueue(ctx context.Context, jobQueueChan chan deviceConfigJob) {
	for {
		select {
		case <-ctx.Done():
			return
		case job := <-jobQueueChan:
			job.Run(ctx)
		}
	}
}

// stopAllDevicePlugins goes through all the device plugin servers and attempts to stop them
func (da deviceAgent) stopAllDevicePlugins() {
	for name, srv := range da.servers {
		srv.Stop()
		delete(da.servers, name)
	}
}

// compactDeviceEvents compacts the device events into a map of containers to update and a
// list of posthook functions to call
func compactDeviceEvents(events []events.DeviceEvent) (map[string]*containers.Container, []func(context.Context)) {
	containers := map[string]*containers.Container{}
	postHookFns := []func(context.Context){}
	for _, updateEvent := range events {
		// merge all containers to update into one map
		maps.Insert(containers, maps.All(updateEvent.Containers()))
		if updateEvent.PosthookFunc() != nil {
			postHookFns = append(postHookFns, updateEvent.PosthookFunc())
		}
	}
	return containers, postHookFns
}

// fetchAllContainersToUpdate generates the container map and posthook functions needed to
// build a deviceConfigJob that updates all containers
func fetchAllContainersToUpdate(ctx context.Context, allContainers map[string]*containers.Container) (map[string]*containers.Container, []func(context.Context)) {
	log := logger.FromContext(ctx)
	deviceEvents := []events.DeviceEvent{}
	for _, ctr := range allContainers {
		event, err := events.NewContainerEvent(ctr)
		if err != nil {
			log.Error("error generating container event", "error", err)
			continue
		}
		deviceEvents = append(deviceEvents, event)
	}
	return compactDeviceEvents(deviceEvents)
}

// newKubeletSocketWatcher creates a new file system notify watcher on the directory
// containing the kubelet socket, so the agent is notified when the socket is recreated
func newKubeletSocketWatcher() (*fsnotify.Watcher, error) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, err
	}
	if err = watcher.Add(kubeletPath); err != nil {
		watcher.Close()
		return nil, err
	}
	return watcher, nil
}

// fetchContainerNames returns the container name label for each of the given containers
func fetchContainerNames(containers []*containers.Container) []string {
	names := make([]string, len(containers))
	for idx, ctr := range containers {
		name := ctr.Labels[cc.AnnContainerName]
		names[idx] = name
	}
	return names
}

// createScheme returns a runtime scheme registered with the
// client-go and device-system schemes.
func createScheme() *runtime.Scheme {
	scheme := runtime.NewScheme()
	utilruntime.Must(clientgoscheme.AddToScheme(scheme))
	utilruntime.Must(v1.AddToScheme(scheme))
	utilruntime.Must(dsv1.AddToScheme(scheme))
	return scheme
}
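
// Illustrative wiring sketch (not part of the agent): one way a caller might construct and run
// the device agent. How restConfig, k8sClient, runtimeClient, cfg, decoder, and resourceManager
// are built depends on the surrounding binary and is assumed here; those names are hypothetical.
//
//	ctrClient, err := containerd.New("/run/containerd/containerd.sock")
//	if err != nil {
//		return err
//	}
//	agent, err := NewDeviceAgent(ctx, restConfig, k8sClient, ctrClient, runtimeClient, cfg, decoder, resourceManager)
//	if err != nil {
//		return err
//	}
//	return agent.Start(ctx)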