1
2
3
4 package bridge
5
6 import (
7 "context"
8 "encoding/base64"
9 "encoding/binary"
10 "encoding/hex"
11 "encoding/json"
12 "fmt"
13 "io"
14 "os"
15 "strconv"
16 "sync"
17 "sync/atomic"
18 "time"
19
20 "github.com/pkg/errors"
21 "github.com/sirupsen/logrus"
22 "go.opencensus.io/trace"
23 "go.opencensus.io/trace/tracestate"
24
25 "github.com/Microsoft/hcsshim/internal/guest/gcserr"
26 "github.com/Microsoft/hcsshim/internal/guest/prot"
27 "github.com/Microsoft/hcsshim/internal/guest/runtime/hcsv2"
28 "github.com/Microsoft/hcsshim/internal/log"
29 "github.com/Microsoft/hcsshim/internal/oc"
30 )
31
32
33
34 func UnknownMessage(r *Request) (RequestResponse, error) {
35 return nil, gcserr.WrapHresult(errors.Errorf("bridge: function not supported, header type: %v", r.Header.Type), gcserr.HrNotImpl)
36 }
37
38
39
40 func UnknownMessageHandler() Handler {
41 return HandlerFunc(UnknownMessage)
42 }
43
44
45 type Handler interface {
46 ServeMsg(*Request) (RequestResponse, error)
47 }
48
49
50 type HandlerFunc func(*Request) (RequestResponse, error)
51
52
53 func (f HandlerFunc) ServeMsg(r *Request) (RequestResponse, error) {
54 return f(r)
55 }
56
57
58
59 type Mux struct {
60 mu sync.Mutex
61 m map[prot.MessageIdentifier]map[prot.ProtocolVersion]Handler
62 }
63
64
65 func NewBridgeMux() *Mux {
66 return &Mux{m: make(map[prot.MessageIdentifier]map[prot.ProtocolVersion]Handler)}
67 }
68
69
70 func (mux *Mux) Handle(id prot.MessageIdentifier, ver prot.ProtocolVersion, handler Handler) {
71 mux.mu.Lock()
72 defer mux.mu.Unlock()
73
74 if handler == nil {
75 panic("bridge: nil handler")
76 }
77
78 if _, ok := mux.m[id]; !ok {
79 mux.m[id] = make(map[prot.ProtocolVersion]Handler)
80 }
81
82 if _, ok := mux.m[id][ver]; ok {
83 logrus.WithFields(logrus.Fields{
84 "message-type": id.String(),
85 "protocol-version": ver,
86 }).Warn("opengcs::bridge - overwriting bridge handler")
87 }
88
89 mux.m[id][ver] = handler
90 }
91
92
93 func (mux *Mux) HandleFunc(id prot.MessageIdentifier, ver prot.ProtocolVersion, handler func(*Request) (RequestResponse, error)) {
94 if handler == nil {
95 panic("bridge: nil handler func")
96 }
97
98 mux.Handle(id, ver, HandlerFunc(handler))
99 }
100
101
102 func (mux *Mux) Handler(r *Request) Handler {
103 mux.mu.Lock()
104 defer mux.mu.Unlock()
105
106 if r == nil {
107 panic("bridge: nil request to handler")
108 }
109
110 var m map[prot.ProtocolVersion]Handler
111 var ok bool
112 if m, ok = mux.m[r.Header.Type]; !ok {
113 return UnknownMessageHandler()
114 }
115
116 var h Handler
117 if h, ok = m[r.Version]; !ok {
118 return UnknownMessageHandler()
119 }
120
121 return h
122 }
123
124
125
126 func (mux *Mux) ServeMsg(r *Request) (RequestResponse, error) {
127 h := mux.Handler(r)
128 return h.ServeMsg(r)
129 }
130
131
132 type Request struct {
133
134 Context context.Context
135
136
137 Header *prot.MessageHeader
138
139 ContainerID string
140
141 ActivityID string
142
143
144 Message []byte
145
146
147 Version prot.ProtocolVersion
148 }
149
150
151 type RequestResponse interface {
152 Base() *prot.MessageResponseBase
153 }
154
155 type bridgeResponse struct {
156
157 ctx context.Context
158 header *prot.MessageHeader
159 response interface{}
160 }
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175 type Bridge struct {
176
177 Handler Handler
178
179 EnableV4 bool
180
181
182
183 responseChan chan bridgeResponse
184
185 hostState *hcsv2.Host
186
187 quitChan chan bool
188
189 hasQuitPending uint32
190
191 protVer prot.ProtocolVersion
192 }
193
194
195
196
197 func (b *Bridge) AssignHandlers(mux *Mux, host *hcsv2.Host) {
198 b.hostState = host
199
200
201
202 if b.EnableV4 {
203 mux.HandleFunc(prot.ComputeSystemNegotiateProtocolV1, prot.PvInvalid, b.negotiateProtocolV2)
204 }
205
206 if b.EnableV4 {
207
208 mux.HandleFunc(prot.ComputeSystemStartV1, prot.PvV4, b.startContainerV2)
209 mux.HandleFunc(prot.ComputeSystemCreateV1, prot.PvV4, b.createContainerV2)
210 mux.HandleFunc(prot.ComputeSystemExecuteProcessV1, prot.PvV4, b.execProcessV2)
211 mux.HandleFunc(prot.ComputeSystemShutdownForcedV1, prot.PvV4, b.killContainerV2)
212 mux.HandleFunc(prot.ComputeSystemShutdownGracefulV1, prot.PvV4, b.shutdownContainerV2)
213 mux.HandleFunc(prot.ComputeSystemSignalProcessV1, prot.PvV4, b.signalProcessV2)
214 mux.HandleFunc(prot.ComputeSystemGetPropertiesV1, prot.PvV4, b.getPropertiesV2)
215 mux.HandleFunc(prot.ComputeSystemWaitForProcessV1, prot.PvV4, b.waitOnProcessV2)
216 mux.HandleFunc(prot.ComputeSystemResizeConsoleV1, prot.PvV4, b.resizeConsoleV2)
217 mux.HandleFunc(prot.ComputeSystemModifySettingsV1, prot.PvV4, b.modifySettingsV2)
218 mux.HandleFunc(prot.ComputeSystemDumpStacksV1, prot.PvV4, b.dumpStacksV2)
219 mux.HandleFunc(prot.ComputeSystemDeleteContainerStateV1, prot.PvV4, b.deleteContainerStateV2)
220 }
221 }
222
223
224
225
226 func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser) error {
227 requestChan := make(chan *Request)
228 requestErrChan := make(chan error)
229 b.responseChan = make(chan bridgeResponse)
230 responseErrChan := make(chan error)
231 b.quitChan = make(chan bool)
232
233 defer close(b.quitChan)
234 defer bridgeOut.Close()
235 defer close(responseErrChan)
236 defer close(b.responseChan)
237 defer close(requestChan)
238 defer close(requestErrChan)
239 defer bridgeIn.Close()
240
241
242 go func() {
243 var recverr error
244 for {
245 if atomic.LoadUint32(&b.hasQuitPending) == 0 {
246 header := &prot.MessageHeader{}
247 if err := binary.Read(bridgeIn, binary.LittleEndian, header); err != nil {
248 if err == io.ErrUnexpectedEOF || err == os.ErrClosed {
249 break
250 }
251 recverr = errors.Wrap(err, "bridge: failed reading message header")
252 break
253 }
254 message := make([]byte, header.Size-prot.MessageHeaderSize)
255 if _, err := io.ReadFull(bridgeIn, message); err != nil {
256 if err == io.ErrUnexpectedEOF || err == os.ErrClosed {
257 break
258 }
259 recverr = errors.Wrap(err, "bridge: failed reading message payload")
260 break
261 }
262
263 base := prot.MessageBase{}
264
265
266
267
268 _ = json.Unmarshal(message, &base)
269
270 var ctx context.Context
271 var span *trace.Span
272 if base.OpenCensusSpanContext != nil {
273 sc := trace.SpanContext{}
274 if bytes, err := hex.DecodeString(base.OpenCensusSpanContext.TraceID); err == nil {
275 copy(sc.TraceID[:], bytes)
276 }
277 if bytes, err := hex.DecodeString(base.OpenCensusSpanContext.SpanID); err == nil {
278 copy(sc.SpanID[:], bytes)
279 }
280 sc.TraceOptions = trace.TraceOptions(base.OpenCensusSpanContext.TraceOptions)
281 if base.OpenCensusSpanContext.Tracestate != "" {
282 if bytes, err := base64.StdEncoding.DecodeString(base.OpenCensusSpanContext.Tracestate); err == nil {
283 var entries []tracestate.Entry
284 if err := json.Unmarshal(bytes, &entries); err == nil {
285 if ts, err := tracestate.New(nil, entries...); err == nil {
286 sc.Tracestate = ts
287 }
288 }
289 }
290 }
291 ctx, span = oc.StartSpanWithRemoteParent(
292 context.Background(),
293 "opengcs::bridge::request",
294 sc,
295 oc.WithServerSpanKind,
296 )
297 } else {
298 ctx, span = oc.StartSpan(
299 context.Background(),
300 "opengcs::bridge::request",
301 oc.WithServerSpanKind,
302 )
303 }
304
305 span.AddAttributes(
306 trace.Int64Attribute("message-id", int64(header.ID)),
307 trace.StringAttribute("message-type", header.Type.String()),
308 trace.StringAttribute("activityID", base.ActivityID),
309 trace.StringAttribute("cid", base.ContainerID))
310
311 entry := log.G(ctx)
312 if entry.Logger.GetLevel() >= logrus.DebugLevel {
313 s := string(message)
314 switch header.Type {
315 case prot.ComputeSystemCreateV1:
316 b, err := log.ScrubBridgeCreate(message)
317 s = string(b)
318 if err != nil {
319 entry.WithError(err).Warning("could not scrub bridge payload")
320 }
321 }
322 entry.WithField("message", s).Debug("request read message")
323 }
324 requestChan <- &Request{
325 Context: ctx,
326 Header: header,
327 ContainerID: base.ContainerID,
328 ActivityID: base.ActivityID,
329 Message: message,
330 Version: b.protVer,
331 }
332 }
333 }
334 requestErrChan <- recverr
335 }()
336
337 go func() {
338 for req := range requestChan {
339 go func(r *Request) {
340 br := bridgeResponse{
341 ctx: r.Context,
342 header: &prot.MessageHeader{
343 Type: prot.GetResponseIdentifier(r.Header.Type),
344 ID: r.Header.ID,
345 },
346 }
347 resp, err := b.Handler.ServeMsg(r)
348 if resp == nil {
349 resp = &prot.MessageResponseBase{}
350 }
351 resp.Base().ActivityID = r.ActivityID
352 if err != nil {
353 span := trace.FromContext(r.Context)
354 if span != nil {
355 oc.SetSpanStatus(span, err)
356 }
357 setErrorForResponseBase(resp.Base(), err)
358 }
359 br.response = resp
360 b.responseChan <- br
361 }(req)
362 }
363 }()
364
365 go func() {
366 var resperr error
367 for resp := range b.responseChan {
368 responseBytes, err := json.Marshal(resp.response)
369 if err != nil {
370 resperr = errors.Wrapf(err, "bridge: failed to marshal JSON for response \"%v\"", resp.response)
371 break
372 }
373 resp.header.Size = uint32(len(responseBytes) + prot.MessageHeaderSize)
374 if err := binary.Write(bridgeOut, binary.LittleEndian, resp.header); err != nil {
375 resperr = errors.Wrap(err, "bridge: failed writing message header")
376 break
377 }
378
379 if _, err := bridgeOut.Write(responseBytes); err != nil {
380 resperr = errors.Wrap(err, "bridge: failed writing message payload")
381 break
382 }
383
384 s := trace.FromContext(resp.ctx)
385 if s != nil {
386 log.G(resp.ctx).WithField("message", string(responseBytes)).Debug("request write response")
387 s.End()
388 }
389 }
390 responseErrChan <- resperr
391 }()
392
393 select {
394 case err := <-requestErrChan:
395 return err
396 case err := <-responseErrChan:
397 return err
398 case <-b.quitChan:
399
400
401 atomic.StoreUint32(&b.hasQuitPending, 1)
402
403
404
405 var err error
406 select {
407 case err = <-requestErrChan:
408 case <-time.After(time.Second * 5):
409
410 if cerr := bridgeIn.Close(); cerr != nil {
411 err = errors.Wrap(cerr, "bridge: failed to close bridgeIn")
412 }
413 <-requestErrChan
414 }
415 <-responseErrChan
416 return err
417 }
418 }
419
420
421 func (b *Bridge) PublishNotification(n *prot.ContainerNotification) {
422 ctx, span := oc.StartSpan(context.Background(),
423 "opengcs::bridge::PublishNotification",
424 oc.WithClientSpanKind)
425 span.AddAttributes(trace.StringAttribute("notification", fmt.Sprintf("%+v", n)))
426
427
428
429 resp := bridgeResponse{
430 ctx: ctx,
431 header: &prot.MessageHeader{
432 Type: prot.ComputeSystemNotificationV1,
433 ID: 0,
434 },
435 response: n,
436 }
437 b.responseChan <- resp
438 }
439
440
441
442 func setErrorForResponseBase(response *prot.MessageResponseBase, errForResponse error) {
443 errorMessage := errForResponse.Error()
444 stackString := ""
445 fileName := ""
446 lineNumber := -1
447 functionName := ""
448 if stack := gcserr.BaseStackTrace(errForResponse); stack != nil {
449 bottomFrame := stack[0]
450 stackString = fmt.Sprintf("%+v", stack)
451 fileName = fmt.Sprintf("%s", bottomFrame)
452 lineNumberStr := fmt.Sprintf("%d", bottomFrame)
453 var err error
454 lineNumber, err = strconv.Atoi(lineNumberStr)
455 if err != nil {
456 logrus.WithFields(logrus.Fields{
457 "line-number": lineNumberStr,
458 logrus.ErrorKey: err,
459 }).Error("opengcs::bridge::setErrorForResponseBase - failed to parse line number, using -1 instead")
460 lineNumber = -1
461 }
462 functionName = fmt.Sprintf("%n", bottomFrame)
463 }
464 hresult, err := gcserr.GetHresult(errForResponse)
465 if err != nil {
466
467 hresult = gcserr.HrFail
468 }
469 response.Result = int32(hresult)
470 response.ErrorMessage = errorMessage
471 newRecord := prot.ErrorRecord{
472 Result: int32(hresult),
473 Message: errorMessage,
474 StackTrace: stackString,
475 ModuleName: "gcs",
476 FileName: fileName,
477 Line: uint32(lineNumber),
478 FunctionName: functionName,
479 }
480 response.ErrorRecords = append(response.ErrorRecords, newRecord)
481 }
482
View as plain text