package services

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"regexp"
	"slices"
	"strconv"
	"strings"

	"github.com/google/uuid"
	"github.com/hashicorp/go-multierror"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/mapper"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
	"edge-infra.dev/pkg/edge/constants"
	linkerd "edge-infra.dev/pkg/edge/linkerd"
	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/ien/topology"
)

const (
	// config_key as expected in the db
	AcRelayKey                    = "ac_relay"
	PxeEnabledKey                 = "pxe_enabled"
	BootstrapAckKey               = "bootstrap_ack"
	VpnEnabledKey                 = "vpn_enabled"
	ThickPosKey                   = "thick_pos"
	EgressGatewayEnabledKey       = "egress_gateway_enabled"
	GatewayRateLimitingEnabledKey = "gateway_rate_limiting_enabled"
	UplinkRateLimitKey            = "uplink_rate_limit"
	DownlinkRateLimitKey          = "downlink_rate_limit"
	ClusterLogLevelKey            = "cluster_log_level"
	NamespaceLogLevelsKey         = "namespace_log_levels"
	AutoUpdateEnabledKey          = "auto_update_enabled"
	MaximumLanOutageHoursKey      = "maximum_lan_outage_hours"
	VncReadWriteAuthRequired      = "vnc_read_write_auth_required"
	VncReadAuthRequired           = "vnc_read_auth_required"
)

// validClusterLogLevels lists the log levels accepted by validateLogLevel.
// Input is upper-cased before it is compared against this list.
var validClusterLogLevels = []string{
	"DEBUG",
	"INFO",
	"NOTICE",
	"WARNING",
	"ERROR",
	"CRITICAL",
	"ALERT",
	"EMERGENCY",
}

var (
	// default values
	defaultRateLimit = "100mbit"
	logLevelDefault  = "ERROR"
)

// rateLimitPattern matches a positive integer without leading zeros followed
// by a kbit/mbit unit (e.g. 1000kbit or 10mbit). It is compiled once at
// package scope instead of on every validation call.
var rateLimitPattern = regexp.MustCompile("^[1-9][0-9]*([kK]bit$|[mM]bit$)")

var (
	ErrInvalidRateLimit             = errors.New("invalid format for upload or download rate limits")
	ErrClusterEdgeIDEmpty           = errors.New("cluster edge id cannot be empty")
	ErrInvalidMaximumLanOutageHours = errors.New("maximum lan outage duration must be at least 24 hours")
)

// ErrInvalidLogLevel is returned when a log level is not one of
// validClusterLogLevels. LogLevel holds the rejected value as supplied.
type ErrInvalidLogLevel struct {
	LogLevel string
}

// Error implements the error interface, reporting the rejected level along
// with the comma-separated list of accepted levels.
func (e *ErrInvalidLogLevel) Error() string {
	return fmt.Sprintf("%s is invalid log level. Valid log levels: %s", e.LogLevel, strings.Join(validClusterLogLevels, ","))
}

// ClusterConfigService reads and updates the configuration of a cluster
// identified by its cluster edge id.
//
//go:generate mockgen -destination=../mocks/mock_cluster_config_service.go -package=mocks edge-infra.dev/pkg/edge/api/services ClusterConfigService
type ClusterConfigService interface {
	UpdateClusterConfig(ctx context.Context, clusterEdgeID string, updatedClusterConfig *model.UpdateClusterConfig) (*model.ClusterConfig, error)
	GetClusterConfig(ctx context.Context, clusterEdgeID string) (*model.ClusterConfig, error)
}

// clusterConfigService is the SQL-backed ClusterConfigService implementation.
type clusterConfigService struct {
	SQLDB *sql.DB
}

// NewClusterConfigService returns a cluster config service that runs its
// queries against sqlDB.
func NewClusterConfigService(sqlDB *sql.DB) *clusterConfigService { //nolint:revive
	return &clusterConfigService{
		SQLDB: sqlDB,
	}
}

// UpdateClusterConfig overlays the non-nil fields of inputCfg on the current
// configuration of clusterEdgeID, validates the merged result and then
// upserts only the keys that were supplied in inputCfg.
//
// It returns the merged configuration on success. On any failure, including a
// failed upsert, the returned configuration is nil so callers never observe a
// configuration that was not persisted.
func (c *clusterConfigService) UpdateClusterConfig(ctx context.Context, clusterEdgeID string, inputCfg *model.UpdateClusterConfig) (*model.ClusterConfig, error) {
	clusterCfg, err := c.GetClusterConfig(ctx, clusterEdgeID)
	if err != nil {
		return nil, err
	}

	// overrides current configuration with new configuration
	mapUpdateClusterConfig(clusterCfg, inputCfg)

	// validate updated cluster config
	if err := validateClusterConfig(clusterCfg); err != nil {
		return nil, err
	}

	// convert values to map[string]string
	clusterCfgValues, err := mapUpdateClusterConfigAttributes(inputCfg)
	if err != nil {
		return nil, err
	}

	if err := c.upsertClusterConfigKeys(ctx, clusterEdgeID, clusterCfgValues); err != nil {
		return nil, err
	}
	return clusterCfg, nil
}

// GetClusterConfig returns the configuration for clusterEdgeID: the defaults
// from defaultClusterConfig with every stored row applied on top. It returns
// ErrClusterEdgeIDEmpty when clusterEdgeID is empty.
func (c *clusterConfigService) GetClusterConfig(ctx context.Context, clusterEdgeID string) (*model.ClusterConfig, error) {
	if clusterEdgeID == "" {
		return nil, ErrClusterEdgeIDEmpty
	}

	rows, err := c.SQLDB.QueryContext(ctx, sqlquery.GetClusterConfig, clusterEdgeID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	return scanClusterConfigRows(clusterEdgeID, rows)
}

// upsertClusterConfigKeys executes one upsert per key/value pair on the
// cluster_config table inside a single transaction.
//
// Keys are processed in sorted order so the statement sequence is
// deterministic across calls; Go map iteration order is otherwise random.
func (c *clusterConfigService) upsertClusterConfigKeys(ctx context.Context, clusterEdgeID string, clusterCfg map[string]string) error {
	transaction, err := c.SQLDB.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	keys := make([]string, 0, len(clusterCfg))
	for key := range clusterCfg {
		keys = append(keys, key)
	}
	slices.Sort(keys)

	for _, key := range keys {
		// queryClusterConfig rolls the transaction back itself on failure.
		if err := queryClusterConfig(ctx, transaction, sqlquery.UpdateClusterConfig, clusterEdgeID, key, clusterCfg[key]); err != nil {
			return err
		}
	}

	return transaction.Commit()
}

// queryClusterConfig executes query within transaction using clusterID, key,
// val and a freshly generated UUID as arguments. If execution fails the
// transaction is rolled back; a rollback failure is appended to the returned
// error.
func queryClusterConfig(ctx context.Context, transaction *sql.Tx, query string, clusterID, key, val string) error {
	if _, err := transaction.ExecContext(ctx, query, clusterID, key, val, uuid.NewString()); err != nil {
		if rollbackErr := transaction.Rollback(); rollbackErr != nil {
			return multierror.Append(err, rollbackErr)
		}
		return err
	}
	return nil
}

// mapUpdateClusterConfigAttributes converts the non-nil fields of
// updateClusterConfig into a config_key -> string value map. Fields left nil
// are omitted from the map.
func mapUpdateClusterConfigAttributes(updateClusterConfig *model.UpdateClusterConfig) (map[string]string, error) {
	output := map[string]string{}

	if updateClusterConfig.AcRelay != nil {
		output[AcRelayKey] = strconv.FormatBool(*updateClusterConfig.AcRelay)
	}
	if updateClusterConfig.PxeEnabled != nil {
		output[PxeEnabledKey] = strconv.FormatBool(*updateClusterConfig.PxeEnabled)
	}
	if updateClusterConfig.BootstrapAck != nil {
		output[BootstrapAckKey] = strconv.FormatBool(*updateClusterConfig.BootstrapAck)
	}
	if updateClusterConfig.VpnEnabled != nil {
		output[VpnEnabledKey] = strconv.FormatBool(*updateClusterConfig.VpnEnabled)
	}
	if updateClusterConfig.ThickPos != nil {
		output[ThickPosKey] = strconv.FormatBool(*updateClusterConfig.ThickPos)
	}
	if updateClusterConfig.EgressGatewayEnabled != nil {
		output[EgressGatewayEnabledKey] = strconv.FormatBool(*updateClusterConfig.EgressGatewayEnabled)
	}
	if updateClusterConfig.GatewayRateLimitingEnabled != nil {
		output[GatewayRateLimitingEnabledKey] = strconv.FormatBool(*updateClusterConfig.GatewayRateLimitingEnabled)
	}
	if updateClusterConfig.UplinkRateLimit != nil {
		output[UplinkRateLimitKey] = *updateClusterConfig.UplinkRateLimit
	}
	if updateClusterConfig.DownlinkRateLimit != nil {
		output[DownlinkRateLimitKey] = *updateClusterConfig.DownlinkRateLimit
	}
	if updateClusterConfig.ClusterLogLevel != nil {
		output[ClusterLogLevelKey] = *updateClusterConfig.ClusterLogLevel
	}
	if updateClusterConfig.NamespaceLogLevels != nil {
		namespaceLogLevelsJSONStr, err := mapper.NLLPToJSON(updateClusterConfig.NamespaceLogLevels)
		if err != nil {
			// %w keeps the underlying error available to errors.Is/errors.As.
			return nil, fmt.Errorf("Error when marshaling to a string in mapUpdatedClusterConfigAttributes: %w", err)
		}
		output[NamespaceLogLevelsKey] = namespaceLogLevelsJSONStr
	}
	if updateClusterConfig.MaximumLanOutageHours != nil {
		output[MaximumLanOutageHoursKey] = strconv.FormatInt(int64(*updateClusterConfig.MaximumLanOutageHours), 10)
	}
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	if updateClusterConfig.LinkerdIdentityIssuerCertDuration != nil {
		output[constants.LinkerdIdentityIssuerCertDuration] = strconv.FormatInt(int64(*updateClusterConfig.LinkerdIdentityIssuerCertDuration), 10)
	}
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	if updateClusterConfig.LinkerdIdentityIssuerCertRenewBefore != nil {
		output[constants.LinkerdIdentityIssuerCertRenewBefore] = strconv.FormatInt(int64(*updateClusterConfig.LinkerdIdentityIssuerCertRenewBefore), 10)
	}
	if updateClusterConfig.AutoUpdateEnabled != nil {
		output[AutoUpdateEnabledKey] = strconv.FormatBool(*updateClusterConfig.AutoUpdateEnabled)
	}
	if updateClusterConfig.VncReadWriteAuthRequired != nil {
		output[VncReadWriteAuthRequired] = strconv.FormatBool(*updateClusterConfig.VncReadWriteAuthRequired)
	}
	if updateClusterConfig.VncReadAuthRequired != nil {
		output[VncReadAuthRequired] = strconv.FormatBool(*updateClusterConfig.VncReadAuthRequired)
	}

	return output, nil
}

// scanClusterConfigRows builds a ClusterConfig by starting from the defaults
// for clusterEdgeID and applying each key/value row on top. Iteration errors
// reported by rows.Err are wrapped with sqlerr.Wrap.
func scanClusterConfigRows(clusterEdgeID string, rows *sql.Rows) (*model.ClusterConfig, error) {
	clusterConfig := defaultClusterConfig(clusterEdgeID)

	for rows.Next() {
		var clusterConfigEdgeID, configKey, configValue string
		// The first column is scanned but not used afterwards; the second
		// column overwrites ClusterEdgeID on the result.
		if err := rows.Scan(&clusterConfigEdgeID, &clusterConfig.ClusterEdgeID, &configKey, &configValue); err != nil {
			return nil, err
		}

		err := insertValueIntoClusterConfig(clusterConfig, configKey, configValue)
		if err != nil {
			return nil, err
		}
	}

	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}

	return clusterConfig, nil
}

// Update a specific key in the cluster config with the given value.
//
// configValue is parsed according to the type of the field selected by
// configKey and any parse error is returned. Unknown keys are logged and
// otherwise ignored so that newly added rows do not fail API calls.
func insertValueIntoClusterConfig(clusterConfig *model.ClusterConfig, configKey, configValue string) error { //nolint:gocyclo // Each config option follows a consistent pattern
	var err error
	switch configKey {
	case AcRelayKey:
		clusterConfig.AcRelay, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case PxeEnabledKey:
		clusterConfig.PxeEnabled, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case BootstrapAckKey:
		clusterConfig.BootstrapAck, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case VpnEnabledKey:
		clusterConfig.VpnEnabled, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case ThickPosKey:
		clusterConfig.ThickPos, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case EgressGatewayEnabledKey:
		clusterConfig.EgressGatewayEnabled, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case GatewayRateLimitingEnabledKey:
		clusterConfig.GatewayRateLimitingEnabled, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case UplinkRateLimitKey:
		clusterConfig.UplinkRateLimit = configValue
	case DownlinkRateLimitKey:
		clusterConfig.DownlinkRateLimit = configValue
	case ClusterLogLevelKey:
		clusterConfig.ClusterLogLevel = configValue
	case NamespaceLogLevelsKey:
		convertedPayload, err := mapper.JSONtoNLL(configValue)
		if err != nil {
			return err
		}
		clusterConfig.NamespaceLogLevels = convertedPayload
	case MaximumLanOutageHoursKey:
		v, err := strconv.ParseInt(configValue, 10, 64)
		if err != nil {
			return err
		}
		clusterConfig.MaximumLanOutageHours = int(v)
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	case constants.LinkerdIdentityIssuerCertDuration:
		v, err := strconv.ParseInt(configValue, 10, 64)
		if err != nil {
			return err
		}
		clusterConfig.LinkerdIdentityIssuerCertDuration = int(v)
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	case constants.LinkerdIdentityIssuerCertRenewBefore:
		v, err := strconv.ParseInt(configValue, 10, 64)
		if err != nil {
			return err
		}
		clusterConfig.LinkerdIdentityIssuerCertRenewBefore = int(v)
	case AutoUpdateEnabledKey:
		clusterConfig.AutoUpdateEnabled, err = strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
	case VncReadWriteAuthRequired:
		val, err := strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
		clusterConfig.VncReadWriteAuthRequired = &val
	case VncReadAuthRequired:
		val, err := strconv.ParseBool(configValue)
		if err != nil {
			return err
		}
		clusterConfig.VncReadAuthRequired = &val
	default:
		// Level 1 is Warning
		log := fog.New(fog.WithLevel(1)).WithName("cluster config service")
		// prevents records with newly added keys from failing API calls
		msg := fmt.Errorf("unknown cluster config key %s", configKey)
		log.Error(msg, "ensure the new key is added to the cluster configuration API")
		return nil
	}
	return nil
}

// validateClusterConfig runs every field validation against cfg and returns
// the first failure. A nil cfg is treated as valid.
func validateClusterConfig(cfg *model.ClusterConfig) error {
	if cfg == nil {
		return nil
	}
	if err := validateRateLimit(&cfg.UplinkRateLimit); err != nil {
		return err
	}
	if err := validateRateLimit(&cfg.DownlinkRateLimit); err != nil {
		return err
	}
	if err := validateLogLevel(cfg.ClusterLogLevel); err != nil {
		return err
	}
	if err := validateNamespaceLogLevel(cfg.NamespaceLogLevels); err != nil {
		return err
	}
	if err := validateMaximumLanOutageHours(cfg.MaximumLanOutageHours); err != nil {
		return err
	}
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	if err := validateLinkerdIdentityIssuerCertTimes(cfg.LinkerdIdentityIssuerCertDuration, cfg.LinkerdIdentityIssuerCertRenewBefore); err != nil {
		return err
	}
	return nil
}

// validateRateLimit checks that the rate limit is in the correct format,
// i.e. 1000kbit or 10mbit. A nil rateLimit is treated as valid; any value
// that does not match yields ErrInvalidRateLimit.
func validateRateLimit(rateLimit *string) error {
	if rateLimit == nil {
		return nil
	}
	if !rateLimitPattern.MatchString(*rateLimit) {
		return ErrInvalidRateLimit
	}
	return nil
}

// validateNamespaceLogLevel validates the level of every entry and returns
// the first failure. A nil slice is treated as valid.
func validateNamespaceLogLevel(namespaceLogLevels []*model.NamespaceLogLevel) error {
	if namespaceLogLevels == nil {
		return nil
	}
	for _, entry := range namespaceLogLevels {
		if err := validateLogLevel(entry.Level); err != nil {
			return err
		}
	}
	return nil
}

// validateLogLevel checks that level, compared case-insensitively, is one of
// validClusterLogLevels. It returns *ErrInvalidLogLevel otherwise.
func validateLogLevel(level string) error {
	levelExists := slices.Contains(validClusterLogLevels, strings.ToUpper(level))
	if !levelExists {
		return &ErrInvalidLogLevel{LogLevel: level}
	}
	return nil
}

// validateLinkerdIdentityIssuerCertTimes checks that both values are positive
// and that the linkerd identity issuer certificate duration is greater than
// the renewBefore time.
//
// Deprecated: Marked as DEPRECATED. Do not use
func validateLinkerdIdentityIssuerCertTimes(duration, renewBefore int) error {
	if duration <= 0 || renewBefore <= 0 {
		return topology.ErrInvalidCertificateDurationOrRenewBefore
	}
	if duration <= renewBefore {
		return topology.ErrInvalidCertificateRenewBefore
	}
	return nil
}

// validateMaximumLanOutageHours rejects values below 24 hours with
// ErrInvalidMaximumLanOutageHours.
func validateMaximumLanOutageHours(maxLanOutageHours int) error {
	if maxLanOutageHours < 24 {
		return ErrInvalidMaximumLanOutageHours
	}
	return nil
}

// mapUpdateClusterConfig maps UpdateClusterConfig to ClusterConfig: every
// non-nil field of updateClusterCfg overwrites the matching field of
// clusterCfg in place; nil fields leave the current value untouched.
func mapUpdateClusterConfig(clusterCfg *model.ClusterConfig, updateClusterCfg *model.UpdateClusterConfig) {
	if updateClusterCfg.AcRelay != nil {
		clusterCfg.AcRelay = *updateClusterCfg.AcRelay
	}
	if updateClusterCfg.BootstrapAck != nil {
		clusterCfg.BootstrapAck = *updateClusterCfg.BootstrapAck
	}
	if updateClusterCfg.VpnEnabled != nil {
		clusterCfg.VpnEnabled = *updateClusterCfg.VpnEnabled
	}
	if updateClusterCfg.PxeEnabled != nil {
		clusterCfg.PxeEnabled = *updateClusterCfg.PxeEnabled
	}
	if updateClusterCfg.ThickPos != nil {
		clusterCfg.ThickPos = *updateClusterCfg.ThickPos
	}
	if updateClusterCfg.EgressGatewayEnabled != nil {
		clusterCfg.EgressGatewayEnabled = *updateClusterCfg.EgressGatewayEnabled
	}
	if updateClusterCfg.GatewayRateLimitingEnabled != nil {
		clusterCfg.GatewayRateLimitingEnabled = *updateClusterCfg.GatewayRateLimitingEnabled
	}
	if updateClusterCfg.UplinkRateLimit != nil {
		clusterCfg.UplinkRateLimit = *updateClusterCfg.UplinkRateLimit
	}
	if updateClusterCfg.DownlinkRateLimit != nil {
		clusterCfg.DownlinkRateLimit = *updateClusterCfg.DownlinkRateLimit
	}
	if updateClusterCfg.ClusterLogLevel != nil {
		clusterCfg.ClusterLogLevel = *updateClusterCfg.ClusterLogLevel
	}
	if updateClusterCfg.NamespaceLogLevels != nil {
		convertedNamespaceLogLevel := mapper.NLLPToNLL(updateClusterCfg.NamespaceLogLevels)
		clusterCfg.NamespaceLogLevels = convertedNamespaceLogLevel
	}
	if updateClusterCfg.MaximumLanOutageHours != nil {
		clusterCfg.MaximumLanOutageHours = *updateClusterCfg.MaximumLanOutageHours
	}
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	if updateClusterCfg.LinkerdIdentityIssuerCertDuration != nil {
		clusterCfg.LinkerdIdentityIssuerCertDuration = *updateClusterCfg.LinkerdIdentityIssuerCertDuration
	}
	// TODO: Marked as DEPRECATED. Remove when appropriate
	//nolint:staticcheck // Allow existing usage of deprecated fields
	if updateClusterCfg.LinkerdIdentityIssuerCertRenewBefore != nil {
		clusterCfg.LinkerdIdentityIssuerCertRenewBefore = *updateClusterCfg.LinkerdIdentityIssuerCertRenewBefore
	}
	if updateClusterCfg.AutoUpdateEnabled != nil {
		clusterCfg.AutoUpdateEnabled = *updateClusterCfg.AutoUpdateEnabled
	}
	if updateClusterCfg.VncReadWriteAuthRequired != nil {
		clusterCfg.VncReadWriteAuthRequired = updateClusterCfg.VncReadWriteAuthRequired
	}
	if updateClusterCfg.VncReadAuthRequired != nil {
		clusterCfg.VncReadAuthRequired = updateClusterCfg.VncReadAuthRequired
	}
}

// defaultClusterConfig returns the default settings for the ClusterConfig
// struct, keyed to clusterEdgeID. These are the values reported for any key
// that has no stored row.
func defaultClusterConfig(clusterEdgeID string) *model.ClusterConfig {
	return &model.ClusterConfig{
		ClusterEdgeID:              clusterEdgeID,
		AcRelay:                    false,
		VpnEnabled:                 false,
		BootstrapAck:               true,
		PxeEnabled:                 false,
		ThickPos:                   true,
		EgressGatewayEnabled:       false,
		GatewayRateLimitingEnabled: false,
		UplinkRateLimit:            defaultRateLimit,
		DownlinkRateLimit:          defaultRateLimit,
		ClusterLogLevel:            logLevelDefault,
		NamespaceLogLevels:         make([]*model.NamespaceLogLevel, 0),
		MaximumLanOutageHours:      int(linkerd.DefaultThinPosIdentityIssuerCertificateDurationHours),
		// TODO: Marked as DEPRECATED. Remove
		LinkerdIdentityIssuerCertDuration:    int(linkerd.DefaultThinPosIdentityIssuerCertificateDurationHours),
		LinkerdIdentityIssuerCertRenewBefore: int(linkerd.DefaultThinPosIdentityIssuerCertificateRenewBeforeHours),
		VncReadWriteAuthRequired:             nil,
		VncReadAuthRequired:                  nil,
	}
}