package plugins

import (
	"context"
	"crypto/sha256"
	"fmt"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"time"

	fsnotify "github.com/fsnotify/fsnotify"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/credentials/insecure"
	"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

	"edge-infra.dev/pkg/lib/filesystem"
	"edge-infra.dev/pkg/lib/kernel/devices"
	cc "edge-infra.dev/pkg/sds/devices/agent/common"
	"edge-infra.dev/pkg/sds/devices/class"
	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
	"edge-infra.dev/pkg/sds/devices/logger"
)

var (
	// kubeletDevicePluginPath is the directory in which the plugin socket
	// files are created.
	kubeletDevicePluginPath = "/var/lib/kubelet/device-plugins/"
	// serverWaitTimeout is used both as the gRPC server connection timeout
	// and as the upper bound when waiting for a socket to become ready.
	serverWaitTimeout = time.Second * 1
)

const (
	// rootUserID is the user id for root.
	rootUserID = 0
	// removable is both the name of the device attribute that is inspected
	// and the attribute value that marks a device as removable.
	removable = "removable"
)

// Plugin is a kubelet device plugin server for a single DeviceClass together
// with the lifecycle hooks used to drive it.
type Plugin interface {
	v1beta1.DevicePluginServer
	// Run serves and registers the plugin, restarting it whenever its
	// socket file disappears.
	Run(ctx context.Context)
	// Stop marks the plugin as not running and stops its gRPC server.
	Stop()
	// Remove deletes the plugin socket file.
	Remove()
	// Update queues a new DeviceClass for the ListAndWatch stream.
	Update(deviceClass *dsv1.DeviceClass)
	// IsRunning reports whether the plugin is marked as running.
	IsRunning() bool
}

// maxDeviceRequestsPerClass is the number of device IDs advertised to the
// kubelet for each class.
const maxDeviceRequestsPerClass = 100

// waitInterval is the pause before every registration attempt.
const waitInterval = time.Second * 5

// plugin is the Plugin implementation backed by a unix-socket gRPC server.
type plugin struct {
	name           string   // resource name registered with the kubelet
	ids            []string // device IDs advertised through ListAndWatch
	socketFilePath string   // path of the plugin's unix socket
	grpcServer     *grpc.Server
	serving        bool // guarded by mutex
	log            *slog.Logger
	mutex          sync.Mutex
	devicesClassCh chan *dsv1.DeviceClass // updates consumed by ListAndWatch
	deviceClass    *dsv1.DeviceClass
}

// NewPlugin returns a new device plugin for the given DeviceClass. The plugin
// is not started; call Run to serve and register it.
func NewPlugin(deviceClass *dsv1.DeviceClass) Plugin {
	name := class.FmtClassLabel(deviceClass.ObjectMeta.GetName())
	socketPath := generateFilePathName(name)
	opts := []logger.Option{
		logger.WithLevel(deviceClass.Spec.LogLevel()),
	}
	return &plugin{
		ids:            []string{},
		devicesClassCh: make(chan *dsv1.DeviceClass, 1),
		mutex:          sync.Mutex{},
		name:           name,
		socketFilePath: socketPath,
		serving:        false,
		log:            logger.New(opts...).With("class", name, "path", socketPath),
		deviceClass:    deviceClass,
	}
}

// Name returns the class name for the resource.
func (p *plugin) Name() string {
	return p.name
}

// File returns the file path of the device plugin socket.
func (p *plugin) File() string {
	return p.socketFilePath
}

// setRunning marks the plugin as running.
func (p *plugin) setRunning() {
	p.mutex.Lock()
	p.serving = true
	p.mutex.Unlock()
}

// Stop marks the plugin as not running. If a gRPC server has been created it
// is stopped; otherwise the socket file is removed directly.
func (p *plugin) Stop() {
	p.mutex.Lock()
	p.serving = false
	p.mutex.Unlock()
	if p.grpcServer != nil {
		p.grpcServer.Stop()
		return
	}
	p.Remove()
}

// Remove deletes the socket file from the file system. The removal is best
// effort: a missing file (or any other error) is ignored.
func (p *plugin) Remove() {
	_ = os.Remove(p.socketFilePath)
}

// IsRunning reports whether the plugin is currently marked as running.
func (p *plugin) IsRunning() bool {
	p.mutex.Lock()
	isRunning := p.serving
	p.mutex.Unlock()
	return isRunning
}

// Run serves and registers the device plugin, retrying every waitInterval,
// until the plugin is stopped or the context is cancelled.
func (p *plugin) Run(ctx context.Context) {
	p.runRegistration(ctx)
}

// Update sends the DeviceClass to the channel consumed by ListAndWatch. Nil
// classes are ignored. The channel holds a single pending update, so this
// call blocks while an earlier update has not been consumed yet.
func (p *plugin) Update(deviceClass *dsv1.DeviceClass) {
	if deviceClass == nil {
		return
	}
	p.devicesClassCh <- deviceClass
}

// runRegistration repeatedly serves the plugin on its socket, registers it
// with the kubelet and then blocks until the socket file is removed, at which
// point the server is stopped and the cycle starts again.
//
// It returns when the plugin is stopped, when the context is cancelled, or
// when something is already serving on the plugin socket.
func (p *plugin) runRegistration(ctx context.Context) {
	p.setRunning()
	for p.IsRunning() {
		// Wait before checking registration again, but do not outlive the
		// context: a plain sleep would keep this loop spinning forever once
		// ctx is cancelled because every later step fails with ctx.Err().
		timer := time.NewTimer(waitInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			p.log.Info("context cancelled, stopping device plugin")
			p.Stop()
			return
		case <-timer.C:
		}

		if err := waitForGRPCServer(ctx, p.socketFilePath); err == nil {
			return // server is already running
		}

		// Best effort: clear a stale socket file so that Listen can bind.
		_ = os.Remove(p.socketFilePath)
		socket, err := net.Listen("unix", p.socketFilePath)
		if err != nil {
			p.log.Error("error listening to unix socket", "error", err)
			continue
		}

		// Register new grpc server and device server. The goroutine works on
		// the local server value so that it never touches a server created
		// by a later iteration.
		server := grpc.NewServer(grpc.ConnectionTimeout(serverWaitTimeout))
		p.grpcServer = server
		v1beta1.RegisterDevicePluginServer(server, p)
		go func() {
			if err := server.Serve(socket); err != nil {
				p.log.Error("error serving grpc server", "error", err.Error())
				server.Stop()
				p.Remove()
			}
		}()

		// On any failure below the server must be stopped before retrying.
		// Otherwise the next iteration finds this very server answering on
		// the socket, concludes that it is "already running" and returns
		// without the plugin ever being registered with the kubelet.
		if err := waitForGRPCServer(ctx, p.socketFilePath); err != nil {
			p.log.Error("error listening to gRPC server socket", "error", err)
			server.Stop()
			continue
		}
		if err := p.registerWithKubelet(ctx); err != nil {
			p.log.Error("error registering with kubelet", "error", err)
			server.Stop()
			continue
		}
		if err := p.watchSocketFile(); err != nil {
			p.log.Error("error watching socket file", "error", err)
			server.Stop()
			continue
		}

		p.log.Info("stopping grpc server")
		p.Stop()
		// attempt to start the device plugin again
		p.setRunning()
	}
}

// Allocate returns the mounts and device specs for every container in the
// request. This gets called during pod creation. Every container response
// mounts the host /dev and /run/udev directories into the container.
//
// If the DeviceClass reports that it should block, Allocate returns an error
// so that the container start is blocked until the requested device is
// present.
//
// For every requested device ID that this plugin advertised, all devices of
// the class are visited: the group owner/permissions of their device node are
// adjusted and, unless the device hangs off a removable parent, the node is
// added to the container's device list.
func (p *plugin) Allocate(_ context.Context, req *v1beta1.AllocateRequest) (*pluginapi.AllocateResponse, error) {
	res := &v1beta1.AllocateResponse{
		ContainerResponses: make([]*v1beta1.ContainerAllocateResponse, 0, len(req.ContainerRequests)),
	}
	if willBlock, blockingSetName := p.deviceClass.WillBlock(); willBlock {
		return nil, fmt.Errorf("blocking container start, waiting for devices to become available for class %s: %s", p.name, blockingSetName)
	}

	for _, containerReq := range req.ContainerRequests {
		resp := &v1beta1.ContainerAllocateResponse{
			Annotations: map[string]string{
				p.name: "1",
			},
			Mounts: []*pluginapi.Mount{
				{
					ContainerPath: "/dev",
					HostPath:      "/dev",
				},
				{
					ContainerPath: "/run/udev",
					HostPath:      "/run/udev",
				},
			},
			Envs: map[string]string{},
		}

		for _, id := range containerReq.DevicesIDs {
			// skip IDs that were never advertised by this plugin
			if !slices.Contains(p.ids, id) {
				continue
			}
			for _, dev := range p.deviceClass.DeviceIter() {
				node, err := dev.Node()
				if err != nil || node == nil {
					p.log.Debug("device has no node", "path", dev.Path())
					continue
				}
				if _, err := os.Stat(node.Path()); err != nil {
					p.log.Debug("could not stat device node", "path", node.Path())
					continue
				}
				// a permission failure is logged but does not abort allocation
				if err := p.overrideDeviceNodeGroupOwner(node); err != nil {
					p.log.Error("error setting device node permissions", "error", err)
				}
				// only add non-removable devices to the runtime spec
				if !p.rootDeviceIsRemovable(dev.Path()) {
					resp.Devices = append(resp.Devices, &pluginapi.DeviceSpec{
						ContainerPath: node.Path(),
						HostPath:      node.Path(),
						Permissions:   "rwm",
					})
				}
			}
		}

		p.log.Info("allocate devices to container", "devices", len(resp.Devices))
		res.ContainerResponses = append(res.ContainerResponses, resp)
	}
	return res, nil
}

// ListAndWatch sends the current device list on the stream and then sends a
// fresh list every time a DeviceClass arrives on the update channel.
//
// It returns when sending fails, when the update channel is closed, or when
// the stream's context is done. Watching the stream context matters: without
// it a handler whose stream has already been closed (kubelet restart, server
// stop) would stay blocked on the channel, steal the next update from the
// live stream and stop the plugin when its send fails.
func (p *plugin) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DevicePlugin_ListAndWatchServer) error {
	if err := p.sendDeviceUpdate(stream); err != nil {
		return err
	}

	// TODO (carried over from the original code): change blocking behaviour to here
	ctx := stream.Context()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case devClass, ok := <-p.devicesClassCh:
			if !ok {
				return nil
			}
			p.deviceClass = devClass
			if err := p.sendDeviceUpdate(stream); err != nil {
				return err
			}
		}
	}
}

// sendDeviceUpdate sends a ListAndWatch response advertising the device IDs
// of this plugin on the given stream.
func (p *plugin) sendDeviceUpdate(stream v1beta1.DevicePlugin_ListAndWatchServer) error { res := &v1beta1.ListAndWatchResponse{ Devices: []*pluginapi.Device{}, } for i := 1; i <= maxDeviceRequestsPerClass; i++ { h := sha256.New() id := fmt.Sprintf("%s-%d", p.name, i) h.Write([]byte(id)) id = fmt.Sprintf("%x", h.Sum(nil)) res.Devices = append(res.Devices, &pluginapi.Device{ ID: id, Health: pluginapi.Healthy, }) p.ids = append(p.ids, id) } if err := stream.Send(res); err != nil { p.Stop() return err } return nil } func (p *plugin) PreStartContainer(context.Context, *v1beta1.PreStartContainerRequest) (*v1beta1.PreStartContainerResponse, error) { return &pluginapi.PreStartContainerResponse{}, nil } func (p *plugin) GetPreferredAllocation(context.Context, *v1beta1.PreferredAllocationRequest) (*v1beta1.PreferredAllocationResponse, error) { return &pluginapi.PreferredAllocationResponse{}, nil } func (p *plugin) GetDevicePluginOptions(_ context.Context, _ *v1beta1.Empty) (*v1beta1.DevicePluginOptions, error) { return &pluginapi.DevicePluginOptions{ PreStartRequired: false, GetPreferredAllocationAvailable: true, }, nil } // registerWithKubelet will register the device plugin with kubelet service func (p *plugin) registerWithKubelet(ctx context.Context) error { c, err := grpc.NewClient(filepath.Join("unix://", v1beta1.KubeletSocket), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err } defer c.Close() client := v1beta1.NewRegistrationClient(c) request := &v1beta1.RegisterRequest{ Version: v1beta1.Version, Endpoint: filepath.Base(p.socketFilePath), ResourceName: p.name, } if _, err := client.Register(ctx, request); err != nil { return fmt.Errorf("failed to register plugin with kubelet service: %v", err) } p.log.Info("registered resource with kubelet", "resource", p.name) return nil } // waitForGRPCServer will wait to connect to the device plugin socket and wait until its in a ready state // before registering it with the kubelet func 
waitForGRPCServer(ctx context.Context, socket string) error { conn, err := grpc.NewClient(filepath.Join("unix://", socket), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return err } defer conn.Close() ctx, cancel := context.WithTimeout(ctx, serverWaitTimeout) defer cancel() for { state := conn.GetState() if state == connectivity.Idle { conn.Connect() } if state == connectivity.Ready { return nil } if !conn.WaitForStateChange(ctx, state) { return ctx.Err() } } } // watchSocketFile will wait until the socket file // is removed (i.e. kubelet restart) and will close // the device plugin server. func (p *plugin) watchSocketFile() error { watcher, err := fsnotify.NewWatcher() if err != nil { return err } defer watcher.Close() if err = watcher.Add(filepath.Dir(p.socketFilePath)); err != nil { return err } for { select { case event := <-watcher.Events: if (event.Op == fsnotify.Remove || event.Op == fsnotify.Rename) && event.Name == p.socketFilePath { return nil } case err := <-watcher.Errors: return err } } } // overrideDeviceNodeGroupOwner will change the group owner of device node to deviceg (1015) group // and set the device node permissions to read/write func (p *plugin) overrideDeviceNodeGroupOwner(node devices.Node) error { groupID, err := node.GroupID() if err != nil { return fmt.Errorf("error could not determin node group owner id: %s : %w", node.Path(), err) } // ignore device nodes not owned by root user if groupID != rootUserID { return nil } if groupID != cc.DeviceGroupID { p.log.Debug("changing device node group owner from root to deviceg (1015)", "node", node.Path()) if err := os.Chown(node.Path(), -1, cc.DeviceGroupID); err != nil { return fmt.Errorf("error changing device node group owner to deviceg (1015): %s: %w", node.Path(), err) } } fileMode, err := node.FileMode() if err != nil { return err } newMode := fileMode | filesystem.GroupReadWritePerm // mask group read write permissions if fileMode == newMode { return nil } 
p.log.Debug("changing device node permissions to read/write by group owner", "node", node.Path()) if err := os.Chmod(node.Path(), newMode); err != nil { return fmt.Errorf("error changing device node permissions to read/write for group owner: %s: %w", node.Path(), err) } return nil } // rootDeviceIsRemovable will attempt to find the root parent device // for a given path and check if the device is removable. func (p *plugin) rootDeviceIsRemovable(path string) bool { pathSplit := strings.Split(path, "/") for idx := 0; idx <= len(pathSplit)-1; idx++ { searchPath := strings.Join(pathSplit[:idx], "/") device := p.deviceClass.DeviceGet(searchPath) if device == nil { continue } canRemove, exists, _ := device.Attribute(removable) if !exists { continue } if canRemove == removable { return true } } return false } // generateFilePathName generates a socker name for the resource under // device-system- func generateFilePathName(resourceName string) string { classFmt := resourceFileName(resourceName) return filepath.Join(kubeletDevicePluginPath, fmt.Sprintf("ds-%s.sock", classFmt)) } // resourceFileName will remove the device class prefix and remove any forward slashes or periods. func resourceFileName(resourceName string) string { classFmt := strings.ReplaceAll(resourceName, class.DeviceClassPrefix, "") classFmt = strings.ReplaceAll(classFmt, "/", "") classFmt = strings.ReplaceAll(classFmt, ".", "-") return classFmt }