...
1
2
3
4 package main
5
6 import (
7 "bytes"
8 "encoding/json"
9 "fmt"
10 "io"
11 "os"
12 "os/exec"
13 "text/template"
14
15 "github.com/Microsoft/hcsshim/cmd/gcstools/generichook"
16 specs "github.com/opencontainers/runtime-spec/specs-go"
17 "github.com/sirupsen/logrus"
18 )
19
20 func runGenericHook() error {
21 state, err := loadHookState(os.Stdin)
22 if err != nil {
23 return err
24 }
25
26 var (
27 tctx = newTemplateContext(state)
28 args = []string(os.Args[1:])
29 env = os.Environ()
30 )
31
32 parsedArgs, err := render(args, tctx)
33 if err != nil {
34 return err
35 }
36 parsedEnv, err := render(env, tctx)
37 if err != nil {
38 return err
39 }
40
41 hookCmd := exec.Command(parsedArgs[0], parsedArgs[1:]...)
42 hookCmd.Env = parsedEnv
43
44 out, err := hookCmd.CombinedOutput()
45 if err != nil {
46 return fmt.Errorf("failed to run nvidia cli tool with: %v, %v", string(out), err)
47 }
48
49 return nil
50 }
51
52 func logDebugFile(debugFilePath string) {
53 contents, err := os.ReadFile(debugFilePath)
54 if err != nil {
55 logrus.Errorf("failed to read debug file at %s: %v", debugFilePath, err)
56 return
57 }
58 numBytesInContents := len(contents)
59
60
61 maxLogSize := 8000
62 startBytes := 0
63 i := 0
64 for startBytes < numBytesInContents {
65 bytesLeft := len(contents[startBytes:])
66 chunkSize := maxLogSize
67 if bytesLeft < maxLogSize {
68 chunkSize = bytesLeft
69 }
70 stopBytes := startBytes + chunkSize
71 output := string(contents[startBytes:stopBytes])
72 logrus.WithField("output", output).Infof("%s debug part %d", debugFilePath, i)
73 i += 1
74 startBytes += chunkSize
75 }
76 }
77
78 func genericHookMain() {
79 if err := runGenericHook(); err != nil {
80 logrus.Errorf("error in generic hook: %s", err)
81 debugFileToRead := os.Getenv(generichook.LogDebugFileEnvKey)
82 if debugFileToRead != "" {
83 logDebugFile(debugFileToRead)
84 }
85 os.Exit(1)
86 }
87 os.Exit(0)
88 }
89
90
91
92 func loadHookState(r io.Reader) (*specs.State, error) {
93 var s *specs.State
94 if err := json.NewDecoder(r).Decode(&s); err != nil {
95 return nil, err
96 }
97 return s, nil
98 }
99
100 func newTemplateContext(state *specs.State) *templateContext {
101 t := &templateContext{
102 state: state,
103 }
104 t.funcs = template.FuncMap{
105 "id": t.id,
106 "pid": t.pid,
107 "annotation": t.annotation,
108 }
109 return t
110 }
111
112 type templateContext struct {
113 state *specs.State
114 funcs template.FuncMap
115 }
116
117 func (t *templateContext) id() string {
118 return t.state.ID
119 }
120
121 func (t *templateContext) pid() int {
122 return t.state.Pid
123 }
124
125 func (t *templateContext) annotation(k string) string {
126 return t.state.Annotations[k]
127 }
128
129 func render(templateList []string, tctx *templateContext) ([]string, error) {
130 buf := bytes.NewBuffer(nil)
131 for i, s := range templateList {
132 buf.Reset()
133
134 t, err := template.New("generic-hook").Funcs(tctx.funcs).Parse(s)
135 if err != nil {
136 return nil, err
137 }
138 if err := t.Execute(buf, tctx); err != nil {
139 return nil, err
140 }
141 templateList[i] = buf.String()
142 }
143 buf.Reset()
144 return templateList, nil
145 }
146
View as plain text