1 package services
2
import (
	"context"
	"database/sql"
	"sort"
	"strconv"

	"github.com/hashicorp/go-multierror"

	sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
	"edge-infra.dev/pkg/edge/api/graph/model"
	sqlquery "edge-infra.dev/pkg/edge/api/sql"
)
14
15 type BannerConfigService interface {
16 UpdateBannerConfig(ctx context.Context, bannerEdgeID string, bannerConfig *model.UpdateBannerConfig) (*model.BannerConfig, error)
17 GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error)
18 }
19
20 type bannerConfigService struct {
21 sqlDB *sql.DB
22 }
23
// Configuration keys for banner settings. These strings are written to and
// read back from the database verbatim, so changing a value orphans any rows
// already stored under the old key.
const (
	// Read-write VNC access keys.
	VNCReadWriteAuthRequired         = "vnc_read_write_auth_required"
	VNCReadWriteAuthRequiredOverride = "vnc_read_write_auth_required_override"

	// Read-only VNC access keys.
	VNCReadAuthRequired         = "vnc_read_auth_required"
	VNCReadAuthRequiredOverride = "vnc_read_auth_required_override"
)
30
31 func (s *bannerConfigService) UpdateBannerConfig(ctx context.Context, bannerEdgeID string, inputCfg *model.UpdateBannerConfig) (*model.BannerConfig, error) {
32 bannerCfg, err := s.GetBannerConfig(ctx, bannerEdgeID)
33 if err != nil {
34 return nil, err
35 }
36
37
38 mapUpdateBannerConfig(bannerCfg, inputCfg)
39
40
41 bannerCfgValues := mapUpdateBannerConfigAttributes(inputCfg)
42
43 return bannerCfg, s.upsertBannerConfigKeys(ctx, bannerEdgeID, bannerCfgValues)
44 }
45
46 func (s *bannerConfigService) upsertBannerConfigKeys(ctx context.Context, bannerEdgeID string, bannerCfg map[string]string) error {
47 transaction, err := s.sqlDB.BeginTx(ctx, nil)
48 if err != nil {
49 return err
50 }
51
52 for key, value := range bannerCfg {
53 err = queryBannerConfig(ctx, transaction, sqlquery.UpdateBannerConfig, bannerEdgeID, key, value)
54 if err != nil {
55 return err
56 }
57 }
58
59 return transaction.Commit()
60 }
61
62 func queryBannerConfig(ctx context.Context, transaction *sql.Tx, query string, bannerID, key, val string) error {
63 if _, err := transaction.ExecContext(ctx, query, bannerID, key, val); err != nil {
64 if rollbackErr := transaction.Rollback(); rollbackErr != nil {
65 return multierror.Append(err, rollbackErr)
66 }
67 return err
68 }
69 return nil
70 }
71
72 func mapUpdateBannerConfigAttributes(updateBannerConfig *model.UpdateBannerConfig) map[string]string {
73 output := map[string]string{}
74
75 if updateBannerConfig.VncReadWriteAuthRequired != nil {
76 output[VNCReadWriteAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequired)
77 }
78
79 if updateBannerConfig.VncReadWriteAuthRequiredOverride != nil {
80 output[VNCReadWriteAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadWriteAuthRequiredOverride)
81 }
82 if updateBannerConfig.VncReadAuthRequired != nil {
83 output[VNCReadAuthRequired] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequired)
84 }
85 if updateBannerConfig.VncReadAuthRequiredOverride != nil {
86 output[VNCReadAuthRequiredOverride] = strconv.FormatBool(*updateBannerConfig.VncReadAuthRequiredOverride)
87 }
88
89 return output
90 }
91
92 func mapUpdateBannerConfig(bannerCfg *model.BannerConfig, updateBannerCfg *model.UpdateBannerConfig) {
93 if updateBannerCfg.VncReadWriteAuthRequired != nil {
94 bannerCfg.VncReadWriteAuthRequired = *updateBannerCfg.VncReadWriteAuthRequired
95 }
96 if updateBannerCfg.VncReadWriteAuthRequiredOverride != nil {
97 bannerCfg.VncReadWriteAuthRequiredOverride = *updateBannerCfg.VncReadWriteAuthRequiredOverride
98 }
99 if updateBannerCfg.VncReadAuthRequired != nil {
100 bannerCfg.VncReadAuthRequired = *updateBannerCfg.VncReadAuthRequired
101 }
102 if updateBannerCfg.VncReadAuthRequiredOverride != nil {
103 bannerCfg.VncReadAuthRequiredOverride = *updateBannerCfg.VncReadAuthRequiredOverride
104 }
105 }
106
107 func (s *bannerConfigService) GetBannerConfig(ctx context.Context, bannerEdgeID string) (*model.BannerConfig, error) {
108 rows, err := s.sqlDB.QueryContext(ctx, sqlquery.GetBannerConfig, bannerEdgeID)
109 if err != nil {
110 return nil, err
111 }
112 defer rows.Close()
113 return scanBannerConfigRows(rows)
114 }
115
116 func scanBannerConfigRows(rows *sql.Rows) (*model.BannerConfig, error) {
117 cfg := &model.BannerConfig{}
118 for rows.Next() {
119 var configKey, configValue string
120 if err := rows.Scan(&configKey, &configValue); err != nil {
121 return nil, err
122 }
123 err := insertValueIntoBannerConfig(cfg, configKey, configValue)
124 if err != nil {
125 return nil, err
126 }
127 }
128 if err := rows.Err(); err != nil {
129 return nil, sqlerr.Wrap(err)
130 }
131 return cfg, nil
132 }
133
134 func insertValueIntoBannerConfig(cfg *model.BannerConfig, configKey, configValue string) error {
135 var err error
136 switch configKey {
137 case VNCReadWriteAuthRequired:
138 cfg.VncReadWriteAuthRequired, err = strconv.ParseBool(configValue)
139 return err
140 case VNCReadWriteAuthRequiredOverride:
141 cfg.VncReadWriteAuthRequiredOverride, err = strconv.ParseBool(configValue)
142 return err
143 case VNCReadAuthRequired:
144 cfg.VncReadAuthRequired, err = strconv.ParseBool(configValue)
145 return err
146 case VNCReadAuthRequiredOverride:
147 cfg.VncReadAuthRequiredOverride, err = strconv.ParseBool(configValue)
148 return err
149 }
150 return nil
151 }
152
153 func NewBannerConfigService(sqlDB *sql.DB) BannerConfigService {
154 return &bannerConfigService{
155 sqlDB: sqlDB,
156 }
157 }
158
View as plain text