...

Source file src/edge-infra.dev/pkg/edge/webhooks/edge-injector/resourcerequest.go

Documentation: edge-infra.dev/pkg/edge/webhooks/edge-injector

     1  package edgeinjector
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"slices"
     7  	"strings"
     8  
     9  	"k8s.io/apimachinery/pkg/runtime"
    10  	"sigs.k8s.io/controller-runtime/pkg/client"
    11  
    12  	appsv1 "k8s.io/api/apps/v1"
    13  	corev1 "k8s.io/api/core/v1"
    14  	virtv1 "kubevirt.io/api/core/v1"
    15  
    16  	dsv1 "edge-infra.dev/pkg/sds/devices/k8s/apis/v1"
    17  
    18  	kresource "k8s.io/apimachinery/pkg/api/resource"
    19  
    20  	"edge-infra.dev/pkg/lib/fog"
    21  	"edge-infra.dev/pkg/sds/devices/class"
    22  	"edge-infra.dev/pkg/sds/ien/resource"
    23  )
    24  
    25  const (
    26  	annVirtLauncherValue       = "virt-launcher"
    27  	computeContainerName       = "compute"
    28  	virtUser             int64 = 107
    29  )
    30  
// ResourceRequestWebhook mutates Deployments, StatefulSets, DaemonSets and
// KubeVirt virt-launcher Pods in its Default method. It adds device-class
// annotations and supplemental groups, and injects UI/audio resources into
// workload pod templates.
type ResourceRequestWebhook struct {
	// Client is used to list DeviceStatuses objects on every Default call.
	Client client.Client
}
    34  
    35  func (rrw *ResourceRequestWebhook) Default(ctx context.Context, obj runtime.Object) error {
    36  	log := fog.FromContext(ctx).WithName("ResourceRequestWebhook")
    37  	deviceStatuses, err := fetchDeviceStatuses(ctx, rrw.Client)
    38  	if err != nil {
    39  		log.Error(err, "error fetching device statuses")
    40  		return err
    41  	}
    42  
    43  	switch obj := obj.(type) {
    44  	case *appsv1.Deployment:
    45  		template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
    46  		if err != nil {
    47  			log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
    48  			return err
    49  		}
    50  		obj.Spec.Template = template
    51  	case *appsv1.StatefulSet:
    52  		template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
    53  		if err != nil {
    54  			log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
    55  			return err
    56  		}
    57  		obj.Spec.Template = template
    58  	case *appsv1.DaemonSet:
    59  		template, err := injectPodTemplateSpec(obj.Spec.Template, deviceStatuses)
    60  		if err != nil {
    61  			log.Error(err, "failed to inject pod template", "name", obj.Name, "namespace", obj.Namespace)
    62  			return err
    63  		}
    64  		obj.Spec.Template = template
    65  	case *corev1.Pod:
    66  		if err := injectKVMClassToVirtualMachinePod(obj, deviceStatuses); err != nil {
    67  			log.Error(err, "failed to inject pod spec", "name", obj.Name, "namespace", obj.Namespace)
    68  			return err
    69  		}
    70  	default:
    71  		err := fmt.Errorf("webhook expected *appsv1.Deployment/DaemonSet/Statefulset/corev1.Pod request, got: %T", obj)
    72  		log.Error(err, "unexpected type for webhook request")
    73  		return err
    74  	}
    75  	return nil
    76  }
    77  
    78  func fetchDeviceNodeGroups(className string, deviceStatuses []dsv1.DeviceStatuses) []int64 {
    79  	groups := []int64{}
    80  	for _, deviceStatus := range deviceStatuses {
    81  		groups = slices.AppendSeq(groups, slices.Values(deviceStatus.Spec.DeviceGroups[className]))
    82  	}
    83  	slices.Sort(groups)
    84  	return slices.Compact(groups)
    85  }
    86  
    87  func fetchDeviceStatuses(ctx context.Context, c client.Client) ([]dsv1.DeviceStatuses, error) {
    88  	deviceStatuses := dsv1.DeviceStatusesList{}
    89  	if err := c.List(ctx, &deviceStatuses); err != nil {
    90  		return nil, err
    91  	}
    92  	return deviceStatuses.Items, nil
    93  }
    94  
    95  func injectPodTemplateSpec(template corev1.PodTemplateSpec, deviceStatuses []dsv1.DeviceStatuses) (corev1.PodTemplateSpec, error) {
    96  	podClasses, err := fetchContainerRequestAnnotations(template.Spec)
    97  	if err != nil {
    98  		return template, err
    99  	}
   100  
   101  	newAnnotations := injectAnnotations(template.ObjectMeta.Annotations, podClasses)
   102  	template.ObjectMeta.Annotations = newAnnotations
   103  	template.Spec = injectPodSpec(template.Spec, podClasses, deviceStatuses)
   104  	template = injectAudioAndDisplayResources(template)
   105  	if template.Spec.SecurityContext == nil {
   106  		template.Spec.SecurityContext = &corev1.PodSecurityContext{}
   107  	}
   108  
   109  	for _, containerClass := range podClasses {
   110  		classes := strings.Split(containerClass, ",")
   111  		for _, className := range classes {
   112  			if className == "" {
   113  				continue
   114  			}
   115  			groups := fetchDeviceNodeGroups(class.FmtClassLabel(className), deviceStatuses)
   116  			template.Spec.SecurityContext.SupplementalGroups = slices.AppendSeq(template.Spec.SecurityContext.SupplementalGroups, slices.Values(groups))
   117  		}
   118  	}
   119  	slices.Sort(template.Spec.SecurityContext.SupplementalGroups)
   120  	template.Spec.SecurityContext.SupplementalGroups = slices.Compact(template.Spec.SecurityContext.SupplementalGroups)
   121  	return template, nil
   122  }
   123  
   124  func injectKVMClassToVirtualMachinePod(pod *corev1.Pod, deviceStatuses []dsv1.DeviceStatuses) error {
   125  	if pod.Labels[virtv1.GroupVersion.Group] != annVirtLauncherValue {
   126  		return nil
   127  	}
   128  
   129  	for idx, ctr := range pod.Spec.Containers {
   130  		if ctr.Name != computeContainerName {
   131  			continue
   132  		}
   133  
   134  		kvmResource := corev1.ResourceName(class.FmtClassLabel("kvm"))
   135  		pod.Spec.Containers[idx].Resources.Requests[kvmResource] = kresource.MustParse("1")
   136  		pod.Spec.Containers[idx].Resources.Limits[kvmResource] = kresource.MustParse("1")
   137  	}
   138  
   139  	podClasses, err := fetchContainerRequestAnnotations(pod.Spec)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	newAnnotations := injectAnnotations(pod.ObjectMeta.Annotations, podClasses)
   145  	pod.ObjectMeta.Annotations = newAnnotations
   146  	pod.Spec.SecurityContext.SupplementalGroups = append(pod.Spec.SecurityContext.SupplementalGroups, virtUser)
   147  	pod.Spec = injectPodSpec(pod.Spec, podClasses, deviceStatuses)
   148  	return nil
   149  }
   150  
   151  func injectPodSpec(podSpec corev1.PodSpec, podClasses map[string]string, deviceStatuses []dsv1.DeviceStatuses) corev1.PodSpec {
   152  	if podSpec.SecurityContext == nil {
   153  		podSpec.SecurityContext = &corev1.PodSecurityContext{}
   154  	}
   155  
   156  	for _, containerClass := range podClasses {
   157  		classes := strings.Split(containerClass, ",")
   158  		for _, className := range classes {
   159  			if className == "" {
   160  				continue
   161  			}
   162  			groups := fetchDeviceNodeGroups(class.FmtClassLabel(className), deviceStatuses)
   163  			podSpec.SecurityContext.SupplementalGroups = slices.AppendSeq(podSpec.SecurityContext.SupplementalGroups, slices.Values(groups))
   164  		}
   165  	}
   166  	slices.Sort(podSpec.SecurityContext.SupplementalGroups)
   167  	podSpec.SecurityContext.SupplementalGroups = slices.Compact(podSpec.SecurityContext.SupplementalGroups)
   168  	return podSpec
   169  }
   170  
   171  func injectAnnotations(annotations map[string]string, podClasses map[string]string) map[string]string {
   172  	if annotations == nil {
   173  		annotations = map[string]string{}
   174  	}
   175  	for k, v := range podClasses {
   176  		annotations[k] = v
   177  	}
   178  	return annotations
   179  }
   180  
   181  func injectAudioAndDisplayResources(template corev1.PodTemplateSpec) corev1.PodTemplateSpec {
   182  	if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.UIRequestResource) {
   183  		template = resource.InjectResourceIntoPod(template, resource.UIRequestResource)
   184  	}
   185  
   186  	if resource.ContainersHasResourceRequest(template.Spec.Containers, resource.AudioRequestResource) {
   187  		template = resource.InjectResourceIntoPod(template, resource.AudioRequestResource)
   188  	}
   189  	return template
   190  }
   191  
   192  func fetchContainerRequestAnnotations(podSpec corev1.PodSpec) (map[string]string, error) {
   193  	containerRequestAnnotations := map[string]string{}
   194  	for _, ctr := range podSpec.Containers {
   195  		classes := []string{}
   196  		for reqName := range ctr.Resources.Requests {
   197  			if !class.IsDeviceClass(reqName.String()) {
   198  				continue
   199  			}
   200  			baseName, err := class.BaseName(reqName.String())
   201  			if err != nil {
   202  				return nil, err
   203  			}
   204  			classes = append(classes, baseName)
   205  		}
   206  		if len(classes) == 0 {
   207  			continue
   208  		}
   209  		classes = append(classes, "default")
   210  		annotationPrefix := fmt.Sprintf("%s/%s", class.DeviceClassPrefix, ctr.Name)
   211  		containerRequestAnnotations[annotationPrefix] = strings.Join(classes, ",")
   212  	}
   213  	return containerRequestAnnotations, nil
   214  }
   215  

View as plain text