...

Source file src/github.com/go-ldap/ldap/v3/conn.go

Documentation: github.com/go-ldap/ldap/v3

     1  package ldap
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/url"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	ber "github.com/go-asn1-ber/asn1-ber"
    16  )
    17  
    18  const (
    19  	// MessageQuit causes the processMessages loop to exit
    20  	MessageQuit = 0
    21  	// MessageRequest sends a request to the server
    22  	MessageRequest = 1
    23  	// MessageResponse receives a response from the server
    24  	MessageResponse = 2
    25  	// MessageFinish indicates the client considers a particular message ID to be finished
    26  	MessageFinish = 3
    27  	// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
    28  	MessageTimeout = 4
    29  )
    30  
    31  const (
    32  	// DefaultLdapPort default ldap port for pure TCP connection
    33  	DefaultLdapPort = "389"
    34  	// DefaultLdapsPort default ldap port for SSL connection
    35  	DefaultLdapsPort = "636"
    36  )
    37  
// PacketResponse contains the packet or error encountered reading a response.
// Values of this type are delivered to request handlers on a message
// context's responses channel; use ReadPacket to unpack one.
//
// Field order matters: values are also built positionally
// (PacketResponse{packet, err}).
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is an error encountered while reading
	Error error
}
    45  
    46  // ReadPacket returns the packet or an error
    47  func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
    48  	if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
    49  		return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
    50  	}
    51  	return pr.Packet, pr.Error
    52  }
    53  
// messageContext tracks one in-flight request. It is identified by its
// message ID and has two channels:
//   - done: closed by the request handler (via finishMessage) when it stops
//     listening;
//   - responses: processMessages delivers the handler's responses on it.
type messageContext struct {
	id int64
	// close(done) should only be called from finishMessage()
	done chan struct{}
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}
    61  
    62  // sendResponse should only be called within the processMessages() loop which
    63  // is also responsible for closing the responses channel.
    64  func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
    65  	timeoutCtx := context.Background()
    66  	if timeout > 0 {
    67  		var cancelFunc context.CancelFunc
    68  		timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
    69  		defer cancelFunc()
    70  	}
    71  	select {
    72  	case msgCtx.responses <- packet:
    73  		// Successfully sent packet to message handler.
    74  	case <-msgCtx.done:
    75  		// The request handler is done and will not receive more
    76  		// packets.
    77  	case <-timeoutCtx.Done():
    78  		// The timeout was reached before the packet was sent.
    79  	}
    80  }
    81  
// messagePacket is a unit of work sent over Conn.chanMessage to the
// processMessages loop. Op is one of the Message* constants. Which other
// fields are set depends on Op:
//   - MessageRequest: Packet and Context;
//   - MessageResponse: Packet;
//   - MessageFinish, MessageTimeout: MessageID only.
type messagePacket struct {
	Op        int
	MessageID int64
	Packet    *ber.Packet
	Context   *messageContext
}

// sendMessageFlags is a bit set of options for sendMessageWithFlags.
type sendMessageFlags uint

const (
	// startTLS marks a StartTLS request. sendMessageWithFlags refuses it
	// while other requests are outstanding. While it is in flight, new
	// requests are refused until finishMessage is called for it.
	startTLS sendMessageFlags = 1 << iota
)
    94  
// Conn represents an LDAP Connection.
//
// Start launches two background goroutines:
//   - reader decodes responses from the network;
//   - processMessages writes requests and routes responses to the requests
//     waiting for them.
type Conn struct {
	// requestTimeout is loaded atomically
	// so we need to ensure 64-bit alignment on 32-bit platforms.
	// https://github.com/go-ldap/ldap/pull/199
	// Keep it as the first field. It holds a time.Duration (see SetTimeout).
	requestTimeout      int64
	// conn is the network connection; a successful StartTLS replaces it with a *tls.Conn.
	conn                net.Conn
	// isTLS records whether conn is encrypted (set by NewConn or StartTLS).
	isTLS               bool
	// closing becomes 1, atomically and only once, when Close starts (see IsClosing).
	closing             uint32
	// closeErr holds an error value stored by reader when a read fails while
	// not closing. processMessages sends it to waiting requests on shutdown.
	closeErr            atomic.Value
	// isStartingTLS is true while a StartTLS request is in flight; guarded by messageMutex.
	isStartingTLS       bool
	// Debug is used for debug output (Printf/PrintPacket); StartTLS also tests it as a boolean.
	Debug               debugging
	// chanConfirm is closed by processMessages when it exits; Close waits on it.
	chanConfirm         chan struct{}
	// messageContexts maps message IDs to in-flight requests. In this file it
	// is only accessed from the processMessages goroutine.
	messageContexts     map[int64]*messageContext
	// chanMessage carries operations to processMessages (buffered; see NewConn).
	chanMessage         chan *messagePacket
	// chanMessageID is where processMessages hands out message IDs; it is closed when that loop exits.
	chanMessageID       chan int64
	// wgClose is released once Close has closed the network connection; later Close calls wait on it.
	wgClose             sync.WaitGroup
	// outstandingRequests counts requests sent but not yet finished; guarded by messageMutex.
	outstandingRequests uint
	// messageMutex guards isStartingTLS and outstandingRequests. It also
	// serializes sends on chanMessage with Close, which closes that channel.
	messageMutex        sync.Mutex

	// err is the last error recorded by background goroutines (see GetLastError).
	// NOTE(review): GetLastError reads it under messageMutex, but not every write holds that lock.
	err error
}

// Compile-time check that *Conn implements Client.
var _ Client = &Conn{}
   119  
   120  // DefaultTimeout is a package-level variable that sets the timeout value
   121  // used for the Dial and DialTLS methods.
   122  //
   123  // WARNING: since this is a package-level variable, setting this value from
   124  // multiple places will probably result in undesired behaviour.
   125  var DefaultTimeout = 60 * time.Second
   126  
   127  // DialOpt configures DialContext.
   128  type DialOpt func(*DialContext)
   129  
   130  // DialWithDialer updates net.Dialer in DialContext.
   131  func DialWithDialer(d *net.Dialer) DialOpt {
   132  	return func(dc *DialContext) {
   133  		dc.dialer = d
   134  	}
   135  }
   136  
   137  // DialWithTLSConfig updates tls.Config in DialContext.
   138  func DialWithTLSConfig(tc *tls.Config) DialOpt {
   139  	return func(dc *DialContext) {
   140  		dc.tlsConfig = tc
   141  	}
   142  }
   143  
   144  // DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
   145  // specify a net.Dialer to for example define a timeout or a custom resolver.
   146  // @deprecated Use DialWithDialer and DialWithTLSConfig instead
   147  func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
   148  	return func(dc *DialContext) {
   149  		dc.tlsConfig = tlsConfig
   150  		dc.dialer = dialer
   151  	}
   152  }
   153  
   154  // DialContext contains necessary parameters to dial the given ldap URL.
   155  type DialContext struct {
   156  	dialer    *net.Dialer
   157  	tlsConfig *tls.Config
   158  }
   159  
   160  func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
   161  	if u.Scheme == "ldapi" {
   162  		if u.Path == "" || u.Path == "/" {
   163  			u.Path = "/var/run/slapd/ldapi"
   164  		}
   165  		return dc.dialer.Dial("unix", u.Path)
   166  	}
   167  
   168  	host, port, err := net.SplitHostPort(u.Host)
   169  	if err != nil {
   170  		// we assume that error is due to missing port
   171  		host = u.Host
   172  		port = ""
   173  	}
   174  
   175  	switch u.Scheme {
   176  	case "cldap":
   177  		if port == "" {
   178  			port = DefaultLdapPort
   179  		}
   180  		return dc.dialer.Dial("udp", net.JoinHostPort(host, port))
   181  	case "ldap":
   182  		if port == "" {
   183  			port = DefaultLdapPort
   184  		}
   185  		return dc.dialer.Dial("tcp", net.JoinHostPort(host, port))
   186  	case "ldaps":
   187  		if port == "" {
   188  			port = DefaultLdapsPort
   189  		}
   190  		return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig)
   191  	}
   192  
   193  	return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
   194  }
   195  
   196  // Dial connects to the given address on the given network using net.Dial
   197  // and then returns a new Conn for the connection.
   198  // @deprecated Use DialURL instead.
   199  func Dial(network, addr string) (*Conn, error) {
   200  	c, err := net.DialTimeout(network, addr, DefaultTimeout)
   201  	if err != nil {
   202  		return nil, NewError(ErrorNetwork, err)
   203  	}
   204  	conn := NewConn(c, false)
   205  	conn.Start()
   206  	return conn, nil
   207  }
   208  
   209  // DialTLS connects to the given address on the given network using tls.Dial
   210  // and then returns a new Conn for the connection.
   211  // @deprecated Use DialURL instead.
   212  func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
   213  	c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
   214  	if err != nil {
   215  		return nil, NewError(ErrorNetwork, err)
   216  	}
   217  	conn := NewConn(c, true)
   218  	conn.Start()
   219  	return conn, nil
   220  }
   221  
   222  // DialURL connects to the given ldap URL.
   223  // The following schemas are supported: ldap://, ldaps://, ldapi://,
   224  // and cldap:// (RFC1798, deprecated but used by Active Directory).
   225  // On success a new Conn for the connection is returned.
   226  func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
   227  	u, err := url.Parse(addr)
   228  	if err != nil {
   229  		return nil, NewError(ErrorNetwork, err)
   230  	}
   231  
   232  	var dc DialContext
   233  	for _, opt := range opts {
   234  		opt(&dc)
   235  	}
   236  	if dc.dialer == nil {
   237  		dc.dialer = &net.Dialer{Timeout: DefaultTimeout}
   238  	}
   239  
   240  	c, err := dc.dial(u)
   241  	if err != nil {
   242  		return nil, NewError(ErrorNetwork, err)
   243  	}
   244  
   245  	conn := NewConn(c, u.Scheme == "ldaps")
   246  	conn.Start()
   247  	return conn, nil
   248  }
   249  
   250  // NewConn returns a new Conn using conn for network I/O.
   251  func NewConn(conn net.Conn, isTLS bool) *Conn {
   252  	l := &Conn{
   253  		conn:            conn,
   254  		chanConfirm:     make(chan struct{}),
   255  		chanMessageID:   make(chan int64),
   256  		chanMessage:     make(chan *messagePacket, 10),
   257  		messageContexts: map[int64]*messageContext{},
   258  		requestTimeout:  0,
   259  		isTLS:           isTLS,
   260  	}
   261  	l.wgClose.Add(1)
   262  	return l
   263  }
   264  
   265  // Start initializes goroutines to read responses and process messages
   266  func (l *Conn) Start() {
   267  	go l.reader()
   268  	go l.processMessages()
   269  }
   270  
   271  // IsClosing returns whether or not we're currently closing.
   272  func (l *Conn) IsClosing() bool {
   273  	return atomic.LoadUint32(&l.closing) == 1
   274  }
   275  
   276  // setClosing sets the closing value to true
   277  func (l *Conn) setClosing() bool {
   278  	return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
   279  }
   280  
   281  // Close closes the connection.
   282  func (l *Conn) Close() (err error) {
   283  	l.messageMutex.Lock()
   284  	defer l.messageMutex.Unlock()
   285  
   286  	if l.setClosing() {
   287  		l.Debug.Printf("Sending quit message and waiting for confirmation")
   288  		l.chanMessage <- &messagePacket{Op: MessageQuit}
   289  
   290  		timeoutCtx := context.Background()
   291  		if l.getTimeout() > 0 {
   292  			var cancelFunc context.CancelFunc
   293  			timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
   294  			defer cancelFunc()
   295  		}
   296  		select {
   297  		case <-l.chanConfirm:
   298  			// Confirmation was received.
   299  		case <-timeoutCtx.Done():
   300  			// The timeout was reached before confirmation was received.
   301  		}
   302  
   303  		close(l.chanMessage)
   304  
   305  		l.Debug.Printf("Closing network connection")
   306  		err = l.conn.Close()
   307  		l.wgClose.Done()
   308  	}
   309  	l.wgClose.Wait()
   310  
   311  	return err
   312  }
   313  
   314  // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
   315  func (l *Conn) SetTimeout(timeout time.Duration) {
   316  	atomic.StoreInt64(&l.requestTimeout, int64(timeout))
   317  }
   318  
   319  func (l *Conn) getTimeout() int64 {
   320  	return atomic.LoadInt64(&l.requestTimeout)
   321  }
   322  
   323  // Returns the next available messageID
   324  func (l *Conn) nextMessageID() int64 {
   325  	if messageID, ok := <-l.chanMessageID; ok {
   326  		return messageID
   327  	}
   328  	return 0
   329  }
   330  
   331  // GetLastError returns the last recorded error from goroutines like processMessages and reader.
   332  // Only the last recorded error will be returned.
   333  func (l *Conn) GetLastError() error {
   334  	l.messageMutex.Lock()
   335  	defer l.messageMutex.Unlock()
   336  	return l.err
   337  }
   338  
   339  // StartTLS sends the command to start a TLS session and then creates a new TLS Client
   340  func (l *Conn) StartTLS(config *tls.Config) error {
   341  	if l.isTLS {
   342  		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
   343  	}
   344  
   345  	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
   346  	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
   347  	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
   348  	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
   349  	packet.AppendChild(request)
   350  	l.Debug.PrintPacket(packet)
   351  
   352  	msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
   353  	if err != nil {
   354  		return err
   355  	}
   356  	defer l.finishMessage(msgCtx)
   357  
   358  	l.Debug.Printf("%d: waiting for response", msgCtx.id)
   359  
   360  	packetResponse, ok := <-msgCtx.responses
   361  	if !ok {
   362  		return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
   363  	}
   364  	packet, err = packetResponse.ReadPacket()
   365  	l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
   366  	if err != nil {
   367  		return err
   368  	}
   369  
   370  	if l.Debug {
   371  		if err := addLDAPDescriptions(packet); err != nil {
   372  			l.Close()
   373  			return err
   374  		}
   375  		l.Debug.PrintPacket(packet)
   376  	}
   377  
   378  	if err := GetLDAPError(packet); err == nil {
   379  		conn := tls.Client(l.conn, config)
   380  
   381  		if connErr := conn.Handshake(); connErr != nil {
   382  			l.Close()
   383  			return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
   384  		}
   385  
   386  		l.isTLS = true
   387  		l.conn = conn
   388  	} else {
   389  		return err
   390  	}
   391  	go l.reader()
   392  
   393  	return nil
   394  }
   395  
   396  // TLSConnectionState returns the client's TLS connection state.
   397  // The return values are their zero values if StartTLS did
   398  // not succeed.
   399  func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
   400  	tc, ok := l.conn.(*tls.Conn)
   401  	if !ok {
   402  		return
   403  	}
   404  	return tc.ConnectionState(), true
   405  }
   406  
   407  func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
   408  	return l.sendMessageWithFlags(packet, 0)
   409  }
   410  
   411  func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
   412  	if l.IsClosing() {
   413  		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
   414  	}
   415  	l.messageMutex.Lock()
   416  	l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
   417  	if l.isStartingTLS {
   418  		l.messageMutex.Unlock()
   419  		return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
   420  	}
   421  	if flags&startTLS != 0 {
   422  		if l.outstandingRequests != 0 {
   423  			l.messageMutex.Unlock()
   424  			return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
   425  		}
   426  		l.isStartingTLS = true
   427  	}
   428  	l.outstandingRequests++
   429  
   430  	l.messageMutex.Unlock()
   431  
   432  	responses := make(chan *PacketResponse)
   433  	messageID := packet.Children[0].Value.(int64)
   434  	message := &messagePacket{
   435  		Op:        MessageRequest,
   436  		MessageID: messageID,
   437  		Packet:    packet,
   438  		Context: &messageContext{
   439  			id:        messageID,
   440  			done:      make(chan struct{}),
   441  			responses: responses,
   442  		},
   443  	}
   444  	if !l.sendProcessMessage(message) {
   445  		if l.IsClosing() {
   446  			return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
   447  		}
   448  		return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason"))
   449  	}
   450  	return message.Context, nil
   451  }
   452  
   453  func (l *Conn) finishMessage(msgCtx *messageContext) {
   454  	close(msgCtx.done)
   455  
   456  	if l.IsClosing() {
   457  		return
   458  	}
   459  
   460  	l.messageMutex.Lock()
   461  	l.outstandingRequests--
   462  	if l.isStartingTLS {
   463  		l.isStartingTLS = false
   464  	}
   465  	l.messageMutex.Unlock()
   466  
   467  	message := &messagePacket{
   468  		Op:        MessageFinish,
   469  		MessageID: msgCtx.id,
   470  	}
   471  	l.sendProcessMessage(message)
   472  }
   473  
   474  func (l *Conn) sendProcessMessage(message *messagePacket) bool {
   475  	l.messageMutex.Lock()
   476  	defer l.messageMutex.Unlock()
   477  	if l.IsClosing() {
   478  		return false
   479  	}
   480  	l.chanMessage <- message
   481  	return true
   482  }
   483  
// processMessages is the goroutine that owns l.messageContexts. It hands out
// sequential message IDs, starting at 1, on chanMessageID. It also serves the
// operations received on chanMessage:
//   - MessageRequest: write the packet to the network and register its
//     context. If a request timeout is set, start a timer goroutine that
//     later sends MessageTimeout unless the request finishes first.
//   - MessageResponse: forward the packet to the matching request.
//   - MessageTimeout: deliver a timeout error and drop the request.
//   - MessageFinish: drop the request.
//   - MessageQuit: exit.
//
// On exit, including after a recovered panic, it closes every remaining
// response channel. If the Conn is closing and closeErr was recorded, it
// first sends that error to each waiting request. It then closes
// chanMessageID and chanConfirm; Close waits on the latter.
//
// NOTE(review): l.err is written here without messageMutex, although
// GetLastError reads it under that lock. Locking here could deadlock:
// sendProcessMessage holds the mutex while blocked sending to chanMessage,
// which only this goroutine drains.
func (l *Conn) processMessages() {
	defer func() {
		if err := recover(); err != nil {
			l.err = fmt.Errorf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.IsClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()

	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)

				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					// The request was never registered, so report the
					// write error and close its channel here. `break`
					// leaves the switch, not the loop.
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
					close(message.Context.responses)
					break
				}

				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context

				// Add timeout if defined
				requestTimeout := l.getTimeout()
				if requestTimeout > 0 {
					// The goroutine ends either when the timer fires
					// (queueing MessageTimeout) or when the handler closes
					// done via finishMessage, whichever comes first.
					go func() {
						timer := time.NewTimer(time.Duration(requestTimeout))
						defer func() {
							if err := recover(); err != nil {
								l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
							}

							timer.Stop()
						}()

						select {
						case <-timer.C:
							timeoutMessage := &messagePacket{
								Op:        MessageTimeout,
								MessageID: message.MessageID,
							}
							l.sendProcessMessage(timeoutMessage)
						case <-message.Context.done:
						}
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
				} else {
					// No registered request has this ID (e.g. it already
					// finished or timed out).
					l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
					l.Debug.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel
				// All reads will return immediately
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
   581  
   582  func (l *Conn) reader() {
   583  	cleanstop := false
   584  	defer func() {
   585  		if err := recover(); err != nil {
   586  			l.err = fmt.Errorf("ldap: recovered panic in reader: %v", err)
   587  		}
   588  		if !cleanstop {
   589  			l.Close()
   590  		}
   591  	}()
   592  
   593  	bufConn := bufio.NewReader(l.conn)
   594  	for {
   595  		if cleanstop {
   596  			l.Debug.Printf("reader clean stopping (without closing the connection)")
   597  			return
   598  		}
   599  		packet, err := ber.ReadPacket(bufConn)
   600  		if err != nil {
   601  			// A read error is expected here if we are closing the connection...
   602  			if !l.IsClosing() {
   603  				l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
   604  				l.Debug.Printf("reader error: %s", err)
   605  			}
   606  			return
   607  		}
   608  		if err := addLDAPDescriptions(packet); err != nil {
   609  			l.Debug.Printf("descriptions error: %s", err)
   610  		}
   611  		if len(packet.Children) == 0 {
   612  			l.Debug.Printf("Received bad ldap packet")
   613  			continue
   614  		}
   615  		l.messageMutex.Lock()
   616  		if l.isStartingTLS {
   617  			cleanstop = true
   618  		}
   619  		l.messageMutex.Unlock()
   620  		message := &messagePacket{
   621  			Op:        MessageResponse,
   622  			MessageID: packet.Children[0].Value.(int64),
   623  			Packet:    packet,
   624  		}
   625  		if !l.sendProcessMessage(message) {
   626  			return
   627  		}
   628  	}
   629  }
   630  

View as plain text