1 package services
2
3 import (
4 "context"
5 "database/sql"
6 "errors"
7
8 "github.com/google/uuid"
9
10 sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
11 "edge-infra.dev/pkg/edge/api/graph/model"
12 sqlquery "edge-infra.dev/pkg/edge/api/sql"
13 "edge-infra.dev/pkg/edge/api/utils"
14 )
15
16 func (t *terminalService) CreateTerminalDiskEntry(ctx context.Context, terminalID string, newTerminalDisk *model.TerminalDiskCreateInput) (*model.TerminalDisk, error) {
17 if err := t.validateTerminalDiskDevicePath(ctx, terminalID, newTerminalDisk.DevicePath); err != nil {
18 return nil, err
19 }
20
21 transaction, err := t.SQLDB.BeginTx(ctx, nil)
22 if err != nil {
23 return nil, err
24 }
25
26 newDiskList := []*model.TerminalDiskCreateInput{newTerminalDisk}
27 diskList, err := t.createTerminalDiskEntries(ctx, transaction, newDiskList, terminalID)
28 if err != nil {
29 if rollbackErr := transaction.Rollback(); rollbackErr != nil {
30 return nil, rollbackErr
31 }
32 return nil, err
33 }
34
35 if err = transaction.Commit(); err != nil {
36 return nil, err
37 }
38
39 if len(diskList) == 0 {
40 return nil, errors.New("no disk to return - potential error creating disk entry")
41 }
42
43 return diskList[0], nil
44 }
45
46 func (t *terminalService) createTerminalDiskEntries(ctx context.Context, transaction *sql.Tx, newDisks []*model.TerminalDiskCreateInput, terminalID string) ([]*model.TerminalDisk, error) {
47 terminalDisks := []*model.TerminalDisk{}
48 for _, newDisk := range newDisks {
49 terminalDisk := utils.CreateTerminalDiskModel(uuid.NewString(), terminalID, newDisk.IncludeDisk, newDisk.ExpectEmpty, newDisk.DevicePath, newDisk.UsePart)
50
51 args := []interface{}{
52 terminalDisk.TerminalDiskID,
53 terminalID,
54 terminalDisk.IncludeDisk,
55 terminalDisk.ExpectEmpty,
56 terminalDisk.DevicePath,
57 terminalDisk.UsePart,
58 }
59
60 if _, err := transaction.ExecContext(ctx, sqlquery.TerminalDiskCreateQuery, args...); err != nil {
61 return nil, err
62 }
63
64 terminalDisks = append(terminalDisks, &terminalDisk)
65 }
66 return terminalDisks, nil
67 }
68
69 func (t *terminalService) DeleteTerminalDiskEntry(ctx context.Context, terminalDiskID string) (*model.Terminal, error) {
70 terminal, err := t.GetTerminalFromDisk(ctx, terminalDiskID)
71 if err != nil {
72 return nil, err
73 }
74
75 if len(terminal.Disks) == 1 {
76 return nil, errors.New("unable to delete the last disk of the terminal")
77 }
78
79 for i, disk := range terminal.Disks {
80 if disk.TerminalDiskID == terminalDiskID {
81 terminal.Disks[i] = terminal.Disks[len(terminal.Disks)-1]
82 terminal.Disks = terminal.Disks[:len(terminal.Disks)-1]
83 break
84 }
85 }
86
87 _, err = t.SQLDB.ExecContext(ctx, sqlquery.TerminalDiskDeleteQuery, terminalDiskID)
88 return terminal, err
89 }
90
91 func (t *terminalService) UpdateTerminalDiskEntry(ctx context.Context, terminalDiskID string, diskInput model.TerminalDiskUpdateInput) (*model.TerminalDisk, error) {
92 transaction, err := t.SQLDB.BeginTx(ctx, nil)
93 if err != nil {
94 return nil, err
95 }
96
97 updateDiskInput := []*model.TerminalDiskIDInput{
98 {
99 TerminalDiskID: terminalDiskID,
100 TerminalDiskValues: &model.TerminalDiskUpdateInput{
101 DevicePath: diskInput.DevicePath,
102 IncludeDisk: diskInput.IncludeDisk,
103 ExpectEmpty: diskInput.ExpectEmpty,
104 UsePart: diskInput.UsePart,
105 },
106 },
107 }
108
109 updatedDisks, err := t.updateTerminalDiskEntries(ctx, transaction, updateDiskInput)
110 if err != nil {
111 if rollbackErr := transaction.Rollback(); rollbackErr != nil {
112 return nil, rollbackErr
113 }
114 return nil, err
115 }
116
117 if err = transaction.Commit(); err != nil {
118 return nil, err
119 }
120
121 if len(updatedDisks) == 0 {
122 return nil, errors.New("no disk to return - potential error updating disk entry")
123 }
124
125 return updatedDisks[0], nil
126 }
127
128 func (t *terminalService) updateTerminalDiskEntries(ctx context.Context, transaction *sql.Tx, updateDisks []*model.TerminalDiskIDInput) ([]*model.TerminalDisk, error) {
129 updatedTerminalDisks := []*model.TerminalDisk{}
130 for _, updateDisk := range updateDisks {
131 disk, err := t.getTerminalDisk(ctx, updateDisk.TerminalDiskID)
132 if err != nil {
133 return nil, err
134 }
135
136 updateDiskInput := updateDisk.TerminalDiskValues
137 if updateDiskInput.DevicePath != nil {
138 if err = t.validateTerminalDiskDevicePath(ctx, disk.TerminalID, *updateDiskInput.DevicePath); err != nil {
139 return nil, err
140 }
141 }
142
143 updatedTerminalDisk, err := utils.UpdateTerminalDisk(disk, updateDiskInput)
144 if err != nil {
145 return nil, err
146 }
147
148 args := []interface{}{
149 updatedTerminalDisk.IncludeDisk,
150 updatedTerminalDisk.ExpectEmpty,
151 updatedTerminalDisk.DevicePath,
152 updatedTerminalDisk.UsePart,
153 updatedTerminalDisk.TerminalDiskID,
154 }
155 _, err = transaction.ExecContext(ctx, sqlquery.TerminalDiskUpdateQuery, args...)
156 if err != nil {
157 return nil, err
158 }
159
160 updatedTerminalDisks = append(updatedTerminalDisks, updatedTerminalDisk)
161 }
162
163 return updatedTerminalDisks, nil
164 }
165
166 func (t *terminalService) GetTerminalFromDisk(ctx context.Context, terminalDiskID string) (*model.Terminal, error) {
167 row := t.SQLDB.QueryRowContext(ctx, sqlquery.GetTerminalIDFromDiskQuery, terminalDiskID)
168 var terminalID string
169 if err := row.Scan(&terminalID); err != nil {
170 return nil, err
171 }
172
173 getLabel := true
174
175 terminal, err := t.GetTerminal(ctx, terminalID, &getLabel)
176 if err != nil {
177 return nil, err
178 }
179
180 return terminal, nil
181 }
182
183 func (t *terminalService) getTerminalDisk(ctx context.Context, terminalDiskID string) (*model.TerminalDisk, error) {
184 row := t.SQLDB.QueryRowContext(ctx, sqlquery.GetTerminalDiskByIDQuery, terminalDiskID)
185
186 terminalDisk := model.TerminalDisk{}
187 if err := row.Scan(&terminalDisk.TerminalDiskID, &terminalDisk.TerminalID, &terminalDisk.IncludeDisk, &terminalDisk.ExpectEmpty, &terminalDisk.DevicePath, &terminalDisk.UsePart); err != nil {
188 return nil, err
189 }
190
191 return &terminalDisk, nil
192 }
193
194 func (t *terminalService) getTerminalDisks(ctx context.Context, terminalID *string) ([]*model.TerminalDisk, error) {
195 rows, err := t.SQLDB.QueryContext(ctx, sqlquery.GetTerminalDiskByTerminalIDQuery, terminalID)
196 if err != nil {
197 return nil, err
198 }
199
200 disks, err := t.scanTerminalDiskRows(rows)
201 if err != nil {
202 return nil, err
203 }
204
205 return disks, nil
206 }
207
208 func (t *terminalService) validateTerminalDiskDevicePath(ctx context.Context, terminalID string, devicePath string) error {
209 disks, err := t.getTerminalDisks(ctx, &terminalID)
210 if err != nil {
211 return err
212 }
213
214 for _, disk := range disks {
215 if disk.DevicePath == devicePath {
216 return ErrDuplicateTerminalDiskDevicePaths
217 }
218 }
219
220 return nil
221 }
222
223 func (t *terminalService) scanTerminalDiskRows(rows *sql.Rows) ([]*model.TerminalDisk, error) {
224 terminalDisks := []*model.TerminalDisk{}
225 for rows.Next() {
226 terminalDisk := model.TerminalDisk{}
227 err := rows.Scan(&terminalDisk.TerminalDiskID, &terminalDisk.TerminalID, &terminalDisk.IncludeDisk, &terminalDisk.ExpectEmpty, &terminalDisk.DevicePath, &terminalDisk.UsePart)
228 if err != nil {
229 return nil, err
230 }
231 terminalDisks = append(terminalDisks, &terminalDisk)
232 }
233 if err := rows.Err(); err != nil {
234 return nil, sqlerr.Wrap(err)
235 }
236 return terminalDisks, nil
237 }
238
View as plain text