1
16
17 package app
18
19 import (
20 "context"
21 "encoding/json"
22 "errors"
23 "fmt"
24 "os"
25 "path/filepath"
26 "sync"
27
28 "google.golang.org/grpc"
29 "google.golang.org/grpc/codes"
30 "google.golang.org/grpc/status"
31
32 resourceapi "k8s.io/api/resource/v1alpha2"
33 "k8s.io/apimachinery/pkg/runtime"
34 "k8s.io/dynamic-resource-allocation/kubeletplugin"
35 "k8s.io/klog/v2"
36 drapbv1alpha2 "k8s.io/kubelet/pkg/apis/dra/v1alpha2"
37 drapbv1alpha3 "k8s.io/kubelet/pkg/apis/dra/v1alpha3"
38 )
39
// ExamplePlugin is an example DRA kubelet plugin. It writes one CDI spec file
// per prepared claim, remembers which claims are currently prepared and
// records every gRPC call it serves.
type ExamplePlugin struct {
	// stopCh is the Done channel of the context passed to StartPlugin;
	// NodeListAndWatchResources keeps its stream open until it is closed.
	stopCh  <-chan struct{}
	logger  klog.Logger
	d       kubeletplugin.DRAPlugin
	fileOps FileOperations

	// cdiDir is where CDI spec files are written. nodeName, when non-empty,
	// is compared with the node a claim was allocated for before preparing it.
	cdiDir     string
	driverName string
	nodeName   string

	// mutex protects prepared and gRPCCalls.
	mutex     sync.Mutex
	prepared  map[ClaimID]bool
	gRPCCalls []GRPCCall

	// block, once set by Block, makes NodePrepareResource and
	// NodeUnprepareResource wait for their context to be canceled.
	// NOTE(review): written by Block and read from gRPC handler goroutines;
	// confirm every access is synchronized (e.g. via mutex).
	block bool
}
56
// GRPCCall records one gRPC call served by the plugin.
type GRPCCall struct {
	// FullMethod is the fully qualified method name reported by gRPC.
	FullMethod string

	// Request contains the parameters of a unary call. It is never set for
	// streaming calls.
	Request interface{}

	// Response contains the reply of a unary call. It is nil while the call
	// is still in progress and is never set for streaming calls.
	Response interface{}

	// Err contains the error returned by the handler. It is nil while the
	// call is still in progress and after it succeeded.
	Err error
}
70
71
72
73
// ClaimID identifies a ResourceClaim by name and UID. It is used as the key
// of the set of currently prepared claims.
type ClaimID struct {
	Name string
	UID  string
}
78
79 var _ drapbv1alpha2.NodeServer = &ExamplePlugin{}
80 var _ drapbv1alpha3.NodeServer = &ExamplePlugin{}
81
82
83 func (ex *ExamplePlugin) getJSONFilePath(claimUID string) string {
84 return filepath.Join(ex.cdiDir, fmt.Sprintf("%s-%s.json", ex.driverName, claimUID))
85 }
86
87
88
// FileOperations defines optional callbacks for handling CDI files plus some
// additional plugin configuration. StartPlugin installs defaults for nil
// callbacks.
type FileOperations struct {
	// Create writes content to the file name. The default implementation
	// uses os.WriteFile with mode 0644, which overwrites an existing file.
	Create func(name string, content []byte) error

	// Remove deletes the file name. The default implementation does not
	// treat a missing file as an error.
	Remove func(name string) error

	// NumResourceInstances is the number of named resource instances
	// ("instance-0", "instance-1", ...) reported by
	// NodeListAndWatchResources. A negative value makes that call fail with
	// codes.Unimplemented.
	NumResourceInstances int
}
102
103
104 func StartPlugin(ctx context.Context, cdiDir, driverName string, nodeName string, fileOps FileOperations, opts ...kubeletplugin.Option) (*ExamplePlugin, error) {
105 logger := klog.FromContext(ctx)
106 if fileOps.Create == nil {
107 fileOps.Create = func(name string, content []byte) error {
108 return os.WriteFile(name, content, os.FileMode(0644))
109 }
110 }
111 if fileOps.Remove == nil {
112 fileOps.Remove = func(name string) error {
113 if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
114 return err
115 }
116 return nil
117 }
118 }
119 ex := &ExamplePlugin{
120 stopCh: ctx.Done(),
121 logger: logger,
122 fileOps: fileOps,
123 cdiDir: cdiDir,
124 driverName: driverName,
125 nodeName: nodeName,
126 prepared: make(map[ClaimID]bool),
127 }
128
129 opts = append(opts,
130 kubeletplugin.Logger(logger),
131 kubeletplugin.DriverName(driverName),
132 kubeletplugin.GRPCInterceptor(ex.recordGRPCCall),
133 kubeletplugin.GRPCStreamInterceptor(ex.recordGRPCStream),
134 )
135 d, err := kubeletplugin.Start(ex, opts...)
136 if err != nil {
137 return nil, fmt.Errorf("start kubelet plugin: %w", err)
138 }
139 ex.d = d
140
141 return ex, nil
142 }
143
144
145 func (ex *ExamplePlugin) Stop() {
146 ex.d.Stop()
147 }
148
149 func (ex *ExamplePlugin) IsRegistered() bool {
150 status := ex.d.RegistrationStatus()
151 if status == nil {
152 return false
153 }
154 return status.PluginRegistered
155 }
156
157
158
159 func (ex *ExamplePlugin) Block() {
160 ex.block = true
161 }
162
163
164
165
166
167 func (ex *ExamplePlugin) NodePrepareResource(ctx context.Context, req *drapbv1alpha2.NodePrepareResourceRequest) (*drapbv1alpha2.NodePrepareResourceResponse, error) {
168 logger := klog.FromContext(ctx)
169
170
171
172 if ex.block {
173 <-ctx.Done()
174 return nil, ctx.Err()
175 }
176
177
178 var p parameters
179 switch len(req.StructuredResourceHandle) {
180 case 0:
181
182 if err := json.Unmarshal([]byte(req.ResourceHandle), &p); err != nil {
183 return nil, fmt.Errorf("unmarshal resource handle: %w", err)
184 }
185 case 1:
186
187 handle := req.StructuredResourceHandle[0]
188 if handle == nil {
189 return nil, errors.New("unexpected nil StructuredResourceHandle")
190 }
191 p.NodeName = handle.NodeName
192 if err := extractParameters(handle.VendorClassParameters, &p.EnvVars, "admin"); err != nil {
193 return nil, err
194 }
195 if err := extractParameters(handle.VendorClaimParameters, &p.EnvVars, "user"); err != nil {
196 return nil, err
197 }
198 for _, result := range handle.Results {
199 if err := extractParameters(result.VendorRequestParameters, &p.EnvVars, "user"); err != nil {
200 return nil, err
201 }
202 }
203 default:
204
205 return nil, fmt.Errorf("invalid length of NodePrepareResourceRequest.StructuredResourceHandle: %d", len(req.StructuredResourceHandle))
206 }
207
208
209 if p.NodeName != "" && ex.nodeName != "" && p.NodeName != ex.nodeName {
210 return nil, fmt.Errorf("claim was allocated for %q, cannot be prepared on %q", p.NodeName, ex.nodeName)
211 }
212
213
214 envs := []string{}
215 for key, val := range p.EnvVars {
216 envs = append(envs, key+"="+val)
217 }
218
219 deviceName := "claim-" + req.ClaimUid
220 vendor := ex.driverName
221 class := "test"
222 spec := &spec{
223 Version: "0.3.0",
224 Kind: vendor + "/" + class,
225
226
227 Devices: []device{
228 {
229 Name: deviceName,
230 ContainerEdits: containerEdits{
231 Env: envs,
232 },
233 },
234 },
235 }
236 filePath := ex.getJSONFilePath(req.ClaimUid)
237 buffer, err := json.Marshal(spec)
238 if err != nil {
239 return nil, fmt.Errorf("marshal spec: %w", err)
240 }
241 if err := ex.fileOps.Create(filePath, buffer); err != nil {
242 return nil, fmt.Errorf("failed to write CDI file %v", err)
243 }
244
245 dev := vendor + "/" + class + "=" + deviceName
246 resp := &drapbv1alpha2.NodePrepareResourceResponse{CdiDevices: []string{dev}}
247
248 ex.mutex.Lock()
249 defer ex.mutex.Unlock()
250 ex.prepared[ClaimID{Name: req.ClaimName, UID: req.ClaimUid}] = true
251
252 logger.V(3).Info("CDI file created", "path", filePath, "device", dev)
253 return resp, nil
254 }
255
256 func extractParameters(parameters runtime.RawExtension, env *map[string]string, kind string) error {
257 if len(parameters.Raw) == 0 {
258 return nil
259 }
260 var data map[string]string
261 if err := json.Unmarshal(parameters.Raw, &data); err != nil {
262 return fmt.Errorf("decoding %s parameters: %v", kind, err)
263 }
264 if len(data) > 0 && *env == nil {
265 *env = make(map[string]string)
266 }
267 for key, value := range data {
268 (*env)[kind+"_"+key] = value
269 }
270 return nil
271 }
272
273 func (ex *ExamplePlugin) NodePrepareResources(ctx context.Context, req *drapbv1alpha3.NodePrepareResourcesRequest) (*drapbv1alpha3.NodePrepareResourcesResponse, error) {
274 resp := &drapbv1alpha3.NodePrepareResourcesResponse{
275 Claims: make(map[string]*drapbv1alpha3.NodePrepareResourceResponse),
276 }
277 for _, claimReq := range req.Claims {
278 claimResp, err := ex.NodePrepareResource(ctx, &drapbv1alpha2.NodePrepareResourceRequest{
279 Namespace: claimReq.Namespace,
280 ClaimName: claimReq.Name,
281 ClaimUid: claimReq.Uid,
282 ResourceHandle: claimReq.ResourceHandle,
283 StructuredResourceHandle: claimReq.StructuredResourceHandle,
284 })
285 if err != nil {
286 resp.Claims[claimReq.Uid] = &drapbv1alpha3.NodePrepareResourceResponse{
287 Error: err.Error(),
288 }
289 } else {
290 resp.Claims[claimReq.Uid] = &drapbv1alpha3.NodePrepareResourceResponse{
291 CDIDevices: claimResp.CdiDevices,
292 }
293 }
294 }
295 return resp, nil
296 }
297
298
299
300
301 func (ex *ExamplePlugin) NodeUnprepareResource(ctx context.Context, req *drapbv1alpha2.NodeUnprepareResourceRequest) (*drapbv1alpha2.NodeUnprepareResourceResponse, error) {
302 logger := klog.FromContext(ctx)
303
304
305
306 if ex.block {
307 <-ctx.Done()
308 return nil, ctx.Err()
309 }
310
311 filePath := ex.getJSONFilePath(req.ClaimUid)
312 if err := ex.fileOps.Remove(filePath); err != nil {
313 return nil, fmt.Errorf("error removing CDI file: %w", err)
314 }
315 logger.V(3).Info("CDI file removed", "path", filePath)
316
317 ex.mutex.Lock()
318 defer ex.mutex.Unlock()
319 delete(ex.prepared, ClaimID{Name: req.ClaimName, UID: req.ClaimUid})
320
321 return &drapbv1alpha2.NodeUnprepareResourceResponse{}, nil
322 }
323
324 func (ex *ExamplePlugin) NodeUnprepareResources(ctx context.Context, req *drapbv1alpha3.NodeUnprepareResourcesRequest) (*drapbv1alpha3.NodeUnprepareResourcesResponse, error) {
325 resp := &drapbv1alpha3.NodeUnprepareResourcesResponse{
326 Claims: make(map[string]*drapbv1alpha3.NodeUnprepareResourceResponse),
327 }
328 for _, claimReq := range req.Claims {
329 _, err := ex.NodeUnprepareResource(ctx, &drapbv1alpha2.NodeUnprepareResourceRequest{
330 Namespace: claimReq.Namespace,
331 ClaimName: claimReq.Name,
332 ClaimUid: claimReq.Uid,
333 ResourceHandle: claimReq.ResourceHandle,
334 })
335 if err != nil {
336 resp.Claims[claimReq.Uid] = &drapbv1alpha3.NodeUnprepareResourceResponse{
337 Error: err.Error(),
338 }
339 } else {
340 resp.Claims[claimReq.Uid] = &drapbv1alpha3.NodeUnprepareResourceResponse{}
341 }
342 }
343 return resp, nil
344 }
345
346 func (ex *ExamplePlugin) NodeListAndWatchResources(req *drapbv1alpha3.NodeListAndWatchResourcesRequest, stream drapbv1alpha3.Node_NodeListAndWatchResourcesServer) error {
347 if ex.fileOps.NumResourceInstances < 0 {
348 ex.logger.Info("Sending no NodeResourcesResponse")
349 return status.New(codes.Unimplemented, "node resource support disabled").Err()
350 }
351
352 instances := make([]resourceapi.NamedResourcesInstance, ex.fileOps.NumResourceInstances)
353 for i := 0; i < ex.fileOps.NumResourceInstances; i++ {
354 instances[i].Name = fmt.Sprintf("instance-%d", i)
355 }
356 resp := &drapbv1alpha3.NodeListAndWatchResourcesResponse{
357 Resources: []*resourceapi.ResourceModel{
358 {
359 NamedResources: &resourceapi.NamedResourcesResources{
360 Instances: instances,
361 },
362 },
363 },
364 }
365
366 ex.logger.Info("Sending NodeListAndWatchResourcesResponse", "response", resp)
367 if err := stream.Send(resp); err != nil {
368 return err
369 }
370
371
372
373 <-ex.stopCh
374 ex.logger.Info("Done sending NodeListAndWatchResourcesResponse, closing stream")
375
376 return nil
377 }
378
379 func (ex *ExamplePlugin) GetPreparedResources() []ClaimID {
380 ex.mutex.Lock()
381 defer ex.mutex.Unlock()
382 var prepared []ClaimID
383 for claimID := range ex.prepared {
384 prepared = append(prepared, claimID)
385 }
386 return prepared
387 }
388
389 func (ex *ExamplePlugin) recordGRPCCall(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
390 call := GRPCCall{
391 FullMethod: info.FullMethod,
392 Request: req,
393 }
394 ex.mutex.Lock()
395 ex.gRPCCalls = append(ex.gRPCCalls, call)
396 index := len(ex.gRPCCalls) - 1
397 ex.mutex.Unlock()
398
399
400 call.Response, call.Err = handler(ctx, req)
401
402 ex.mutex.Lock()
403 ex.gRPCCalls[index] = call
404 ex.mutex.Unlock()
405
406 return call.Response, call.Err
407 }
408
409 func (ex *ExamplePlugin) recordGRPCStream(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
410 call := GRPCCall{
411 FullMethod: info.FullMethod,
412 }
413 ex.mutex.Lock()
414 ex.gRPCCalls = append(ex.gRPCCalls, call)
415 index := len(ex.gRPCCalls) - 1
416 ex.mutex.Unlock()
417
418
419 call.Err = handler(srv, stream)
420
421 ex.mutex.Lock()
422 ex.gRPCCalls[index] = call
423 ex.mutex.Unlock()
424
425 return call.Err
426 }
427
428 func (ex *ExamplePlugin) GetGRPCCalls() []GRPCCall {
429 ex.mutex.Lock()
430 defer ex.mutex.Unlock()
431
432
433
434
435 calls := make([]GRPCCall, 0, len(ex.gRPCCalls))
436 calls = append(calls, ex.gRPCCalls...)
437 return calls
438 }
439
View as plain text