...

Source file src/github.com/go-ldap/ldap/v3/conn_test.go

Documentation: github.com/go-ldap/ldap/v3

     1  package ldap
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"runtime"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	ber "github.com/go-asn1-ber/asn1-ber"
    16  )
    17  
    18  func TestUnresponsiveConnection(t *testing.T) {
    19  	// The do-nothing server that accepts requests and does nothing
    20  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    21  	}))
    22  	defer ts.Close()
    23  	c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
    24  	if err != nil {
    25  		t.Fatalf("error connecting to localhost tcp: %v", err)
    26  	}
    27  
    28  	// Create an Ldap connection
    29  	conn := NewConn(c, false)
    30  	conn.SetTimeout(time.Millisecond)
    31  	conn.Start()
    32  	defer conn.Close()
    33  
    34  	// Mock a packet
    35  	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
    36  	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID"))
    37  	bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
    38  	bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
    39  	packet.AppendChild(bindRequest)
    40  
    41  	// Send packet and test response
    42  	msgCtx, err := conn.sendMessage(packet)
    43  	if err != nil {
    44  		t.Fatalf("error sending message: %v", err)
    45  	}
    46  	defer conn.finishMessage(msgCtx)
    47  
    48  	packetResponse, ok := <-msgCtx.responses
    49  	if !ok {
    50  		t.Fatalf("no PacketResponse in response channel")
    51  	}
    52  	_, err = packetResponse.ReadPacket()
    53  	if err == nil {
    54  		t.Fatalf("expected timeout error")
    55  	}
    56  	if !IsErrorWithCode(err, ErrorNetwork) || err.(*Error).Err.Error() != "ldap: connection timed out" {
    57  		t.Fatalf("unexpected error: %v", err)
    58  	}
    59  }
    60  
    61  func TestRequestTimeoutDeadlock(t *testing.T) {
    62  	// The do-nothing server that accepts requests and does nothing
    63  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    64  	}))
    65  	defer ts.Close()
    66  	c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
    67  	if err != nil {
    68  		t.Fatalf("error connecting to localhost tcp: %v", err)
    69  	}
    70  
    71  	// Create an Ldap connection
    72  	conn := NewConn(c, false)
    73  	conn.Start()
    74  	// trigger a race condition on accessing request timeout
    75  	n := 3
    76  	for i := 0; i < n; i++ {
    77  		go func() {
    78  			conn.SetTimeout(time.Millisecond)
    79  		}()
    80  	}
    81  
    82  	// Attempt to close the connection when the message handler is
    83  	// blocked or inactive
    84  	conn.Close()
    85  }
    86  
    87  // TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
    88  // message handler is blocked or inactive.
    89  func TestInvalidStateCloseDeadlock(t *testing.T) {
    90  	// The do-nothing server that accepts requests and does nothing
    91  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    92  	}))
    93  	defer ts.Close()
    94  	c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
    95  	if err != nil {
    96  		t.Fatalf("error connecting to localhost tcp: %v", err)
    97  	}
    98  
    99  	// Create an Ldap connection
   100  	conn := NewConn(c, false)
   101  	conn.SetTimeout(time.Millisecond)
   102  
   103  	// Attempt to close the connection when the message handler is
   104  	// blocked or inactive
   105  	conn.Close()
   106  }
   107  
   108  // TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
   109  // message handler is blocked or inactive.
   110  func TestInvalidStateSendResponseDeadlock(t *testing.T) {
   111  	// Attempt to send a response packet when the message handler is blocked or inactive
   112  	msgCtx := &messageContext{
   113  		id:        0,
   114  		done:      make(chan struct{}),
   115  		responses: make(chan *PacketResponse),
   116  	}
   117  	msgCtx.sendResponse(&PacketResponse{}, time.Millisecond)
   118  }
   119  
   120  // TestFinishMessage tests that we do not enter deadlock when a goroutine makes
   121  // a request but does not handle all responses from the server.
   122  func TestFinishMessage(t *testing.T) {
   123  	ptc := newPacketTranslatorConn()
   124  	defer ptc.Close()
   125  
   126  	conn := NewConn(ptc, false)
   127  	conn.Start()
   128  
   129  	// Test sending 5 different requests in series. Ensure that we can
   130  	// get a response packet from the underlying connection and also
   131  	// ensure that we can gracefully ignore unhandled responses.
   132  	for i := 0; i < 5; i++ {
   133  		t.Logf("serial request %d", i)
   134  		// Create a message and make sure we can receive responses.
   135  		msgCtx := testSendRequest(t, ptc, conn)
   136  		testReceiveResponse(t, ptc, msgCtx)
   137  
   138  		// Send a few unhandled responses and finish the message.
   139  		testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
   140  		t.Logf("serial request %d done", i)
   141  	}
   142  
   143  	// Test sending 5 different requests in parallel.
   144  	var wg sync.WaitGroup
   145  	for i := 0; i < 5; i++ {
   146  		wg.Add(1)
   147  		go func(i int) {
   148  			defer wg.Done()
   149  			t.Logf("parallel request %d", i)
   150  			// Create a message and make sure we can receive responses.
   151  			msgCtx := testSendRequest(t, ptc, conn)
   152  			testReceiveResponse(t, ptc, msgCtx)
   153  
   154  			// Send a few unhandled responses and finish the message.
   155  			testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5)
   156  			t.Logf("parallel request %d done", i)
   157  		}(i)
   158  	}
   159  	wg.Wait()
   160  
   161  	// We cannot run Close() in a defer because t.FailNow() will run it and
   162  	// it will block if the processMessage Loop is in a deadlock.
   163  	conn.Close()
   164  }
   165  
   166  // See: https://github.com/go-ldap/ldap/issues/332
   167  func TestNilConnection(t *testing.T) {
   168  	var conn *Conn
   169  	_, err := conn.Search(&SearchRequest{})
   170  	if err != ErrNilConnection {
   171  		t.Fatalf("expected error to be ErrNilConnection, got %v", err)
   172  	}
   173  }
   174  
   175  func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) {
   176  	var msgID int64
   177  	runWithTimeout(t, time.Second, func() {
   178  		msgID = conn.nextMessageID()
   179  	})
   180  
   181  	requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
   182  	requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID"))
   183  
   184  	var err error
   185  
   186  	runWithTimeout(t, time.Second, func() {
   187  		msgCtx, err = conn.sendMessage(requestPacket)
   188  		if err != nil {
   189  			t.Fatalf("unable to send request message: %s", err)
   190  		}
   191  	})
   192  
   193  	// We should now be able to get this request packet out from the other
   194  	// side.
   195  	runWithTimeout(t, time.Second, func() {
   196  		if _, err = ptc.ReceiveRequest(); err != nil {
   197  			t.Fatalf("unable to receive request packet: %s", err)
   198  		}
   199  	})
   200  
   201  	return msgCtx
   202  }
   203  
   204  func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) {
   205  	// Send a mock response packet.
   206  	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
   207  	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
   208  
   209  	runWithTimeout(t, time.Second, func() {
   210  		if err := ptc.SendResponse(responsePacket); err != nil {
   211  			t.Fatalf("unable to send response packet: %s", err)
   212  		}
   213  	})
   214  
   215  	// We should be able to receive the packet from the connection.
   216  	runWithTimeout(t, time.Second, func() {
   217  		if _, ok := <-msgCtx.responses; !ok {
   218  			t.Fatal("response channel closed")
   219  		}
   220  	})
   221  }
   222  
   223  func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) {
   224  	// Send a mock response packet.
   225  	responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
   226  	responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID"))
   227  
   228  	// Send extra responses but do not attempt to receive them on the
   229  	// client side.
   230  	for i := 0; i < numResponses; i++ {
   231  		runWithTimeout(t, time.Second, func() {
   232  			if err := ptc.SendResponse(responsePacket); err != nil {
   233  				t.Fatalf("unable to send response packet: %s", err)
   234  			}
   235  		})
   236  	}
   237  
   238  	// Finally, attempt to finish this message.
   239  	runWithTimeout(t, time.Second, func() {
   240  		conn.finishMessage(msgCtx)
   241  	})
   242  }
   243  
   244  func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
   245  	done := make(chan struct{})
   246  	go func() {
   247  		f()
   248  		close(done)
   249  	}()
   250  
   251  	select {
   252  	case <-done: // Success!
   253  	case <-time.After(timeout):
   254  		_, file, line, _ := runtime.Caller(1)
   255  		t.Fatalf("%s:%d timed out", file, line)
   256  	}
   257  }
   258  
   259  // packetTranslatorConn is a helpful type which can be used with various tests
   260  // in this package. It implements the net.Conn interface to be used as an
   261  // underlying connection for a *ldap.Conn. Most methods are no-ops but the
   262  // Read() and Write() methods are able to translate ber-encoded packets for
   263  // testing LDAP requests and responses.
   264  //
   265  // Test cases can simulate an LDAP server sending a response by calling the
   266  // SendResponse() method with a ber-encoded LDAP response packet. Test cases
   267  // can simulate an LDAP server receiving a request from a client by calling the
   268  // ReceiveRequest() method which returns a ber-encoded LDAP request packet.
   269  type packetTranslatorConn struct {
   270  	lock     sync.Mutex
   271  	isClosed bool
   272  
   273  	responseCond sync.Cond
   274  	requestCond  sync.Cond
   275  
   276  	responseBuf bytes.Buffer
   277  	requestBuf  bytes.Buffer
   278  }
   279  
   280  var errPacketTranslatorConnClosed = errors.New("connection closed")
   281  
   282  func newPacketTranslatorConn() *packetTranslatorConn {
   283  	conn := &packetTranslatorConn{}
   284  	conn.responseCond = sync.Cond{L: &conn.lock}
   285  	conn.requestCond = sync.Cond{L: &conn.lock}
   286  
   287  	return conn
   288  }
   289  
   290  // Read is called by the reader() loop to receive response packets. It will
   291  // block until there are more packet bytes available or this connection is
   292  // closed.
   293  func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
   294  	c.lock.Lock()
   295  	defer c.lock.Unlock()
   296  
   297  	for !c.isClosed {
   298  		// Attempt to read data from the response buffer. If it fails
   299  		// with an EOF, wait and try again.
   300  		n, err = c.responseBuf.Read(b)
   301  		if err != io.EOF {
   302  			return n, err
   303  		}
   304  
   305  		c.responseCond.Wait()
   306  	}
   307  
   308  	return 0, errPacketTranslatorConnClosed
   309  }
   310  
   311  // SendResponse writes the given response packet to the response buffer for
   312  // this connection, signalling any goroutine waiting to read a response.
   313  func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
   314  	c.lock.Lock()
   315  	defer c.lock.Unlock()
   316  
   317  	if c.isClosed {
   318  		return errPacketTranslatorConnClosed
   319  	}
   320  
   321  	// Signal any goroutine waiting to read a response.
   322  	defer c.responseCond.Broadcast()
   323  
   324  	// Writes to the buffer should always succeed.
   325  	c.responseBuf.Write(packet.Bytes())
   326  
   327  	return nil
   328  }
   329  
   330  // Write is called by the processMessages() loop to send request packets.
   331  func (c *packetTranslatorConn) Write(b []byte) (n int, err error) {
   332  	c.lock.Lock()
   333  	defer c.lock.Unlock()
   334  
   335  	if c.isClosed {
   336  		return 0, errPacketTranslatorConnClosed
   337  	}
   338  
   339  	// Signal any goroutine waiting to read a request.
   340  	defer c.requestCond.Broadcast()
   341  
   342  	// Writes to the buffer should always succeed.
   343  	return c.requestBuf.Write(b)
   344  }
   345  
   346  // ReceiveRequest attempts to read a request packet from this connection. It
   347  // will block until it is able to read a full request packet or until this
   348  // connection is closed.
   349  func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) {
   350  	c.lock.Lock()
   351  	defer c.lock.Unlock()
   352  
   353  	for !c.isClosed {
   354  		// Attempt to parse a request packet from the request buffer.
   355  		// If it fails with an unexpected EOF, wait and try again.
   356  		requestReader := bytes.NewReader(c.requestBuf.Bytes())
   357  		packet, err := ber.ReadPacket(requestReader)
   358  		switch err {
   359  		case io.EOF, io.ErrUnexpectedEOF:
   360  			c.requestCond.Wait()
   361  		case nil:
   362  			// Advance the request buffer by the number of bytes
   363  			// read to decode the request packet.
   364  			c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len())
   365  			return packet, nil
   366  		default:
   367  			return nil, err
   368  		}
   369  	}
   370  
   371  	return nil, errPacketTranslatorConnClosed
   372  }
   373  
   374  // Close closes this connection causing Read() and Write() calls to fail.
   375  func (c *packetTranslatorConn) Close() error {
   376  	c.lock.Lock()
   377  	defer c.lock.Unlock()
   378  
   379  	c.isClosed = true
   380  	c.responseCond.Broadcast()
   381  	c.requestCond.Broadcast()
   382  
   383  	return nil
   384  }
   385  
   386  func (c *packetTranslatorConn) LocalAddr() net.Addr {
   387  	return (*net.TCPAddr)(nil)
   388  }
   389  
   390  func (c *packetTranslatorConn) RemoteAddr() net.Addr {
   391  	return (*net.TCPAddr)(nil)
   392  }
   393  
   394  func (c *packetTranslatorConn) SetDeadline(t time.Time) error {
   395  	return nil
   396  }
   397  
   398  func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error {
   399  	return nil
   400  }
   401  
   402  func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error {
   403  	return nil
   404  }
   405  

View as plain text