...
1 package opts
2
3 import (
4 "encoding/csv"
5 "fmt"
6 "strconv"
7 "strings"
8
9 "github.com/docker/docker/api/types/container"
10 "github.com/pkg/errors"
11 )
12
13
14 type GpuOpts struct {
15 values []container.DeviceRequest
16 }
17
// parseCount converts the count part of a gpu request into an integer.
//
// The literal "all" maps to the sentinel -1. Any other input is parsed with
// strconv.Atoi. On failure, the error is wrapped with "count must be an
// integer", and the int Atoi produced is returned alongside it.
// errors.Wrap returns nil for a nil error, so success yields (n, nil).
func parseCount(s string) (int, error) {
	if s != "all" {
		n, convErr := strconv.Atoi(s)
		return n, errors.Wrap(convErr, "count must be an integer")
	}
	return -1, nil
}
25
26
27
28
// Set parses a single gpu request value and appends the resulting
// container.DeviceRequest to o.
//
// value is one CSV record. Each field is either key=value or a bare value:
//   - key=value, where key is one of driver, count, device, capabilities
//     or options.
//   - a bare value (for example "all" or "2"), which is shorthand for
//     count=<value>.
//
// The device and capabilities values are split on commas, and options is
// parsed as a nested CSV record. Because the outer value is also CSV, such
// fields must be quoted, e.g. `"device=0,1"`.
//
// Every key, including the bare count shorthand, may be given at most once.
// If neither a count nor a device list is given, Count defaults to 1.
// Options defaults to an empty map. "gpu" is always added to the requested
// capabilities.
func (o *GpuOpts) Set(value string) error {
	csvReader := csv.NewReader(strings.NewReader(value))
	fields, err := csvReader.Read()
	if err != nil {
		// NOTE(review): for an empty value this is the csv reader's io.EOF.
		return err
	}

	req := container.DeviceRequest{}
	seen := map[string]struct{}{}

	for _, field := range fields {
		key, val, withValue := strings.Cut(field, "=")
		if !withValue {
			// A bare value is shorthand for count=<value>. Normalizing it
			// before the duplicate check means "count=2,all" and "2,3" are
			// rejected like "all,count=2" already was. Previously the later
			// value silently won.
			key, val = "count", key
		}
		if _, ok := seen[key]; ok {
			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
		}
		seen[key] = struct{}{}

		switch key {
		case "driver":
			req.Driver = val
		case "count":
			req.Count, err = parseCount(val)
			if err != nil {
				return err
			}
		case "device":
			req.DeviceIDs = strings.Split(val, ",")
		case "capabilities":
			req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")}
		case "options":
			r := csv.NewReader(strings.NewReader(val))
			optFields, err := r.Read()
			if err != nil {
				return errors.Wrap(err, "failed to read gpu options")
			}
			req.Options = ConvertKVStringsToMap(optFields)
		default:
			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
		}
	}

	// With no explicit count and no explicit devices, request a single device.
	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
		req.Count = 1
	}
	if req.Options == nil {
		req.Options = make(map[string]string)
	}
	if req.Capabilities == nil {
		req.Capabilities = [][]string{{"gpu"}}
	}

	o.values = append(o.values, req)
	return nil
}
93
94
95 func (o *GpuOpts) Type() string {
96 return "gpu-request"
97 }
98
99
// String formats every accumulated request with fmt's %v verb and joins
// the results with ", ". It returns "" when no requests have been added.
func (o *GpuOpts) String() string {
	parts := make([]string, 0, len(o.values))
	for i := range o.values {
		parts = append(parts, fmt.Sprintf("%v", o.values[i]))
	}
	return strings.Join(parts, ", ")
}
107
108
109 func (o *GpuOpts) Value() []container.DeviceRequest {
110 return o.values
111 }
112
View as plain text