package services

import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"strings"

	"github.com/hashicorp/go-multierror"
	netutils "k8s.io/utils/net"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/edge/constants"
	"edge-infra.dev/pkg/lib/networkvalidator"
)

const (
	// defaultPriority is stored when a priority is omitted on create, or when a
	// supplied priority is below 1.
	defaultPriority = 100
)

var (
	// networkServiceValidators maps every supported network service type to the
	// validator run before a service of that type is created or updated.
	// Types absent from this map are rejected.
	networkServiceValidators = map[string]networkServiceValidator{
		constants.ServiceTypeNTP:                validateNetworkServiceNTP,
		constants.ServiceTypeDNS:                validateNetworkServiceIPAddress,
		constants.ServiceTypeVIP:                validateNetworkServiceIPAddress,
		constants.ServiceTypeClusterDNS:         validateNetworkServiceClusterDNS,
		constants.ServiceTypePodNetworkCIDR:     validateK8sNetworkCIDR,
		constants.ServiceTypeServiceNetworkCIDR: validateK8sNetworkCIDR,
		constants.ServiceTypeEgressTunnelsCIDR:  validateK8sNetworkCIDR,
	}

	// networkRangeValidators constrains the prefix length of each k8s CIDR
	// service type validated by validateK8sNetworkCIDR.
	networkRangeValidators = map[string]networkRangeValidator{
		constants.ServiceTypePodNetworkCIDR:     validatePodNetworkRange,
		constants.ServiceTypeServiceNetworkCIDR: validateServiceNetworkRange,
		constants.ServiceTypeEgressTunnelsCIDR:  validateEgressTunnelNetworkRange,
	}
)

// networkServiceValidator validates a candidate network service against the
// network services already stored for the same cluster.
type networkServiceValidator func(*model.ClusterNetworkServiceInfo, []*model.ClusterNetworkServiceInfo) error

// networkRangeValidator validates the prefix length encoded in a CIDR mask.
type networkRangeValidator func(net.IPMask) error

// GetClusterNetworkServices returns all network services stored for the
// cluster identified by clusterEdgeID. When there are none, an empty non-nil
// slice is returned.
func (s *storeClusterService) GetClusterNetworkServices(ctx context.Context, clusterEdgeID string) ([]*model.ClusterNetworkServiceInfo, error) {
	rows, err := s.SQLDB.QueryContext(ctx, sqlquery.GetClusterNetworkServices, clusterEdgeID)
	if err != nil {
		return nil, sqlerr.Wrap(err)
	}
	defer rows.Close()

	networkServices := []*model.ClusterNetworkServiceInfo{}
	for rows.Next() {
		netService := &model.ClusterNetworkServiceInfo{}
		if err := rows.Scan(&netService.NetworkServiceID, &netService.IP, &netService.Family, &netService.ServiceType, &netService.Priority); err != nil {
			return nil, sqlerr.Wrap(err)
		}
		networkServices = append(networkServices, netService)
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return networkServices, nil
}

// GetClusterNetworkServiceByNetworkID returns the network service with ID
// networkServiceID belonging to the cluster. On any failure, including no
// matching row, it returns nil and the wrapped SQL error.
func (s *storeClusterService) GetClusterNetworkServiceByNetworkID(ctx context.Context, clusterEdgeID, networkServiceID string) (*model.ClusterNetworkServiceInfo, error) {
	networkService := &model.ClusterNetworkServiceInfo{}
	row := s.SQLDB.QueryRowContext(ctx, sqlquery.GetClusterNetworkServiceByNetworkID, clusterEdgeID, networkServiceID)
	if err := row.Scan(&networkService.NetworkServiceID, &networkService.IP, &networkService.Family, &networkService.ServiceType, &networkService.Priority); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return networkService, nil
}

// CreateClusterNetworkService validates networkService against the cluster's
// existing network services and inserts it. It rejects unsupported types,
// invalid values and duplicate {type, IP} pairs. A missing priority or one
// below 1 is stored as defaultPriority. The returned service is populated from
// the row returned by the insert query.
func (s *storeClusterService) CreateClusterNetworkService(ctx context.Context, clusterEdgeID string, networkService *model.CreateNetworkServiceInfo) (*model.ClusterNetworkServiceInfo, error) {
	// A missing map entry is a nil func; calling it would panic.
	validate, ok := networkServiceValidators[networkService.ServiceType]
	if !ok {
		return nil, fmt.Errorf("unsupported network service type %q", networkService.ServiceType)
	}

	netService := createNetServiceToNetService(networkService)
	existingNetworkServices, err := s.GetClusterNetworkServices(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	if err := validate(netService, existingNetworkServices); err != nil {
		return nil, err
	}
	if hasDuplicateIP(netService, existingNetworkServices) {
		return nil, fmt.Errorf("IP address %s already exists for %s", netService.IP, netService.ServiceType)
	}

	priority := validateCreateNetworkServicePriorityField(networkService.Priority)
	result := s.SQLDB.QueryRowContext(ctx, sqlquery.CreateClusterNetworkServices, clusterEdgeID, networkService.IP, networkService.Family, networkService.ServiceType, priority)
	if err := castNetworkServiceResult(netService, result); err != nil {
		return nil, err
	}
	return netService, nil
}

// CreateClusterNetworkServices creates each entry of networkServicesInfo in
// order and keeps going after individual failures. Nil entries are skipped.
// It returns the services that were created, plus every failure combined into
// a *multierror.Error (nil when all succeed).
func (s *storeClusterService) CreateClusterNetworkServices(ctx context.Context, clusterEdgeID string, networkServicesInfo []*model.CreateNetworkServiceInfo) ([]*model.ClusterNetworkServiceInfo, error) {
	var allErrs error
	networkServices := []*model.ClusterNetworkServiceInfo{}
	for _, networkService := range networkServicesInfo {
		if networkService == nil {
			continue
		}
		netService, err := s.CreateClusterNetworkService(ctx, clusterEdgeID, networkService)
		if err != nil {
			allErrs = multierror.Append(allErrs, err)
			continue
		}
		networkServices = append(networkServices, netService)
	}
	return networkServices, allErrs
}

// UpdateClusterNetworkService validates networkService, whose stored type is
// serviceType, against the cluster's other network services and writes it.
// A supplied priority below 1 is replaced by defaultPriority. When the service
// network CIDR is updated, the cluster DNS IP is moved into the new subnet.
func (s *storeClusterService) UpdateClusterNetworkService(ctx context.Context, clusterEdgeID string, networkService *model.UpdateNetworkServiceInfo, serviceType string) (*model.ClusterNetworkServiceInfo, error) {
	// serviceType may be "" when the caller could not resolve the ID; guard the
	// lookup instead of calling a nil func.
	validate, ok := networkServiceValidators[serviceType]
	if !ok {
		return nil, fmt.Errorf("unsupported network service type %q", serviceType)
	}

	netService := updateNetServiceToNetService(networkService, serviceType)
	clusterNetworkServices, err := s.GetClusterNetworkServices(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	if err := validate(netService, clusterNetworkServices); err != nil {
		return nil, err
	}
	if hasDuplicateIP(netService, clusterNetworkServices) {
		return nil, fmt.Errorf("IP address %s already exists for %s", netService.IP, netService.ServiceType)
	}

	if netService.Priority != nil {
		// Normalize on netService, which is what the query below writes. It holds
		// its own copy of the pointer, so reassigning networkService.Priority would
		// leave the raw value in place.
		priority := validateUpdateNetworkServicePriorityField(netService.Priority)
		netService.Priority = &priority
	}

	result := s.SQLDB.QueryRowContext(ctx, sqlquery.UpdateClusterNetworkServices, netService.IP, netService.Family, netService.Priority, netService.NetworkServiceID, clusterEdgeID)
	if err := castNetworkServiceResult(netService, result); err != nil {
		return nil, err
	}

	// If we have updated the service network CIDR, update the cluster DNS IP to
	// keep it within the same subnet.
	// NOTE(review): this is a second, separate statement (no transaction is
	// visible here), so a failure leaves the new CIDR stored with the old DNS IP.
	if netService.ServiceType == constants.ServiceTypeServiceNetworkCIDR {
		if err := s.updateClusterDNSIP(ctx, clusterEdgeID, netService.IP); err != nil {
			return nil, err
		}
	}
	return netService, nil
}

// UpdateClusterNetworkServices updates each entry of networkServicesInfo in
// order and keeps going after individual failures. existingServiceTypesByID
// supplies each service's stored type, keyed by NetworkServiceID. Nil entries
// are skipped. It returns the services that were updated, plus every failure
// combined into a *multierror.Error (nil when all succeed).
func (s *storeClusterService) UpdateClusterNetworkServices(ctx context.Context, clusterEdgeID string, networkServicesInfo []*model.UpdateNetworkServiceInfo, existingServiceTypesByID map[string]string) ([]*model.ClusterNetworkServiceInfo, error) {
	var allErrs error
	networkServices := []*model.ClusterNetworkServiceInfo{}
	for _, networkService := range networkServicesInfo {
		if networkService == nil {
			continue
		}
		netService, err := s.UpdateClusterNetworkService(ctx, clusterEdgeID, networkService, existingServiceTypesByID[networkService.NetworkServiceID])
		if err != nil {
			allErrs = multierror.Append(allErrs, err)
			continue
		}
		networkServices = append(networkServices, netService)
	}
	return networkServices, allErrs
}

// DeleteClusterNetworkService deletes the network service networkServiceID from
// the cluster and reports whether it was deleted. The pod network CIDR,
// service network CIDR and cluster DNS services cannot be deleted this way; for
// them it returns (false, error).
func (s *storeClusterService) DeleteClusterNetworkService(ctx context.Context, clusterEdgeID, networkServiceID string) (bool, error) {
	networkService, err := s.GetClusterNetworkServiceByNetworkID(ctx, clusterEdgeID, networkServiceID)
	if err != nil {
		return false, err
	}
	switch networkService.ServiceType {
	case constants.ServiceTypePodNetworkCIDR, constants.ServiceTypeServiceNetworkCIDR, constants.ServiceTypeClusterDNS:
		return false, fmt.Errorf("adhoc deletion of the %s is forbidden", networkService.ServiceType)
	}
	if _, err := s.SQLDB.ExecContext(ctx, sqlquery.DeleteClusterNetworkService, clusterEdgeID, networkServiceID); err != nil {
		return false, sqlerr.Wrap(err)
	}
	return true, nil
}

// GetClusterK8sNetworkServices returns the cluster's k8s network settings,
// keyed by service type: cluster DNS, pod network CIDR, service network CIDR
// and egress tunnels CIDR. Types with no stored service map to "". If a type
// is stored more than once, the last row returned wins.
func (s *storeClusterService) GetClusterK8sNetworkServices(ctx context.Context, clusterEdgeID string) (map[string]string, error) {
	clusterNetworkServices, err := s.GetClusterNetworkServices(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	services := map[string]string{
		constants.ServiceTypeClusterDNS:         "",
		constants.ServiceTypePodNetworkCIDR:     "",
		constants.ServiceTypeServiceNetworkCIDR: "",
		constants.ServiceTypeEgressTunnelsCIDR:  "",
	}
	for _, clusterNetworkService := range clusterNetworkServices {
		if _, exists := services[clusterNetworkService.ServiceType]; exists {
			services[clusterNetworkService.ServiceType] = clusterNetworkService.IP
		}
	}
	return services, nil
}

// GetK8sClusterNetworkService returns the first stored network service of type
// networkService for the cluster, or an error if none exists.
func (s *storeClusterService) GetK8sClusterNetworkService(ctx context.Context, clusterEdgeID, networkService string) (*model.ClusterNetworkServiceInfo, error) {
	clusterNetworkServices, err := s.GetClusterNetworkServices(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	return getNetworkServiceFromList(clusterNetworkServices, networkService)
}

// getNetworkServiceFromList returns the first entry whose ServiceType equals
// networkService. When there is no match it returns a "could not find network
// service" error; validateDisjointSubnets matches on that text.
func getNetworkServiceFromList(clusterNetworkServices []*model.ClusterNetworkServiceInfo, networkService string) (*model.ClusterNetworkServiceInfo, error) {
	for _, service := range clusterNetworkServices {
		if service.ServiceType == networkService {
			return service, nil
		}
	}
	return nil, fmt.Errorf("could not find network service %s", networkService)
}

// updateClusterDNSIP points the cluster's cluster DNS service at index 10 of
// serviceCIDR (the network address plus 10). It writes family "inet" and
// defaultPriority.
func (s *storeClusterService) updateClusterDNSIP(ctx context.Context, clusterEdgeID, serviceCIDR string) error {
	clusterDNS, err := s.GetK8sClusterNetworkService(ctx, clusterEdgeID, constants.ServiceTypeClusterDNS)
	if err != nil {
		return err
	}

	// Set the DNS IP to the 10th IP in the service CIDR subnet. Check the parse
	// error so a nil subnet is never handed to GetIndexedIP.
	_, serviceSubnet, err := net.ParseCIDR(serviceCIDR)
	if err != nil {
		return fmt.Errorf("invalid CIDR address %s for %s", serviceCIDR, constants.ServiceTypeServiceNetworkCIDR)
	}
	dnsIP, err := netutils.GetIndexedIP(serviceSubnet, 10)
	if err != nil {
		return err
	}

	priority := defaultPriority
	networkService := &model.ClusterNetworkServiceInfo{
		NetworkServiceID: clusterDNS.NetworkServiceID,
		ServiceType:      constants.ServiceTypeClusterDNS,
		IP:               dnsIP.String(),
		Family:           "inet",
		Priority:         &priority,
	}
	result := s.SQLDB.QueryRowContext(ctx, sqlquery.UpdateClusterNetworkServices, networkService.IP, networkService.Family, networkService.Priority, networkService.NetworkServiceID, clusterEdgeID)
	return castNetworkServiceResult(networkService, result)
}

// validateCreateNetworkServicePriorityField returns the priority to store on
// create: defaultPriority when priority is nil or below 1, otherwise *priority.
func validateCreateNetworkServicePriorityField(priority *int) int {
	if priority == nil || *priority < 1 {
		return defaultPriority
	}
	return *priority
}

// validateUpdateNetworkServicePriorityField returns the priority to store on
// update: defaultPriority when it is below 1, otherwise *priority. Callers
// only invoke it with a non-nil priority; nil is treated like an omitted value
// instead of panicking.
func validateUpdateNetworkServicePriorityField(priority *int) int {
	return validateCreateNetworkServicePriorityField(priority)
}

// createNetServiceToNetService converts a create request into the model used
// for validation. NetworkServiceID is left empty.
func createNetServiceToNetService(createNetService *model.CreateNetworkServiceInfo) *model.ClusterNetworkServiceInfo {
	return &model.ClusterNetworkServiceInfo{
		ServiceType: createNetService.ServiceType,
		IP:          createNetService.IP,
		Family:      createNetService.Family,
		Priority:    createNetService.Priority,
	}
}

// updateNetServiceToNetService converts an update request into the model used
// for validation and persistence. The type comes from serviceType because the
// request does not carry it. Priority shares the request's pointer.
func updateNetServiceToNetService(updateNetService *model.UpdateNetworkServiceInfo, serviceType string) *model.ClusterNetworkServiceInfo {
	return &model.ClusterNetworkServiceInfo{
		ServiceType:      serviceType,
		NetworkServiceID: updateNetService.NetworkServiceID,
		IP:               updateNetService.IP,
		Family:           updateNetService.Family,
		Priority:         updateNetService.Priority,
	}
}

// castNetworkServiceResult scans the single row in result into networkService,
// overwriting its ID, IP, family, type and priority. Errors are wrapped with
// sqlerr.Wrap.
func castNetworkServiceResult(networkService *model.ClusterNetworkServiceInfo, result *sql.Row) error {
	if err := result.Err(); err != nil {
		return sqlerr.Wrap(err)
	}
	if err := result.Scan(&networkService.NetworkServiceID, &networkService.IP, &networkService.Family, &networkService.ServiceType, &networkService.Priority); err != nil {
		return sqlerr.Wrap(err)
	}
	return nil
}

// validateNetworkServiceIPAddress requires the service's IP to be a literal
// IPv4 or IPv6 address.
func validateNetworkServiceIPAddress(networkService *model.ClusterNetworkServiceInfo, _ []*model.ClusterNetworkServiceInfo) error {
	if net.ParseIP(networkService.IP) == nil {
		return fmt.Errorf("invalid IP address %s for %s", networkService.IP, networkService.ServiceType)
	}
	return nil
}

// validateNetworkServiceNTP accepts either a valid domain name or an IP address,
// as judged by the networkvalidator package.
func validateNetworkServiceNTP(networkService *model.ClusterNetworkServiceInfo, _ []*model.ClusterNetworkServiceInfo) error {
	if !networkvalidator.IsValidDomain(networkService.IP) && !networkvalidator.ValidateIP(networkService.IP) {
		return fmt.Errorf("invalid IP/domain address %s for %s", networkService.IP, networkService.ServiceType)
	}
	return nil
}

// validateNetworkServiceClusterDNS performs no validation and accepts any value.
func validateNetworkServiceClusterDNS(_ *model.ClusterNetworkServiceInfo, _ []*model.ClusterNetworkServiceInfo) error {
	return nil
}

// validateK8sNetworkCIDR requires the service's IP to be a CIDR whose prefix
// length is allowed for its type. The subnet must also not overlap the
// cluster's pod or service network CIDR (see validateDisjointSubnets).
func validateK8sNetworkCIDR(networkService *model.ClusterNetworkServiceInfo, clusterNetworkServices []*model.ClusterNetworkServiceInfo) error {
	_, network, err := net.ParseCIDR(networkService.IP)
	if err != nil || network == nil {
		return fmt.Errorf("invalid CIDR address %s for %s", networkService.IP, networkService.ServiceType)
	}
	validateRange, ok := networkRangeValidators[networkService.ServiceType]
	if !ok {
		return fmt.Errorf("unsupported network service type %q", networkService.ServiceType)
	}
	if err := validateRange(network.Mask); err != nil {
		return err
	}
	return validateDisjointSubnets(networkService, clusterNetworkServices)
}

// hasDuplicateIP reports whether a different stored network service (different
// NetworkServiceID) already has the same {ServiceType, IP} pair.
func hasDuplicateIP(networkService *model.ClusterNetworkServiceInfo, clusterNetworkServices []*model.ClusterNetworkServiceInfo) bool {
	for _, clusterNetworkService := range clusterNetworkServices {
		if networkService.ServiceType == clusterNetworkService.ServiceType && networkService.IP == clusterNetworkService.IP {
			if networkService.NetworkServiceID == clusterNetworkService.NetworkServiceID {
				// The same network service is being updated, so duplicate IPs are expected here
				continue
			}
			// Network service is being created/updated with a non-unique {ServiceType, IP} pairing
			return true
		}
	}
	return false
}

// validatePodNetworkRange allows prefix lengths /16 through /21 inclusive.
func validatePodNetworkRange(mask net.IPMask) error {
	prefixLen, _ := mask.Size()
	if prefixLen < 16 || prefixLen > 21 {
		return fmt.Errorf("invalid prefix length. Prefix length must be between /16 and /21 for k8s pod network")
	}
	return nil
}

// validateServiceNetworkRange allows prefix lengths /16 through /22 inclusive.
func validateServiceNetworkRange(mask net.IPMask) error {
	prefixLen, _ := mask.Size()
	if prefixLen < 16 || prefixLen > 22 {
		return fmt.Errorf("invalid prefix length. Prefix length must be between /16 and /22 for k8s service network")
	}
	return nil
}

// validateEgressTunnelNetworkRange allows prefix lengths /22 through /31 inclusive.
func validateEgressTunnelNetworkRange(mask net.IPMask) error {
	prefixLen, _ := mask.Size()
	if prefixLen < 22 || prefixLen > 31 {
		return fmt.Errorf("invalid prefix length. Prefix length must be between /22 and /31 for egress gateway tunnels")
	}
	return nil
}

// validateDisjointSubnets rejects networkService if its CIDR overlaps the
// cluster's stored pod network CIDR or service network CIDR, excluding the
// entry of its own type. Types with no stored service are skipped.
// NOTE(review): the egress tunnels CIDR is not in this list, so pod/service
// CIDR changes are not checked against an existing egress CIDR. Confirm
// whether that is intended.
func validateDisjointSubnets(networkService *model.ClusterNetworkServiceInfo, clusterNetworkServices []*model.ClusterNetworkServiceInfo) error {
	subnets := []string{constants.ServiceTypePodNetworkCIDR, constants.ServiceTypeServiceNetworkCIDR}
	_, subnetToValidate, err := net.ParseCIDR(networkService.IP)
	if err != nil {
		return fmt.Errorf("invalid CIDR address %s for %s", networkService.IP, networkService.ServiceType)
	}
	for _, subnet := range subnets {
		if subnet == networkService.ServiceType {
			continue
		}
		subnetService, err := getNetworkServiceFromList(clusterNetworkServices, subnet)
		if err != nil {
			if strings.Contains(err.Error(), "could not find network service") {
				// Nothing stored for this type, so nothing to overlap with. Keep
				// checking the remaining types instead of returning early.
				continue
			}
			return err
		}
		_, otherSubnet, err := net.ParseCIDR(subnetService.IP)
		if err != nil {
			return fmt.Errorf("invalid CIDR address %s for %s", subnetService.IP, subnet)
		}
		// Two CIDR blocks overlap iff one contains the other's network address.
		if subnetToValidate.Contains(otherSubnet.IP) || otherSubnet.Contains(subnetToValidate.IP) {
			return fmt.Errorf("invalid subnet - %s must not overlap with %s", networkService.ServiceType, subnet)
		}
	}
	return nil
}