...
1 package agent
2
3 import (
4 "context"
5 "os"
6 "time"
7
8 "github.com/containerd/containerd/containers"
9
10 "edge-infra.dev/pkg/lib/uuid"
11 cc "edge-infra.dev/pkg/sds/devices/agent/common"
12 "edge-infra.dev/pkg/sds/devices/agent/metrics"
13 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
14 "edge-infra.dev/pkg/sds/devices/logger"
15 )
16
// deviceConfigJob is a unit of device-configuration work that is executed
// with a context.
type deviceConfigJob interface {
	// Run performs the job. It returns nothing, so callers receive neither a
	// result nor an error from it.
	Run(context.Context)
}
20
21 type job struct {
22 allDeviceClasses map[string]*dsv1.DeviceClass
23 containers map[string]*containers.Container
24 postHookFns []func(context.Context)
25 }
26
27 func newJob(containers map[string]*containers.Container, postHookFns []func(context.Context), allDeviceClasses map[string]*dsv1.DeviceClass) deviceConfigJob {
28 return &job{
29 allDeviceClasses: allDeviceClasses,
30 containers: containers,
31 postHookFns: postHookFns,
32 }
33 }
34
35 func (j job) Run(ctx context.Context) {
36 j.applyDeviceConfigurationUpdate(ctx)
37 }
38
39
40 func (j job) applyDeviceConfigurationUpdate(ctx context.Context) {
41 requestID := uuid.New().UUID
42 log := logger.FromContext(ctx).With("requestId", requestID)
43 ctx = logger.IntoContext(ctx, log)
44 changedContainerNames := []string{}
45
46 configUpdate := func(ctx context.Context) {
47 log.Debug("device request started")
48
49 for _, ctr := range j.containers {
50 ctrName := ctr.Labels[cc.AnnContainerName]
51 changedContainerNames = append(changedContainerNames, ctrName)
52 ApplyCgroupsToContainer(ctx, requestID, ctr, j.allDeviceClasses)
53 }
54
55 for _, postHookFn := range j.postHookFns {
56 postHookFn(ctx)
57 }
58 log.Debug("device request finished", "containers", changedContainerNames)
59 }
60
61 duration := timeConfigUpdate(ctx, configUpdate)
62 if len(changedContainerNames) != 0 {
63 log.Info("applied device update to containers", "updated containers", changedContainerNames, "time", duration)
64 }
65 }
66
67
68 func timeConfigUpdate(ctx context.Context, fn func(ctx context.Context)) string {
69 startTime := time.Now()
70 fn(ctx)
71 endTime := time.Now()
72 diff := endTime.Sub(startTime)
73 metrics.RecordDuration(nodeName(), diff.Seconds())
74 return diff.String()
75 }
76
// nodeName returns the name used to identify this host when recording
// metrics.
//
// It prefers the HOSTNAME environment variable. HOSTNAME is not guaranteed to
// be exported (for example, to processes started by an init system rather
// than a shell or container runtime). When it is empty, this falls back to the
// kernel-reported hostname from os.Hostname, so metrics are not recorded under
// an empty name. If os.Hostname also fails, the empty string is returned,
// which matches the previous behaviour.
func nodeName() string {
	if name := os.Getenv("HOSTNAME"); name != "" {
		return name
	}
	if name, err := os.Hostname(); err == nil {
		return name
	}
	return ""
}
80
View as plain text