1
2
3 package main
4
5 import (
6 "context"
7 "fmt"
8 "io"
9 "net"
10 "os"
11 "strings"
12 "time"
13 "unsafe"
14
15 "github.com/Microsoft/go-winio"
16 "github.com/containerd/containerd/runtime/v2/task"
17 "github.com/containerd/ttrpc"
18 "github.com/containerd/typeurl"
19 "github.com/gogo/protobuf/proto"
20 "github.com/gogo/protobuf/types"
21 "github.com/pkg/errors"
22 "github.com/sirupsen/logrus"
23 "github.com/urfave/cli"
24 "golang.org/x/sys/windows"
25
26 runhcsopts "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options"
27 "github.com/Microsoft/hcsshim/internal/extendedtask"
28 hcslog "github.com/Microsoft/hcsshim/internal/log"
29 "github.com/Microsoft/hcsshim/internal/shimdiag"
30 "github.com/Microsoft/hcsshim/pkg/octtrpc"
31 )
32
// svc is the package-level shim service instance. It is assigned from
// NewService inside the serve command's Action and then registered with the
// ttrpc server.
var svc *service
34
35 var serveCommand = cli.Command{
36 Name: "serve",
37 Hidden: true,
38 SkipArgReorder: true,
39 Flags: []cli.Flag{
40 cli.StringFlag{
41 Name: "socket",
42 Usage: "the socket path to serve",
43 },
44 cli.BoolFlag{
45 Name: "is-sandbox",
46 Usage: "is the task id a Kubernetes sandbox id",
47 },
48 },
49 Action: func(ctx *cli.Context) error {
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 var lerrs chan error
72
73
74 shimOpts := &runhcsopts.Options{
75 Debug: false,
76 DebugType: runhcsopts.Options_NPIPE,
77 }
78
79
80 newShimOpts, err := readOptions(os.Stdin)
81 if err != nil {
82 return errors.Wrap(err, "failed to read shim options from stdin")
83 } else if newShimOpts != nil {
84
85 shimOpts = newShimOpts
86 }
87
88 if shimOpts.Debug && shimOpts.LogLevel != "" {
89 logrus.Warning("Both Debug and LogLevel specified, Debug will be overridden")
90 }
91
92
93
94 if shimOpts.Debug {
95 logrus.SetLevel(logrus.DebugLevel)
96 }
97
98
99
100 if shimOpts.LogLevel != "" {
101 lvl, err := logrus.ParseLevel(shimOpts.LogLevel)
102 if err != nil {
103 return errors.Wrapf(err, "failed to parse shim log level %q", shimOpts.LogLevel)
104 }
105 logrus.SetLevel(lvl)
106 }
107
108 switch shimOpts.DebugType {
109 case runhcsopts.Options_NPIPE:
110 logrus.SetFormatter(&logrus.TextFormatter{
111 TimestampFormat: hcslog.TimeFormat,
112 FullTimestamp: true,
113 })
114
115
116
117
118
119
120 const logAddrFmt = "\\\\.\\pipe\\containerd-shim-%s-%s-log"
121 logl, err := winio.ListenPipe(fmt.Sprintf(logAddrFmt, namespaceFlag, idFlag), nil)
122 if err != nil {
123 return err
124 }
125 defer logl.Close()
126
127 lerrs = make(chan error, 1)
128 go func() {
129 var cur net.Conn
130 for {
131
132
133
134
135
136
137 new, err := logl.Accept()
138 if err != nil {
139 lerrs <- err
140 return
141 }
142 if cur != nil {
143 cur.Close()
144 }
145 cur = new
146
147
148
149 logrus.SetOutput(cur)
150 }
151 }()
152
153
154 case runhcsopts.Options_FILE:
155 panic("file log output mode is not supported")
156 case runhcsopts.Options_ETW:
157 logrus.SetFormatter(nopFormatter{})
158 logrus.SetOutput(io.Discard)
159 }
160
161 os.Stdin.Close()
162
163
164 if shimOpts.ScrubLogs {
165 hcslog.SetScrubbing(true)
166 }
167
168
169
170 cli.ErrWriter = os.Stdout
171
172 socket := ctx.String("socket")
173 if !strings.HasPrefix(socket, `\\.\pipe`) {
174 return errors.New("socket is required to be pipe address")
175 }
176
177 ttrpcAddress := os.Getenv(ttrpcAddressEnv)
178 ttrpcEventPublisher, err := newEventPublisher(ttrpcAddress, namespaceFlag)
179 if err != nil {
180 return err
181 }
182 defer func() {
183 if err != nil {
184 ttrpcEventPublisher.close()
185 }
186 }()
187
188
189 svc, err = NewService(WithEventPublisher(ttrpcEventPublisher),
190 WithTID(idFlag),
191 WithIsSandbox(ctx.Bool("is-sandbox")))
192 if err != nil {
193 return fmt.Errorf("failed to create new service: %w", err)
194 }
195
196 s, err := ttrpc.NewServer(ttrpc.WithUnaryServerInterceptor(octtrpc.ServerInterceptor()))
197 if err != nil {
198 return err
199 }
200 defer s.Close()
201 task.RegisterTaskService(s, svc)
202 shimdiag.RegisterShimDiagService(s, svc)
203 extendedtask.RegisterExtendedTaskService(s, svc)
204
205 sl, err := winio.ListenPipe(socket, nil)
206 if err != nil {
207 return err
208 }
209 defer sl.Close()
210
211 serrs := make(chan error, 1)
212 defer close(serrs)
213 go func() {
214
215
216
217 if err := trapClosedConnErr(s.Serve(context.Background(), sl)); err != nil {
218 logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure")
219 serrs <- err
220 return
221 }
222 serrs <- nil
223 }()
224
225 select {
226 case err := <-lerrs:
227 return err
228 case err := <-serrs:
229 return err
230 case <-time.After(2 * time.Millisecond):
231
232
233
234
235
236
237
238
239
240
241
242 os.Stdout.Close()
243 }
244
245
246 select {
247 case err = <-serrs:
248
249 case <-svc.Done():
250 if !svc.gracefulShutdown {
251
252
253 return nil
254 }
255
256 sctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
257 defer cancel()
258 err = s.Shutdown(sctx)
259 }
260
261 return err
262 },
263 }
264
// trapClosedConnErr filters out the error produced by operating on a closed
// network connection. It returns nil when err is nil or when err's message
// contains "use of closed network connection"; any other error is returned
// unchanged.
func trapClosedConnErr(err error) error {
	const closedConnMsg = "use of closed network connection"

	switch {
	case err == nil:
		return nil
	case strings.Contains(err.Error(), closedConnMsg):
		return nil
	default:
		return err
	}
}
271
272
273
274 func readOptions(r io.Reader) (*runhcsopts.Options, error) {
275 d, err := io.ReadAll(r)
276 if err != nil {
277 return nil, errors.Wrap(err, "failed to read input")
278 }
279 if len(d) > 0 {
280 var a types.Any
281 if err := proto.Unmarshal(d, &a); err != nil {
282 return nil, errors.Wrap(err, "failed unmarshalling into Any")
283 }
284 v, err := typeurl.UnmarshalAny(&a)
285 if err != nil {
286 return nil, errors.Wrap(err, "failed unmarshalling by typeurl")
287 }
288 return v.(*runhcsopts.Options), nil
289 }
290 return nil, nil
291 }
292
293
294
295 func createEvent(event string) (windows.Handle, error) {
296 ev, _ := windows.UTF16PtrFromString(event)
297 sd, err := windows.SecurityDescriptorFromString("D:P(A;;GA;;;BA)(A;;GA;;;SY)")
298 if err != nil {
299 return 0, errors.Wrapf(err, "failed to get security descriptor for event '%s'", event)
300 }
301 var sa windows.SecurityAttributes
302 sa.Length = uint32(unsafe.Sizeof(sa))
303 sa.InheritHandle = 1
304 sa.SecurityDescriptor = sd
305 h, err := windows.CreateEvent(&sa, 0, 0, ev)
306 if h == 0 || err != nil {
307 return 0, errors.Wrapf(err, "failed to create event '%s'", event)
308 }
309 return h, nil
310 }
311
312
313
314 func setupDebuggerEvent() {
315 if os.Getenv("CONTAINERD_SHIM_RUNHCS_V1_WAIT_DEBUGGER") == "" {
316 return
317 }
318 event := "Global\\debugger-" + fmt.Sprint(os.Getpid())
319 handle, err := createEvent(event)
320 if err != nil {
321 return
322 }
323 logrus.WithField("event", event).Info("Halting until signalled")
324 _, _ = windows.WaitForSingleObject(handle, windows.INFINITE)
325 }
326
View as plain text