1 package ldap
2
3 import (
4 "bytes"
5 enchex "encoding/hex"
6 "errors"
7 "fmt"
8 "sort"
9 "strings"
10
11 ber "github.com/go-asn1-ber/asn1-ber"
12 )
13
14
// AttributeTypeAndValue is a single "type=value" pair inside a relative
// distinguished name, e.g. the "cn=admin" in "cn=admin,dc=example,dc=com".
type AttributeTypeAndValue struct {
	// Type is the attribute type, for example "cn" or "ou".
	Type string
	// Value is the attribute value in its unescaped form.
	Value string
}

// String returns the pair formatted as "type=value". The type is
// lower-cased and the value is escaped for inclusion in a string DN.
func (a *AttributeTypeAndValue) String() string {
	return strings.ToLower(a.Type) + "=" + a.encodeValue()
}

// encodeValue returns a.Value escaped for inclusion in a string DN.
//
// The rules follow RFC 4514 section 2.4:
//   - a leading space or '#' is backslash-escaped;
//   - a trailing space is backslash-escaped;
//   - '"', '+', ',', ';', '<', '>' and '\' are backslash-escaped anywhere;
//   - bytes outside printable ASCII (0x20-0x7E) are written as a backslash
//     followed by two lower-case hex digits, so multi-byte UTF-8 sequences
//     are escaped byte by byte.
func (a *AttributeTypeAndValue) encodeValue() string {
	value := a.Value

	var encoded strings.Builder
	encoded.Grow(len(value))

	escapeChar := func(c byte) {
		encoded.WriteByte('\\')
		encoded.WriteByte(c)
	}

	escapeHex := func(c byte) {
		encoded.WriteByte('\\')
		encoded.WriteString(enchex.EncodeToString([]byte{c}))
	}

	for i := 0; i < len(value); i++ {
		char := value[i]

		// Only a *leading* space or '#' needs escaping: unescaped, a parser
		// would skip the space or treat '#' as the start of a hex-encoded
		// BER value. The parentheses matter; without them '#' would be
		// escaped at every position.
		if i == 0 && (char == ' ' || char == '#') {
			escapeChar(char)
			continue
		}

		// An unescaped trailing space would be trimmed when parsed back.
		if i == len(value)-1 && char == ' ' {
			escapeChar(char)
			continue
		}

		switch char {
		case '"', '+', ',', ';', '<', '>', '\\':
			escapeChar(char)
			continue
		}

		if char < ' ' || char > '~' {
			escapeHex(char)
			continue
		}

		encoded.WriteByte(char)
	}

	return encoded.String()
}
79
80
81 type RelativeDN struct {
82 Attributes []*AttributeTypeAndValue
83 }
84
85
86
87 func (r *RelativeDN) String() string {
88 attrs := make([]string, len(r.Attributes))
89 for i := range r.Attributes {
90 attrs[i] = r.Attributes[i].String()
91 }
92 sort.Strings(attrs)
93 return strings.Join(attrs, "+")
94 }
95
96
97 type DN struct {
98 RDNs []*RelativeDN
99 }
100
101
102
103 func (d *DN) String() string {
104 rdns := make([]string, len(d.RDNs))
105 for i := range d.RDNs {
106 rdns[i] = d.RDNs[i].String()
107 }
108 return strings.Join(rdns, ",")
109 }
110
111
112
113 func ParseDN(str string) (*DN, error) {
114 dn := new(DN)
115 dn.RDNs = make([]*RelativeDN, 0)
116 rdn := new(RelativeDN)
117 rdn.Attributes = make([]*AttributeTypeAndValue, 0)
118 buffer := bytes.Buffer{}
119 attribute := new(AttributeTypeAndValue)
120 escaping := false
121
122 unescapedTrailingSpaces := 0
123 stringFromBuffer := func() string {
124 s := buffer.String()
125 s = s[0 : len(s)-unescapedTrailingSpaces]
126 buffer.Reset()
127 unescapedTrailingSpaces = 0
128 return s
129 }
130
131 for i := 0; i < len(str); i++ {
132 char := str[i]
133 switch {
134 case escaping:
135 unescapedTrailingSpaces = 0
136 escaping = false
137 switch char {
138 case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
139 buffer.WriteByte(char)
140 continue
141 }
142
143 if len(str) == i+1 {
144 return nil, errors.New("got corrupted escaped character")
145 }
146
147 dst := []byte{0}
148 n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
149 if err != nil {
150 return nil, fmt.Errorf("failed to decode escaped character: %s", err)
151 } else if n != 1 {
152 return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n)
153 }
154 buffer.WriteByte(dst[0])
155 i++
156 case char == '\\':
157 unescapedTrailingSpaces = 0
158 escaping = true
159 case char == '=' && attribute.Type == "":
160 attribute.Type = stringFromBuffer()
161
162
163
164 if len(str) > i+1 && str[i+1] == '#' {
165 i += 2
166 index := strings.IndexAny(str[i:], ",+")
167 var data string
168 if index > 0 {
169 data = str[i : i+index]
170 } else {
171 data = str[i:]
172 }
173 rawBER, err := enchex.DecodeString(data)
174 if err != nil {
175 return nil, fmt.Errorf("failed to decode BER encoding: %s", err)
176 }
177 packet, err := ber.DecodePacketErr(rawBER)
178 if err != nil {
179 return nil, fmt.Errorf("failed to decode BER packet: %s", err)
180 }
181 buffer.WriteString(packet.Data.String())
182 i += len(data) - 1
183 }
184 case char == ',' || char == '+' || char == ';':
185
186 if len(attribute.Type) == 0 {
187 return nil, errors.New("incomplete type, value pair")
188 }
189 attribute.Value = stringFromBuffer()
190 rdn.Attributes = append(rdn.Attributes, attribute)
191 attribute = new(AttributeTypeAndValue)
192 if char == ',' || char == ';' {
193 dn.RDNs = append(dn.RDNs, rdn)
194 rdn = new(RelativeDN)
195 rdn.Attributes = make([]*AttributeTypeAndValue, 0)
196 }
197 case char == ' ' && buffer.Len() == 0:
198
199 continue
200 default:
201 if char == ' ' {
202
203 unescapedTrailingSpaces++
204 } else {
205
206 unescapedTrailingSpaces = 0
207 }
208 buffer.WriteByte(char)
209 }
210 }
211 if buffer.Len() > 0 {
212 if len(attribute.Type) == 0 {
213 return nil, errors.New("DN ended with incomplete type, value pair")
214 }
215 attribute.Value = stringFromBuffer()
216 rdn.Attributes = append(rdn.Attributes, attribute)
217 dn.RDNs = append(dn.RDNs, rdn)
218 }
219 return dn, nil
220 }
221
222
223
224
225 func (d *DN) Equal(other *DN) bool {
226 if len(d.RDNs) != len(other.RDNs) {
227 return false
228 }
229 for i := range d.RDNs {
230 if !d.RDNs[i].Equal(other.RDNs[i]) {
231 return false
232 }
233 }
234 return true
235 }
236
237
238
239
240
241 func (d *DN) AncestorOf(other *DN) bool {
242 if len(d.RDNs) >= len(other.RDNs) {
243 return false
244 }
245
246 otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
247 for i := range d.RDNs {
248 if !d.RDNs[i].Equal(otherRDNs[i]) {
249 return false
250 }
251 }
252 return true
253 }
254
255
256
257
258
259
260 func (r *RelativeDN) Equal(other *RelativeDN) bool {
261 if len(r.Attributes) != len(other.Attributes) {
262 return false
263 }
264 return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
265 }
266
267 func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
268 for _, attr := range attrs {
269 found := false
270 for _, myattr := range r.Attributes {
271 if myattr.Equal(attr) {
272 found = true
273 break
274 }
275 }
276 if !found {
277 return false
278 }
279 }
280 return true
281 }
282
283
284
285 func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
286 return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
287 }
288
289
290
291
292
293 func (d *DN) EqualFold(other *DN) bool {
294 if len(d.RDNs) != len(other.RDNs) {
295 return false
296 }
297 for i := range d.RDNs {
298 if !d.RDNs[i].EqualFold(other.RDNs[i]) {
299 return false
300 }
301 }
302 return true
303 }
304
305
306
307 func (d *DN) AncestorOfFold(other *DN) bool {
308 if len(d.RDNs) >= len(other.RDNs) {
309 return false
310 }
311
312 otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
313 for i := range d.RDNs {
314 if !d.RDNs[i].EqualFold(otherRDNs[i]) {
315 return false
316 }
317 }
318 return true
319 }
320
321
322
323 func (r *RelativeDN) EqualFold(other *RelativeDN) bool {
324 if len(r.Attributes) != len(other.Attributes) {
325 return false
326 }
327 return r.hasAllAttributesFold(other.Attributes) && other.hasAllAttributesFold(r.Attributes)
328 }
329
330 func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
331 for _, attr := range attrs {
332 found := false
333 for _, myattr := range r.Attributes {
334 if myattr.EqualFold(attr) {
335 found = true
336 break
337 }
338 }
339 if !found {
340 return false
341 }
342 }
343 return true
344 }
345
346
347
348 func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
349 return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
350 }
351
View as plain text