...

Source file src/edge-infra.dev/pkg/edge/api/services/banner_config_service.go

Documentation: edge-infra.dev/pkg/edge/api/services

     1  package services
     2  
import (
	"context"
	"database/sql"
	"sort"
	"strconv"

	"github.com/hashicorp/go-multierror"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
)
    14  
// BannerConfigService reads and updates the configuration values stored for a
// banner, addressed by the banner's edge ID.
type BannerConfigService interface {
	// UpdateBannerConfig applies bannerConfig to the configuration of the banner
	// identified by bannerEdgeID and returns the resulting configuration.
	UpdateBannerConfig(ctx context.Context, bannerEdgeID string, bannerConfig *model.UpdateBannerConfig) (*model.BannerConfig, error)
	// GetBannerConfig returns the configuration of the banner identified by
	// bannerEdgeID.
	GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error)
}
    19  
// bannerConfigService is the database/sql-backed implementation of
// BannerConfigService.
type bannerConfigService struct {
	// sqlDB is the handle used for every query and transaction issued by the service.
	sqlDB *sql.DB
}
    23  
// Banner configuration keys. Each constant is the key string passed to the
// update query when a value is written and matched against the key column
// when rows are read back into a model.BannerConfig.
const (
	// VNCReadWriteAuthRequired is the key for BannerConfig.VncReadWriteAuthRequired.
	VNCReadWriteAuthRequired         = "vnc_read_write_auth_required"
	// VNCReadWriteAuthRequiredOverride is the key for BannerConfig.VncReadWriteAuthRequiredOverride.
	VNCReadWriteAuthRequiredOverride = "vnc_read_write_auth_required_override"
	// VNCReadAuthRequired is the key for BannerConfig.VncReadAuthRequired.
	VNCReadAuthRequired              = "vnc_read_auth_required"
	// VNCReadAuthRequiredOverride is the key for BannerConfig.VncReadAuthRequiredOverride.
	VNCReadAuthRequiredOverride      = "vnc_read_auth_required_override"
)
    30  
    31  func (s *bannerConfigService) UpdateBannerConfig(ctx context.Context, bannerEdgeID string, inputCfg *model.UpdateBannerConfig) (*model.BannerConfig, error) {
    32  	bannerCfg, err := s.GetBannerConfig(ctx, bannerEdgeID)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	// overrides current configuration with new configuration
    38  	mapUpdateBannerConfig(bannerCfg, inputCfg)
    39  
    40  	// convert values to map[string]string
    41  	bannerCfgValues := mapUpdateBannerConfigAttributes(inputCfg)
    42  
    43  	return bannerCfg, s.upsertBannerConfigKeys(ctx, bannerEdgeID, bannerCfgValues)
    44  }
    45  
    46  func (s *bannerConfigService) upsertBannerConfigKeys(ctx context.Context, bannerEdgeID string, bannerCfg map[string]string) error {
    47  	transaction, err := s.sqlDB.BeginTx(ctx, nil)
    48  	if err != nil {
    49  		return err
    50  	}
    51  
    52  	for key, value := range bannerCfg {
    53  		err = queryBannerConfig(ctx, transaction, sqlquery.UpdateBannerConfig, bannerEdgeID, key, value)
    54  		if err != nil {
    55  			return err
    56  		}
    57  	}
    58  
    59  	return transaction.Commit()
    60  }
    61  
    62  func queryBannerConfig(ctx context.Context, transaction *sql.Tx, query string, bannerID, key, val string) error {
    63  	if _, err := transaction.ExecContext(ctx, query, bannerID, key, val); err != nil {
    64  		if rollbackErr := transaction.Rollback(); rollbackErr != nil {
    65  			return multierror.Append(err, rollbackErr)
    66  		}
    67  		return err
    68  	}
    69  	return nil
    70  }
    71  
    72  func mapUpdateBannerConfigAttributes(updateBannerConfig *model.UpdateBannerConfig) map[string]string {
    73  	output := map[string]string{}
    74  
    75  	if updateBannerConfig.VncReadWriteAuthRequired != nil {
    76  		output[VNCReadWriteAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequired)
    77  	}
    78  
    79  	if updateBannerConfig.VncReadWriteAuthRequiredOverride != nil {
    80  		output[VNCReadWriteAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequiredOverride)
    81  	}
    82  	if updateBannerConfig.VncReadAuthRequired != nil {
    83  		output[VNCReadAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequired)
    84  	}
    85  	if updateBannerConfig.VncReadAuthRequiredOverride != nil {
    86  		output[VNCReadAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequiredOverride)
    87  	}
    88  
    89  	return output
    90  }
    91  
    92  func mapUpdateBannerConfig(bannerCfg *model.BannerConfig, updateBannerCfg *model.UpdateBannerConfig) {
    93  	if updateBannerCfg.VncReadWriteAuthRequired != nil {
    94  		bannerCfg.VncReadWriteAuthRequired = *updateBannerCfg.VncReadWriteAuthRequired
    95  	}
    96  	if updateBannerCfg.VncReadWriteAuthRequiredOverride != nil {
    97  		bannerCfg.VncReadWriteAuthRequiredOverride = *updateBannerCfg.VncReadWriteAuthRequiredOverride
    98  	}
    99  	if updateBannerCfg.VncReadAuthRequired != nil {
   100  		bannerCfg.VncReadAuthRequired = *updateBannerCfg.VncReadAuthRequired
   101  	}
   102  	if updateBannerCfg.VncReadAuthRequiredOverride != nil {
   103  		bannerCfg.VncReadAuthRequiredOverride = *updateBannerCfg.VncReadAuthRequiredOverride
   104  	}
   105  }
   106  
   107  func (s *bannerConfigService) GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error) {
   108  	rows, err := s.sqlDB.QueryContext(ctx, sqlquery.GetBannerConfig, bannerEdgeID)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	defer rows.Close()
   113  	return scanBannerConfigRows(rows)
   114  }
   115  
   116  func scanBannerConfigRows(rows *sql.Rows) (*model.BannerConfig, error) {
   117  	cfg := &model.BannerConfig{}
   118  	for rows.Next() {
   119  		var configKey, configValue string
   120  		if err := rows.Scan(&configKey, &configValue); err != nil {
   121  			return nil, err
   122  		}
   123  		err := insertValueIntoBannerConfig(cfg, configKey, configValue)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  	}
   128  	if err := rows.Err(); err != nil {
   129  		return nil, sqlerr.Wrap(err)
   130  	}
   131  	return cfg, nil
   132  }
   133  
   134  func insertValueIntoBannerConfig(cfg *model.BannerConfig, configKey, configValue string) error {
   135  	var err error
   136  	switch configKey {
   137  	case VNCReadWriteAuthRequired:
   138  		cfg.VncReadWriteAuthRequired, err = strconv.ParseBool(configValue)
   139  		return err
   140  	case VNCReadWriteAuthRequiredOverride:
   141  		cfg.VncReadWriteAuthRequiredOverride, err = strconv.ParseBool(configValue)
   142  		return err
   143  	case VNCReadAuthRequired:
   144  		cfg.VncReadAuthRequired, err = strconv.ParseBool(configValue)
   145  		return err
   146  	case VNCReadAuthRequiredOverride:
   147  		cfg.VncReadAuthRequiredOverride, err = strconv.ParseBool(configValue)
   148  		return err
   149  	}
   150  	return nil
   151  }
   152  
   153  func NewBannerConfigService(sqlDB *sql.DB) BannerConfigService {
   154  	return &bannerConfigService{
   155  		sqlDB: sqlDB,
   156  	}
   157  }
   158  

View as plain text