...

Source file src/edge-infra.dev/pkg/sds/devices/k8s/device-plugins/plugins.go

Documentation: edge-infra.dev/pkg/sds/devices/k8s/device-plugins

     1  package plugins
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"fmt"
     7  	"log/slog"
     8  	"net"
     9  	"os"
    10  	"path/filepath"
    11  	"slices"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	fsnotify "github.com/fsnotify/fsnotify"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/connectivity"
    19  	"google.golang.org/grpc/credentials/insecure"
    20  	"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
    21  	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
    22  
    23  	"edge-infra.dev/pkg/lib/filesystem"
    24  	"edge-infra.dev/pkg/lib/kernel/devices"
    25  	cc "edge-infra.dev/pkg/sds/devices/agent/common"
    26  	"edge-infra.dev/pkg/sds/devices/class"
    27  	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
    28  	"edge-infra.dev/pkg/sds/devices/logger"
    29  )
    30  
    31  var (
    32  	// kubelet plugin dir
    33  	kubeletDevicePluginPath = "/var/lib/kubelet/device-plugins/"
    34  	// serverWaitTimeout is amount seconds to wait for server to respond
    35  	serverWaitTimeout = time.Second * 1
    36  )
    37  
    38  const (
    39  	// rootUserID is user id for root
    40  	rootUserID = 0
    41  	// removable is constant for removable devices
    42  	removable = "removable"
    43  )
    44  
    45  type Plugin interface {
    46  	v1beta1.DevicePluginServer
    47  	Run(ctx context.Context)
    48  	Stop()
    49  	Remove()
    50  	Update(deviceClass *dsv1.DeviceClass)
    51  	IsRunning() bool
    52  }
    53  
    54  const maxDeviceRequestsPerClass = 100
    55  const waitInterval = time.Second * 5
    56  
    57  type plugin struct {
    58  	name           string
    59  	ids            []string
    60  	socketFilePath string
    61  	grpcServer     *grpc.Server
    62  	serving        bool
    63  	log            *slog.Logger
    64  	mutex          sync.Mutex
    65  	devicesClassCh chan *dsv1.DeviceClass
    66  	deviceClass    *dsv1.DeviceClass
    67  }
    68  
    69  // NewPlugin returns a new device plugin from the DeviceClass
    70  func NewPlugin(deviceClass *dsv1.DeviceClass) Plugin {
    71  	name := class.FmtClassLabel(deviceClass.ObjectMeta.GetName())
    72  	socketPath := generateFilePathName(name)
    73  	opts := []logger.Option{
    74  		logger.WithLevel(deviceClass.Spec.LogLevel()),
    75  	}
    76  	return &plugin{
    77  		ids:            []string{},
    78  		devicesClassCh: make(chan *dsv1.DeviceClass, 1),
    79  		mutex:          sync.Mutex{},
    80  		name:           name,
    81  		socketFilePath: socketPath,
    82  		serving:        false,
    83  		log:            logger.New(opts...).With("class", name, "path", socketPath),
    84  		deviceClass:    deviceClass,
    85  	}
    86  }
    87  
    88  // Name returns the class name for the resource
    89  func (p *plugin) Name() string {
    90  	return p.name
    91  }
    92  
    93  // File returns the file path for device plugin socket
    94  func (p *plugin) File() string {
    95  	return p.socketFilePath
    96  }
    97  
    98  func (p *plugin) setRunning() {
    99  	p.mutex.Lock()
   100  	p.serving = true
   101  	p.mutex.Unlock()
   102  }
   103  
   104  func (p *plugin) Stop() {
   105  	p.mutex.Lock()
   106  	p.serving = false
   107  	p.mutex.Unlock()
   108  	if p.grpcServer != nil {
   109  		p.grpcServer.Stop()
   110  		return
   111  	}
   112  	p.Remove()
   113  }
   114  
   115  // Remove will delete the socket file from file system
   116  func (p *plugin) Remove() {
   117  	_ = os.Remove(p.socketFilePath)
   118  }
   119  
   120  func (p *plugin) IsRunning() bool {
   121  	p.mutex.Lock()
   122  	isRunning := p.serving
   123  	p.mutex.Unlock()
   124  	return isRunning
   125  }
   126  
   127  // Run runs the device plugin every 5s until the context is cancelled
   128  func (p *plugin) Run(ctx context.Context) {
   129  	p.runRegistration(ctx)
   130  }
   131  
   132  // AddDevices will send the devices to channel to be sent to device plugin socket
   133  func (p *plugin) Update(deviceClass *dsv1.DeviceClass) {
   134  	if deviceClass == nil {
   135  		return
   136  	}
   137  	p.devicesClassCh <- deviceClass
   138  }
   139  
   140  // runRegistration will attempt to run the device plugin registration once
   141  func (p *plugin) runRegistration(ctx context.Context) {
   142  	p.setRunning()
   143  	for p.IsRunning() {
   144  		// wait before checking registration again
   145  		time.Sleep(waitInterval)
   146  
   147  		if err := waitForGRPCServer(ctx, p.socketFilePath); err == nil {
   148  			return // server is already running
   149  		}
   150  		_ = os.Remove(p.socketFilePath)
   151  
   152  		socket, err := net.Listen("unix", p.socketFilePath)
   153  		if err != nil {
   154  			p.log.Error("error listening to unix socket", "error", err)
   155  			continue
   156  		}
   157  
   158  		// register new grpc server and device server
   159  		p.grpcServer = grpc.NewServer(grpc.ConnectionTimeout(serverWaitTimeout))
   160  		v1beta1.RegisterDevicePluginServer(p.grpcServer, p)
   161  
   162  		go func() {
   163  			if err := p.grpcServer.Serve(socket); err != nil {
   164  				p.log.Error("error serving grpc server", "error", err.Error())
   165  				p.grpcServer.Stop()
   166  				p.Remove()
   167  			}
   168  		}()
   169  
   170  		if err := waitForGRPCServer(ctx, p.socketFilePath); err != nil {
   171  			p.log.Error("error listening to gRPC server socket", "error", err)
   172  			continue
   173  		}
   174  
   175  		if err := p.registerWithKubelet(ctx); err != nil {
   176  			p.log.Error("error registering with kubelet", "error", err)
   177  			continue
   178  		}
   179  
   180  		if err := p.watchSocketFile(); err != nil {
   181  			p.log.Error("error watching socket file", "error", err)
   182  			continue
   183  		}
   184  		p.log.Info("stopping grpc server")
   185  		p.Stop()
   186  
   187  		// attempt to start the device plugin again
   188  		p.setRunning()
   189  	}
   190  }
   191  
   192  // Allocate which return list of devices. This gets called during pod creation. If a pod requests a device,
   193  // the device plugin will automatically mount /dev dir and /run/udev/control socket to the container.
   194  //
   195  // Given a slice of devices, if a DeviceClass should block, the Allocate request will return an UnexpectedAdmissionError
   196  // to block the start of the container until the requested device is present.
   197  func (p *plugin) Allocate(_ context.Context, req *v1beta1.AllocateRequest) (*pluginapi.AllocateResponse, error) {
   198  	res := &v1beta1.AllocateResponse{
   199  		ContainerResponses: make([]*v1beta1.ContainerAllocateResponse, 0, len(req.ContainerRequests)),
   200  	}
   201  
   202  	if willBlock, blockingSetName := p.deviceClass.WillBlock(); willBlock {
   203  		return nil, fmt.Errorf("blocking container start, waiting for devices to become available for class %s: %s", p.name, blockingSetName)
   204  	}
   205  
   206  	for _, req := range req.ContainerRequests {
   207  		resp := &v1beta1.ContainerAllocateResponse{
   208  			Annotations: map[string]string{
   209  				p.name: "1",
   210  			},
   211  			Mounts: []*pluginapi.Mount{
   212  				{
   213  					ContainerPath: "/dev",
   214  					HostPath:      "/dev",
   215  				},
   216  				{
   217  					ContainerPath: "/run/udev",
   218  					HostPath:      "/run/udev",
   219  				},
   220  			},
   221  			Envs: map[string]string{},
   222  		}
   223  
   224  		for _, id := range req.DevicesIDs {
   225  			if !slices.Contains(p.ids, id) {
   226  				continue
   227  			}
   228  			for _, dev := range p.deviceClass.DeviceIter() {
   229  				node, err := dev.Node()
   230  				if err != nil || node == nil {
   231  					p.log.Debug("device has no node", "path", dev.Path())
   232  					continue
   233  				}
   234  
   235  				if _, err := os.Stat(node.Path()); err != nil {
   236  					p.log.Debug("could not stat device node", "path", node.Path())
   237  					continue
   238  				}
   239  
   240  				if err := p.overrideDeviceNodeGroupOwner(node); err != nil {
   241  					p.log.Error("error setting device node permissions", "error", err)
   242  				}
   243  
   244  				// only add non-removable devices to the runtime spec
   245  				if !p.rootDeviceIsRemovable(dev.Path()) {
   246  					resp.Devices = append(resp.Devices, &pluginapi.DeviceSpec{
   247  						ContainerPath: node.Path(),
   248  						HostPath:      node.Path(),
   249  						Permissions:   "rwm",
   250  					})
   251  				}
   252  			}
   253  		}
   254  		p.log.Info("allocate devices to container", "devices", len(resp.Devices))
   255  		res.ContainerResponses = append(res.ContainerResponses, resp)
   256  	}
   257  	return res, nil
   258  }
   259  
   260  // ListAndWatch will check every 5s for add/removed devices from the device channel and send an updated
   261  // response to the device plugins list and watcher server.
   262  func (p *plugin) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DevicePlugin_ListAndWatchServer) error {
   263  	if err := p.sendDeviceUpdate(stream); err != nil {
   264  		return err
   265  	}
   266  
   267  	// change blocking behaviour to here
   268  	for devClass := range p.devicesClassCh {
   269  		p.deviceClass = devClass
   270  		if err := p.sendDeviceUpdate(stream); err != nil {
   271  			return err
   272  		}
   273  	}
   274  	return nil
   275  }
   276  
   277  // sendDevice update will send an updated response to the device plugins list
   278  // and watcher server with current set of devices matched to the device plugin.
   279  func (p *plugin) sendDeviceUpdate(stream v1beta1.DevicePlugin_ListAndWatchServer) error {
   280  	res := &v1beta1.ListAndWatchResponse{
   281  		Devices: []*pluginapi.Device{},
   282  	}
   283  
   284  	for i := 1; i <= maxDeviceRequestsPerClass; i++ {
   285  		h := sha256.New()
   286  		id := fmt.Sprintf("%s-%d", p.name, i)
   287  		h.Write([]byte(id))
   288  		id = fmt.Sprintf("%x", h.Sum(nil))
   289  		res.Devices = append(res.Devices, &pluginapi.Device{
   290  			ID:     id,
   291  			Health: pluginapi.Healthy,
   292  		})
   293  		p.ids = append(p.ids, id)
   294  	}
   295  	if err := stream.Send(res); err != nil {
   296  		p.Stop()
   297  		return err
   298  	}
   299  	return nil
   300  }
   301  
   302  func (p *plugin) PreStartContainer(context.Context, *v1beta1.PreStartContainerRequest) (*v1beta1.PreStartContainerResponse, error) {
   303  	return &pluginapi.PreStartContainerResponse{}, nil
   304  }
   305  
   306  func (p *plugin) GetPreferredAllocation(context.Context, *v1beta1.PreferredAllocationRequest) (*v1beta1.PreferredAllocationResponse, error) {
   307  	return &pluginapi.PreferredAllocationResponse{}, nil
   308  }
   309  
   310  func (p *plugin) GetDevicePluginOptions(_ context.Context, _ *v1beta1.Empty) (*v1beta1.DevicePluginOptions, error) {
   311  	return &pluginapi.DevicePluginOptions{
   312  		PreStartRequired:                false,
   313  		GetPreferredAllocationAvailable: true,
   314  	}, nil
   315  }
   316  
   317  // registerWithKubelet will register the device plugin with kubelet service
   318  func (p *plugin) registerWithKubelet(ctx context.Context) error {
   319  	c, err := grpc.NewClient(filepath.Join("unix://", v1beta1.KubeletSocket), grpc.WithTransportCredentials(insecure.NewCredentials()))
   320  	if err != nil {
   321  		return err
   322  	}
   323  	defer c.Close()
   324  
   325  	client := v1beta1.NewRegistrationClient(c)
   326  	request := &v1beta1.RegisterRequest{
   327  		Version:      v1beta1.Version,
   328  		Endpoint:     filepath.Base(p.socketFilePath),
   329  		ResourceName: p.name,
   330  	}
   331  
   332  	if _, err := client.Register(ctx, request); err != nil {
   333  		return fmt.Errorf("failed to register plugin with kubelet service: %v", err)
   334  	}
   335  	p.log.Info("registered resource with kubelet", "resource", p.name)
   336  	return nil
   337  }
   338  
   339  // waitForGRPCServer will wait to connect to the device plugin socket and wait until its in a ready state
   340  // before registering it with the kubelet
   341  func waitForGRPCServer(ctx context.Context, socket string) error {
   342  	conn, err := grpc.NewClient(filepath.Join("unix://", socket), grpc.WithTransportCredentials(insecure.NewCredentials()))
   343  	if err != nil {
   344  		return err
   345  	}
   346  	defer conn.Close()
   347  
   348  	ctx, cancel := context.WithTimeout(ctx, serverWaitTimeout)
   349  	defer cancel()
   350  	for {
   351  		state := conn.GetState()
   352  		if state == connectivity.Idle {
   353  			conn.Connect()
   354  		}
   355  		if state == connectivity.Ready {
   356  			return nil
   357  		}
   358  		if !conn.WaitForStateChange(ctx, state) {
   359  			return ctx.Err()
   360  		}
   361  	}
   362  }
   363  
   364  // watchSocketFile will wait until the socket file
   365  // is removed (i.e. kubelet restart) and will close
   366  // the device plugin server.
   367  func (p *plugin) watchSocketFile() error {
   368  	watcher, err := fsnotify.NewWatcher()
   369  	if err != nil {
   370  		return err
   371  	}
   372  	defer watcher.Close()
   373  
   374  	if err = watcher.Add(filepath.Dir(p.socketFilePath)); err != nil {
   375  		return err
   376  	}
   377  
   378  	for {
   379  		select {
   380  		case event := <-watcher.Events:
   381  			if (event.Op == fsnotify.Remove || event.Op == fsnotify.Rename) && event.Name == p.socketFilePath {
   382  				return nil
   383  			}
   384  		case err := <-watcher.Errors:
   385  			return err
   386  		}
   387  	}
   388  }
   389  
   390  // overrideDeviceNodeGroupOwner will change the group owner of device node to deviceg (1015) group
   391  // and set the device node permissions to read/write
   392  func (p *plugin) overrideDeviceNodeGroupOwner(node devices.Node) error {
   393  	groupID, err := node.GroupID()
   394  	if err != nil {
   395  		return fmt.Errorf("error could not determin node group owner id: %s : %w", node.Path(), err)
   396  	}
   397  
   398  	// ignore device nodes not owned by root user
   399  	if groupID != rootUserID {
   400  		return nil
   401  	}
   402  
   403  	if groupID != cc.DeviceGroupID {
   404  		p.log.Debug("changing device node group owner from root to deviceg (1015)", "node", node.Path())
   405  		if err := os.Chown(node.Path(), -1, cc.DeviceGroupID); err != nil {
   406  			return fmt.Errorf("error changing device node group owner to deviceg (1015): %s: %w", node.Path(), err)
   407  		}
   408  	}
   409  
   410  	fileMode, err := node.FileMode()
   411  	if err != nil {
   412  		return err
   413  	}
   414  
   415  	newMode := fileMode | filesystem.GroupReadWritePerm // mask group read write permissions
   416  	if fileMode == newMode {
   417  		return nil
   418  	}
   419  
   420  	p.log.Debug("changing device node permissions to read/write by group owner", "node", node.Path())
   421  	if err := os.Chmod(node.Path(), newMode); err != nil {
   422  		return fmt.Errorf("error changing device node permissions to read/write for group owner: %s: %w", node.Path(), err)
   423  	}
   424  	return nil
   425  }
   426  
   427  // rootDeviceIsRemovable will attempt to find the root parent device
   428  // for a given path and check if the device is removable.
   429  func (p *plugin) rootDeviceIsRemovable(path string) bool {
   430  	pathSplit := strings.Split(path, "/")
   431  	for idx := 0; idx <= len(pathSplit)-1; idx++ {
   432  		searchPath := strings.Join(pathSplit[:idx], "/")
   433  		device := p.deviceClass.DeviceGet(searchPath)
   434  		if device == nil {
   435  			continue
   436  		}
   437  
   438  		canRemove, exists, _ := device.Attribute(removable)
   439  		if !exists {
   440  			continue
   441  		}
   442  
   443  		if canRemove == removable {
   444  			return true
   445  		}
   446  	}
   447  	return false
   448  }
   449  
   450  // generateFilePathName generates a socker name for the resource under
   451  // device-system-<className>
   452  func generateFilePathName(resourceName string) string {
   453  	classFmt := resourceFileName(resourceName)
   454  	return filepath.Join(kubeletDevicePluginPath, fmt.Sprintf("ds-%s.sock", classFmt))
   455  }
   456  
   457  // resourceFileName will remove the device class prefix and remove any forward slashes or periods.
   458  func resourceFileName(resourceName string) string {
   459  	classFmt := strings.ReplaceAll(resourceName, class.DeviceClassPrefix, "")
   460  	classFmt = strings.ReplaceAll(classFmt, "/", "")
   461  	classFmt = strings.ReplaceAll(classFmt, ".", "-")
   462  	return classFmt
   463  }
   464  

View as plain text