...

Source file src/edge-infra.dev/pkg/sds/devices/agent/agent.go

Documentation: edge-infra.dev/pkg/sds/devices/agent

     1  //go:build linux
     2  
     3  package agent
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"maps"
    11  	"path/filepath"
    12  	"slices"
    13  	"time"
    14  
    15  	"github.com/containerd/containerd"
    16  	"github.com/containerd/containerd/containers"
    17  	fsnotify "github.com/fsnotify/fsnotify"
    18  	"k8s.io/apimachinery/pkg/runtime"
    19  	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
    20  	"k8s.io/client-go/dynamic/dynamicinformer"
    21  	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
    22  	"k8s.io/client-go/rest"
    23  	criruntime "k8s.io/cri-api/pkg/apis/runtime/v1"
    24  	"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
    25  	"sigs.k8s.io/controller-runtime/pkg/client"
    26  
    27  	"edge-infra.dev/pkg/k8s/runtime/sap"
    28  	"edge-infra.dev/pkg/lib/kernel/udev/reader"
    29  	cc "edge-infra.dev/pkg/sds/devices/agent/common"
    30  	devicecontainers "edge-infra.dev/pkg/sds/devices/agent/containers"
    31  	"edge-infra.dev/pkg/sds/devices/agent/events"
    32  	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
    33  	v1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
    34  	plugins "edge-infra.dev/pkg/sds/devices/k8s/device-plugins"
    35  	"edge-infra.dev/pkg/sds/devices/logger"
    36  )
    37  
    38  const (
    39  	// batchTimeSince is the quiet period, in milliseconds, that must have elapsed since
    40  	// the most recent event in the batch before the device agent processes the batch
    41  	batchTimeSince int64 = 200
    42  	// batchWaitInterval is the wait time, in milliseconds, between checks for new events
    43  	batchWaitInterval = 10
    44  	// etcPath is the path of the etc directory
    45  	etcPath = "/etc"
    46  )
    47  
    48  var (
    49  	// kubeletPath is the directory that contains the kubelet socket file
    50  	kubeletPath = filepath.Dir(v1beta1.KubeletSocket)
    51  )
    52  
    53  // pluginFn is the function called to create a new device plugin for a device class
    54  var pluginFn = plugins.NewPlugin
    55  
    56  // newDeviceClassInformer is the function called to create a new dsv1.DeviceClass informer
    57  var newDeviceClassInformer = dsv1.WatchFromClient
    58  
    59  var (
    60  	// eventBatch is the package-level batch of device events awaiting processing
    61  	eventBatch = []events.DeviceEvent{} // appended to by Start; reset by checkIfProcessDeviceEventBatch when the batch is flushed
    62  )
    63  
    64  type DeviceAgent interface { // DeviceAgent is the runnable returned by NewDeviceAgent.
    65  	Start(ctx context.Context) error // Start runs the agent; the implementation in this file returns nil once ctx is done.
    66  }
    67  
    68  type deviceAgent struct { // deviceAgent is the DeviceAgent implementation built by NewDeviceAgent; field order matters because NewDeviceAgent uses a positional literal.
    69  	k8sClient                   client.Client // Kubernetes client passed to the device class list and informer calls
    70  	ctrClient                   *containerd.Client // containerd client; Start subscribes to its event service
    71  	runtimeClient               criruntime.RuntimeServiceClient // CRI runtime client handed to the container event constructor
    72  	cfg                         Config // agent configuration (Workers sizes the device job queue in Start)
    73  	decoder                     reader.Decoder // decoder passed to reader.StreamUEvents
    74  	servers                     map[string]plugins.Plugin // device plugin servers, keyed with the same keys as deviceClasses
    75  	deviceClasses               map[string]*dsv1.DeviceClass // device classes returned by dsv1.ListFromClient
    76  	containers                  map[string]*containers.Container // containers returned by devicecontainers.FetchAllContainers
    77  	deviceClassChan             <-chan *dsv1.DeviceClass // receive side of the channel given to the DeviceClass informer
    78  	containerDeviceRuleJobQueue map[string]chan func(context.Context) // initialised empty by NewDeviceAgent; not otherwise used in this file
    79  	watcher                     *fsnotify.Watcher // file system watcher on kubeletPath; closed when Start returns
    80  	resourceManager             *sap.ResourceManager // passed to the device class patching worker
    81  	informer                    dynamicinformer.DynamicSharedInformerFactory // DeviceClass informer factory, started in Start
    82  	clientOpts                  []dsv1.ListOption // list options used when querying device classes
    83  }
    84  
    85  // NewDeviceAgent instatiates a new device agent runnable
    86  func NewDeviceAgent(ctx context.Context, config *rest.Config, k8sClient client.Client, ctrClient *containerd.Client, rc criruntime.RuntimeServiceClient, cfg Config, decoder reader.Decoder, resourceManager *sap.ResourceManager) (DeviceAgent, error) {
    87  	log := logger.FromContext(ctx)
    88  	log = log.WithGroup("device-agent-startup")
    89  
    90  	// opts is the list options for querying device classes
    91  	opts := []dsv1.ListOption{
    92  		dsv1.WithPersistence(true),
    93  		dsv1.WithPersistencePath(cfg.ClassesPath),
    94  	}
    95  
    96  	deviceClasses, err := dsv1.ListFromClient(ctx, k8sClient, opts...)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	containers, err := devicecontainers.FetchAllContainers(ctx, ctrClient, rc)
   102  	if err != nil {
   103  		return nil, fmt.Errorf("failed to fetch all containers: %w", err)
   104  	}
   105  	log.Info("loaded containers requesting devices", "containers", fetchContainerNames(slices.Collect(maps.Values(containers))))
   106  
   107  	deviceClassChan := make(chan *dsv1.DeviceClass, 1)
   108  	informer, err := newDeviceClassInformer(ctx, k8sClient, config, deviceClassChan, opts...)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("failed to start Device resources informer: %w", err)
   111  	}
   112  
   113  	watcher, err := newKubeletSocketWatcher()
   114  	if err != nil {
   115  		return nil, fmt.Errorf("failed to start kubelet socket watcher: %w", err)
   116  	}
   117  
   118  	containerDeviceRuleJobQueue := map[string]chan func(context.Context){}
   119  	return &deviceAgent{k8sClient, ctrClient, rc, cfg, decoder, map[string]plugins.Plugin{}, deviceClasses, containers, deviceClassChan, containerDeviceRuleJobQueue, watcher, resourceManager, informer, opts}, nil
   120  }
   121  
   122  // Start triggers the device agent to start watching for device change events
   123  // including container starts, udev events, and kubelet restarts.
   124  //
   125  // These device events are processed in batches and propagate changes
   126  // to device-plugins and container cgroup updates.
   127  func (da *deviceAgent) Start(ctx context.Context) error {
   128  	log := logger.FromContext(ctx)
   129  	log = log.WithGroup("device-agent")
   130  	log.Info("starting device agent",
   131  		"device classes", slices.Collect(maps.Keys(da.deviceClasses)),
   132  		"device plugin servers", slices.Collect(maps.Keys(da.servers)),
   133  		"configuration", da.cfg,
   134  	)
   135  
   136  	ctx, cancelFn := context.WithCancel(ctx)
   137  	defer cancelFn()
   138  	defer da.watcher.Close()
   139  
   140  	eventService := da.ctrClient.EventService()
   141  	ctrEventChan, ctrStreamErrChan := eventService.Subscribe(ctx, []string{}...)
   142  	udevEventChan, udevEventErrChan := reader.StreamUEvents(ctx, da.decoder)
   143  
   144  	da.reconcileDevicePluginServers(ctx)
   145  
   146  	go reportDiskSizeMetrics(ctx)
   147  
   148  	deviceClassPatchQueue := make(chan map[string]*dsv1.DeviceClass, 1)
   149  	go startDeviceClassPatchingWorker(ctx, da.resourceManager, deviceClassPatchQueue)
   150  
   151  	log.Info("starting DeviceClass informer")
   152  	go da.informer.Start(ctx.Done())
   153  
   154  	deviceUpdateQueue := make(chan deviceConfigJob, da.cfg.Workers)
   155  	go da.startDeviceJobQueue(ctx, deviceUpdateQueue)
   156  
   157  	// update all containers when the device agent starts
   158  	containers, postHookFns := fetchAllContainersToUpdate(ctx, da.containers)
   159  	deviceUpdateQueue <- newJob(containers, postHookFns, da.deviceClasses)
   160  	deviceClassPatchQueue <- da.deviceClasses
   161  
   162  	classEventConstructor := events.ClassEventConstructor(da.servers, da.deviceClasses, da.containers)
   163  	udevEventConstructor := events.UDevEventConstructor(da.ctrClient, da.deviceClasses, da.containers)
   164  	containerEventConstructor := events.ContainerEventConstructor(da.containers, da.ctrClient, da.runtimeClient)
   165  
   166  	for {
   167  		da.checkIfProcessDeviceEventBatch(ctx, da.deviceClasses, deviceUpdateQueue)
   168  
   169  		select {
   170  		case deviceClass := <-da.deviceClassChan:
   171  			eventBatch = append(eventBatch, classEventConstructor(ctx, deviceClass))
   172  			deviceClassPatchQueue <- da.deviceClasses
   173  		case udevEvent := <-udevEventChan:
   174  			event, err := udevEventConstructor(ctx, udevEvent)
   175  			if err != nil {
   176  				log.Error("error generating udev event", "error", err)
   177  				continue
   178  			}
   179  			eventBatch = append(eventBatch, event)
   180  			deviceClassPatchQueue <- da.deviceClasses
   181  		case ctrEvent := <-ctrEventChan:
   182  			event, err := containerEventConstructor(ctx, ctrEvent)
   183  			if err != nil {
   184  				log.Error("error generating container event", "error", err)
   185  				continue
   186  			}
   187  			if event == nil {
   188  				continue
   189  			}
   190  			eventBatch = append(eventBatch, event)
   191  		case kubeletRestartEvent := <-da.watcher.Events:
   192  			if kubeletRestartEvent.Op&fsnotify.Create != fsnotify.Create || kubeletRestartEvent.Name != v1beta1.KubeletSocket {
   193  				continue
   194  			}
   195  			log.Info("kubelet has restarted or device classes have been updated, updating device plugins")
   196  			eventBatch = append(eventBatch, events.NewKubeletEvent())
   197  		case err := <-da.watcher.Errors:
   198  			log.Error("error fs watcher", "error", err)
   199  		case err := <-ctrStreamErrChan:
   200  			log.Error("error container events watcher", "error", err)
   201  		case err := <-udevEventErrChan:
   202  			if errors.Is(err, io.EOF) {
   203  				continue
   204  			}
   205  			log.Error("error has occurred monitoring udev events", "error", err)
   206  		case <-ctx.Done():
   207  			log.Info("stopping all device plugin servers")
   208  			da.stopAllDevicePlugins()
   209  			return nil
   210  		case <-time.After(time.Millisecond * batchWaitInterval):
   211  			continue
   212  		}
   213  	}
   214  }
   215  
   216  // checkIfProcessDeviceEventBatch checks if no new device events have passed in the last 100ms and generates a new device configuration job to process
   217  func (da deviceAgent) checkIfProcessDeviceEventBatch(ctx context.Context, deviceClasses map[string]*dsv1.DeviceClass, queue chan deviceConfigJob) {
   218  	if len(eventBatch) >= 1 {
   219  		// check if batchTimeSince has passed without new uevents
   220  		ts := time.Since(eventBatch[len(eventBatch)-1].Timestamp()).Milliseconds()
   221  		if ts >= batchTimeSince {
   222  			da.reconcileDevicePluginServers(ctx)
   223  			containers, postHookFns := compactDeviceEvents(eventBatch)
   224  			eventBatch = make([]events.DeviceEvent, 0)
   225  			queue <- newJob(containers, postHookFns, deviceClasses)
   226  		}
   227  	}
   228  }
   229  
   230  // reconcileDevicePluginServers will instantiate the device plugin server for each class if it does not exist, start the
   231  // server and send an update to the device plugin. This ensures that the plugin is up-to-date with latest list of devices
   232  func (da deviceAgent) reconcileDevicePluginServers(ctx context.Context) {
   233  	for name, devClass := range da.deviceClasses {
   234  		if srv, ok := da.servers[name]; !ok || srv != nil {
   235  			da.servers[name] = pluginFn(devClass)
   236  		}
   237  		if !da.servers[name].IsRunning() {
   238  			go da.servers[name].Run(ctx)
   239  		}
   240  		da.servers[name].Update(devClass)
   241  	}
   242  }
   243  
   244  // startDeviceJobQueue takes the device event jobs and processes them
   245  func (da deviceAgent) startDeviceJobQueue(ctx context.Context, jobQueueChan chan deviceConfigJob) {
   246  	for {
   247  		select {
   248  		case <-ctx.Done():
   249  			return
   250  		case job := <-jobQueueChan:
   251  			job.Run(ctx)
   252  		}
   253  	}
   254  }
   255  
   256  // stopAllDevicePluginServers will go through all the device
   257  // plugin servers and attempt to stop them
   258  func (da deviceAgent) stopAllDevicePlugins() {
   259  	for name, srv := range da.servers {
   260  		srv.Stop()
   261  		delete(da.servers, name)
   262  	}
   263  }
   264  
   265  // compactDeviceEvents compacts the device events into a list of containers to update and a list of posthook functions to call
   266  func compactDeviceEvents(events []events.DeviceEvent) (map[string]*containers.Container, []func(context.Context)) {
   267  	containers := map[string]*containers.Container{}
   268  	postHookFns := []func(context.Context){}
   269  	for _, updateEvent := range events {
   270  		// merges all containers to update into one map
   271  		maps.Insert(containers, maps.All(updateEvent.Containers()))
   272  		if updateEvent.PosthookFunc() != nil {
   273  			postHookFns = append(postHookFns, updateEvent.PosthookFunc())
   274  		}
   275  	}
   276  	return containers, postHookFns
   277  }
   278  
   279  // fetchAllContainersToUpdate will generate a deviceConfigJob to update all containers
   280  func fetchAllContainersToUpdate(ctx context.Context, allContainers map[string]*containers.Container) (map[string]*containers.Container, []func(context.Context)) {
   281  	log := logger.FromContext(ctx)
   282  	deviceEvents := []events.DeviceEvent{}
   283  	for _, ctr := range allContainers {
   284  		event, err := events.NewContainerEvent(ctr)
   285  		if err != nil {
   286  			log.Error("error generating container event", "error", err)
   287  			continue
   288  		}
   289  		deviceEvents = append(deviceEvents, event)
   290  	}
   291  	return compactDeviceEvents(deviceEvents)
   292  }
   293  
   294  // newKubeletSocketWatcher creates a new file system notify watcher
   295  // on the kubelet socket file
   296  func newKubeletSocketWatcher() (*fsnotify.Watcher, error) {
   297  	watcher, err := fsnotify.NewWatcher()
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  	if err = watcher.Add(kubeletPath); err != nil {
   302  		watcher.Close()
   303  		return nil, err
   304  	}
   305  	return watcher, nil
   306  }
   307  
   308  func fetchContainerNames(containers []*containers.Container) []string {
   309  	names := make([]string, len(containers))
   310  	for idx, ctr := range containers {
   311  		name := ctr.Labels[cc.AnnContainerName]
   312  		names[idx] = name
   313  	}
   314  	return names
   315  }
   316  
   317  // createScheme returns a runtime schema registered
   318  // with the corev1 and device-system scheme's.
   319  func createScheme() *runtime.Scheme {
   320  	scheme := runtime.NewScheme()
   321  	utilruntime.Must(clientgoscheme.AddToScheme(scheme))
   322  	utilruntime.Must(v1.AddToScheme(scheme))
   323  	utilruntime.Must(dsv1.AddToScheme(scheme))
   324  	return scheme
   325  }
   326  

View as plain text