1 package ldap
2
3 import (
4 "bytes"
5 "errors"
6 "io"
7 "net"
8 "net/http"
9 "net/http/httptest"
10 "runtime"
11 "sync"
12 "testing"
13 "time"
14
15 ber "github.com/go-asn1-ber/asn1-ber"
16 )
17
18 func TestUnresponsiveConnection(t *testing.T) {
19
20 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21 }))
22 defer ts.Close()
23 c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
24 if err != nil {
25 t.Fatalf("error connecting to localhost tcp: %v", err)
26 }
27
28
29 conn := NewConn(c, false)
30 conn.SetTimeout(time.Millisecond)
31 conn.Start()
32 defer conn.Close()
33
34
35 packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
36 packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
37 bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
38 bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
39 packet.AppendChild(bindRequest)
40
41
42 msgCtx, err := conn.sendMessage(packet)
43 if err != nil {
44 t.Fatalf("error sending message: %v", err)
45 }
46 defer conn.finishMessage(msgCtx)
47
48 packetResponse, ok := <-msgCtx.responses
49 if !ok {
50 t.Fatalf("no PacketResponse in response channel")
51 }
52 _, err = packetResponse.ReadPacket()
53 if err == nil {
54 t.Fatalf("expected timeout error")
55 }
56 if !IsErrorWithCode(err, ErrorNetwork) || err.(*Error).Err.Error() != "ldap: connection timed out" {
57 t.Fatalf("unexpected error: %v", err)
58 }
59 }
60
61 func TestRequestTimeoutDeadlock(t *testing.T) {
62
63 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64 }))
65 defer ts.Close()
66 c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
67 if err != nil {
68 t.Fatalf("error connecting to localhost tcp: %v", err)
69 }
70
71
72 conn := NewConn(c, false)
73 conn.Start()
74
75 n := 3
76 for i := 0; i < n; i++ {
77 go func() {
78 conn.SetTimeout(time.Millisecond)
79 }()
80 }
81
82
83
84 conn.Close()
85 }
86
87
88
89 func TestInvalidStateCloseDeadlock(t *testing.T) {
90
91 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
92 }))
93 defer ts.Close()
94 c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
95 if err != nil {
96 t.Fatalf("error connecting to localhost tcp: %v", err)
97 }
98
99
100 conn := NewConn(c, false)
101 conn.SetTimeout(time.Millisecond)
102
103
104
105 conn.Close()
106 }
107
108
109
110 func TestInvalidStateSendResponseDeadlock(t *testing.T) {
111
112 msgCtx := &messageContext{
113 id: 0,
114 done: make(chan struct{}),
115 responses: make(chan *PacketResponse),
116 }
117 msgCtx.sendResponse(&PacketResponse{}, time.Millisecond)
118 }
119
120
121
122 func TestFinishMessage(t *testing.T) {
123 ptc := newPacketTranslatorConn()
124 defer ptc.Close()
125
126 conn := NewConn(ptc, false)
127 conn.Start()
128
129
130
131
132 for i := 0; i < 5; i++ {
133 t.Logf("serial request %d", i)
134
135 msgCtx := testSendRequest(t, ptc, conn)
136 testReceiveResponse(t, ptc, msgCtx)
137
138
139 testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
140 t.Logf("serial request %d done", i)
141 }
142
143
144 var wg sync.WaitGroup
145 for i := 0; i < 5; i++ {
146 wg.Add(1)
147 go func(i int) {
148 defer wg.Done()
149 t.Logf("parallel request %d", i)
150
151 msgCtx := testSendRequest(t, ptc, conn)
152 testReceiveResponse(t, ptc, msgCtx)
153
154
155 testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
156 t.Logf("parallel request %d done", i)
157 }(i)
158 }
159 wg.Wait()
160
161
162
163 conn.Close()
164 }
165
166
167 func TestNilConnection(t *testing.T) {
168 var conn *Conn
169 _, err := conn.Search(&SearchRequest{})
170 if err != ErrNilConnection {
171 t.Fatalf("expected error to be ErrNilConnection, got %v", err)
172 }
173 }
174
175 func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
176 var msgID int64
177 runWithTimeout(t, time.Second, func() {
178 msgID = conn.nextMessageID()
179 })
180
181 requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
182 requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
183
184 var err error
185
186 runWithTimeout(t, time.Second, func() {
187 msgCtx, err = conn.sendMessage(requestPacket)
188 if err != nil {
189 t.Fatalf("unable to send request message: %s", err)
190 }
191 })
192
193
194
195 runWithTimeout(t, time.Second, func() {
196 if _, err = ptc.ReceiveRequest(); err != nil {
197 t.Fatalf("unable to receive request packet: %s", err)
198 }
199 })
200
201 return msgCtx
202 }
203
204 func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
205
206 responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
207 responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
208
209 runWithTimeout(t, time.Second, func() {
210 if err := ptc.SendResponse(responsePacket); err != nil {
211 t.Fatalf("unable to send response packet: %s", err)
212 }
213 })
214
215
216 runWithTimeout(t, time.Second, func() {
217 if _, ok := <-msgCtx.responses; !ok {
218 t.Fatal("response channel closed")
219 }
220 })
221 }
222
223 func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
224
225 responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
226 responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
227
228
229
230 for i := 0; i < numResponses; i++ {
231 runWithTimeout(t, time.Second, func() {
232 if err := ptc.SendResponse(responsePacket); err != nil {
233 t.Fatalf("unable to send response packet: %s", err)
234 }
235 })
236 }
237
238
239 runWithTimeout(t, time.Second, func() {
240 conn.finishMessage(msgCtx)
241 })
242 }
243
// runWithTimeout runs f on a new goroutine and waits up to timeout for it to
// finish. If f has not finished in time, the test is failed with t.Fatalf,
// naming the file and line of runWithTimeout's caller.
//
// f should not call t.Fatal/t.FailNow itself, since those are only permitted
// on the test goroutine. If f does exit through runtime.Goexit (which is what
// FailNow uses), the deferred close still signals completion, so this returns
// immediately instead of waiting out the timeout and reporting a misleading
// "timed out".
func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		f()
	}()

	// An explicit timer (rather than time.After) is released as soon as f
	// completes.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
	case <-timer.C:
		_, file, line, _ := runtime.Caller(1)
		t.Fatalf("%s:%d timed out", file, line)
	}
}
258
259
260
261
262
263
264
265
266
267
268
// packetTranslatorConn is an in-memory connection used in place of a socket
// by these tests. Bytes the client writes pile up in requestBuf and are
// decoded by ReceiveRequest. Packets queued with SendResponse pile up in
// responseBuf and are handed out by Read. Both condition variables use lock
// (wired up by newPacketTranslatorConn), so a value must not be copied once
// constructed.
type packetTranslatorConn struct {
	lock     sync.Mutex // guards every field below
	isClosed bool       // set by Close; checked by every operation

	requestBuf  bytes.Buffer // bytes written by the client via Write
	requestCond sync.Cond    // broadcast by Write and Close

	responseBuf  bytes.Buffer // packets queued by SendResponse, drained by Read
	responseCond sync.Cond    // broadcast by SendResponse and Close
}
279
var (
	// errPacketTranslatorConnClosed is returned by packetTranslatorConn
	// operations once Close has been called.
	errPacketTranslatorConnClosed = errors.New("connection closed")
)
281
282 func newPacketTranslatorConn() *packetTranslatorConn {
283 conn := &packetTranslatorConn{}
284 conn.responseCond = sync.Cond{L: &conn.lock}
285 conn.requestCond = sync.Cond{L: &conn.lock}
286
287 return conn
288 }
289
290
291
292
293 func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
294 c.lock.Lock()
295 defer c.lock.Unlock()
296
297 for !c.isClosed {
298
299
300 n, err = c.responseBuf.Read(b)
301 if err != io.EOF {
302 return n, err
303 }
304
305 c.responseCond.Wait()
306 }
307
308 return 0, errPacketTranslatorConnClosed
309 }
310
311
312
313 func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
314 c.lock.Lock()
315 defer c.lock.Unlock()
316
317 if c.isClosed {
318 return errPacketTranslatorConnClosed
319 }
320
321
322 defer c.responseCond.Broadcast()
323
324
325 c.responseBuf.Write(packet.Bytes())
326
327 return nil
328 }
329
330
331 func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
332 c.lock.Lock()
333 defer c.lock.Unlock()
334
335 if c.isClosed {
336 return 0, errPacketTranslatorConnClosed
337 }
338
339
340 defer c.requestCond.Broadcast()
341
342
343 return c.requestBuf.Write(b)
344 }
345
346
347
348
349 func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
350 c.lock.Lock()
351 defer c.lock.Unlock()
352
353 for !c.isClosed {
354
355
356 requestReader := bytes.NewReader(c.requestBuf.Bytes())
357 packet, err := ber.ReadPacket(requestReader)
358 switch err {
359 case io.EOF, io.ErrUnexpectedEOF:
360 c.requestCond.Wait()
361 case nil:
362
363
364 c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
365 return packet, nil
366 default:
367 return nil, err
368 }
369 }
370
371 return nil, errPacketTranslatorConnClosed
372 }
373
374
375 func (c *packetTranslatorConn) Close() error {
376 c.lock.Lock()
377 defer c.lock.Unlock()
378
379 c.isClosed = true
380 c.responseCond.Broadcast()
381 c.requestCond.Broadcast()
382
383 return nil
384 }
385
386 func (c *packetTranslatorConn) LocalAddr() net.Addr {
387 return (*net.TCPAddr)(nil)
388 }
389
390 func (c *packetTranslatorConn) RemoteAddr() net.Addr {
391 return (*net.TCPAddr)(nil)
392 }
393
394 func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
395 return nil
396 }
397
398 func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
399 return nil
400 }
401
402 func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
403 return nil
404 }
405
View as plain text