1 package edgeinjector
2
3 import (
4 "context"
5 "fmt"
6 "slices"
7 "strings"
8
9 "k8s.io/apimachinery/pkg/runtime"
10 "sigs.k8s.io/controller-runtime/pkg/client"
11
12 appsv1 "k8s.io/api/apps/v1"
13 corev1 "k8s.io/api/core/v1"
14 virtv1 "kubevirt.io/api/core/v1"
15
16 dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
17
18 kresource "k8s.io/apimachinery/pkg/api/resource"
19
20 "edge-infra.dev/pkg/lib/fog"
21 "edge-infra.dev/pkg/sds/devices/class"
22 "edge-infra.dev/pkg/sds/ien/resource"
23 )
24
const (
	// virtUser is appended to the supplemental groups of virt-launcher pods.
	// NOTE(review): presumably the KubeVirt qemu uid/gid — confirm.
	virtUser int64 = 107

	// computeContainerName is the virt-launcher container that receives the
	// kvm device-class request and limit.
	computeContainerName = "compute"

	// annVirtLauncherValue is the label value (under the key
	// virtv1.GroupVersion.Group) that marks a pod as a virt-launcher pod.
	annVirtLauncherValue = "virt-launcher"
)
30
31 type ResourceRequestWebhook struct {
32 Client client.Client
33 }
34
35 func (rrw *ResourceRequestWebhook) Default(ctx context.Context, obj runtime.Object) error {
36 log := fog.FromContext(ctx).WithName("ResourceRequestWebhook")
37 deviceStatuses, err := fetchDeviceStatuses(ctx, rrw.Client)
38 if err != nil {
39 log.Error(err, "error fetching device statuses")
40 return err
41 }
42
43 switch obj := obj.(type) {
44 case *appsv1.Deployment:
45 template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
46 if err != nil {
47 log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
48 return err
49 }
50 obj.Spec.Template = template
51 case *appsv1.StatefulSet:
52 template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
53 if err != nil {
54 log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
55 return err
56 }
57 obj.Spec.Template = template
58 case *appsv1.DaemonSet:
59 template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
60 if err != nil {
61 log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
62 return err
63 }
64 obj.Spec.Template = template
65 case *corev1.Pod:
66 if err := injectKVMClassToVirtualMachinePod(obj, deviceStatuses); err != nil {
67 log.Error(err, "failed to inject pod spec", "name", obj.Name, "namespace", obj.Namespace)
68 return err
69 }
70 default:
71 err := fmt.Errorf("webhook expected *appsv1.Deployment/DaemonSet/Statefulset/corev1.Pod request, got: %T", obj)
72 log.Error(err, "unexpected type for webhook request")
73 return err
74 }
75 return nil
76 }
77
78 func fetchDeviceNodeGroups(className string, deviceStatuses []dsv1.DeviceStatuses) []int64 {
79 groups := []int64{}
80 for _, deviceStatus := range deviceStatuses {
81 groups = slices.AppendSeq(groups, slices.Values(deviceStatus.Spec.DeviceGroups[className]))
82 }
83 slices.Sort(groups)
84 return slices.Compact(groups)
85 }
86
87 func fetchDeviceStatuses(ctx context.Context, c client.Client) ([]dsv1.DeviceStatuses, error) {
88 deviceStatuses := dsv1.DeviceStatusesList{}
89 if err := c.List(ctx, &deviceStatuses); err != nil {
90 return nil, err
91 }
92 return deviceStatuses.Items, nil
93 }
94
95 func injectPodTemplateSpec(template corev1.PodTemplateSpec, deviceStatuses []dsv1.DeviceStatuses) (corev1.PodTemplateSpec, error) {
96 podClasses, err := fetchContainerRequestAnnotations(template.Spec)
97 if err != nil {
98 return template, err
99 }
100
101 newAnnotations := injectAnnotations(template.ObjectMeta.Annotations, podClasses)
102 template.ObjectMeta.Annotations = newAnnotations
103 template.Spec = injectPodSpec(template.Spec, podClasses, deviceStatuses)
104 template = injectAudioAndDisplayResources(template)
105 if template.Spec.SecurityContext == nil {
106 template.Spec.SecurityContext = &corev1.PodSecurityContext{}
107 }
108
109 for _, containerClass := range podClasses {
110 classes := strings.Split(containerClass, ",")
111 for _, className := range classes {
112 if className == "" {
113 continue
114 }
115 groups := fetchDeviceNodeGroups(class.FmtClassLabel(className), deviceStatuses)
116 template.Spec.SecurityContext.SupplementalGroups = slices.AppendSeq(template.Spec.SecurityContext.SupplementalGroups, slices.Values(groups))
117 }
118 }
119 slices.Sort(template.Spec.SecurityContext.SupplementalGroups)
120 template.Spec.SecurityContext.SupplementalGroups = slices.Compact(template.Spec.SecurityContext.SupplementalGroups)
121 return template, nil
122 }
123
124 func injectKVMClassToVirtualMachinePod(pod *corev1.Pod, deviceStatuses []dsv1.DeviceStatuses) error {
125 if pod.Labels[virtv1.GroupVersion.Group] != annVirtLauncherValue {
126 return nil
127 }
128
129 for idx, ctr := range pod.Spec.Containers {
130 if ctr.Name != computeContainerName {
131 continue
132 }
133
134 kvmResource := corev1.ResourceName(class.FmtClassLabel("kvm"))
135 pod.Spec.Containers[idx].Resources.Requests[kvmResource] = kresource.MustParse("1")
136 pod.Spec.Containers[idx].Resources.Limits[kvmResource] = kresource.MustParse("1")
137 }
138
139 podClasses, err := fetchContainerRequestAnnotations(pod.Spec)
140 if err != nil {
141 return err
142 }
143
144 newAnnotations := injectAnnotations(pod.ObjectMeta.Annotations, podClasses)
145 pod.ObjectMeta.Annotations = newAnnotations
146 pod.Spec.SecurityContext.SupplementalGroups = append(pod.Spec.SecurityContext.SupplementalGroups, virtUser)
147 pod.Spec = injectPodSpec(pod.Spec, podClasses, deviceStatuses)
148 return nil
149 }
150
151 func injectPodSpec(podSpec corev1.PodSpec, podClasses map[string]string, deviceStatuses []dsv1.DeviceStatuses) corev1.PodSpec {
152 if podSpec.SecurityContext == nil {
153 podSpec.SecurityContext = &corev1.PodSecurityContext{}
154 }
155
156 for _, containerClass := range podClasses {
157 classes := strings.Split(containerClass, ",")
158 for _, className := range classes {
159 if className == "" {
160 continue
161 }
162 groups := fetchDeviceNodeGroups(class.FmtClassLabel(className), deviceStatuses)
163 podSpec.SecurityContext.SupplementalGroups = slices.AppendSeq(podSpec.SecurityContext.SupplementalGroups, slices.Values(groups))
164 }
165 }
166 slices.Sort(podSpec.SecurityContext.SupplementalGroups)
167 podSpec.SecurityContext.SupplementalGroups = slices.Compact(podSpec.SecurityContext.SupplementalGroups)
168 return podSpec
169 }
170
// injectAnnotations copies every entry of podClasses into annotations,
// overwriting existing keys, and returns the result. A nil annotations map is
// replaced by a new one; a non-nil map is updated in place and returned.
func injectAnnotations(annotations map[string]string, podClasses map[string]string) map[string]string {
	merged := annotations
	if merged == nil {
		merged = make(map[string]string, len(podClasses))
	}
	for key, value := range podClasses {
		merged[key] = value
	}
	return merged
}
180
181 func injectAudioAndDisplayResources(template corev1.PodTemplateSpec) corev1.PodTemplateSpec {
182 if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.UIRequestResource) {
183 template = resource.InjectResourceIntoPod(template, resource.UIRequestResource)
184 }
185
186 if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.AudioRequestResource) {
187 template = resource.InjectResourceIntoPod(template, resource.AudioRequestResource)
188 }
189 return template
190 }
191
192 func fetchContainerRequestAnnotations(podSpec corev1.PodSpec) (map[string]string, error) {
193 containerRequestAnnotations := map[string]string{}
194 for _, ctr := range podSpec.Containers {
195 classes := []string{}
196 for reqName := range ctr.Resources.Requests {
197 if !class.IsDeviceClass(reqName.String()) {
198 continue
199 }
200 baseName, err := class.BaseName(reqName.String())
201 if err != nil {
202 return nil, err
203 }
204 classes = append(classes, baseName)
205 }
206 if len(classes) == 0 {
207 continue
208 }
209 classes = append(classes, "default")
210 annotationPrefix := fmt.Sprintf("%s/%s", class.DeviceClassPrefix, ctr.Name)
211 containerRequestAnnotations[annotationPrefix] = strings.Join(classes, ",")
212 }
213 return containerRequestAnnotations, nil
214 }
215
View as plain text