package services

import (
	"context"
	"database/sql"
	"sort"
	"strconv"

	"github.com/hashicorp/go-multierror"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
)

// BannerConfigService reads and updates a banner's configuration, which is
// stored in SQL as string key/value pairs.
type BannerConfigService interface {
	// UpdateBannerConfig overlays the non-nil fields of bannerConfig onto the
	// banner's current configuration, persists the supplied keys, and returns
	// the merged configuration.
	UpdateBannerConfig(ctx context.Context, bannerEdgeID string, bannerConfig *model.UpdateBannerConfig) (*model.BannerConfig, error)
	// GetBannerConfig returns the banner's stored configuration.
	GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error)
}

// Compile-time check that bannerConfigService implements BannerConfigService.
var _ BannerConfigService = (*bannerConfigService)(nil)

// bannerConfigService is the SQL-backed implementation of BannerConfigService.
type bannerConfigService struct {
	sqlDB *sql.DB
}

// Config keys under which each banner setting is stored. Values are written
// with strconv.FormatBool and read back with strconv.ParseBool.
const (
	VNCReadWriteAuthRequired         = "vnc_read_write_auth_required"
	VNCReadWriteAuthRequiredOverride = "vnc_read_write_auth_required_override"
	VNCReadAuthRequired              = "vnc_read_auth_required"
	VNCReadAuthRequiredOverride      = "vnc_read_auth_required_override"
)

// UpdateBannerConfig loads the banner's current configuration, overlays every
// non-nil field of inputCfg onto it, and writes only those supplied keys back
// in a single transaction.
//
// A nil inputCfg (or one with no fields set) is treated as "no changes": the
// current configuration is returned and nothing is written. On any error the
// returned configuration is nil, so callers never see a merged value that was
// not persisted.
func (s *bannerConfigService) UpdateBannerConfig(ctx context.Context, bannerEdgeID string, inputCfg *model.UpdateBannerConfig) (*model.BannerConfig, error) {
	bannerCfg, err := s.GetBannerConfig(ctx, bannerEdgeID)
	if err != nil {
		return nil, err
	}
	// Guard: mapUpdateBannerConfig dereferences its argument.
	if inputCfg == nil {
		return bannerCfg, nil
	}

	// Overlay the requested changes onto the current configuration.
	mapUpdateBannerConfig(bannerCfg, inputCfg)

	// Only the keys the caller actually set are written back.
	bannerCfgValues := mapUpdateBannerConfigAttributes(inputCfg)
	if len(bannerCfgValues) == 0 {
		// Nothing to persist; skip opening an empty transaction.
		return bannerCfg, nil
	}
	if err := s.upsertBannerConfigKeys(ctx, bannerEdgeID, bannerCfgValues); err != nil {
		return nil, err
	}
	return bannerCfg, nil
}

// upsertBannerConfigKeys executes sqlquery.UpdateBannerConfig once per
// key/value pair in bannerCfg inside a single transaction, so either every key
// is written or none is. (The query is presumably an upsert, per this
// function's name — confirm against sqlquery.)
func (s *bannerConfigService) upsertBannerConfigKeys(ctx context.Context, bannerEdgeID string, bannerCfg map[string]string) error {
	transaction, err := s.sqlDB.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Safety net for any exit path that neither commits nor rolls back (e.g. a
	// panic). After Commit, or after queryBannerConfig has already rolled back,
	// Rollback is a no-op returning sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = transaction.Rollback() }()

	// Map iteration order is random; writing keys in sorted order makes the
	// statement sequence deterministic, so concurrent updates of the same
	// banner touch rows in the same order (relevant if the statement takes
	// row locks, as an upsert typically does).
	keys := make([]string, 0, len(bannerCfg))
	for key := range bannerCfg {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	for _, key := range keys {
		if err := queryBannerConfig(ctx, transaction, sqlquery.UpdateBannerConfig, bannerEdgeID, key, bannerCfg[key]); err != nil {
			return err
		}
	}
	return transaction.Commit()
}

// queryBannerConfig executes query within transaction with the arguments
// (bannerID, key, val). If execution fails the transaction is rolled back here,
// and the exec error is returned, combined with any rollback error via
// multierror.Append.
func queryBannerConfig(ctx context.Context, transaction *sql.Tx, query string, bannerID, key, val string) error {
	if _, err := transaction.ExecContext(ctx, query, bannerID, key, val); err != nil {
		if rollbackErr := transaction.Rollback(); rollbackErr != nil {
			return multierror.Append(err, rollbackErr)
		}
		return err
	}
	return nil
}

// mapUpdateBannerConfigAttributes converts the non-nil fields of
// updateBannerConfig into their storage keys and "true"/"false" string values.
// Fields left nil are omitted, so the result may be empty.
func mapUpdateBannerConfigAttributes(updateBannerConfig *model.UpdateBannerConfig) map[string]string {
	output := map[string]string{}
	if updateBannerConfig.VncReadWriteAuthRequired != nil {
		output[VNCReadWriteAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequired)
	}
	if updateBannerConfig.VncReadWriteAuthRequiredOverride != nil {
		output[VNCReadWriteAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequiredOverride)
	}
	if updateBannerConfig.VncReadAuthRequired != nil {
		output[VNCReadAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequired)
	}
	if updateBannerConfig.VncReadAuthRequiredOverride != nil {
		output[VNCReadAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequiredOverride)
	}
	return output
}

// mapUpdateBannerConfig copies each non-nil field of updateBannerCfg onto
// bannerCfg in place; nil fields leave the existing value untouched.
func mapUpdateBannerConfig(bannerCfg *model.BannerConfig, updateBannerCfg *model.UpdateBannerConfig) {
	if updateBannerCfg.VncReadWriteAuthRequired != nil {
		bannerCfg.VncReadWriteAuthRequired = *updateBannerCfg.VncReadWriteAuthRequired
	}
	if updateBannerCfg.VncReadWriteAuthRequiredOverride != nil {
		bannerCfg.VncReadWriteAuthRequiredOverride = *updateBannerCfg.VncReadWriteAuthRequiredOverride
	}
	if updateBannerCfg.VncReadAuthRequired != nil {
		bannerCfg.VncReadAuthRequired = *updateBannerCfg.VncReadAuthRequired
	}
	if updateBannerCfg.VncReadAuthRequiredOverride != nil {
		bannerCfg.VncReadAuthRequiredOverride = *updateBannerCfg.VncReadAuthRequiredOverride
	}
}

// GetBannerConfig runs sqlquery.GetBannerConfig for the banner and assembles
// the returned (key, value) rows into a BannerConfig. Settings with no stored
// row keep their zero value.
func (s *bannerConfigService) GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error) {
	rows, err := s.sqlDB.QueryContext(ctx, sqlquery.GetBannerConfig, bannerEdgeID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanBannerConfigRows(rows)
}

// scanBannerConfigRows reads every (config_key, config_value) row from rows
// into a fresh BannerConfig. It does not close rows; the caller owns that.
func scanBannerConfigRows(rows *sql.Rows) (*model.BannerConfig, error) {
	cfg := &model.BannerConfig{}
	for rows.Next() {
		var configKey, configValue string
		// NOTE(review): Scan errors are returned unwrapped while rows.Err()
		// below goes through sqlerr.Wrap — confirm whether that is intended.
		if err := rows.Scan(&configKey, &configValue); err != nil {
			return nil, err
		}
		if err := insertValueIntoBannerConfig(cfg, configKey, configValue); err != nil {
			return nil, err
		}
	}
	if err := rows.Err(); err != nil {
		return nil, sqlerr.Wrap(err)
	}
	return cfg, nil
}

// insertValueIntoBannerConfig parses configValue as a bool and stores it in
// the cfg field matching configKey. Unrecognized keys are ignored; a value
// that strconv.ParseBool rejects yields its error.
func insertValueIntoBannerConfig(cfg *model.BannerConfig, configKey, configValue string) error {
	var err error
	switch configKey {
	case VNCReadWriteAuthRequired:
		cfg.VncReadWriteAuthRequired, err = strconv.ParseBool(configValue)
	case VNCReadWriteAuthRequiredOverride:
		cfg.VncReadWriteAuthRequiredOverride, err = strconv.ParseBool(configValue)
	case VNCReadAuthRequired:
		cfg.VncReadAuthRequired, err = strconv.ParseBool(configValue)
	case VNCReadAuthRequiredOverride:
		cfg.VncReadAuthRequiredOverride, err = strconv.ParseBool(configValue)
	}
	return err
}

// NewBannerConfigService returns a BannerConfigService backed by sqlDB.
func NewBannerConfigService(sqlDB *sql.DB) BannerConfigService {
	return &bannerConfigService{
		sqlDB: sqlDB,
	}
}