1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package btree

import (
	"sort"
	"sync"
)

type (
	// Key is the type of the keys stored in a BTree. The tree orders keys
	// solely through the less function given to New.
	Key interface{}

	// Value is the type of the values associated with keys.
	Value interface{}
)

// item is one key/value entry stored in a node.
type item struct {
	key   Key
	value Value
}

// lessFunc reports whether a orders strictly before b.
type lessFunc func(a, b interface{}) bool
70
71
72
73
74
75
76
77
78
79
80 func New(degree int, less func(interface{}, interface{}) bool) *BTree {
81 if degree <= 1 {
82 panic("bad degree")
83 }
84 return &BTree{
85 degree: degree,
86 less: less,
87 cow: ©OnWriteContext{},
88 }
89 }
90
91
92 type items []item
93
94
95
96 func (s *items) insertAt(index int, m item) {
97 *s = append(*s, item{})
98 if index < len(*s) {
99 copy((*s)[index+1:], (*s)[index:])
100 }
101 (*s)[index] = m
102 }
103
104
105
106 func (s *items) removeAt(index int) item {
107 m := (*s)[index]
108 copy((*s)[index:], (*s)[index+1:])
109 (*s)[len(*s)-1] = item{}
110 *s = (*s)[:len(*s)-1]
111 return m
112 }
113
114
115 func (s *items) pop() item {
116 index := len(*s) - 1
117 out := (*s)[index]
118 (*s)[index] = item{}
119 *s = (*s)[:index]
120 return out
121 }
122
123 var nilItems = make(items, 16)
124
125
126
127 func (s *items) truncate(index int) {
128 var toClear items
129 *s, toClear = (*s)[:index], (*s)[index:]
130 for len(toClear) > 0 {
131 toClear = toClear[copy(toClear, nilItems):]
132 }
133 }
134
135
136
137
138 func (s items) find(k Key, less lessFunc) (index int, found bool) {
139 i := sort.Search(len(s), func(i int) bool { return less(k, s[i].key) })
140
141 if i > 0 && !less(s[i-1].key, k) {
142 return i - 1, true
143 }
144 return i, false
145 }
146
147
148 type children []*node
149
150
151
152 func (s *children) insertAt(index int, n *node) {
153 *s = append(*s, nil)
154 if index < len(*s) {
155 copy((*s)[index+1:], (*s)[index:])
156 }
157 (*s)[index] = n
158 }
159
160
161
162 func (s *children) removeAt(index int) *node {
163 n := (*s)[index]
164 copy((*s)[index:], (*s)[index+1:])
165 (*s)[len(*s)-1] = nil
166 *s = (*s)[:len(*s)-1]
167 return n
168 }
169
170
171 func (s *children) pop() (out *node) {
172 index := len(*s) - 1
173 out = (*s)[index]
174 (*s)[index] = nil
175 *s = (*s)[:index]
176 return
177 }
178
179 var nilChildren = make(children, 16)
180
181
182
183 func (s *children) truncate(index int) {
184 var toClear children
185 *s, toClear = (*s)[:index], (*s)[index:]
186 for len(toClear) > 0 {
187 toClear = toClear[copy(toClear, nilChildren):]
188 }
189 }

// node is one B-tree node. A leaf has no children. An internal node has
// len(items)+1 children: children[i] holds keys ordering before items[i],
// and children[i+1] holds keys ordering after it.
type node struct {
	items    items
	children children
	// size is the number of entries in the subtree rooted at this node:
	// its own items plus every descendant's (see computeSize).
	size int
	// cow is the context that owns this node. The node is modified in place
	// only under that same context; otherwise it is copied first
	// (see mutableFor).
	cow *copyOnWriteContext
}
202
203 func (n *node) computeSize() int {
204 sz := len(n.items)
205 for _, c := range n.children {
206 sz += c.size
207 }
208 return sz
209 }
210
211 func (n *node) mutableFor(cow *copyOnWriteContext) *node {
212 if n.cow == cow {
213 return n
214 }
215 out := cow.newNode()
216 if cap(out.items) >= len(n.items) {
217 out.items = out.items[:len(n.items)]
218 } else {
219 out.items = make(items, len(n.items), cap(n.items))
220 }
221 copy(out.items, n.items)
222
223 if cap(out.children) >= len(n.children) {
224 out.children = out.children[:len(n.children)]
225 } else {
226 out.children = make(children, len(n.children), cap(n.children))
227 }
228 copy(out.children, n.children)
229 out.size = n.size
230 return out
231 }
232
233 func (n *node) mutableChild(i int) *node {
234 c := n.children[i].mutableFor(n.cow)
235 n.children[i] = c
236 return c
237 }
238
239
240
241
242 func (n *node) split(i int) (item, *node) {
243 item := n.items[i]
244 next := n.cow.newNode()
245 next.items = append(next.items, n.items[i+1:]...)
246 n.items.truncate(i)
247 if len(n.children) > 0 {
248 next.children = append(next.children, n.children[i+1:]...)
249 n.children.truncate(i + 1)
250 }
251 n.size = n.computeSize()
252 next.size = next.computeSize()
253 return item, next
254 }
255
256
257
258 func (n *node) maybeSplitChild(i, maxItems int) bool {
259 if len(n.children[i].items) < maxItems {
260 return false
261 }
262 first := n.mutableChild(i)
263 item, second := first.split(maxItems / 2)
264 n.items.insertAt(i, item)
265 n.children.insertAt(i+1, second)
266
267 return true
268 }

// insert adds m to the subtree rooted at n, or replaces the entry with an
// equal key.
//
// n must already be modifiable under the tree's current context. Any child
// that is full is split (maybeSplitChild) and made mutable (mutableChild)
// before insert recurses into it.
//
// It returns the previous value and true if the key was already present.
// Otherwise it returns a zero Value and false. When computeIndex is true,
// idx is the entry's position in sorted order within this subtree. When
// computeIndex is false, idx should be ignored.
func (n *node) insert(m item, maxItems int, less lessFunc, computeIndex bool) (old Value, present bool, idx int) {
	i, found := n.items.find(m.key, less)
	if found {
		// Replace in place; the number of entries does not change.
		out := n.items[i]
		n.items[i] = m
		if computeIndex {
			idx = n.itemIndex(i)
		}
		return out.value, true, idx
	}
	if len(n.children) == 0 {
		// Leaf: insert directly. i is also m's position within this leaf.
		n.items.insertAt(i, m)
		n.size++
		return old, false, i
	}
	if n.maybeSplitChild(i, maxItems) {
		// The split moved a median into items[i]. Decide which half to
		// descend into, or replace the median itself if its key equals m's.
		inTree := n.items[i]
		switch {
		case less(m.key, inTree.key):
			// m belongs in the left half, children[i]; keep i.
		case less(inTree.key, m.key):
			i++ // m belongs in the right half, children[i+1].
		default:
			out := n.items[i]
			n.items[i] = m
			if computeIndex {
				idx = n.itemIndex(i)
			}
			return out.value, true, idx
		}
	}
	old, present, idx = n.mutableChild(i).insert(m, maxItems, less, computeIndex)
	if !present {
		n.size++
	}
	if computeIndex {
		// Convert the child-relative index by adding everything that orders
		// before children[i] in this subtree.
		idx += n.partialSize(i)
	}
	return old, present, idx
}
315
316
317
318
319 func (n *node) get(k Key, computeIndex bool, less lessFunc) (item, bool, int) {
320 i, found := n.items.find(k, less)
321 if found {
322 return n.items[i], true, n.itemIndex(i)
323 }
324 if len(n.children) > 0 {
325 m, found, idx := n.children[i].get(k, computeIndex, less)
326 if computeIndex && found {
327 idx += n.partialSize(i)
328 }
329 return m, found, idx
330 }
331 return item{}, false, -1
332 }
333
334
335 func (n *node) itemIndex(i int) int {
336 if len(n.children) == 0 {
337 return i
338 }
339
340
341 return n.partialSize(i+1) - 1
342 }
343
344
345 func (n *node) partialSize(i int) int {
346 var sz int
347 for j, c := range n.children {
348 if j == i {
349 break
350 }
351 sz += c.size + 1
352 }
353 return sz
354 }
355
356
357 func (n *node) cursorStackForKey(k Key, cs cursorStack, less lessFunc) (cursorStack, bool, int) {
358 i, found := n.items.find(k, less)
359 cs.push(cursor{n, i})
360 idx := i
361 if found {
362 if len(n.children) > 0 {
363 idx = n.partialSize(i+1) - 1
364 }
365 return cs, true, idx
366 }
367 if len(n.children) > 0 {
368 cs, found, idx := n.children[i].cursorStackForKey(k, cs, less)
369 return cs, found, idx + n.partialSize(i)
370 }
371 return cs, false, idx
372 }
373
374
375
376 func (n *node) at(i int) item {
377 if len(n.children) == 0 {
378 return n.items[i]
379 }
380 for j, c := range n.children {
381 if i < c.size {
382 return c.at(i)
383 }
384 i -= c.size
385 if i == 0 {
386 return n.items[j]
387 }
388 i--
389 }
390 panic("impossible")
391 }
392
393
394
395 func (n *node) cursorStackForIndex(i int, cs cursorStack) cursorStack {
396 if len(n.children) == 0 {
397 return cs.push(cursor{n, i})
398 }
399 for j, c := range n.children {
400 if i < c.size {
401 return c.cursorStackForIndex(i, cs.push(cursor{n, j}))
402 }
403 i -= c.size
404 if i == 0 {
405 return cs.push(cursor{n, j})
406 }
407 i--
408 }
409 panic("impossible")
410 }

// toRemove selects which entry node.remove deletes.
type toRemove int

const (
	// removeItem deletes the entry whose key equals the given key.
	removeItem toRemove = 0
	// removeMin deletes the smallest entry in the subtree.
	removeMin toRemove = 1
	// removeMax deletes the largest entry in the subtree.
	removeMax toRemove = 2
)

// remove deletes an entry from the subtree rooted at n and returns it, along
// with whether anything was removed. typ selects what to delete:
//   - removeItem: the entry whose key equals key;
//   - removeMin: the smallest entry (key is ignored);
//   - removeMax: the largest entry (key is ignored).
//
// n must be modifiable. Before descending into a child, remove ensures the
// child has more than minItems entries (via growChildAndRemove). The
// recursive call can therefore always delete from it without underflowing.
func (n *node) remove(key Key, minItems int, typ toRemove, less lessFunc) (item, bool) {
	var i int
	var found bool
	switch typ {
	case removeMax:
		if len(n.children) == 0 {
			n.size--
			return n.items.pop(), true
		}
		i = len(n.items) // descend into the last child
	case removeMin:
		if len(n.children) == 0 {
			n.size--
			return n.items.removeAt(0), true
		}
		i = 0 // descend into the first child
	case removeItem:
		i, found = n.items.find(key, less)
		if len(n.children) == 0 {
			if found {
				n.size--
				return n.items.removeAt(i), true
			}
			return item{}, false
		}
	default:
		panic("invalid type")
	}

	// From here on n has children, and i is the child to work with.
	if len(n.children[i].items) <= minItems {
		return n.growChildAndRemove(i, key, minItems, typ, less)
	}
	child := n.mutableChild(i)

	// children[i] now has enough entries to give one up.

	if found {
		// key is items[i] of this internal node. Replace it with its
		// predecessor (the largest entry in children[i]), which is removed
		// from that subtree.
		out := n.items[i]
		n.items[i], _ = child.remove(nil, minItems, removeMax, less)
		n.size--
		return out, true
	}

	// The target is not in n itself; it can only be in children[i].
	m, removed := child.remove(key, minItems, typ, less)
	if removed {
		n.size--
	}
	return m, removed
}

// growChildAndRemove makes children[i] large enough to give up an entry and
// then retries the removal on n. remove calls it when children[i] holds at
// most minItems entries. It tries these options in order:
//
//  1. Rotate from the left sibling, if it has more than minItems entries.
//     The left sibling's last entry moves up into n.items[i-1], and the old
//     separator moves down to the front of children[i]. The left sibling's
//     last child (if any) moves along with it.
//  2. Otherwise, rotate from the right sibling in the mirror-image way.
//  3. Otherwise, merge children[i] with a neighbour and the separator between
//     them into one node, removing that separator from n. The neighbour is
//     the right sibling, or the left sibling when children[i] is the last
//     child.
//
// n's own size never changes here, since entries only move within its
// subtree. The affected children's sizes are adjusted explicitly.
func (n *node) growChildAndRemove(i int, key Key, minItems int, typ toRemove, less lessFunc) (item, bool) {
	if i > 0 && len(n.children[i-1].items) > minItems {
		// Steal from the left sibling.
		child := n.mutableChild(i)
		stealFrom := n.mutableChild(i - 1)
		stolenItem := stealFrom.items.pop()
		stealFrom.size--
		child.items.insertAt(0, n.items[i-1])
		child.size++
		n.items[i-1] = stolenItem
		if len(stealFrom.children) > 0 {
			c := stealFrom.children.pop()
			stealFrom.size -= c.size
			child.children.insertAt(0, c)
			child.size += c.size
		}
	} else if i < len(n.items) && len(n.children[i+1].items) > minItems {
		// Steal from the right sibling.
		child := n.mutableChild(i)
		stealFrom := n.mutableChild(i + 1)
		stolenItem := stealFrom.items.removeAt(0)
		stealFrom.size--
		child.items = append(child.items, n.items[i])
		child.size++
		n.items[i] = stolenItem
		if len(stealFrom.children) > 0 {
			c := stealFrom.children.removeAt(0)
			stealFrom.size -= c.size
			child.children = append(child.children, c)
			child.size += c.size
		}
	} else {
		if i >= len(n.items) {
			// children[i] is the last child; merge it into its left sibling.
			i--
		}
		child := n.mutableChild(i)
		// Merge separator items[i] and children[i+1] into children[i].
		mergeItem := n.items.removeAt(i)
		mergeChild := n.children.removeAt(i + 1)
		child.items = append(child.items, mergeItem)
		child.items = append(child.items, mergeChild.items...)
		child.children = append(child.children, mergeChild.children...)
		child.size = child.computeSize()
		// mergeChild's contents now live in child. freeNode recycles it
		// only if this context owns it.
		n.cow.freeNode(mergeChild)
	}
	return n.remove(key, minItems, typ, less)
}

// BTree is an ordered map from Key to Value, stored as a B-tree whose nodes
// also record subtree sizes. Entries can therefore be addressed by position
// (At, BeforeIndex, AfterIndex, SetWithIndex, GetWithIndex) as well as by
// key. Create one with New.
//
// NOTE(review): no locking is visible in this type; presumably a BTree must
// not be mutated concurrently without external synchronization — confirm.
type BTree struct {
	degree int
	less   lessFunc
	// root is nil until the first Set.
	root *node
	// cow is the context that owns the nodes this tree may modify in
	// place; Clone replaces it.
	cow *copyOnWriteContext
}

// copyOnWriteContext identifies which tree may modify a node in place. A
// node is mutated directly only when its cow pointer equals the tree's
// current context; otherwise it is copied first (see mutableFor). Contexts
// are compared only by pointer identity, and Clone hands each tree a freshly
// allocated one.
//
// The unused byte field gives the struct a non-zero size. The Go spec allows
// pointers to distinct zero-size variables to compare equal, which would make
// two different contexts indistinguishable.
type copyOnWriteContext struct{ byte }
580
581
582
583
584
585
586
587
588
589
590
591
592 func (t *BTree) Clone() *BTree {
593
594
595
596
597
598 cow1, cow2 := *t.cow, *t.cow
599 out := *t
600 t.cow = &cow1
601 out.cow = &cow2
602 return &out
603 }
604
605
606 func (t *BTree) maxItems() int {
607 return t.degree*2 - 1
608 }
609
610
611
612 func (t *BTree) minItems() int {
613 return t.degree - 1
614 }
615
616 var nodePool = sync.Pool{New: func() interface{} { return new(node) }}
617
618 func (c *copyOnWriteContext) newNode() *node {
619 n := nodePool.Get().(*node)
620 n.cow = c
621 return n
622 }
623
624 func (c *copyOnWriteContext) freeNode(n *node) {
625 if n.cow == c {
626
627 n.items.truncate(0)
628 n.children.truncate(0)
629 n.cow = nil
630 nodePool.Put(n)
631 }
632 }
633
634
635
636
637
638 func (t *BTree) Set(k Key, v Value) (old Value, present bool) {
639 old, present, _ = t.set(k, v, false)
640 return old, present
641 }
642
643
644
645 func (t *BTree) SetWithIndex(k Key, v Value) (old Value, present bool, index int) {
646 return t.set(k, v, true)
647 }
648
649 func (t *BTree) set(k Key, v Value, computeIndex bool) (old Value, present bool, idx int) {
650 if t.root == nil {
651 t.root = t.cow.newNode()
652 t.root.items = append(t.root.items, item{k, v})
653 t.root.size = 1
654 return old, false, 0
655 }
656 t.root = t.root.mutableFor(t.cow)
657 if len(t.root.items) >= t.maxItems() {
658 sz := t.root.size
659 item2, second := t.root.split(t.maxItems() / 2)
660 oldroot := t.root
661 t.root = t.cow.newNode()
662 t.root.items = append(t.root.items, item2)
663 t.root.children = append(t.root.children, oldroot, second)
664 t.root.size = sz
665 }
666
667 return t.root.insert(item{k, v}, t.maxItems(), t.less, computeIndex)
668 }
669
670
671
672 func (t *BTree) Delete(k Key) (Value, bool) {
673 m, removed := t.deleteItem(k, removeItem)
674 return m.value, removed
675 }
676
677
678
679 func (t *BTree) DeleteMin() (Key, Value) {
680 item, _ := t.deleteItem(nil, removeMin)
681 return item.key, item.value
682 }
683
684
685
686 func (t *BTree) DeleteMax() (Key, Value) {
687 item, _ := t.deleteItem(nil, removeMax)
688 return item.key, item.value
689 }
690
691 func (t *BTree) deleteItem(key Key, typ toRemove) (item, bool) {
692 if t.root == nil || len(t.root.items) == 0 {
693 return item{}, false
694 }
695 t.root = t.root.mutableFor(t.cow)
696 out, removed := t.root.remove(key, t.minItems(), typ, t.less)
697 if len(t.root.items) == 0 && len(t.root.children) > 0 {
698 oldroot := t.root
699 t.root = t.root.children[0]
700 t.cow.freeNode(oldroot)
701 }
702 return out, removed
703 }
704
705
706
707
708
709 func (t *BTree) Get(k Key) Value {
710 var z Value
711 if t.root == nil {
712 return z
713 }
714 item, ok, _ := t.root.get(k, false, t.less)
715 if !ok {
716 return z
717 }
718 return item.value
719 }
720
721
722
723 func (t *BTree) GetWithIndex(k Key) (Value, int) {
724 var z Value
725 if t.root == nil {
726 return z, -1
727 }
728 item, _, index := t.root.get(k, true, t.less)
729 return item.value, index
730 }
731
732
733
734 func (t *BTree) At(i int) (Key, Value) {
735 if i < 0 || i >= t.Len() {
736 panic("btree: index out of range")
737 }
738 item := t.root.at(i)
739 return item.key, item.value
740 }
741
742
743 func (t *BTree) Has(k Key) bool {
744 if t.root == nil {
745 return false
746 }
747 _, ok, _ := t.root.get(k, false, t.less)
748 return ok
749 }
750
751
752
753 func (t *BTree) Min() (Key, Value) {
754 var k Key
755 var v Value
756 if t.root == nil {
757 return k, v
758 }
759 n := t.root
760 for len(n.children) > 0 {
761 n = n.children[0]
762 }
763 if len(n.items) == 0 {
764 return k, v
765 }
766 return n.items[0].key, n.items[0].value
767 }
768
769
770
771 func (t *BTree) Max() (Key, Value) {
772 var k Key
773 var v Value
774 if t.root == nil {
775 return k, v
776 }
777 n := t.root
778 for len(n.children) > 0 {
779 n = n.children[len(n.children)-1]
780 }
781 if len(n.items) == 0 {
782 return k, v
783 }
784 m := n.items[len(n.items)-1]
785 return m.key, m.value
786 }
787
788
789 func (t *BTree) Len() int {
790 if t.root == nil {
791 return 0
792 }
793 return t.root.size
794 }
795
796
797
798
799 func (t *BTree) Before(k Key) *Iterator {
800 if t.root == nil {
801 return &Iterator{}
802 }
803 var cs cursorStack
804 cs, found, idx := t.root.cursorStackForKey(k, cs, t.less)
805
806
807
808
809
810 var stay bool
811 top := cs[len(cs)-1]
812 if found {
813 stay = true
814 } else if top.index < len(top.node.items) {
815 stay = true
816 } else {
817 idx--
818 }
819 return &Iterator{
820 cursors: cs,
821 stay: stay,
822 descending: false,
823 Index: idx,
824 }
825 }
826
827
828
829
830 func (t *BTree) After(k Key) *Iterator {
831 if t.root == nil {
832 return &Iterator{}
833 }
834 var cs cursorStack
835 cs, found, idx := t.root.cursorStackForKey(k, cs, t.less)
836
837
838
839
840 return &Iterator{
841 cursors: cs,
842 stay: found,
843 descending: true,
844 Index: idx,
845 }
846 }
847
848
849
850
851
852 func (t *BTree) BeforeIndex(i int) *Iterator {
853 return t.indexIterator(i, false)
854 }
855
856
857
858
859
860 func (t *BTree) AfterIndex(i int) *Iterator {
861 return t.indexIterator(i, true)
862 }
863
864 func (t *BTree) indexIterator(i int, descending bool) *Iterator {
865 if i < 0 || i > t.Len() {
866 panic("btree: index out of range")
867 }
868 if i == t.Len() {
869 return &Iterator{}
870 }
871 var cs cursorStack
872 return &Iterator{
873 cursors: t.root.cursorStackForIndex(i, cs),
874 stay: true,
875 descending: descending,
876 Index: i,
877 }
878 }

// Iterator walks the entries of a BTree in ascending or descending key
// order. Obtain one from Before, After, BeforeIndex or AfterIndex, and call
// Next before reading its fields.
//
// NOTE(review): the iterator holds pointers into the tree's nodes; modifying
// the tree while iterating is presumably unsupported — confirm.
type Iterator struct {
	// Key and Value hold the current entry after a successful Next.
	Key   Key
	Value Value

	// Index is the position of the current entry in sorted order after a
	// successful Next.
	Index int

	// cursors is the path from the root to the current position.
	cursors cursorStack
	// stay makes the next call to Next report the entry the cursors already
	// point at, instead of advancing first.
	stay bool
	// descending selects dec (true) or inc (false) when advancing.
	descending bool
}
892
893
894
895
896
897
898 func (it *Iterator) Next() bool {
899 var more bool
900 switch {
901 case len(it.cursors) == 0:
902 more = false
903 case it.stay:
904 it.stay = false
905 more = true
906 case it.descending:
907 more = it.dec()
908 default:
909 more = it.inc()
910 }
911 if !more {
912 return false
913 }
914 top := it.cursors[len(it.cursors)-1]
915 item := top.node.items[top.index]
916 it.Key = item.key
917 it.Value = item.value
918 return true
919 }

// inc moves the cursors to the in-order successor of the current position
// and increments Index. It returns false, with the cursor stack emptied, when
// there is no successor.
func (it *Iterator) inc() bool {
	// Index tracks the position of the entry we are about to land on, so it
	// advances even if no such entry turns out to exist.
	it.Index++

	// Step past the current item. In an internal node this leaves the cursor
	// on the child just after the item; descend to that subtree's leftmost
	// leaf.
	top := it.cursors.incTop(1)
	for len(top.node.children) > 0 {
		top = cursor{top.node.children[top.index], 0}
		it.cursors.push(top)
	}

	// If we ran off the end of a node's items, pop up the stack. A parent's
	// cursor index is the child we came from, which is also the index of the
	// separator item that follows it, so no increment is needed after a pop.
	for top.index >= len(top.node.items) {

		it.cursors.pop()

		if it.cursors.empty() {
			return false
		}
		top = it.cursors.top()
	}

	return true
}

// dec moves the cursors to the in-order predecessor of the current position
// and decrements Index. It returns false, with the cursor stack emptied, when
// there is no predecessor.
func (it *Iterator) dec() bool {
	// Index tracks the position of the entry we are about to land on.
	it.Index--
	top := it.cursors.top()

	// In an internal node the predecessor of items[i] is the last entry of
	// children[i]. Descend along rightmost paths, parking each cursor one
	// past the node's last item, down to a leaf.
	for len(top.node.children) > 0 {
		c := top.node.children[top.index]
		top = cursor{c, len(c.items)}
		it.cursors.push(top)
	}
	top = it.cursors.incTop(-1)

	// If we moved before a node's first item, pop up the stack. A parent's
	// cursor index is the child we came from, and the item before that child
	// is at index-1, so decrement after each pop.
	for top.index < 0 {

		it.cursors.pop()

		if it.cursors.empty() {
			return false
		}

		top = it.cursors.incTop(-1)
	}
	return true
}

// cursor is a position inside one node. In the node holding the current
// entry, index is an item index. In the nodes above it on the stack, index
// is the child index that leads down toward the current entry. That child is
// followed by the item at the same index.
type cursor struct {
	node  *node
	index int
}

// cursorStack is the path of cursors from the root (first element) down to
// the current position (last element).
type cursorStack []cursor
996
997 func (s *cursorStack) push(c cursor) cursorStack {
998 *s = append(*s, c)
999 return *s
1000 }
1001
1002 func (s *cursorStack) pop() cursor {
1003 last := len(*s) - 1
1004 t := (*s)[last]
1005 *s = (*s)[:last]
1006 return t
1007 }
1008
1009 func (s *cursorStack) top() cursor {
1010 return (*s)[len(*s)-1]
1011 }
1012
1013 func (s *cursorStack) empty() bool {
1014 return len(*s) == 0
1015 }
1016
1017
1018 func (s *cursorStack) incTop(n int) cursor {
1019 (*s)[len(*s)-1].index += n
1020 return s.top()
1021 }
1022
View as plain text