1 package ldap
2
3 import (
4 "bufio"
5 "context"
6 "crypto/tls"
7 "errors"
8 "fmt"
9 "net"
10 "net/url"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 ber "github.com/go-asn1-ber/asn1-ber"
16 )
17
// Operation codes carried in messagePacket.Op and consumed by processMessages.
const (
	// MessageQuit makes processMessages return; Close sends it.
	MessageQuit = iota
	// MessageRequest carries a request packet that processMessages writes to
	// the connection before registering its messageContext.
	MessageRequest
	// MessageResponse carries a packet read by reader, forwarded to the
	// messageContext registered for its message ID.
	MessageResponse
	// MessageFinish tells processMessages the caller is done with a message ID
	// so its context can be dropped; finishMessage sends it.
	MessageFinish
	// MessageTimeout is posted by a request's timer when the request timeout
	// expires before the caller finished.
	MessageTimeout
)
30
// DefaultLdapPort is the port used for ldap:// and cldap:// URLs that do not
// specify one.
const DefaultLdapPort = "389"

// DefaultLdapsPort is the port used for ldaps:// URLs that do not specify one.
const DefaultLdapsPort = "636"
37
38
// PacketResponse is what a caller receives on a message's responses channel:
// either a response packet read from the server or an error.
//
// Field order matters: processMessages builds values with positional
// composite literals ({packet, err}).
type PacketResponse struct {
	// Packet is the response packet, nil when the request failed.
	Packet *ber.Packet
	// Error is set when the request failed, e.g. the request could not be
	// written, the request timed out, or the connection shut down with an error.
	Error error
}
45
46
47 func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
48 if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
49 return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
50 }
51 return pr.Packet, pr.Error
52 }
53
// messageContext tracks a single in-flight request.
type messageContext struct {
	// id is the LDAP message ID of the request.
	id int64
	// done is closed by finishMessage once the caller stops waiting. It
	// unblocks a pending sendResponse and stops the request's timeout timer.
	done chan struct{}
	// responses delivers results to the caller. It is created unbuffered by
	// sendMessageWithFlags. processMessages closes it when the message
	// finishes, times out, fails to send, or the connection shuts down.
	responses chan *PacketResponse
}
61
62
63
64 func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
65 timeoutCtx := context.Background()
66 if timeout > 0 {
67 var cancelFunc context.CancelFunc
68 timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
69 defer cancelFunc()
70 }
71 select {
72 case msgCtx.responses <- packet:
73
74 case <-msgCtx.done:
75
76
77 case <-timeoutCtx.Done():
78
79 }
80 }
81
// messagePacket is the unit of work sent to processMessages over
// Conn.chanMessage.
type messagePacket struct {
	// Op is one of the Message* constants.
	Op int
	// MessageID is the LDAP message ID the operation refers to. It is unset
	// for MessageQuit.
	MessageID int64
	// Packet is the request to write (MessageRequest) or the packet that was
	// read (MessageResponse). It is nil for the other operations.
	Packet *ber.Packet
	// Context is the caller's context; only MessageRequest sets it.
	Context *messageContext
}
88
// sendMessageFlags is a bit set of options for sendMessageWithFlags.
type sendMessageFlags uint

const (
	// startTLS marks a StartTLS request. It is refused while other requests
	// are outstanding, and while it is in flight no other requests may be sent.
	startTLS sendMessageFlags = 1 << iota
)
94
95
// Conn is an LDAP connection. A reader goroutine reads packets from conn, and a
// processMessages goroutine writes requests and routes responses to callers.
type Conn struct {
	// requestTimeout is a time.Duration stored as int64 and accessed only via
	// sync/atomic (SetTimeout/getTimeout). It is the first field, presumably
	// to keep it 64-bit aligned for atomic access on 32-bit platforms.
	// Values <= 0 disable request timeouts.
	requestTimeout int64
	// conn is the underlying transport; StartTLS replaces it with a *tls.Conn.
	conn net.Conn
	// isTLS reports whether conn is already encrypted.
	isTLS bool
	// closing is 0 while open and set to 1 exactly once by setClosing (atomic).
	closing uint32
	// closeErr holds the read error (type error) that made reader stop.
	closeErr atomic.Value
	// isStartingTLS is true while a StartTLS request is in flight; guarded by
	// messageMutex.
	isStartingTLS bool
	// Debug controls debug output of packets and trace messages.
	Debug debugging
	// chanConfirm is closed by processMessages when it exits; Close waits on it.
	chanConfirm chan struct{}
	// messageContexts maps message IDs to in-flight requests. It is accessed
	// only from the processMessages goroutine.
	messageContexts map[int64]*messageContext
	// chanMessage carries work to processMessages (buffered, see NewConn).
	chanMessage chan *messagePacket
	// chanMessageID hands out sequential message IDs from processMessages.
	chanMessageID chan int64
	// wgClose lets concurrent Close calls wait for the first one to finish.
	wgClose sync.WaitGroup
	// outstandingRequests counts requests sent but not yet finished; guarded
	// by messageMutex.
	outstandingRequests uint
	// messageMutex guards isStartingTLS, outstandingRequests and sends on
	// chanMessage, and is held by Close for its whole duration.
	messageMutex sync.Mutex

	// err is the last internal error (recovered panics, unexpected messages),
	// read by GetLastError under messageMutex.
	// NOTE(review): several goroutines write it without holding messageMutex.
	err error
}
117
// Compile-time check that *Conn implements the Client interface.
var _ Client = &Conn{}
119
120
121
122
123
124
// DefaultTimeout is the connect timeout used by Dial and DialTLS. DialURL also
// uses it when no dialer is supplied through its options. Being a
// package-level variable, it can be adjusted before dialing.
var DefaultTimeout = time.Minute
126
127
// DialOpt configures the DialContext used by DialURL.
type DialOpt func(*DialContext)
129
130
131 func DialWithDialer(d *net.Dialer) DialOpt {
132 return func(dc *DialContext) {
133 dc.dialer = d
134 }
135 }
136
137
138 func DialWithTLSConfig(tc *tls.Config) DialOpt {
139 return func(dc *DialContext) {
140 dc.tlsConfig = tc
141 }
142 }
143
144
145
146
147 func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
148 return func(dc *DialContext) {
149 dc.tlsConfig = tlsConfig
150 dc.dialer = dialer
151 }
152 }
153
154
// DialContext collects the options DialURL uses to open the transport.
type DialContext struct {
	// dialer opens the network connection; DialURL defaults it to a
	// net.Dialer with DefaultTimeout.
	dialer *net.Dialer
	// tlsConfig is used for ldaps:// connections.
	tlsConfig *tls.Config
}
159
160 func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
161 if u.Scheme == "ldapi" {
162 if u.Path == "" || u.Path == "/" {
163 u.Path = "/var/run/slapd/ldapi"
164 }
165 return dc.dialer.Dial("unix", u.Path)
166 }
167
168 host, port, err := net.SplitHostPort(u.Host)
169 if err != nil {
170
171 host = u.Host
172 port = ""
173 }
174
175 switch u.Scheme {
176 case "cldap":
177 if port == "" {
178 port = DefaultLdapPort
179 }
180 return dc.dialer.Dial("udp", net.JoinHostPort(host, port))
181 case "ldap":
182 if port == "" {
183 port = DefaultLdapPort
184 }
185 return dc.dialer.Dial("tcp", net.JoinHostPort(host, port))
186 case "ldaps":
187 if port == "" {
188 port = DefaultLdapsPort
189 }
190 return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig)
191 }
192
193 return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
194 }
195
196
197
198
199 func Dial(network, addr string) (*Conn, error) {
200 c, err := net.DialTimeout(network, addr, DefaultTimeout)
201 if err != nil {
202 return nil, NewError(ErrorNetwork, err)
203 }
204 conn := NewConn(c, false)
205 conn.Start()
206 return conn, nil
207 }
208
209
210
211
212 func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
213 c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
214 if err != nil {
215 return nil, NewError(ErrorNetwork, err)
216 }
217 conn := NewConn(c, true)
218 conn.Start()
219 return conn, nil
220 }
221
222
223
224
225
226 func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
227 u, err := url.Parse(addr)
228 if err != nil {
229 return nil, NewError(ErrorNetwork, err)
230 }
231
232 var dc DialContext
233 for _, opt := range opts {
234 opt(&dc)
235 }
236 if dc.dialer == nil {
237 dc.dialer = &net.Dialer{Timeout: DefaultTimeout}
238 }
239
240 c, err := dc.dial(u)
241 if err != nil {
242 return nil, NewError(ErrorNetwork, err)
243 }
244
245 conn := NewConn(c, u.Scheme == "ldaps")
246 conn.Start()
247 return conn, nil
248 }
249
250
251 func NewConn(conn net.Conn, isTLS bool) *Conn {
252 l := &Conn{
253 conn: conn,
254 chanConfirm: make(chan struct{}),
255 chanMessageID: make(chan int64),
256 chanMessage: make(chan *messagePacket, 10),
257 messageContexts: map[int64]*messageContext{},
258 requestTimeout: 0,
259 isTLS: isTLS,
260 }
261 l.wgClose.Add(1)
262 return l
263 }
264
265
266 func (l *Conn) Start() {
267 go l.reader()
268 go l.processMessages()
269 }
270
271
272 func (l *Conn) IsClosing() bool {
273 return atomic.LoadUint32(&l.closing) == 1
274 }
275
276
277 func (l *Conn) setClosing() bool {
278 return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
279 }
280
281
// Close shuts the connection down.
//
// The first call marks the connection as closing and sends MessageQuit to
// processMessages. It waits for chanConfirm, bounded by the request timeout
// when one is set. It then closes chanMessage and the network connection, and
// returns the error from closing l.conn. Later calls wait for the first one
// to finish and return nil.
//
// messageMutex is held for the whole call. sendProcessMessage checks
// IsClosing under the same mutex, so no one can send on chanMessage after it
// is closed here.
//
// NOTE(review): with no request timeout set, this waits on chanConfirm
// indefinitely. If the wait times out instead, chanMessage is closed while
// processMessages may still be running.
func (l *Conn) Close() (err error) {
	l.messageMutex.Lock()
	defer l.messageMutex.Unlock()

	if l.setClosing() {
		l.Debug.Printf("Sending quit message and waiting for confirmation")
		l.chanMessage <- &messagePacket{Op: MessageQuit}

		timeoutCtx := context.Background()
		if l.getTimeout() > 0 {
			var cancelFunc context.CancelFunc
			timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
			defer cancelFunc()
		}
		select {
		case <-l.chanConfirm:
			// processMessages has exited and closed chanConfirm.
		case <-timeoutCtx.Done():
			// Gave up waiting for processMessages.
		}

		close(l.chanMessage)

		l.Debug.Printf("Closing network connection")
		err = l.conn.Close()
		l.wgClose.Done()
	}
	// Concurrent/later callers block here until the first Close has finished.
	l.wgClose.Wait()

	return err
}
313
314
315 func (l *Conn) SetTimeout(timeout time.Duration) {
316 atomic.StoreInt64(&l.requestTimeout, int64(timeout))
317 }
318
319 func (l *Conn) getTimeout() int64 {
320 return atomic.LoadInt64(&l.requestTimeout)
321 }
322
323
324 func (l *Conn) nextMessageID() int64 {
325 if messageID, ok := <-l.chanMessageID; ok {
326 return messageID
327 }
328 return 0
329 }
330
331
332
333 func (l *Conn) GetLastError() error {
334 l.messageMutex.Lock()
335 defer l.messageMutex.Unlock()
336 return l.err
337 }
338
339
340 func (l *Conn) StartTLS(config *tls.Config) error {
341 if l.isTLS {
342 return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
343 }
344
345 packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
346 packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
347 request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
348 request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
349 packet.AppendChild(request)
350 l.Debug.PrintPacket(packet)
351
352 msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
353 if err != nil {
354 return err
355 }
356 defer l.finishMessage(msgCtx)
357
358 l.Debug.Printf("%d: waiting for response", msgCtx.id)
359
360 packetResponse, ok := <-msgCtx.responses
361 if !ok {
362 return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
363 }
364 packet, err = packetResponse.ReadPacket()
365 l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
366 if err != nil {
367 return err
368 }
369
370 if l.Debug {
371 if err := addLDAPDescriptions(packet); err != nil {
372 l.Close()
373 return err
374 }
375 l.Debug.PrintPacket(packet)
376 }
377
378 if err := GetLDAPError(packet); err == nil {
379 conn := tls.Client(l.conn, config)
380
381 if connErr := conn.Handshake(); connErr != nil {
382 l.Close()
383 return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
384 }
385
386 l.isTLS = true
387 l.conn = conn
388 } else {
389 return err
390 }
391 go l.reader()
392
393 return nil
394 }
395
396
397
398
399 func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
400 tc, ok := l.conn.(*tls.Conn)
401 if !ok {
402 return
403 }
404 return tc.ConnectionState(), true
405 }
406
407 func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
408 return l.sendMessageWithFlags(packet, 0)
409 }
410
411 func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
412 if l.IsClosing() {
413 return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
414 }
415 l.messageMutex.Lock()
416 l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
417 if l.isStartingTLS {
418 l.messageMutex.Unlock()
419 return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
420 }
421 if flags&startTLS != 0 {
422 if l.outstandingRequests != 0 {
423 l.messageMutex.Unlock()
424 return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
425 }
426 l.isStartingTLS = true
427 }
428 l.outstandingRequests++
429
430 l.messageMutex.Unlock()
431
432 responses := make(chan *PacketResponse)
433 messageID := packet.Children[0].Value.(int64)
434 message := &messagePacket{
435 Op: MessageRequest,
436 MessageID: messageID,
437 Packet: packet,
438 Context: &messageContext{
439 id: messageID,
440 done: make(chan struct{}),
441 responses: responses,
442 },
443 }
444 if !l.sendProcessMessage(message) {
445 if l.IsClosing() {
446 return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
447 }
448 return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason"))
449 }
450 return message.Context, nil
451 }
452
453 func (l *Conn) finishMessage(msgCtx *messageContext) {
454 close(msgCtx.done)
455
456 if l.IsClosing() {
457 return
458 }
459
460 l.messageMutex.Lock()
461 l.outstandingRequests--
462 if l.isStartingTLS {
463 l.isStartingTLS = false
464 }
465 l.messageMutex.Unlock()
466
467 message := &messagePacket{
468 Op: MessageFinish,
469 MessageID: msgCtx.id,
470 }
471 l.sendProcessMessage(message)
472 }
473
474 func (l *Conn) sendProcessMessage(message *messagePacket) bool {
475 l.messageMutex.Lock()
476 defer l.messageMutex.Unlock()
477 if l.IsClosing() {
478 return false
479 }
480 l.chanMessage <- message
481 return true
482 }
483
// processMessages is the goroutine that owns l.messageContexts and writes
// request packets to l.conn. In a loop it either hands out the next message
// ID (starting at 1) on chanMessageID, or handles one messagePacket:
//   - MessageQuit: stop.
//   - MessageRequest: write the packet, register its context, and arm a
//     timer that posts MessageTimeout if a request timeout is set.
//   - MessageResponse: forward a packet read by reader to its context.
//   - MessageTimeout: deliver a timeout error and drop the context.
//   - MessageFinish: drop the context after the caller is done.
//
// On exit, including after a recovered panic, each remaining context gets the
// stored close error (when closing with one recorded) and has its responses
// channel closed. chanMessageID and chanConfirm are then closed, so
// nextMessageID and Close stop waiting.
//
// NOTE(review): l.err is written here and in the timer goroutine without
// messageMutex, while GetLastError reads it under that mutex. Taking the
// mutex here is not a drop-in fix: Close holds it while waiting on
// chanConfirm.
func (l *Conn) processMessages() {
	defer func() {
		if err := recover(); err != nil {
			l.err = fmt.Errorf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, tell anyone still waiting
			// about that error before closing their channel.
			if l.IsClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()

	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)

				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
					close(message.Context.responses)
					// Leaves the switch only: the context is never registered.
					break
				}

				// Only register the context after a successful write, so a
				// failed request is not tracked.
				l.messageContexts[message.MessageID] = message.Context

				// Add timeout if defined
				requestTimeout := l.getTimeout()
				if requestTimeout > 0 {
					go func() {
						timer := time.NewTimer(time.Duration(requestTimeout))
						defer func() {
							if err := recover(); err != nil {
								l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
							}

							timer.Stop()
						}()

						select {
						case <-timer.C:
							timeoutMessage := &messagePacket{
								Op:        MessageTimeout,
								MessageID: message.MessageID,
							}
							l.sendProcessMessage(timeoutMessage)
						case <-message.Context.done:
							// The caller finished first (finishMessage closed done).
						}
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
				} else {
					l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
					l.Debug.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel. All
				// subsequent responses for this message are dropped.
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}
581
// reader reads BER packets from l.conn and forwards each one to
// processMessages as a MessageResponse. Packets without children are skipped.
//
// It stops when:
//   - a read fails; the error is stored in l.closeErr unless the connection
//     is already closing;
//   - processMessages no longer accepts messages (connection closing);
//   - a packet arrives while a StartTLS request is in flight. That packet is
//     still forwarded, then the reader returns without closing the connection
//     (a "clean stop"), because StartTLS may replace l.conn.
//
// Every stop other than a clean stop, including a recovered panic, calls
// l.Close.
func (l *Conn) reader() {
	cleanstop := false
	defer func() {
		if err := recover(); err != nil {
			l.err = fmt.Errorf("ldap: recovered panic in reader: %v", err)
		}
		if !cleanstop {
			l.Close()
		}
	}()

	bufConn := bufio.NewReader(l.conn)
	for {
		if cleanstop {
			l.Debug.Printf("reader clean stopping (without closing the connection)")
			return
		}
		packet, err := ber.ReadPacket(bufConn)
		if err != nil {
			// A read error while not closing becomes the close error that
			// processMessages reports to waiting callers.
			if !l.IsClosing() {
				l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
				l.Debug.Printf("reader error: %s", err)
			}
			return
		}
		if err := addLDAPDescriptions(packet); err != nil {
			l.Debug.Printf("descriptions error: %s", err)
		}
		if len(packet.Children) == 0 {
			l.Debug.Printf("Received bad ldap packet")
			continue
		}
		// Decide before forwarding: if StartTLS is in flight, this is its
		// response and the reader must stop after handing it over.
		l.messageMutex.Lock()
		if l.isStartingTLS {
			cleanstop = true
		}
		l.messageMutex.Unlock()
		message := &messagePacket{
			Op:        MessageResponse,
			MessageID: packet.Children[0].Value.(int64),
			Packet:    packet,
		}
		if !l.sendProcessMessage(message) {
			return
		}
	}
}
630
View as plain text