1 package services
2
3 import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8
9 "github.com/google/uuid"
10
11 sqlerr "edge-infra.dev/pkg/edge/api/apierror/sql"
12 "edge-infra.dev/pkg/edge/api/graph/model"
13 sqlquery "edge-infra.dev/pkg/edge/api/sql"
14 "edge-infra.dev/pkg/edge/api/utils"
15 )
16
17 func (v *virtualMachineService) CreateVirtualMachineDiskEntries(ctx context.Context, virtualMachineID string, createDisks []*model.VirtualMachineDiskCreateInput) (createdDisks []*model.VirtualMachineDisk, err error) {
18 transaction, err := v.SQLDB.BeginTx(ctx, nil)
19 if err != nil {
20 return nil, err
21 }
22
23 defer func() {
24 if err != nil {
25 err = errors.Join(err, transaction.Rollback())
26 }
27 }()
28
29 createdDisks, err = v.createVirtualMachineDiskEntries(ctx, transaction, createDisks, virtualMachineID)
30 if err != nil {
31 return nil, err
32 }
33
34 if len(createdDisks) == 0 {
35 return nil, fmt.Errorf("no vm disks to return - error creating disk entries")
36 }
37
38 allDisks, err := v.getVirtualMachineDisks(ctx, transaction, virtualMachineID)
39 if err != nil {
40 return nil, err
41 }
42
43 if utils.HasDuplicateDiskBootOrders(allDisks) {
44 return nil, fmt.Errorf("cannot create vm disks that have duplicate boot orders")
45 }
46
47 if err = transaction.Commit(); err != nil {
48 return nil, err
49 }
50
51 return createdDisks, nil
52 }
53
54 func (v *virtualMachineService) createVirtualMachineDiskEntries(ctx context.Context, transaction *sql.Tx, createDisks []*model.VirtualMachineDiskCreateInput, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
55 createdDisks := []*model.VirtualMachineDisk{}
56 for _, createDisk := range createDisks {
57 createdDisk := utils.CreateVirtualMachineDiskModel(uuid.NewString(), virtualMachineID, *createDisk.Type, *createDisk.Bus, createDisk.BootOrder, *createDisk.Size, *createDisk.ContainerImageURL)
58
59 if err := utils.ValidateVirtualMachineDisk(&createdDisk); err != nil {
60 return nil, err
61 }
62
63 args := []interface{}{
64 createdDisk.DiskID,
65 virtualMachineID,
66 createdDisk.Type,
67 createdDisk.Bus,
68 createdDisk.BootOrder,
69 createdDisk.Size,
70 createdDisk.ContainerImageURL,
71 }
72
73 if _, err := transaction.ExecContext(ctx, sqlquery.VirtualMachineDiskCreateQuery, args...); err != nil {
74 return nil, err
75 }
76
77 createdDisks = append(createdDisks, &createdDisk)
78 }
79 return createdDisks, nil
80 }
81
82 func (v *virtualMachineService) DeleteVirtualMachineDiskEntry(ctx context.Context, diskID string) (*model.VirtualMachineDisk, error) {
83 virtualMachine, err := v.GetVirtualMachineFromDisk(ctx, diskID)
84 if err != nil {
85 return nil, err
86 }
87 if virtualMachine == nil {
88 return nil, fmt.Errorf("associated virtual machine does not exist so cannot delete disk")
89 }
90
91 disks, err := v.GetVirtualMachineDisks(ctx, virtualMachine.VirtualMachineID)
92 if err != nil {
93 return nil, err
94 }
95 if disks == nil {
96 return nil, fmt.Errorf("cannot find any disks for the linked virtual machine")
97 }
98 if len(disks) == 1 {
99 return nil, fmt.Errorf("unable to delete the last disk of the virtual machine")
100 }
101
102 var deletedDisk *model.VirtualMachineDisk
103 for _, disk := range disks {
104 if disk.DiskID == diskID {
105 deletedDisk = disk
106 break
107 }
108 }
109
110 _, err = v.SQLDB.ExecContext(ctx, sqlquery.VirtualMachineDiskDeleteQuery, diskID)
111 return deletedDisk, err
112 }
113
114 func (v *virtualMachineService) UpdateVirtualMachineDiskEntries(ctx context.Context, updateDisks []*model.VirtualMachineDiskIDInput) (updatedDisks []*model.VirtualMachineDisk, err error) {
115 if len(updateDisks) == 0 {
116 return nil, fmt.Errorf("no vm disks provided to update")
117 }
118
119 transaction, err := v.SQLDB.BeginTx(ctx, nil)
120 if err != nil {
121 return nil, err
122 }
123
124 defer func() {
125 if err != nil {
126 err = errors.Join(err, transaction.Rollback())
127 }
128 }()
129
130 updatedDisks, err = v.updateVirtualMachineDiskEntries(ctx, transaction, updateDisks)
131 if err != nil {
132 return nil, err
133 }
134
135 allDisks, err := v.getVirtualMachineDisks(ctx, transaction, updatedDisks[0].VirtualMachineID)
136 if err != nil {
137 return nil, err
138 }
139
140 if utils.HasDuplicateDiskBootOrders(allDisks) {
141 return nil, fmt.Errorf("cannot update vm disks to cause duplicate boot orders")
142 }
143
144 if err = transaction.Commit(); err != nil {
145 return nil, err
146 }
147
148 return updatedDisks, nil
149 }
150
151 func (v *virtualMachineService) updateVirtualMachineDiskEntries(ctx context.Context, transaction *sql.Tx, updateDisks []*model.VirtualMachineDiskIDInput) ([]*model.VirtualMachineDisk, error) {
152 updatedDisks := []*model.VirtualMachineDisk{}
153 for _, updateDisk := range updateDisks {
154 updatedDisk, err := v.updateVirtualMachineDiskEntry(ctx, transaction, *updateDisk)
155 if err != nil {
156 return nil, err
157 }
158
159 updatedDisks = append(updatedDisks, updatedDisk)
160 }
161 return updatedDisks, nil
162 }
163
164 func (v *virtualMachineService) updateVirtualMachineDiskEntry(ctx context.Context, transaction *sql.Tx, updateDisk model.VirtualMachineDiskIDInput) (*model.VirtualMachineDisk, error) {
165 currentDisk, err := v.GetVirtualMachineDisk(ctx, updateDisk.DiskID)
166 if err != nil {
167 return nil, err
168 }
169
170 if currentDisk == nil {
171 return nil, fmt.Errorf("cannot update a disk that does not exist")
172 }
173
174 updatedDisk := utils.UpdateVirtualMachineDisk(currentDisk, updateDisk.VirtualMachineDiskValues)
175
176 if err = utils.ValidateVirtualMachineDisk(updatedDisk); err != nil {
177 return nil, err
178 }
179
180 args := []interface{}{
181 updatedDisk.Type,
182 updatedDisk.Bus,
183 updatedDisk.BootOrder,
184 updatedDisk.Size,
185 updatedDisk.ContainerImageURL,
186 updatedDisk.DiskID,
187 }
188
189 _, err = transaction.ExecContext(ctx, sqlquery.VirtualMachineDiskUpdateQuery, args...)
190 if err != nil {
191 return nil, err
192 }
193
194 return updatedDisk, nil
195 }
196
197 func (v *virtualMachineService) GetVirtualMachineDisk(ctx context.Context, diskID string) (*model.VirtualMachineDisk, error) {
198 row := v.SQLDB.QueryRowContext(ctx, sqlquery.GetVirtualMachineDiskByIDQuery, diskID)
199 return v.scanVirtualMachineDiskRow(row)
200 }
201
202
203 func (v *virtualMachineService) GetVirtualMachineDisks(ctx context.Context, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
204 virtualMachine, err := v.GetVirtualMachine(ctx, virtualMachineID)
205 if err != nil {
206 return nil, err
207 }
208 if virtualMachine == nil {
209 return nil, fmt.Errorf("cannot get disks for virtual machine that does not exist")
210 }
211 rows, err := v.SQLDB.QueryContext(ctx, sqlquery.GetVirtualMachineDiskByVirtualMachineIDQuery, virtualMachineID)
212 if err != nil {
213 return nil, err
214 }
215 return v.scanVirtualMachineDiskRows(rows)
216 }
217
218
219 func (v *virtualMachineService) getVirtualMachineDisks(ctx context.Context, transaction *sql.Tx, virtualMachineID string) ([]*model.VirtualMachineDisk, error) {
220 rows, err := transaction.QueryContext(ctx, sqlquery.GetVirtualMachineDiskByVirtualMachineIDQuery, virtualMachineID)
221 if err != nil {
222 return nil, err
223 }
224 return v.scanVirtualMachineDiskRows(rows)
225 }
226
227 func (v *virtualMachineService) GetVirtualMachineFromDisk(ctx context.Context, diskID string) (*model.VirtualMachine, error) {
228 row := v.SQLDB.QueryRowContext(ctx, sqlquery.GetVirtualMachineIDFromDiskQuery, diskID)
229 var virtualMachineID string
230 if err := row.Scan(&virtualMachineID); err != nil {
231 return nil, err
232 }
233
234 virtualMachine, err := v.GetVirtualMachine(ctx, virtualMachineID)
235 if err != nil {
236 return nil, err
237 }
238
239 return virtualMachine, nil
240 }
241
242 func (v *virtualMachineService) scanVirtualMachineDiskRow(row *sql.Row) (*model.VirtualMachineDisk, error) {
243 disk := model.VirtualMachineDisk{}
244 err := row.Scan(&disk.DiskID, &disk.VirtualMachineID, &disk.Type, &disk.Bus, &disk.BootOrder, &disk.Size, &disk.ContainerImageURL)
245 if err != nil {
246 return nil, err
247 }
248 return &disk, nil
249 }
250
251 func (v *virtualMachineService) scanVirtualMachineDiskRows(rows *sql.Rows) ([]*model.VirtualMachineDisk, error) {
252 disks := []*model.VirtualMachineDisk{}
253 for rows.Next() {
254 disk := model.VirtualMachineDisk{}
255 err := rows.Scan(&disk.DiskID, &disk.VirtualMachineID, &disk.Type, &disk.Bus, &disk.BootOrder, &disk.Size, &disk.ContainerImageURL)
256 if err != nil {
257 return nil, err
258 }
259 disks = append(disks, &disk)
260 }
261 if err := rows.Err(); err != nil {
262 return nil, sqlerr.Wrap(err)
263 }
264 return disks, nil
265 }
266
View as plain text