1 package ldap
2
3 import (
4 "errors"
5 "io"
6 "net"
7 "strings"
8 "testing"
9 "time"
10
11 ber "github.com/go-asn1-ber/asn1-ber"
12 )
13
14
15 func TestNilPacket(t *testing.T) {
16
17 err := GetLDAPError(nil)
18 if !IsErrorWithCode(err, ErrorUnexpectedResponse) {
19 t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err)
20 }
21
22
23 kids := []*ber.Packet{
24 {},
25 nil,
26 }
27 pack := &ber.Packet{Children: kids}
28 err = GetLDAPError(pack)
29
30 if !IsErrorWithCode(err, ErrorUnexpectedResponse) {
31 t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err)
32 }
33 }
34
35
36
37 func TestConnReadErr(t *testing.T) {
38 conn := &signalErrConn{
39 signals: make(chan error),
40 }
41
42 ldapConn := NewConn(conn, false)
43 ldapConn.Start()
44
45
46 searchReq := NewSearchRequest("dc=example,dc=com", ScopeWholeSubtree, DerefAlways, 0, 0, false, "(objectClass=*)", nil, nil)
47
48 expectedError := errors.New("this is the error you are looking for")
49
50
51 time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError })
52
53
54
55
56 _, err := ldapConn.Search(searchReq)
57 if err == nil || !strings.Contains(err.Error(), expectedError.Error()) {
58 t.Errorf("not the expected error: %s", err)
59 }
60 }
61
62
63 func TestGetLDAPError(t *testing.T) {
64 diagnosticMessage := "Detailed error message"
65 bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
66 bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode"))
67 bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "dc=example,dc=org", "matchedDN"))
68 bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, diagnosticMessage, "diagnosticMessage"))
69 packet := ber.NewSequence("LDAPMessage")
70 packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID"))
71 packet.AppendChild(bindResponse)
72 err := GetLDAPError(packet)
73 if err == nil {
74 t.Errorf("Did not get error response")
75 }
76
77 ldapError := err.(*Error)
78 if ldapError.ResultCode != LDAPResultInvalidCredentials {
79 t.Errorf("Got incorrect error code in LDAP error; got %v, expected %v", ldapError.ResultCode, LDAPResultInvalidCredentials)
80 }
81 if ldapError.Err.Error() != diagnosticMessage {
82 t.Errorf("Got incorrect error message in LDAP error; got %v, expected %v", ldapError.Err.Error(), diagnosticMessage)
83 }
84 }
85
86
87
88 func TestGetLDAPErrorInvalidResponse(t *testing.T) {
89 bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
90 bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "dc=example,dc=org", "matchedDN"))
91 bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode"))
92 bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode"))
93 packet := ber.NewSequence("LDAPMessage")
94 packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID"))
95 packet.AppendChild(bindResponse)
96 err := GetLDAPError(packet)
97 if err == nil {
98 t.Errorf("Did not get error response")
99 }
100
101 ldapError := err.(*Error)
102 if ldapError.ResultCode != ErrorNetwork {
103 t.Errorf("Got incorrect error code in LDAP error; got %v, expected %v", ldapError.ResultCode, ErrorNetwork)
104 }
105 }
106
107 func TestErrorIs(t *testing.T) {
108 err := NewError(ErrorNetwork, io.EOF)
109 if !errors.Is(err, io.EOF) {
110 t.Errorf("Expected an io.EOF error: %v", err)
111 }
112 }
113
114 func TestErrorAs(t *testing.T) {
115 var netErr net.InvalidAddrError = "invalid addr"
116 err := NewError(ErrorNetwork, netErr)
117
118 var target net.InvalidAddrError
119 ok := errors.As(err, &target)
120 if !ok {
121 t.Error("Expected an InvalidAddrError")
122 }
123 }
124
125
126 func TestGetLDAPErrorSuccess(t *testing.T) {
127 bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
128 bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "resultCode"))
129 bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN"))
130 bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "diagnosticMessage"))
131 packet := ber.NewSequence("LDAPMessage")
132 packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID"))
133 packet.AppendChild(bindResponse)
134 err := GetLDAPError(packet)
135 if err != nil {
136 t.Errorf("Successful responses should not produce an error, but got: %v", err)
137 }
138 }
139
140
141
142
143
// signalErrConn is a test double used in place of a real network connection.
// Its reads are driven by values the test sends on signals. Writes are
// accepted and discarded.
type signalErrConn struct{ signals chan error }
147
148
149
150 func (c *signalErrConn) Read(b []byte) (n int, err error) {
151 return 0, <-c.signals
152 }
153
154 func (c *signalErrConn) Write(b []byte) (n int, err error) {
155 return len(b), nil
156 }
157
158 func (c *signalErrConn) Close() error {
159 close(c.signals)
160 return nil
161 }
162
163 func (c *signalErrConn) LocalAddr() net.Addr {
164 return (*net.TCPAddr)(nil)
165 }
166
167 func (c *signalErrConn) RemoteAddr() net.Addr {
168 return (*net.TCPAddr)(nil)
169 }
170
171 func (c *signalErrConn) SetDeadline(t time.Time) error {
172 return nil
173 }
174
175 func (c *signalErrConn) SetReadDeadline(t time.Time) error {
176 return nil
177 }
178
179 func (c *signalErrConn) SetWriteDeadline(t time.Time) error {
180 return nil
181 }
182
View as plain text