1
2
3
4
5
6
7 package drivertest
8
9 import (
10 "context"
11 "errors"
12
13 "go.mongodb.org/mongo-driver/mongo/address"
14 "go.mongodb.org/mongo-driver/mongo/description"
15 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
16 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
17 )
18
19
20
21 type ChannelConn struct {
22 WriteErr error
23 Written chan []byte
24 ReadResp chan []byte
25 ReadErr chan error
26 Desc description.Server
27 }
28
29
30 func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error {
31
32 b := make([]byte, len(wm))
33 copy(b, wm)
34 select {
35 case c.Written <- b:
36 case <-ctx.Done():
37 return ctx.Err()
38 default:
39 c.WriteErr = errors.New("could not write wiremessage to written channel")
40 }
41 return c.WriteErr
42 }
43
44
45 func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) {
46 var wm []byte
47 var err error
48 select {
49 case wm = <-c.ReadResp:
50 case err = <-c.ReadErr:
51 case <-ctx.Done():
52 err = ctx.Err()
53 }
54 return wm, err
55 }
56
57
58 func (c *ChannelConn) Description() description.Server { return c.Desc }
59
60
61 func (c *ChannelConn) Close() error {
62 return nil
63 }
64
65
66 func (c *ChannelConn) ID() string {
67 return "faked"
68 }
69
70
71
72 func (c *ChannelConn) DriverConnectionID() uint64 {
73 return 0
74 }
75
76
77 func (c *ChannelConn) ServerConnectionID() *int64 {
78 serverConnectionID := int64(42)
79 return &serverConnectionID
80 }
81
82
83 func (c *ChannelConn) Address() address.Address { return address.Address("0.0.0.0") }
84
85
86 func (c *ChannelConn) Stale() bool {
87 return false
88 }
89
90
91 func MakeReply(doc bsoncore.Document) []byte {
92 var dst []byte
93 idx, dst := wiremessage.AppendHeaderStart(dst, 10, 9, wiremessage.OpReply)
94 dst = wiremessage.AppendReplyFlags(dst, 0)
95 dst = wiremessage.AppendReplyCursorID(dst, 0)
96 dst = wiremessage.AppendReplyStartingFrom(dst, 0)
97 dst = wiremessage.AppendReplyNumberReturned(dst, 1)
98 dst = append(dst, doc...)
99 return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
100 }
101
102
103 func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
104 var ok bool
105 _, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
106 if !ok {
107 return nil, errors.New("could not read header")
108 }
109 _, wm, ok = wiremessage.ReadQueryFlags(wm)
110 if !ok {
111 return nil, errors.New("could not read flags")
112 }
113 _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
114 if !ok {
115 return nil, errors.New("could not read fullCollectionName")
116 }
117 _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
118 if !ok {
119 return nil, errors.New("could not read numberToSkip")
120 }
121 _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
122 if !ok {
123 return nil, errors.New("could not read numberToReturn")
124 }
125
126 var query bsoncore.Document
127 query, wm, ok = wiremessage.ReadQueryQuery(wm)
128 if !ok {
129 return nil, errors.New("could not read query")
130 }
131 return query, nil
132 }
133
134
135 func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
136 var ok bool
137 _, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
138 if !ok {
139 return nil, errors.New("could not read header")
140 }
141
142 _, wm, ok = wiremessage.ReadMsgFlags(wm)
143 if !ok {
144 return nil, errors.New("could not read flags")
145 }
146 _, wm, ok = wiremessage.ReadMsgSectionType(wm)
147 if !ok {
148 return nil, errors.New("could not read section type")
149 }
150
151 cmdDoc, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm)
152 if !ok {
153 return nil, errors.New("could not read command document")
154 }
155 return cmdDoc, nil
156 }
157
View as plain text