package edgeinjector

import (
	"context"
	"fmt"
	"maps"
	"slices"
	"strings"

	appsv1 "k8s.io/api/apps/v1"
	corev1 "k8s.io/api/core/v1"
	kresource "k8s.io/apimachinery/pkg/api/resource"
	"k8s.io/apimachinery/pkg/runtime"
	virtv1 "kubevirt.io/api/core/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/devices/class"
	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
	"edge-infra.dev/pkg/sds/ien/resource"
)

const (
	// annVirtLauncherValue is the value expected under the
	// virtv1.GroupVersion.Group label for a pod to be treated as a
	// virt-launcher pod by injectKVMClassToVirtualMachinePod.
	annVirtLauncherValue = "virt-launcher"
	// computeContainerName is the container in a virt-launcher pod that
	// receives the kvm device-class request/limit.
	computeContainerName = "compute"
	// virtUser is added as a supplemental group on virt-launcher pods.
	// NOTE(review): presumably the uid/gid the VM process runs as inside the
	// virt-launcher image — confirm.
	virtUser int64 = 107
)

// ResourceRequestWebhook mutates workloads and pods that request device-class
// resources: it records the requested classes as pod annotations and adds the
// matching device node groups to the pod's supplemental groups.
type ResourceRequestWebhook struct {
	Client client.Client
}

// Default mutates obj in place.
//
// Deployments, StatefulSets and DaemonSets have their pod template injected
// via injectPodTemplateSpec. Pods are passed to
// injectKVMClassToVirtualMachinePod, which only mutates virt-launcher pods.
// Any other type results in an error. Device statuses are listed from the
// cluster once per call; listing errors are returned unchanged.
func (rrw *ResourceRequestWebhook) Default(ctx context.Context, obj runtime.Object) error {
	log := fog.FromContext(ctx).WithName("ResourceRequestWebhook")

	deviceStatuses, err := fetchDeviceStatuses(ctx, rrw.Client)
	if err != nil {
		log.Error(err, "error fetching device statuses")
		return err
	}

	// injectTemplate rewrites a workload's pod template in place. The template
	// is only replaced on success so a failed injection leaves obj untouched.
	injectTemplate := func(template *corev1.PodTemplateSpec, name, namespace string) error {
		injected, err := injectPodTemplateSpec(*template, deviceStatuses)
		if err != nil {
			log.Error(err, "failed to inject pod template", "name", name, "namespace", namespace)
			return err
		}
		*template = injected
		return nil
	}

	switch obj := obj.(type) {
	case *appsv1.Deployment:
		return injectTemplate(&obj.Spec.Template, obj.Name, obj.Namespace)
	case *appsv1.StatefulSet:
		return injectTemplate(&obj.Spec.Template, obj.Name, obj.Namespace)
	case *appsv1.DaemonSet:
		return injectTemplate(&obj.Spec.Template, obj.Name, obj.Namespace)
	case *corev1.Pod:
		if err := injectKVMClassToVirtualMachinePod(obj, deviceStatuses); err != nil {
			log.Error(err, "failed to inject pod spec", "name", obj.Name, "namespace", obj.Namespace)
			return err
		}
		return nil
	default:
		err := fmt.Errorf("webhook expected *appsv1.Deployment/DaemonSet/Statefulset/corev1.Pod request, got: %T", obj)
		log.Error(err, "unexpected type for webhook request")
		return err
	}
}

// fetchDeviceNodeGroups collects the group IDs listed under className in every
// DeviceStatuses' Spec.DeviceGroups and returns them sorted and de-duplicated.
// The result never aliases the input slices.
func fetchDeviceNodeGroups(className string, deviceStatuses []dsv1.DeviceStatuses) []int64 {
	var groups []int64
	for i := range deviceStatuses {
		groups = append(groups, deviceStatuses[i].Spec.DeviceGroups[className]...)
	}
	slices.Sort(groups)
	return slices.Compact(groups)
}

// fetchDeviceStatuses lists all DeviceStatuses objects visible to c.
func fetchDeviceStatuses(ctx context.Context, c client.Client) ([]dsv1.DeviceStatuses, error) {
	deviceStatuses := dsv1.DeviceStatusesList{}
	if err := c.List(ctx, &deviceStatuses); err != nil {
		return nil, err
	}
	return deviceStatuses.Items, nil
}

// injectPodTemplateSpec returns a copy of template with device-class
// annotations, audio/display resources and device supplemental groups
// injected. On error the template is returned unmodified.
func injectPodTemplateSpec(template corev1.PodTemplateSpec, deviceStatuses []dsv1.DeviceStatuses) (corev1.PodTemplateSpec, error) {
	podClasses, err := fetchContainerRequestAnnotations(template.Spec)
	if err != nil {
		return template, err
	}
	template.ObjectMeta.Annotations = injectAnnotations(template.ObjectMeta.Annotations, podClasses)

	// The security context is populated before the audio/display injection runs
	// and re-applied afterwards. NOTE(review): the second pass is preserved from
	// the original code, presumably because InjectResourceIntoPod may replace or
	// extend the pod security context. injectPodSpec is idempotent (append, sort,
	// compact), so the repeat is harmless either way.
	template.Spec = injectPodSpec(template.Spec, podClasses, deviceStatuses)
	template = injectAudioAndDisplayResources(template)
	template.Spec = injectPodSpec(template.Spec, podClasses, deviceStatuses)
	return template, nil
}

// injectKVMClassToVirtualMachinePod mutates virt-launcher pods (identified by
// the virtv1.GroupVersion.Group label) so the compute container requests and
// is limited to one kvm device-class resource, then records device-class
// annotations and supplemental groups (including virtUser). Other pods are
// left untouched.
func injectKVMClassToVirtualMachinePod(pod *corev1.Pod, deviceStatuses []dsv1.DeviceStatuses) error {
	if pod.Labels[virtv1.GroupVersion.Group] != annVirtLauncherValue {
		return nil
	}

	kvmResource := corev1.ResourceName(class.FmtClassLabel("kvm"))
	for i := range pod.Spec.Containers {
		ctr := &pod.Spec.Containers[i]
		if ctr.Name != computeContainerName {
			continue
		}
		// Assigning into a nil ResourceList panics, so allocate on demand.
		if ctr.Resources.Requests == nil {
			ctr.Resources.Requests = corev1.ResourceList{}
		}
		if ctr.Resources.Limits == nil {
			ctr.Resources.Limits = corev1.ResourceList{}
		}
		ctr.Resources.Requests[kvmResource] = kresource.MustParse("1")
		ctr.Resources.Limits[kvmResource] = kresource.MustParse("1")
	}

	podClasses, err := fetchContainerRequestAnnotations(pod.Spec)
	if err != nil {
		return err
	}
	pod.ObjectMeta.Annotations = injectAnnotations(pod.ObjectMeta.Annotations, podClasses)

	// The pod may arrive without a security context; guard before appending.
	// injectPodSpec below sorts and de-duplicates the full group list.
	if pod.Spec.SecurityContext == nil {
		pod.Spec.SecurityContext = &corev1.PodSecurityContext{}
	}
	pod.Spec.SecurityContext.SupplementalGroups = append(pod.Spec.SecurityContext.SupplementalGroups, virtUser)
	pod.Spec = injectPodSpec(pod.Spec, podClasses, deviceStatuses)
	return nil
}

// injectPodSpec adds, for every class named in podClasses' comma-separated
// values, the matching device node groups to the pod's supplemental groups.
// It ensures SecurityContext is non-nil and leaves SupplementalGroups sorted
// and de-duplicated. Note the SecurityContext pointer is shared with the
// caller's spec, so the caller's object is mutated too.
func injectPodSpec(podSpec corev1.PodSpec, podClasses map[string]string, deviceStatuses []dsv1.DeviceStatuses) corev1.PodSpec {
	if podSpec.SecurityContext == nil {
		podSpec.SecurityContext = &corev1.PodSecurityContext{}
	}
	groups := podSpec.SecurityContext.SupplementalGroups
	for _, containerClasses := range podClasses {
		for _, className := range strings.Split(containerClasses, ",") {
			if className == "" {
				continue
			}
			groups = append(groups, fetchDeviceNodeGroups(class.FmtClassLabel(className), deviceStatuses)...)
		}
	}
	slices.Sort(groups)
	podSpec.SecurityContext.SupplementalGroups = slices.Compact(groups)
	return podSpec
}

// injectAnnotations copies podClasses into annotations (allocating the map if
// nil), overwriting existing keys, and returns the resulting map.
func injectAnnotations(annotations map[string]string, podClasses map[string]string) map[string]string {
	if annotations == nil {
		annotations = map[string]string{}
	}
	maps.Copy(annotations, podClasses)
	return annotations
}

// injectAudioAndDisplayResources applies resource.InjectResourceIntoPod for the
// UI and audio request resources when any container requests them.
func injectAudioAndDisplayResources(template corev1.PodTemplateSpec) corev1.PodTemplateSpec {
	if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.UIRequestResource) {
		template = resource.InjectResourceIntoPod(template, resource.UIRequestResource)
	}
	if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.AudioRequestResource) {
		template = resource.InjectResourceIntoPod(template, resource.AudioRequestResource)
	}
	return template
}

// fetchContainerRequestAnnotations builds one annotation per container that
// requests at least one device-class resource. The key is
// "<class.DeviceClassPrefix>/<container name>". The value is the sorted,
// comma-separated list of requested class base names, followed by "default".
// It returns an error if a device-class resource name cannot be parsed.
func fetchContainerRequestAnnotations(podSpec corev1.PodSpec) (map[string]string, error) {
	containerRequestAnnotations := map[string]string{}
	for i := range podSpec.Containers {
		ctr := &podSpec.Containers[i]
		var classes []string
		for reqName := range ctr.Resources.Requests {
			name := reqName.String()
			if !class.IsDeviceClass(name) {
				continue
			}
			baseName, err := class.BaseName(name)
			if err != nil {
				return nil, fmt.Errorf("container %q: resource %q: %w", ctr.Name, name, err)
			}
			classes = append(classes, baseName)
		}
		if len(classes) == 0 {
			continue
		}
		// Requests is a map, so iteration order is random. Sorting keeps the
		// annotation value stable across admissions. Otherwise the pod template
		// changes on every update, which can trigger needless workload rollouts.
		slices.Sort(classes)
		classes = append(classes, "default")
		annotationKey := fmt.Sprintf("%s/%s", class.DeviceClassPrefix, ctr.Name)
		containerRequestAnnotations[annotationKey] = strings.Join(classes, ",")
	}
	return containerRequestAnnotations, nil
}