1
38 package bitset
39
40 import (
41 "bufio"
42 "bytes"
43 "encoding/base64"
44 "encoding/binary"
45 "encoding/json"
46 "errors"
47 "fmt"
48 "io"
49 "strconv"
50 )
51
52
const (
	// wordSize is the number of bits stored in one element of the backing slice.
	wordSize = uint(64)

	// log2WordSize is log2(wordSize); shifting a bit index right by it yields
	// the index of the word holding that bit.
	log2WordSize = uint(6)

	// allBits is a storage word with every bit set.
	allBits uint64 = 0xffffffffffffffff
)

var (
	// binaryOrder is the byte order used by the binary (de)serialisation
	// methods; it defaults to big endian.
	binaryOrder binary.ByteOrder = binary.BigEndian

	// base64Encoding is the base64 codec used by the JSON (de)serialisation
	// methods; it defaults to the URL-safe alphabet.
	base64Encoding = base64.URLEncoding
)
66
67
68 func Base64StdEncoding() { base64Encoding = base64.StdEncoding }
69
70
71 func LittleEndian() { binaryOrder = binary.LittleEndian }
72
73
// A BitSet is a set of bits addressed by unsigned index and stored packed
// into 64-bit words.
type BitSet struct {
	// length is the number of addressable bits; indexes >= length read as unset.
	length uint
	// set holds the bits, 64 per word; bit i lives in set[i>>6] at position i&63.
	set []uint64
}
78
79
// Error is the string-based type this package uses for the values it panics
// with (see panicIfNull).
// NOTE(review): presumably an Error() method is defined elsewhere so that it
// satisfies the error interface — confirm.
type Error string
81
82
83 func (b *BitSet) safeSet() []uint64 {
84 if b.set == nil {
85 b.set = make([]uint64, wordsNeeded(0))
86 }
87 return b.set
88 }
89
90
91 func From(buf []uint64) *BitSet {
92 return &BitSet{uint(len(buf)) * 64, buf}
93 }
94
95
// Bytes returns the backing slice of 64-bit words. The slice is returned
// directly, not copied, so writes through it modify the BitSet.
func (b *BitSet) Bytes() []uint64 {
	return b.set
}
99
100
101 func wordsNeeded(i uint) int {
102 if i > (Cap() - wordSize + 1) {
103 return int(Cap() >> log2WordSize)
104 }
105 return int((i + (wordSize - 1)) >> log2WordSize)
106 }
107
108
109 func New(length uint) (bset *BitSet) {
110 defer func() {
111 if r := recover(); r != nil {
112 bset = &BitSet{
113 0,
114 make([]uint64, 0),
115 }
116 }
117 }()
118
119 bset = &BitSet{
120 length,
121 make([]uint64, wordsNeeded(length)),
122 }
123
124 return bset
125 }
126
127
// Cap returns the largest value representable by uint, which bounds the
// number of bits a BitSet can address.
func Cap() uint {
	const maxUint = ^uint(0)
	return maxUint
}
131
132
133
// Len returns the number of addressable bits in the BitSet (its length),
// not the number of bits that are set.
func (b *BitSet) Len() uint {
	return b.length
}
137
138
139 func (b *BitSet) extendSetMaybe(i uint) {
140 if i >= b.length {
141 if i >= Cap() {
142 panic("You are exceeding the capacity")
143 }
144 nsize := wordsNeeded(i + 1)
145 if b.set == nil {
146 b.set = make([]uint64, nsize)
147 } else if cap(b.set) >= nsize {
148 b.set = b.set[:nsize]
149 } else if len(b.set) < nsize {
150 newset := make([]uint64, nsize, 2*nsize)
151 copy(newset, b.set)
152 b.set = newset
153 }
154 b.length = i + 1
155 }
156 }
157
158
159 func (b *BitSet) Test(i uint) bool {
160 if i >= b.length {
161 return false
162 }
163 return b.set[i>>log2WordSize]&(1<<(i&(wordSize-1))) != 0
164 }
165
166
167
168
169
170
171
172 func (b *BitSet) Set(i uint) *BitSet {
173 b.extendSetMaybe(i)
174 b.set[i>>log2WordSize] |= 1 << (i & (wordSize - 1))
175 return b
176 }
177
178
179 func (b *BitSet) Clear(i uint) *BitSet {
180 if i >= b.length {
181 return b
182 }
183 b.set[i>>log2WordSize] &^= 1 << (i & (wordSize - 1))
184 return b
185 }
186
187
188
189
190
191
192 func (b *BitSet) SetTo(i uint, value bool) *BitSet {
193 if value {
194 return b.Set(i)
195 }
196 return b.Clear(i)
197 }
198
199
200
201
202
203
204 func (b *BitSet) Flip(i uint) *BitSet {
205 if i >= b.length {
206 return b.Set(i)
207 }
208 b.set[i>>log2WordSize] ^= 1 << (i & (wordSize - 1))
209 return b
210 }
211
212
213
214
215
216
217 func (b *BitSet) FlipRange(start, end uint) *BitSet {
218 if start >= end {
219 return b
220 }
221
222 b.extendSetMaybe(end - 1)
223 var startWord uint = start >> log2WordSize
224 var endWord uint = end >> log2WordSize
225 b.set[startWord] ^= ^(^uint64(0) << (start & (wordSize - 1)))
226 for i := startWord; i < endWord; i++ {
227 b.set[i] = ^b.set[i]
228 }
229 b.set[endWord] ^= ^uint64(0) >> (-end & (wordSize - 1))
230 return b
231 }
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247 func (b *BitSet) Shrink(lastbitindex uint) *BitSet {
248 length := lastbitindex + 1
249 idx := wordsNeeded(length)
250 if idx > len(b.set) {
251 return b
252 }
253 shrunk := make([]uint64, idx)
254 copy(shrunk, b.set[:idx])
255 b.set = shrunk
256 b.length = length
257 b.set[idx-1] &= (allBits >> (uint64(64) - uint64(length&(wordSize-1))))
258 return b
259 }
260
261
262
263 func (b *BitSet) Compact() *BitSet {
264 idx := len(b.set) - 1
265 for ; idx >= 0 && b.set[idx] == 0; idx-- {
266 }
267 newlength := uint((idx + 1) << log2WordSize)
268 if newlength >= b.length {
269 return b
270 }
271 if newlength > 0 {
272 return b.Shrink(newlength - 1)
273 }
274
275 return b.Shrink(63)
276 }
277
278
279
280
281
282
283
284
285 func (b *BitSet) InsertAt(idx uint) *BitSet {
286 insertAtElement := (idx >> log2WordSize)
287
288
289 if b.isLenExactMultiple() {
290 b.set = append(b.set, uint64(0))
291 }
292
293 var i uint
294 for i = uint(len(b.set) - 1); i > insertAtElement; i-- {
295
296 b.set[i] <<= 1
297
298
299
300 b.set[i] |= (b.set[i-1] & 0x8000000000000000) >> 63
301 }
302
303
304
305 dataMask := ^(uint64(1)<<uint64(idx&(wordSize-1)) - 1)
306
307
308 data := b.set[i] & dataMask
309
310
311 b.set[i] &= ^dataMask
312
313
314 b.set[i] |= data << 1
315
316
317 b.length++
318
319 return b
320 }
321
322
323 func (b *BitSet) String() string {
324
325 var buffer bytes.Buffer
326 start := []byte("{")
327 buffer.Write(start)
328 counter := 0
329 i, e := b.NextSet(0)
330 for e {
331 counter = counter + 1
332
333 if counter > 0x40000 {
334 buffer.WriteString("...")
335 break
336 }
337 buffer.WriteString(strconv.FormatInt(int64(i), 10))
338 i, e = b.NextSet(i + 1)
339 if e {
340 buffer.WriteString(",")
341 }
342 }
343 buffer.WriteString("}")
344 return buffer.String()
345 }
346
347
348
349
350
351
352
353 func (b *BitSet) DeleteAt(i uint) *BitSet {
354
355 deleteAtElement := i >> log2WordSize
356
357
358
359 dataMask := ^((uint64(1) << (i & (wordSize - 1))) - 1)
360
361
362 data := b.set[deleteAtElement] & dataMask
363
364
365 b.set[deleteAtElement] &= ^dataMask
366
367
368
369 b.set[deleteAtElement] |= (data >> 1) & dataMask
370
371
372
373
374 for i := int(deleteAtElement) + 1; i < len(b.set); i++ {
375 b.set[i-1] |= (b.set[i] & 1) << 63
376 b.set[i] >>= 1
377 }
378
379 b.length = b.length - 1
380
381 return b
382 }
383
384
385
386
387
388
389
390
391 func (b *BitSet) NextSet(i uint) (uint, bool) {
392 x := int(i >> log2WordSize)
393 if x >= len(b.set) {
394 return 0, false
395 }
396 w := b.set[x]
397 w = w >> (i & (wordSize - 1))
398 if w != 0 {
399 return i + trailingZeroes64(w), true
400 }
401 x = x + 1
402 for x < len(b.set) {
403 if b.set[x] != 0 {
404 return uint(x)*wordSize + trailingZeroes64(b.set[x]), true
405 }
406 x = x + 1
407
408 }
409 return 0, false
410 }
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434 func (b *BitSet) NextSetMany(i uint, buffer []uint) (uint, []uint) {
435 myanswer := buffer
436 capacity := cap(buffer)
437 x := int(i >> log2WordSize)
438 if x >= len(b.set) || capacity == 0 {
439 return 0, myanswer[:0]
440 }
441 skip := i & (wordSize - 1)
442 word := b.set[x] >> skip
443 myanswer = myanswer[:capacity]
444 size := int(0)
445 for word != 0 {
446 r := trailingZeroes64(word)
447 t := word & ((^word) + 1)
448 myanswer[size] = r + i
449 size++
450 if size == capacity {
451 goto End
452 }
453 word = word ^ t
454 }
455 x++
456 for idx, word := range b.set[x:] {
457 for word != 0 {
458 r := trailingZeroes64(word)
459 t := word & ((^word) + 1)
460 myanswer[size] = r + (uint(x+idx) << 6)
461 size++
462 if size == capacity {
463 goto End
464 }
465 word = word ^ t
466 }
467 }
468 End:
469 if size > 0 {
470 return myanswer[size-1], myanswer[:size]
471 }
472 return 0, myanswer[:0]
473 }
474
475
476
477
478 func (b *BitSet) NextClear(i uint) (uint, bool) {
479 x := int(i >> log2WordSize)
480 if x >= len(b.set) {
481 return 0, false
482 }
483 w := b.set[x]
484 w = w >> (i & (wordSize - 1))
485 wA := allBits >> (i & (wordSize - 1))
486 index := i + trailingZeroes64(^w)
487 if w != wA && index < b.length {
488 return index, true
489 }
490 x++
491 for x < len(b.set) {
492 index = uint(x)*wordSize + trailingZeroes64(^b.set[x])
493 if b.set[x] != allBits && index < b.length {
494 return index, true
495 }
496 x++
497 }
498 return 0, false
499 }
500
501
502 func (b *BitSet) ClearAll() *BitSet {
503 if b != nil && b.set != nil {
504 for i := range b.set {
505 b.set[i] = 0
506 }
507 }
508 return b
509 }
510
511
// wordCount returns the number of 64-bit words currently in the backing
// slice.
func (b *BitSet) wordCount() int {
	return len(b.set)
}
515
516
517 func (b *BitSet) Clone() *BitSet {
518 c := New(b.length)
519 if b.set != nil {
520 copy(c.set, b.set)
521 }
522 return c
523 }
524
525
526
527
528 func (b *BitSet) Copy(c *BitSet) (count uint) {
529 if c == nil {
530 return
531 }
532 if b.set != nil {
533 copy(c.set, b.set)
534 }
535 count = c.length
536 if b.length < c.length {
537 count = b.length
538 }
539 return
540 }
541
542
543
544 func (b *BitSet) Count() uint {
545 if b != nil && b.set != nil {
546 return uint(popcntSlice(b.set))
547 }
548 return 0
549 }
550
551
552
553
554 func (b *BitSet) Equal(c *BitSet) bool {
555 if c == nil || b == nil {
556 return c == b
557 }
558 if b.length != c.length {
559 return false
560 }
561 if b.length == 0 {
562 return true
563 }
564
565
566 for p, v := range b.set {
567 if c.set[p] != v {
568 return false
569 }
570 }
571 return true
572 }
573
574 func panicIfNull(b *BitSet) {
575 if b == nil {
576 panic(Error("BitSet must not be null"))
577 }
578 }
579
580
581
582 func (b *BitSet) Difference(compare *BitSet) (result *BitSet) {
583 panicIfNull(b)
584 panicIfNull(compare)
585 result = b.Clone()
586 l := int(compare.wordCount())
587 if l > int(b.wordCount()) {
588 l = int(b.wordCount())
589 }
590 for i := 0; i < l; i++ {
591 result.set[i] = b.set[i] &^ compare.set[i]
592 }
593 return
594 }
595
596
597 func (b *BitSet) DifferenceCardinality(compare *BitSet) uint {
598 panicIfNull(b)
599 panicIfNull(compare)
600 l := int(compare.wordCount())
601 if l > int(b.wordCount()) {
602 l = int(b.wordCount())
603 }
604 cnt := uint64(0)
605 cnt += popcntMaskSlice(b.set[:l], compare.set[:l])
606 cnt += popcntSlice(b.set[l:])
607 return uint(cnt)
608 }
609
610
611
612 func (b *BitSet) InPlaceDifference(compare *BitSet) {
613 panicIfNull(b)
614 panicIfNull(compare)
615 l := int(compare.wordCount())
616 if l > int(b.wordCount()) {
617 l = int(b.wordCount())
618 }
619 for i := 0; i < l; i++ {
620 b.set[i] &^= compare.set[i]
621 }
622 }
623
624
625
626 func sortByLength(a *BitSet, b *BitSet) (ap *BitSet, bp *BitSet) {
627 if a.length <= b.length {
628 ap, bp = a, b
629 } else {
630 ap, bp = b, a
631 }
632 return
633 }
634
635
636
637 func (b *BitSet) Intersection(compare *BitSet) (result *BitSet) {
638 panicIfNull(b)
639 panicIfNull(compare)
640 b, compare = sortByLength(b, compare)
641 result = New(b.length)
642 for i, word := range b.set {
643 result.set[i] = word & compare.set[i]
644 }
645 return
646 }
647
648
649 func (b *BitSet) IntersectionCardinality(compare *BitSet) uint {
650 panicIfNull(b)
651 panicIfNull(compare)
652 b, compare = sortByLength(b, compare)
653 cnt := popcntAndSlice(b.set, compare.set)
654 return uint(cnt)
655 }
656
657
658
659
660 func (b *BitSet) InPlaceIntersection(compare *BitSet) {
661 panicIfNull(b)
662 panicIfNull(compare)
663 l := int(compare.wordCount())
664 if l > int(b.wordCount()) {
665 l = int(b.wordCount())
666 }
667 for i := 0; i < l; i++ {
668 b.set[i] &= compare.set[i]
669 }
670 for i := l; i < len(b.set); i++ {
671 b.set[i] = 0
672 }
673 if compare.length > 0 {
674 b.extendSetMaybe(compare.length - 1)
675 }
676 }
677
678
679
680 func (b *BitSet) Union(compare *BitSet) (result *BitSet) {
681 panicIfNull(b)
682 panicIfNull(compare)
683 b, compare = sortByLength(b, compare)
684 result = compare.Clone()
685 for i, word := range b.set {
686 result.set[i] = word | compare.set[i]
687 }
688 return
689 }
690
691
692
693 func (b *BitSet) UnionCardinality(compare *BitSet) uint {
694 panicIfNull(b)
695 panicIfNull(compare)
696 b, compare = sortByLength(b, compare)
697 cnt := popcntOrSlice(b.set, compare.set)
698 if len(compare.set) > len(b.set) {
699 cnt += popcntSlice(compare.set[len(b.set):])
700 }
701 return uint(cnt)
702 }
703
704
705
706 func (b *BitSet) InPlaceUnion(compare *BitSet) {
707 panicIfNull(b)
708 panicIfNull(compare)
709 l := int(compare.wordCount())
710 if l > int(b.wordCount()) {
711 l = int(b.wordCount())
712 }
713 if compare.length > 0 {
714 b.extendSetMaybe(compare.length - 1)
715 }
716 for i := 0; i < l; i++ {
717 b.set[i] |= compare.set[i]
718 }
719 if len(compare.set) > l {
720 for i := l; i < len(compare.set); i++ {
721 b.set[i] = compare.set[i]
722 }
723 }
724 }
725
726
727
728 func (b *BitSet) SymmetricDifference(compare *BitSet) (result *BitSet) {
729 panicIfNull(b)
730 panicIfNull(compare)
731 b, compare = sortByLength(b, compare)
732
733 result = compare.Clone()
734 for i, word := range b.set {
735 result.set[i] = word ^ compare.set[i]
736 }
737 return
738 }
739
740
741 func (b *BitSet) SymmetricDifferenceCardinality(compare *BitSet) uint {
742 panicIfNull(b)
743 panicIfNull(compare)
744 b, compare = sortByLength(b, compare)
745 cnt := popcntXorSlice(b.set, compare.set)
746 if len(compare.set) > len(b.set) {
747 cnt += popcntSlice(compare.set[len(b.set):])
748 }
749 return uint(cnt)
750 }
751
752
753
754 func (b *BitSet) InPlaceSymmetricDifference(compare *BitSet) {
755 panicIfNull(b)
756 panicIfNull(compare)
757 l := int(compare.wordCount())
758 if l > int(b.wordCount()) {
759 l = int(b.wordCount())
760 }
761 if compare.length > 0 {
762 b.extendSetMaybe(compare.length - 1)
763 }
764 for i := 0; i < l; i++ {
765 b.set[i] ^= compare.set[i]
766 }
767 if len(compare.set) > l {
768 for i := l; i < len(compare.set); i++ {
769 b.set[i] = compare.set[i]
770 }
771 }
772 }
773
774
775 func (b *BitSet) isLenExactMultiple() bool {
776 return b.length%wordSize == 0
777 }
778
779
780 func (b *BitSet) cleanLastWord() {
781 if !b.isLenExactMultiple() {
782 b.set[len(b.set)-1] &= allBits >> (wordSize - b.length%wordSize)
783 }
784 }
785
786
787 func (b *BitSet) Complement() (result *BitSet) {
788 panicIfNull(b)
789 result = New(b.length)
790 for i, word := range b.set {
791 result.set[i] = ^word
792 }
793 result.cleanLastWord()
794 return
795 }
796
797
798
799 func (b *BitSet) All() bool {
800 panicIfNull(b)
801 return b.Count() == b.length
802 }
803
804
805
806 func (b *BitSet) None() bool {
807 panicIfNull(b)
808 if b != nil && b.set != nil {
809 for _, word := range b.set {
810 if word > 0 {
811 return false
812 }
813 }
814 return true
815 }
816 return true
817 }
818
819
820 func (b *BitSet) Any() bool {
821 panicIfNull(b)
822 return !b.None()
823 }
824
825
826 func (b *BitSet) IsSuperSet(other *BitSet) bool {
827 for i, e := other.NextSet(0); e; i, e = other.NextSet(i + 1) {
828 if !b.Test(i) {
829 return false
830 }
831 }
832 return true
833 }
834
835
836 func (b *BitSet) IsStrictSuperSet(other *BitSet) bool {
837 return b.Count() > other.Count() && b.IsSuperSet(other)
838 }
839
840
841 func (b *BitSet) DumpAsBits() string {
842 if b.set == nil {
843 return "."
844 }
845 buffer := bytes.NewBufferString("")
846 i := len(b.set) - 1
847 for ; i >= 0; i-- {
848 fmt.Fprintf(buffer, "%064b.", b.set[i])
849 }
850 return buffer.String()
851 }
852
853
854 func (b *BitSet) BinaryStorageSize() int {
855 return binary.Size(uint64(0)) + binary.Size(b.set)
856 }
857
858
859 func (b *BitSet) WriteTo(stream io.Writer) (int64, error) {
860 length := uint64(b.length)
861
862
863 err := binary.Write(stream, binaryOrder, length)
864 if err != nil {
865 return 0, err
866 }
867
868
869 err = binary.Write(stream, binaryOrder, b.set)
870 return int64(b.BinaryStorageSize()), err
871 }
872
873
874 func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
875 var length uint64
876
877
878 err := binary.Read(stream, binaryOrder, &length)
879 if err != nil {
880 return 0, err
881 }
882 newset := New(uint(length))
883
884 if uint64(newset.length) != length {
885 return 0, errors.New("unmarshalling error: type mismatch")
886 }
887
888
889 err = binary.Read(stream, binaryOrder, newset.set)
890 if err != nil {
891 return 0, err
892 }
893
894 *b = *newset
895 return int64(b.BinaryStorageSize()), nil
896 }
897
898
899 func (b *BitSet) MarshalBinary() ([]byte, error) {
900 var buf bytes.Buffer
901 writer := bufio.NewWriter(&buf)
902
903 _, err := b.WriteTo(writer)
904 if err != nil {
905 return []byte{}, err
906 }
907
908 err = writer.Flush()
909
910 return buf.Bytes(), err
911 }
912
913
914 func (b *BitSet) UnmarshalBinary(data []byte) error {
915 buf := bytes.NewReader(data)
916 reader := bufio.NewReader(buf)
917
918 _, err := b.ReadFrom(reader)
919
920 return err
921 }
922
923
924 func (b *BitSet) MarshalJSON() ([]byte, error) {
925 buffer := bytes.NewBuffer(make([]byte, 0, b.BinaryStorageSize()))
926 _, err := b.WriteTo(buffer)
927 if err != nil {
928 return nil, err
929 }
930
931
932 return json.Marshal(base64Encoding.EncodeToString(buffer.Bytes()))
933 }
934
935
936 func (b *BitSet) UnmarshalJSON(data []byte) error {
937
938 var s string
939 err := json.Unmarshal(data, &s)
940 if err != nil {
941 return err
942 }
943
944
945 buf, err := base64Encoding.DecodeString(s)
946 if err != nil {
947 return err
948 }
949
950 _, err = b.ReadFrom(bytes.NewReader(buf))
951 return err
952 }
953
View as plain text