1 package plugins
2
3 import (
4 "context"
5 "crypto/sha256"
6 "fmt"
7 "log/slog"
8 "net"
9 "os"
10 "path/filepath"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 fsnotify "github.com/fsnotify/fsnotify"
17 "google.golang.org/grpc"
18 "google.golang.org/grpc/connectivity"
19 "google.golang.org/grpc/credentials/insecure"
20 "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
21 pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
22
23 "edge-infra.dev/pkg/lib/filesystem"
24 "edge-infra.dev/pkg/lib/kernel/devices"
25 cc "edge-infra.dev/pkg/sds/devices/agent/common"
26 "edge-infra.dev/pkg/sds/devices/class"
27 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
28 "edge-infra.dev/pkg/sds/devices/logger"
29 )
30
// kubeletDevicePluginPath is the directory in which the plugin socket files
// are created (see generateFilePathName).
var kubeletDevicePluginPath = "/var/lib/kubelet/device-plugins/"

// serverWaitTimeout is used both as the gRPC server connection timeout and as
// the upper bound on how long waitForGRPCServer waits for a socket to become
// ready.
var serverWaitTimeout = 1 * time.Second
37
// rootUserID is the numeric ID of root. It is compared against the group
// owner ID of a device node in overrideDeviceNodeGroupOwner.
const rootUserID = 0

// removable is both the name of the device attribute inspected by
// rootDeviceIsRemovable and the attribute value that marks a device as
// removable.
const removable = "removable"
44
45 type Plugin interface {
46 v1beta1.DevicePluginServer
47 Run(ctx context.Context)
48 Stop()
49 Remove()
50 Update(deviceClass *dsv1.DeviceClass)
51 IsRunning() bool
52 }
53
const (
	// maxDeviceRequestsPerClass is the number of virtual device IDs that are
	// advertised to the kubelet for a class (see sendDeviceUpdate).
	maxDeviceRequestsPerClass = 100

	// waitInterval is the pause taken at the start of every iteration of the
	// registration loop.
	waitInterval = 5 * time.Second
)
56
57 type plugin struct {
58 name string
59 ids []string
60 socketFilePath string
61 grpcServer *grpc.Server
62 serving bool
63 log *slog.Logger
64 mutex sync.Mutex
65 devicesClassCh chan *dsv1.DeviceClass
66 deviceClass *dsv1.DeviceClass
67 }
68
69
70 func NewPlugin(deviceClass *dsv1.DeviceClass) Plugin {
71 name := class.FmtClassLabel(deviceClass.ObjectMeta.GetName())
72 socketPath := generateFilePathName(name)
73 opts := []logger.Option{
74 logger.WithLevel(deviceClass.Spec.LogLevel()),
75 }
76 return &plugin{
77 ids: []string{},
78 devicesClassCh: make(chan *dsv1.DeviceClass, 1),
79 mutex: sync.Mutex{},
80 name: name,
81 socketFilePath: socketPath,
82 serving: false,
83 log: logger.New(opts...).With("class", name, "path", socketPath),
84 deviceClass: deviceClass,
85 }
86 }
87
88
89 func (p *plugin) Name() string {
90 return p.name
91 }
92
93
94 func (p *plugin) File() string {
95 return p.socketFilePath
96 }
97
98 func (p *plugin) setRunning() {
99 p.mutex.Lock()
100 p.serving = true
101 p.mutex.Unlock()
102 }
103
104 func (p *plugin) Stop() {
105 p.mutex.Lock()
106 p.serving = false
107 p.mutex.Unlock()
108 if p.grpcServer != nil {
109 p.grpcServer.Stop()
110 return
111 }
112 p.Remove()
113 }
114
115
116 func (p *plugin) Remove() {
117 _ = os.Remove(p.socketFilePath)
118 }
119
120 func (p *plugin) IsRunning() bool {
121 p.mutex.Lock()
122 isRunning := p.serving
123 p.mutex.Unlock()
124 return isRunning
125 }
126
127
128 func (p *plugin) Run(ctx context.Context) {
129 p.runRegistration(ctx)
130 }
131
132
133 func (p *plugin) Update(deviceClass *dsv1.DeviceClass) {
134 if deviceClass == nil {
135 return
136 }
137 p.devicesClassCh <- deviceClass
138 }
139
140
141 func (p *plugin) runRegistration(ctx context.Context) {
142 p.setRunning()
143 for p.IsRunning() {
144
145 time.Sleep(waitInterval)
146
147 if err := waitForGRPCServer(ctx, p.socketFilePath); err == nil {
148 return
149 }
150 _ = os.Remove(p.socketFilePath)
151
152 socket, err := net.Listen("unix", p.socketFilePath)
153 if err != nil {
154 p.log.Error("error listening to unix socket", "error", err)
155 continue
156 }
157
158
159 p.grpcServer = grpc.NewServer(grpc.ConnectionTimeout(serverWaitTimeout))
160 v1beta1.RegisterDevicePluginServer(p.grpcServer, p)
161
162 go func() {
163 if err := p.grpcServer.Serve(socket); err != nil {
164 p.log.Error("error serving grpc server", "error", err.Error())
165 p.grpcServer.Stop()
166 p.Remove()
167 }
168 }()
169
170 if err := waitForGRPCServer(ctx, p.socketFilePath); err != nil {
171 p.log.Error("error listening to gRPC server socket", "error", err)
172 continue
173 }
174
175 if err := p.registerWithKubelet(ctx); err != nil {
176 p.log.Error("error registering with kubelet", "error", err)
177 continue
178 }
179
180 if err := p.watchSocketFile(); err != nil {
181 p.log.Error("error watching socket file", "error", err)
182 continue
183 }
184 p.log.Info("stopping grpc server")
185 p.Stop()
186
187
188 p.setRunning()
189 }
190 }
191
192
193
194
195
196
197 func (p *plugin) Allocate(_ context.Context, req *v1beta1.AllocateRequest) (*pluginapi.AllocateResponse, error) {
198 res := &v1beta1.AllocateResponse{
199 ContainerResponses: make([]*v1beta1.ContainerAllocateResponse, 0, len(req.ContainerRequests)),
200 }
201
202 if willBlock, blockingSetName := p.deviceClass.WillBlock(); willBlock {
203 return nil, fmt.Errorf("blocking container start, waiting for devices to become available for class %s: %s", p.name, blockingSetName)
204 }
205
206 for _, req := range req.ContainerRequests {
207 resp := &v1beta1.ContainerAllocateResponse{
208 Annotations: map[string]string{
209 p.name: "1",
210 },
211 Mounts: []*pluginapi.Mount{
212 {
213 ContainerPath: "/dev",
214 HostPath: "/dev",
215 },
216 {
217 ContainerPath: "/run/udev",
218 HostPath: "/run/udev",
219 },
220 },
221 Envs: map[string]string{},
222 }
223
224 for _, id := range req.DevicesIDs {
225 if !slices.Contains(p.ids, id) {
226 continue
227 }
228 for _, dev := range p.deviceClass.DeviceIter() {
229 node, err := dev.Node()
230 if err != nil || node == nil {
231 p.log.Debug("device has no node", "path", dev.Path())
232 continue
233 }
234
235 if _, err := os.Stat(node.Path()); err != nil {
236 p.log.Debug("could not stat device node", "path", node.Path())
237 continue
238 }
239
240 if err := p.overrideDeviceNodeGroupOwner(node); err != nil {
241 p.log.Error("error setting device node permissions", "error", err)
242 }
243
244
245 if !p.rootDeviceIsRemovable(dev.Path()) {
246 resp.Devices = append(resp.Devices, &pluginapi.DeviceSpec{
247 ContainerPath: node.Path(),
248 HostPath: node.Path(),
249 Permissions: "rwm",
250 })
251 }
252 }
253 }
254 p.log.Info("allocate devices to container", "devices", len(resp.Devices))
255 res.ContainerResponses = append(res.ContainerResponses, resp)
256 }
257 return res, nil
258 }
259
260
261
262 func (p *plugin) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DevicePlugin_ListAndWatchServer) error {
263 if err := p.sendDeviceUpdate(stream); err != nil {
264 return err
265 }
266
267
268 for devClass := range p.devicesClassCh {
269 p.deviceClass = devClass
270 if err := p.sendDeviceUpdate(stream); err != nil {
271 return err
272 }
273 }
274 return nil
275 }
276
277
278
279 func (p *plugin) sendDeviceUpdate(stream v1beta1.DevicePlugin_ListAndWatchServer) error {
280 res := &v1beta1.ListAndWatchResponse{
281 Devices: []*pluginapi.Device{},
282 }
283
284 for i := 1; i <= maxDeviceRequestsPerClass; i++ {
285 h := sha256.New()
286 id := fmt.Sprintf("%s-%d", p.name, i)
287 h.Write([]byte(id))
288 id = fmt.Sprintf("%x", h.Sum(nil))
289 res.Devices = append(res.Devices, &pluginapi.Device{
290 ID: id,
291 Health: pluginapi.Healthy,
292 })
293 p.ids = append(p.ids, id)
294 }
295 if err := stream.Send(res); err != nil {
296 p.Stop()
297 return err
298 }
299 return nil
300 }
301
302 func (p *plugin) PreStartContainer(context.Context, *v1beta1.PreStartContainerRequest) (*v1beta1.PreStartContainerResponse, error) {
303 return &pluginapi.PreStartContainerResponse{}, nil
304 }
305
306 func (p *plugin) GetPreferredAllocation(context.Context, *v1beta1.PreferredAllocationRequest) (*v1beta1.PreferredAllocationResponse, error) {
307 return &pluginapi.PreferredAllocationResponse{}, nil
308 }
309
310 func (p *plugin) GetDevicePluginOptions(_ context.Context, _ *v1beta1.Empty) (*v1beta1.DevicePluginOptions, error) {
311 return &pluginapi.DevicePluginOptions{
312 PreStartRequired: false,
313 GetPreferredAllocationAvailable: true,
314 }, nil
315 }
316
317
318 func (p *plugin) registerWithKubelet(ctx context.Context) error {
319 c, err := grpc.NewClient(filepath.Join("unix://", v1beta1.KubeletSocket), grpc.WithTransportCredentials(insecure.NewCredentials()))
320 if err != nil {
321 return err
322 }
323 defer c.Close()
324
325 client := v1beta1.NewRegistrationClient(c)
326 request := &v1beta1.RegisterRequest{
327 Version: v1beta1.Version,
328 Endpoint: filepath.Base(p.socketFilePath),
329 ResourceName: p.name,
330 }
331
332 if _, err := client.Register(ctx, request); err != nil {
333 return fmt.Errorf("failed to register plugin with kubelet service: %v", err)
334 }
335 p.log.Info("registered resource with kubelet", "resource", p.name)
336 return nil
337 }
338
339
340
341 func waitForGRPCServer(ctx context.Context, socket string) error {
342 conn, err := grpc.NewClient(filepath.Join("unix://", socket), grpc.WithTransportCredentials(insecure.NewCredentials()))
343 if err != nil {
344 return err
345 }
346 defer conn.Close()
347
348 ctx, cancel := context.WithTimeout(ctx, serverWaitTimeout)
349 defer cancel()
350 for {
351 state := conn.GetState()
352 if state == connectivity.Idle {
353 conn.Connect()
354 }
355 if state == connectivity.Ready {
356 return nil
357 }
358 if !conn.WaitForStateChange(ctx, state) {
359 return ctx.Err()
360 }
361 }
362 }
363
364
365
366
367 func (p *plugin) watchSocketFile() error {
368 watcher, err := fsnotify.NewWatcher()
369 if err != nil {
370 return err
371 }
372 defer watcher.Close()
373
374 if err = watcher.Add(filepath.Dir(p.socketFilePath)); err != nil {
375 return err
376 }
377
378 for {
379 select {
380 case event := <-watcher.Events:
381 if (event.Op == fsnotify.Remove || event.Op == fsnotify.Rename) && event.Name == p.socketFilePath {
382 return nil
383 }
384 case err := <-watcher.Errors:
385 return err
386 }
387 }
388 }
389
390
391
392 func (p *plugin) overrideDeviceNodeGroupOwner(node devices.Node) error {
393 groupID, err := node.GroupID()
394 if err != nil {
395 return fmt.Errorf("error could not determin node group owner id: %s : %w", node.Path(), err)
396 }
397
398
399 if groupID != rootUserID {
400 return nil
401 }
402
403 if groupID != cc.DeviceGroupID {
404 p.log.Debug("changing device node group owner from root to deviceg (1015)", "node", node.Path())
405 if err := os.Chown(node.Path(), -1, cc.DeviceGroupID); err != nil {
406 return fmt.Errorf("error changing device node group owner to deviceg (1015): %s: %w", node.Path(), err)
407 }
408 }
409
410 fileMode, err := node.FileMode()
411 if err != nil {
412 return err
413 }
414
415 newMode := fileMode | filesystem.GroupReadWritePerm
416 if fileMode == newMode {
417 return nil
418 }
419
420 p.log.Debug("changing device node permissions to read/write by group owner", "node", node.Path())
421 if err := os.Chmod(node.Path(), newMode); err != nil {
422 return fmt.Errorf("error changing device node permissions to read/write for group owner: %s: %w", node.Path(), err)
423 }
424 return nil
425 }
426
427
428
429 func (p *plugin) rootDeviceIsRemovable(path string) bool {
430 pathSplit := strings.Split(path, "/")
431 for idx := 0; idx <= len(pathSplit)-1; idx++ {
432 searchPath := strings.Join(pathSplit[:idx], "/")
433 device := p.deviceClass.DeviceGet(searchPath)
434 if device == nil {
435 continue
436 }
437
438 canRemove, exists, _ := device.Attribute(removable)
439 if !exists {
440 continue
441 }
442
443 if canRemove == removable {
444 return true
445 }
446 }
447 return false
448 }
449
450
451
452 func generateFilePathName(resourceName string) string {
453 classFmt := resourceFileName(resourceName)
454 return filepath.Join(kubeletDevicePluginPath, fmt.Sprintf("ds-%s.sock", classFmt))
455 }
456
457
458 func resourceFileName(resourceName string) string {
459 classFmt := strings.ReplaceAll(resourceName, class.DeviceClassPrefix, "")
460 classFmt = strings.ReplaceAll(classFmt, "/", "")
461 classFmt = strings.ReplaceAll(classFmt, ".", "-")
462 return classFmt
463 }
464
View as plain text