1
2
3 package gcs
4
5 import (
6 "bufio"
7 "bytes"
8 "context"
9 "encoding/binary"
10 "encoding/json"
11 "errors"
12 "fmt"
13 "io"
14 "net"
15 "sync"
16 "time"
17
18 "github.com/sirupsen/logrus"
19 "golang.org/x/sys/windows"
20
21 "github.com/Microsoft/hcsshim/internal/log"
22 )
23
// Layout of the fixed-size header that precedes every bridge message. All
// fields are little-endian (see readMessage and writeMessage).
const (
	hdrSize    = 16 // total header length in bytes
	hdrOffType = 0  // offset of the 32-bit message type
	hdrOffSize = 4  // offset of the 32-bit total message size, header included
	hdrOffID   = 8  // offset of the 64-bit message ID

	// maxMsgSize is the largest total message size readMessage accepts.
	// Bounding it keeps a bogus size field from forcing an arbitrarily large
	// allocation.
	maxMsgSize = 0x10000
)
35
// requestMessage is the interface every RPC request payload must satisfy.
// Base returns the request's *requestBase.
type requestMessage interface {
	Base() *requestBase
}
39
// responseMessage is the interface every RPC response payload must satisfy.
// Base returns the response's *responseBase, which carries the result code
// and error details inspected by rpc.Err and recvLoop.
type responseMessage interface {
	Base() *responseBase
}
43
44
// rpc tracks a single request/response exchange over the bridge.
type rpc struct {
	proc    rpcProc         // procedure to invoke; OR'ed with msgTypeRequest to form the wire type
	id      int64           // message ID assigned by sendRPC
	req     requestMessage  // request payload that is JSON-encoded onto the wire
	resp    responseMessage // destination the response JSON is unmarshaled into
	brdgErr error           // bridge-level failure recorded by complete; takes precedence in Err
	ch      chan struct{}   // closed by complete once the call has finished
}
53
54
55
// bridge carries RPC requests, their responses, and container notifications
// over a single connection. One goroutine (sendLoop) writes requests and one
// (recvLoopRoutine) reads responses and notifications.
type bridge struct {
	// Timeout is how long RPC waits for a response before it kills the bridge.
	// newBridge initializes it to bridgeFailureTimeout.
	Timeout time.Duration

	mu      sync.Mutex         // guards nextID, rpcs, and closed
	nextID  int64              // ID handed to the next outgoing request
	rpcs    map[int64]*rpc     // calls awaiting a response, keyed by message ID; nil once the receive side has shut down
	conn    io.ReadWriteCloser // underlying transport; closed by kill
	rpcCh   chan *rpc          // hands calls from AsyncRPC to sendLoop
	notify  notifyFunc         // callback invoked for each container notification
	closed  bool               // set by the first call to kill
	log     *logrus.Entry      // logger used for all bridge diagnostics
	brdgErr error              // reason recorded by kill; nil for an orderly shutdown
	waitCh  chan struct{}      // closed by kill once the connection has been closed
}
71
// errBridgeClosed is reported for calls that cannot complete because the
// bridge has shut down. It wraps net.ErrClosed so callers can match it with
// errors.Is.
var errBridgeClosed = fmt.Errorf("bridge closed: %w", net.ErrClosed)
73
const (
	// bridgeFailureTimeout is the default value newBridge assigns to
	// bridge.Timeout: the longest RPC waits for a response before the bridge
	// is forcibly terminated.
	bridgeFailureTimeout = time.Minute * 5
)
78
// notifyFunc is the callback recvLoop invokes for each container notification
// it receives. Returning a non-nil error makes recvLoop stop, which terminates
// the bridge.
type notifyFunc func(*containerNotification) error
80
81
82
83
84 func newBridge(conn io.ReadWriteCloser, notify notifyFunc, log *logrus.Entry) *bridge {
85 return &bridge{
86 conn: conn,
87 rpcs: make(map[int64]*rpc),
88 rpcCh: make(chan *rpc),
89 waitCh: make(chan struct{}),
90 notify: notify,
91 log: log,
92 Timeout: bridgeFailureTimeout,
93 }
94 }
95
96
// Start launches the goroutine that reads incoming messages
// (recvLoopRoutine) and the goroutine that writes outgoing RPCs (sendLoop).
// It returns immediately.
func (brdg *bridge) Start() {
	go brdg.recvLoopRoutine()
	go brdg.sendLoop()
}
101
102
103
104 func (brdg *bridge) kill(err error) {
105 brdg.mu.Lock()
106 if brdg.closed {
107 brdg.mu.Unlock()
108 if err != nil {
109 brdg.log.WithError(err).Warn("bridge error, already terminated")
110 }
111 return
112 }
113 brdg.closed = true
114 brdg.mu.Unlock()
115 brdg.brdgErr = err
116 if err != nil {
117 brdg.log.WithError(err).Error("bridge forcibly terminating")
118 } else {
119 brdg.log.Debug("bridge terminating")
120 }
121 brdg.conn.Close()
122 close(brdg.waitCh)
123 }
124
125
126
// Close requests an orderly shutdown of the bridge and returns the error the
// bridge terminated with. That error is nil unless an earlier failure had
// already killed the bridge.
func (brdg *bridge) Close() error {
	brdg.kill(nil)
	return brdg.brdgErr
}
131
132
133
// Wait blocks until the bridge has been killed (waitCh is closed) and returns
// the error it terminated with, or nil for an orderly shutdown.
func (brdg *bridge) Wait() error {
	<-brdg.waitCh
	return brdg.brdgErr
}
138
139
140
141
142 func (brdg *bridge) AsyncRPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage) (*rpc, error) {
143 call := &rpc{
144 ch: make(chan struct{}),
145 proc: proc,
146 req: req,
147 resp: resp,
148 }
149 if err := ctx.Err(); err != nil {
150 return nil, err
151 }
152
153 select {
154 case brdg.rpcCh <- call:
155 return call, nil
156 case <-brdg.waitCh:
157 err := brdg.brdgErr
158 if err == nil {
159 err = errBridgeClosed
160 }
161 return nil, err
162 case <-ctx.Done():
163 return nil, ctx.Err()
164 }
165 }
166
// complete records err as the call's bridge-level error (nil when a response
// was received and decoded) and closes ch to wake every waiter. Because it
// closes a channel, it must run at most once per call.
func (call *rpc) complete(err error) {
	call.brdgErr = err
	close(call.ch)
}
171
// rpcError is the error returned by rpc.Err when the guest answered a call
// with a non-zero result.
type rpcError struct {
	result  int32  // result code from the response; converted to windows.Errno by Error and Unwrap
	message string // error message from the response; may be empty
}
176
177 func (err *rpcError) Error() string {
178 msg := err.message
179 if msg == "" {
180 msg = windows.Errno(err.result).Error()
181 }
182 return "guest RPC failure: " + msg
183 }
184
// Unwrap exposes the result code as a windows.Errno so callers can match it
// with errors.Is / errors.As.
func (err *rpcError) Unwrap() error {
	return windows.Errno(err.result)
}
188
189
190
191 func (call *rpc) Err() error {
192 if call.brdgErr != nil {
193 return call.brdgErr
194 }
195 resp := call.resp.Base()
196 if resp.Result == 0 {
197 return nil
198 }
199 return &rpcError{result: resp.Result, message: resp.ErrorMessage}
200 }
201
202
// Done reports, without blocking, whether the call has completed.
func (call *rpc) Done() bool {
	select {
	case <-call.ch:
		return true
	default:
		// ch is still open, so complete has not run yet.
		return false
	}
}
211
212
// Wait blocks until the call has completed.
func (call *rpc) Wait() {
	<-call.ch
}
216
217
218
219
220
221
222
223 func (brdg *bridge) RPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage, allowCancel bool) error {
224 call, err := brdg.AsyncRPC(ctx, proc, req, resp)
225 if err != nil {
226 return err
227 }
228 var ctxDone <-chan struct{}
229 if allowCancel {
230
231 ctxDone = ctx.Done()
232 }
233 t := time.NewTimer(brdg.Timeout)
234 defer t.Stop()
235 select {
236 case <-call.ch:
237 return call.Err()
238 case <-ctxDone:
239 brdg.log.WithField("reason", ctx.Err()).Warn("ignoring response to bridge message")
240 return ctx.Err()
241 case <-t.C:
242 brdg.kill(errors.New("message timeout"))
243 <-call.ch
244 return call.Err()
245 }
246 }
247
// recvLoopRoutine runs recvLoop to completion, kills the bridge with whatever
// error it returned (nil for an orderly shutdown), and then fails every RPC
// still awaiting a response with errBridgeClosed.
func (brdg *bridge) recvLoopRoutine() {
	brdg.kill(brdg.recvLoop())

	// Take ownership of the pending-call map and leave nil behind. sendRPC
	// treats a nil map as "bridge closed" and will not register new calls.
	brdg.mu.Lock()
	rpcs := brdg.rpcs
	brdg.rpcs = nil
	brdg.mu.Unlock()
	// Complete the calls outside the lock. They are no longer reachable
	// through brdg.rpcs, so nothing else can complete them a second time.
	for _, call := range rpcs {
		call.complete(errBridgeClosed)
	}
}
259
260 func readMessage(r io.Reader) (int64, msgType, []byte, error) {
261 var h [hdrSize]byte
262 _, err := io.ReadFull(r, h[:])
263 if err != nil {
264 return 0, 0, nil, err
265 }
266 typ := msgType(binary.LittleEndian.Uint32(h[hdrOffType:]))
267 n := binary.LittleEndian.Uint32(h[hdrOffSize:])
268 id := int64(binary.LittleEndian.Uint64(h[hdrOffID:]))
269 if n < hdrSize || n > maxMsgSize {
270 return 0, 0, nil, fmt.Errorf("invalid message size %d", n)
271 }
272 n -= hdrSize
273 b := make([]byte, n)
274 _, err = io.ReadFull(r, b)
275 if err != nil {
276 if err == io.EOF {
277 err = io.ErrUnexpectedEOF
278 }
279 return 0, 0, nil, err
280 }
281 return id, typ, b, nil
282 }
283
// isLocalDisconnectError reports whether err is, or wraps,
// windows.WSAECONNABORTED. recvLoop treats such an error as an orderly
// shutdown rather than a failure.
// NOTE(review): presumably this is what a pending read returns when the local
// side closes the connection — confirm against the transport in use.
func isLocalDisconnectError(err error) bool {
	return errors.Is(err, windows.WSAECONNABORTED)
}
287
288 func (brdg *bridge) recvLoop() error {
289 br := bufio.NewReader(brdg.conn)
290 for {
291 id, typ, b, err := readMessage(br)
292 if err != nil {
293 if err == io.EOF || isLocalDisconnectError(err) {
294 return nil
295 }
296 return fmt.Errorf("bridge read failed: %s", err)
297 }
298 brdg.log.WithFields(logrus.Fields{
299 "payload": string(b),
300 "type": typ.String(),
301 "message-id": id}).Debug("bridge receive")
302 switch typ & msgTypeMask {
303 case msgTypeResponse:
304
305 brdg.mu.Lock()
306 call := brdg.rpcs[id]
307 delete(brdg.rpcs, id)
308 brdg.mu.Unlock()
309 if call == nil {
310 return fmt.Errorf("bridge received unknown rpc response for id %d, type %s", id, typ)
311 }
312 err := json.Unmarshal(b, call.resp)
313 if err != nil {
314 err = fmt.Errorf("bridge response unmarshal failed: %s", err)
315 } else if resp := call.resp.Base(); resp.Result != 0 {
316 for _, rec := range resp.ErrorRecords {
317 brdg.log.WithFields(logrus.Fields{
318 "message-id": id,
319 "result": rec.Result,
320 "result-message": windows.Errno(rec.Result).Error(),
321 "error-message": rec.Message,
322 "stack": rec.StackTrace,
323 "module": rec.ModuleName,
324 "file": rec.FileName,
325 "line": rec.Line,
326 "function": rec.FunctionName,
327 }).Error("bridge RPC error record")
328 }
329 }
330 call.complete(err)
331 if err != nil {
332 return err
333 }
334
335 case msgTypeNotify:
336 if typ != notifyContainer|msgTypeNotify {
337 return fmt.Errorf("bridge received unknown unknown notification message %s", typ)
338 }
339 var ntf containerNotification
340 ntf.ResultInfo.Value = &json.RawMessage{}
341 err := json.Unmarshal(b, &ntf)
342 if err != nil {
343 return fmt.Errorf("bridge response unmarshal failed: %s", err)
344 }
345 err = brdg.notify(&ntf)
346 if err != nil {
347 return fmt.Errorf("bridge notification failed: %s", err)
348 }
349 default:
350 return fmt.Errorf("bridge received unknown unknown message type %s", typ)
351 }
352 }
353 }
354
355 func (brdg *bridge) sendLoop() {
356 var buf bytes.Buffer
357 enc := json.NewEncoder(&buf)
358 enc.SetEscapeHTML(false)
359 for {
360 select {
361 case <-brdg.waitCh:
362
363 return
364 case call := <-brdg.rpcCh:
365 err := brdg.sendRPC(&buf, enc, call)
366 if err != nil {
367 brdg.kill(err)
368 return
369 }
370 }
371 }
372 }
373
374 func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgType, id int64, req interface{}) error {
375
376 var h [hdrSize]byte
377 binary.LittleEndian.PutUint32(h[hdrOffType:], uint32(typ))
378 binary.LittleEndian.PutUint64(h[hdrOffID:], uint64(id))
379 buf.Write(h[:])
380 err := enc.Encode(req)
381 if err != nil {
382 return fmt.Errorf("bridge encode: %s", err)
383 }
384
385 binary.LittleEndian.PutUint32(buf.Bytes()[hdrOffSize:], uint32(buf.Len()))
386
387 if brdg.log.Logger.GetLevel() >= logrus.DebugLevel {
388 b := buf.Bytes()[hdrSize:]
389 switch typ {
390
391 case msgType(rpcCreate) | msgTypeRequest:
392 b, err = log.ScrubBridgeCreate(b)
393 case msgType(rpcExecuteProcess) | msgTypeRequest:
394 b, err = log.ScrubBridgeExecProcess(b)
395 }
396 if err != nil {
397 brdg.log.WithError(err).Warning("could not scrub bridge payload")
398 }
399 brdg.log.WithFields(logrus.Fields{
400 "payload": string(b),
401 "type": typ.String(),
402 "message-id": id}).Debug("bridge send")
403 }
404
405
406 _, err = buf.WriteTo(brdg.conn)
407 if err != nil {
408 return fmt.Errorf("bridge write: %s", err)
409 }
410 return nil
411 }
412
// sendRPC assigns call the next message ID, registers it as pending, and
// writes the request to the connection. It returns a non-nil error only when
// the write fails, in which case the caller (sendLoop) kills the bridge. If
// the bridge has already shut down, the call is completed with
// errBridgeClosed and nil is returned.
func (brdg *bridge) sendRPC(buf *bytes.Buffer, enc *json.Encoder, call *rpc) error {
	// Register the call before writing so a response cannot arrive for an ID
	// the receive loop does not know about.
	brdg.mu.Lock()
	if brdg.rpcs == nil {
		// recvLoopRoutine has already torn down the pending-call map; fail the
		// call here because nothing else will ever complete it.
		brdg.mu.Unlock()
		call.complete(errBridgeClosed)
		return nil
	}
	id := brdg.nextID
	call.id = id
	brdg.rpcs[id] = call
	brdg.nextID++
	brdg.mu.Unlock()
	typ := msgType(call.proc) | msgTypeRequest
	err := brdg.writeMessage(buf, enc, typ, id, call.req)
	if err != nil {
		// Complete the call only if it is still registered. If it is gone,
		// the receive side has already completed it, and completing it again
		// would close call.ch twice.
		brdg.mu.Lock()
		if brdg.rpcs[id] == nil {
			call = nil
		}
		delete(brdg.rpcs, id)
		brdg.mu.Unlock()
		if call != nil {
			call.complete(err)
		} else {
			brdg.log.WithError(err).Error("bridge write failed but call is already complete")
		}
		return err
	}
	return nil
}
445
View as plain text