...
1
2
3
4 package hcsv2
5
6 import (
7 "context"
8 "fmt"
9 "os"
10 "os/exec"
11 "strings"
12
13 oci "github.com/opencontainers/runtime-spec/specs-go"
14 "github.com/pkg/errors"
15
16 "github.com/Microsoft/hcsshim/cmd/gcstools/generichook"
17 "github.com/Microsoft/hcsshim/internal/guest/storage/pci"
18 "github.com/Microsoft/hcsshim/internal/guestpath"
19 "github.com/Microsoft/hcsshim/internal/hooks"
20 "github.com/Microsoft/hcsshim/pkg/annotations"
21 )
22
const (
	// nvidiaDebugFilePath is the log file handed to nvidia-container-cli via
	// its --debug flag and to generichook via its debug-log env variable.
	nvidiaDebugFilePath = "/nvidia-container.log"

	// nvidiaToolBinary is the command that the GPU device hook asks
	// generichook to run.
	nvidiaToolBinary = "nvidia-container-cli"
)
26
27
28
29 func addNvidiaDeviceHook(ctx context.Context, spec *oci.Spec) error {
30 genericHookBinary := "generichook"
31 genericHookPath, err := exec.LookPath(genericHookBinary)
32 if err != nil {
33 return errors.Wrapf(err, "failed to find %s for container device support", genericHookBinary)
34 }
35
36 debugOption := fmt.Sprintf("--debug=%s", nvidiaDebugFilePath)
37
38
39
40
41 args := []string{
42 genericHookPath,
43 nvidiaToolBinary,
44 debugOption,
45 "--load-kmods",
46 "--no-pivot",
47 "configure",
48 "--ldconfig=@/sbin/ldconfig",
49 }
50 if capabilities, ok := spec.Annotations[annotations.ContainerGPUCapabilities]; ok {
51 caps := strings.Split(capabilities, ",")
52 for _, c := range caps {
53 args = append(args, fmt.Sprintf("--%s", c))
54 }
55 }
56
57 for _, d := range spec.Windows.Devices {
58 switch d.IDType {
59 case "gpu":
60 busLocation, err := pci.FindDeviceBusLocationFromVMBusGUID(ctx, d.ID)
61 if err != nil {
62 return errors.Wrapf(err, "failed to find nvidia gpu bus location")
63 }
64 args = append(args, fmt.Sprintf("--device=%s", busLocation))
65 }
66 }
67
68
69 args = append(args, "--no-cgroups", "--pid={{pid}}", spec.Root.Path)
70
71 hookLogDebugFileEnvOpt := fmt.Sprintf("%s=%s", generichook.LogDebugFileEnvKey, nvidiaDebugFilePath)
72 hookEnv := append(updateEnvWithNvidiaVariables(), hookLogDebugFileEnvOpt)
73 nvidiaHook := hooks.NewOCIHook(genericHookPath, args, hookEnv)
74 return hooks.AddOCIHook(spec, hooks.CreateRuntime, nvidiaHook)
75 }
76
77
78
79
80 func getNvidiaDriversUsrLibPath() string {
81 return fmt.Sprintf("%s/content/usr/lib", guestpath.LCOWNvidiaMountPath)
82 }
83
84
85
86
87 func getNvidiaDriverUsrBinPath() string {
88 return fmt.Sprintf("%s/content/usr/bin", guestpath.LCOWNvidiaMountPath)
89 }
90
91
92 func updateEnvWithNvidiaVariables() []string {
93 env := updatePathEnv(getNvidiaDriverUsrBinPath())
94
95
96 env = append(env, "NVC_INSECURE_MODE=1")
97 return env
98 }
99
100
// updatePathEnv returns a copy of the current process environment
// (os.Environ) with dirs appended, in order, to the end of PATH.
//
// If PATH is not present it is added, containing only dirs. Empty strings in
// dirs are ignored. No empty element is ever introduced by this function: a
// zero-length PATH element (a leading, trailing or doubled ':') is treated by
// POSIX path search as the current working directory. If there is nothing to
// add, the environment is returned unchanged (no "PATH=" entry is created).
func updatePathEnv(dirs ...string) []string {
	const pathPrefix = "PATH="

	extra := make([]string, 0, len(dirs))
	for _, d := range dirs {
		if d != "" {
			extra = append(extra, d)
		}
	}

	env := os.Environ()
	for i, v := range env {
		if !strings.HasPrefix(v, pathPrefix) {
			continue
		}
		parts := extra
		// Only keep the existing value as a component when it is non-empty.
		// Joining "" with extra would yield a leading ':' (i.e. cwd).
		if existing := strings.TrimPrefix(v, pathPrefix); existing != "" {
			parts = append([]string{existing}, extra...)
		}
		env[i] = pathPrefix + strings.Join(parts, ":")
		return env
	}

	if len(extra) == 0 {
		// Nothing to add; don't invent an empty PATH the process never had.
		return env
	}
	return append(env, pathPrefix+strings.Join(extra, ":"))
}
114
View as plain text