1
2
3
4
5
6
7 package mtest
8
9 import (
10 "errors"
11 "fmt"
12
13 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
14 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
15 )
16
17
18 type ReceivedMessage struct {
19 ResponseTo int32
20 RawMessage wiremessage.WireMessage
21 Response bsoncore.Document
22 }
23
24 type receivedMsgParseFn func([]byte) (*ReceivedMessage, error)
25
26 func getReceivedMessageParser(opcode wiremessage.OpCode) (receivedMsgParseFn, bool) {
27 switch opcode {
28 case wiremessage.OpReply:
29 return parseOpReply, true
30 case wiremessage.OpMsg:
31 return parseReceivedOpMsg, true
32 case wiremessage.OpCompressed:
33 return parseReceivedOpCompressed, true
34 default:
35 return nil, false
36 }
37 }
38
39 func parseReceivedMessage(wm []byte) (*ReceivedMessage, error) {
40
41 _, _, responseTo, opcode, remaining, ok := wiremessage.ReadHeader(wm)
42 if !ok {
43 return nil, errors.New("failed to read wiremessage header")
44 }
45
46 parseFn, ok := getReceivedMessageParser(opcode)
47 if !ok {
48 return nil, fmt.Errorf("unknown opcode: %s", opcode)
49 }
50 received, err := parseFn(remaining)
51 if err != nil {
52 return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %w", opcode, err)
53 }
54
55 received.ResponseTo = responseTo
56 received.RawMessage = wm
57 return received, nil
58 }
59
60 func parseOpReply(wm []byte) (*ReceivedMessage, error) {
61 var ok bool
62
63 if _, wm, ok = wiremessage.ReadReplyFlags(wm); !ok {
64 return nil, errors.New("failed to read reply flags")
65 }
66 if _, wm, ok = wiremessage.ReadReplyCursorID(wm); !ok {
67 return nil, errors.New("failed to read cursor ID")
68 }
69 if _, wm, ok = wiremessage.ReadReplyStartingFrom(wm); !ok {
70 return nil, errors.New("failed to read starting from")
71 }
72 if _, wm, ok = wiremessage.ReadReplyNumberReturned(wm); !ok {
73 return nil, errors.New("failed to read number returned")
74 }
75
76 var replyDocuments []bsoncore.Document
77 replyDocuments, wm, ok = wiremessage.ReadReplyDocuments(wm)
78 if !ok {
79 return nil, errors.New("failed to read reply documents")
80 }
81 if len(replyDocuments) == 0 {
82 return nil, errors.New("no documents in response")
83 }
84
85 rm := &ReceivedMessage{
86 Response: replyDocuments[0],
87 }
88 return rm, nil
89 }
90
91 func parseReceivedOpMsg(wm []byte) (*ReceivedMessage, error) {
92 var ok bool
93 var err error
94
95 if _, wm, ok = wiremessage.ReadMsgFlags(wm); !ok {
96 return nil, errors.New("failed to read flags")
97 }
98
99 if wm, err = assertMsgSectionType(wm, wiremessage.SingleDocument); err != nil {
100 return nil, fmt.Errorf("error verifying section type for response document: %w", err)
101 }
102
103 response, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm)
104 if !ok {
105 return nil, errors.New("failed to read response document")
106 }
107 rm := &ReceivedMessage{
108 Response: response,
109 }
110 return rm, nil
111 }
112
113 func parseReceivedOpCompressed(wm []byte) (*ReceivedMessage, error) {
114 originalOpcode, wm, err := parseOpCompressed(wm)
115 if err != nil {
116 return nil, err
117 }
118
119 parser, ok := getReceivedMessageParser(originalOpcode)
120 if !ok {
121 return nil, fmt.Errorf("unknown original opcode %v", originalOpcode)
122 }
123 return parser(wm)
124 }
125
View as plain text