1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 package immutable
43
44 import (
45 "fmt"
46 "math/bits"
47 "reflect"
48 "sort"
49 "strings"
50
51 "golang.org/x/exp/constraints"
52 )
53
54
55
56
57
// List is an ordered, integer-indexed collection whose elements are stored
// in a tree of listNode values. The exported methods that change the list
// (Set, Append, Prepend, Slice) leave the receiver untouched and return a
// list reflecting the change.
type List[T any] struct {
	root   listNode[T] // top-level node of the tree
	origin int         // offset added to a logical index to obtain its position in root
	size   int         // number of elements in the list
}
63
64
65 func NewList[T any](values ...T) *List[T] {
66 l := &List[T]{
67 root: &listLeafNode[T]{},
68 }
69 for _, value := range values {
70 l.append(value, true)
71 }
72 return l
73 }
74
75
76 func (l *List[T]) clone() *List[T] {
77 other := *l
78 return &other
79 }
80
81
82 func (l *List[T]) Len() int {
83 return l.size
84 }
85
86
87 func (l *List[T]) cap() int {
88 return 1 << (l.root.depth() * listNodeBits)
89 }
90
91
92
93 func (l *List[T]) Get(index int) T {
94 if index < 0 || index >= l.size {
95 panic(fmt.Sprintf("immutable.List.Get: index %d out of bounds", index))
96 }
97 return l.root.get(l.origin + index)
98 }
99
100
101
102
103 func (l *List[T]) Set(index int, value T) *List[T] {
104 return l.set(index, value, false)
105 }
106
107 func (l *List[T]) set(index int, value T, mutable bool) *List[T] {
108 if index < 0 || index >= l.size {
109 panic(fmt.Sprintf("immutable.List.Set: index %d out of bounds", index))
110 }
111 other := l
112 if !mutable {
113 other = l.clone()
114 }
115 other.root = other.root.set(l.origin+index, value, mutable)
116 return other
117 }
118
119
120 func (l *List[T]) Append(value T) *List[T] {
121 return l.append(value, false)
122 }
123
124 func (l *List[T]) append(value T, mutable bool) *List[T] {
125 other := l
126 if !mutable {
127 other = l.clone()
128 }
129
130
131 if other.size+other.origin >= l.cap() {
132 newRoot := &listBranchNode[T]{d: other.root.depth() + 1}
133 newRoot.children[0] = other.root
134 other.root = newRoot
135 }
136
137
138 other.size++
139 other.root = other.root.set(other.origin+other.size-1, value, mutable)
140 return other
141 }
142
143
144 func (l *List[T]) Prepend(value T) *List[T] {
145 return l.prepend(value, false)
146 }
147
148 func (l *List[T]) prepend(value T, mutable bool) *List[T] {
149 other := l
150 if !mutable {
151 other = l.clone()
152 }
153
154
155 if other.origin == 0 {
156 newRoot := &listBranchNode[T]{d: other.root.depth() + 1}
157 newRoot.children[listNodeSize-1] = other.root
158 other.root = newRoot
159 other.origin += (listNodeSize - 1) << (other.root.depth() * listNodeBits)
160 }
161
162
163 other.size++
164 other.origin--
165 other.root = other.root.set(other.origin, value, mutable)
166 return other
167 }
168
169
170
171
172
173
174
175
176 func (l *List[T]) Slice(start, end int) *List[T] {
177 return l.slice(start, end, false)
178 }
179
180 func (l *List[T]) slice(start, end int, mutable bool) *List[T] {
181
182 if start < 0 || start > l.size {
183 panic(fmt.Sprintf("immutable.List.Slice: start index %d out of bounds", start))
184 } else if end < 0 || end > l.size {
185 panic(fmt.Sprintf("immutable.List.Slice: end index %d out of bounds", end))
186 } else if start > end {
187 panic(fmt.Sprintf("immutable.List.Slice: invalid slice index: [%d:%d]", start, end))
188 }
189
190
191 if start == 0 && end == l.size {
192 return l
193 }
194
195
196 other := l
197 if !mutable {
198 other = l.clone()
199 }
200
201
202 other.origin = l.origin + start
203 other.size = end - start
204
205
206 for other.root.depth() > 1 {
207 i := (other.origin >> (other.root.depth() * listNodeBits)) & listNodeMask
208 j := ((other.origin + other.size - 1) >> (other.root.depth() * listNodeBits)) & listNodeMask
209 if i != j {
210 break
211 }
212
213
214 other.origin -= i << (other.root.depth() * listNodeBits)
215 other.root = other.root.(*listBranchNode[T]).children[i]
216 }
217
218
219 other.root = other.root.deleteBefore(other.origin, mutable)
220 other.root = other.root.deleteAfter(other.origin+other.size-1, mutable)
221
222 return other
223 }
224
225
226 func (l *List[T]) Iterator() *ListIterator[T] {
227 itr := &ListIterator[T]{list: l}
228 itr.First()
229 return itr
230 }
231
232
// ListBuilder assembles a List by applying changes in place. Once List()
// has been called the builder holds no list and its other methods assert.
type ListBuilder[T any] struct {
	list *List[T] // list under construction; nil after List() is called
}
236
237
238 func NewListBuilder[T any]() *ListBuilder[T] {
239 return &ListBuilder[T]{list: NewList[T]()}
240 }
241
242
243
244 func (b *ListBuilder[T]) List() *List[T] {
245 assert(b.list != nil, "immutable.ListBuilder.List(): duplicate call to fetch list")
246 list := b.list
247 b.list = nil
248 return list
249 }
250
251
252 func (b *ListBuilder[T]) Len() int {
253 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
254 return b.list.Len()
255 }
256
257
258
259 func (b *ListBuilder[T]) Get(index int) T {
260 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
261 return b.list.Get(index)
262 }
263
264
265
266
267 func (b *ListBuilder[T]) Set(index int, value T) {
268 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
269 b.list = b.list.set(index, value, true)
270 }
271
272
273 func (b *ListBuilder[T]) Append(value T) {
274 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
275 b.list = b.list.append(value, true)
276 }
277
278
279 func (b *ListBuilder[T]) Prepend(value T) {
280 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
281 b.list = b.list.prepend(value, true)
282 }
283
284
285
286 func (b *ListBuilder[T]) Slice(start, end int) {
287 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
288 b.list = b.list.slice(start, end, true)
289 }
290
291
292 func (b *ListBuilder[T]) Iterator() *ListIterator[T] {
293 assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation")
294 return b.list.Iterator()
295 }
296
297
// Sizing constants for list tree nodes.
const (
	listNodeBits = 5                 // index bits consumed per tree level
	listNodeSize = 1 << listNodeBits // slots per node (32)
	listNodeMask = listNodeSize - 1  // mask selecting one level's slot number
)
303
304
// listNode is the behaviour shared by the branch and leaf nodes of a List
// tree. Every index argument is a position within the whole tree; each
// node extracts the bits that apply to its own level.
type listNode[T any] interface {
	// depth reports the node's level; leaf nodes report 0.
	depth() uint
	// get returns the value stored at index.
	get(index int) T
	// set stores v at index and returns the resulting node. With mutable
	// false the receiver is copied rather than changed.
	set(index int, v T, mutable bool) listNode[T]

	// containsBefore and containsAfter report whether the node holds
	// anything at a position below, respectively above, index.
	containsBefore(index int) bool
	containsAfter(index int) bool

	// deleteBefore and deleteAfter return a node with everything below,
	// respectively above, index removed. With mutable false the receiver
	// is left unchanged.
	deleteBefore(index int, mutable bool) listNode[T]
	deleteAfter(index int, mutable bool) listNode[T]
}
316
317
318 func newListNode[T any](depth uint) listNode[T] {
319 if depth == 0 {
320 return &listLeafNode[T]{}
321 }
322 return &listBranchNode[T]{d: depth}
323 }
324
325
// listBranchNode is an interior node of a List tree. Each non-nil child is
// a node one level closer to the leaves.
type listBranchNode[T any] struct {
	d        uint                      // depth of this node
	children [listNodeSize]listNode[T] // child nodes; nil where nothing is stored
}
330
331
332 func (n *listBranchNode[T]) depth() uint { return n.d }
333
334
335 func (n *listBranchNode[T]) get(index int) T {
336 idx := (index >> (n.d * listNodeBits)) & listNodeMask
337 return n.children[idx].get(index)
338 }
339
340
341 func (n *listBranchNode[T]) set(index int, v T, mutable bool) listNode[T] {
342 idx := (index >> (n.d * listNodeBits)) & listNodeMask
343
344
345 child := n.children[idx]
346 if child == nil {
347 child = newListNode[T](n.depth() - 1)
348 }
349
350
351 var other *listBranchNode[T]
352 if mutable {
353 other = n
354 } else {
355 tmp := *n
356 other = &tmp
357 }
358 other.children[idx] = child.set(index, v, mutable)
359 return other
360 }
361
362
363 func (n *listBranchNode[T]) containsBefore(index int) bool {
364 idx := (index >> (n.d * listNodeBits)) & listNodeMask
365
366
367 for i := 0; i < idx; i++ {
368 if n.children[i] != nil {
369 return true
370 }
371 }
372
373
374 if n.children[idx] != nil && n.children[idx].containsBefore(index) {
375 return true
376 }
377 return false
378 }
379
380
381 func (n *listBranchNode[T]) containsAfter(index int) bool {
382 idx := (index >> (n.d * listNodeBits)) & listNodeMask
383
384
385 for i := idx + 1; i < len(n.children); i++ {
386 if n.children[i] != nil {
387 return true
388 }
389 }
390
391
392 if n.children[idx] != nil && n.children[idx].containsAfter(index) {
393 return true
394 }
395 return false
396 }
397
398
399 func (n *listBranchNode[T]) deleteBefore(index int, mutable bool) listNode[T] {
400
401 if !n.containsBefore(index) {
402 return n
403 }
404
405
406 idx := (index >> (n.d * listNodeBits)) & listNodeMask
407
408 var other *listBranchNode[T]
409 if mutable {
410 other = n
411 for i := 0; i < idx; i++ {
412 n.children[i] = nil
413 }
414 } else {
415 other = &listBranchNode[T]{d: n.d}
416 copy(other.children[idx:][:], n.children[idx:][:])
417 }
418
419 if other.children[idx] != nil {
420 other.children[idx] = other.children[idx].deleteBefore(index, mutable)
421 }
422 return other
423 }
424
425
426 func (n *listBranchNode[T]) deleteAfter(index int, mutable bool) listNode[T] {
427
428 if !n.containsAfter(index) {
429 return n
430 }
431
432
433 idx := (index >> (n.d * listNodeBits)) & listNodeMask
434
435 var other *listBranchNode[T]
436 if mutable {
437 other = n
438 for i := idx + 1; i < len(n.children); i++ {
439 n.children[i] = nil
440 }
441 } else {
442 other = &listBranchNode[T]{d: n.d}
443 copy(other.children[:idx+1], n.children[:idx+1])
444 }
445
446 if other.children[idx] != nil {
447 other.children[idx] = other.children[idx].deleteAfter(index, mutable)
448 }
449 return other
450 }
451
452
// listLeafNode is a depth-0 node of a List tree and stores the element
// values themselves.
type listLeafNode[T any] struct {
	children [listNodeSize]T // element values, addressed by the low listNodeBits bits of the index

	occupied uint32 // bit i is set once slot i has been written; cleared by deleteBefore/deleteAfter
}
458
459
460 func (n *listLeafNode[T]) depth() uint { return 0 }
461
462
463 func (n *listLeafNode[T]) get(index int) T {
464 return n.children[index&listNodeMask]
465 }
466
467
468 func (n *listLeafNode[T]) set(index int, v T, mutable bool) listNode[T] {
469 idx := index & listNodeMask
470 var other *listLeafNode[T]
471 if mutable {
472 other = n
473 } else {
474 tmp := *n
475 other = &tmp
476 }
477 other.children[idx] = v
478 other.occupied |= 1 << idx
479 return other
480 }
481
482
483 func (n *listLeafNode[T]) containsBefore(index int) bool {
484 idx := index & listNodeMask
485 return bits.TrailingZeros32(n.occupied) < idx
486 }
487
488
489 func (n *listLeafNode[T]) containsAfter(index int) bool {
490 idx := index & listNodeMask
491 lastSetPos := 31 - bits.LeadingZeros32(n.occupied)
492 return lastSetPos > idx
493 }
494
495
496 func (n *listLeafNode[T]) deleteBefore(index int, mutable bool) listNode[T] {
497 if !n.containsBefore(index) {
498 return n
499 }
500
501 idx := index & listNodeMask
502 var other *listLeafNode[T]
503 if mutable {
504 other = n
505 var empty T
506 for i := 0; i < idx; i++ {
507 other.children[i] = empty
508 }
509 } else {
510 other = &listLeafNode[T]{occupied: n.occupied}
511 copy(other.children[idx:][:], n.children[idx:][:])
512 }
513
514 other.occupied &= ^((1 << idx) - 1)
515 return other
516 }
517
518
519 func (n *listLeafNode[T]) deleteAfter(index int, mutable bool) listNode[T] {
520 if !n.containsAfter(index) {
521 return n
522 }
523
524 idx := index & listNodeMask
525 var other *listLeafNode[T]
526 if mutable {
527 other = n
528 var empty T
529 for i := idx + 1; i < len(n.children); i++ {
530 other.children[i] = empty
531 }
532 } else {
533 other = &listLeafNode[T]{occupied: n.occupied}
534 copy(other.children[:idx+1][:], n.children[:idx+1][:])
535 }
536
537 other.occupied &= (1 << (idx + 1)) - 1
538 return other
539 }
540
541
// ListIterator walks the elements of a List in either direction.
type ListIterator[T any] struct {
	list  *List[T] // list being iterated
	index int      // logical index of the current element

	stack [32]listIteratorElem[T] // path of nodes from the root to the current leaf
	depth int                     // position of the current leaf within stack
}
549
550
551 func (itr *ListIterator[T]) Done() bool {
552 return itr.index < 0 || itr.index >= itr.list.Len()
553 }
554
555
556
557 func (itr *ListIterator[T]) First() {
558 if itr.list.Len() != 0 {
559 itr.Seek(0)
560 }
561 }
562
563
564
565 func (itr *ListIterator[T]) Last() {
566 if n := itr.list.Len(); n != 0 {
567 itr.Seek(n - 1)
568 }
569 }
570
571
572
573
574 func (itr *ListIterator[T]) Seek(index int) {
575
576 if index < 0 || index >= itr.list.Len() {
577 panic(fmt.Sprintf("immutable.ListIterator.Seek: index %d out of bounds", index))
578 }
579 itr.index = index
580
581
582 itr.stack[0] = listIteratorElem[T]{node: itr.list.root}
583 itr.depth = 0
584 itr.seek(index)
585 }
586
587
588
589 func (itr *ListIterator[T]) Next() (index int, value T) {
590
591 var empty T
592 if itr.Done() {
593 return -1, empty
594 }
595
596
597 elem := &itr.stack[itr.depth]
598 index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index]
599
600
601 itr.index++
602 if itr.Done() {
603 return index, value
604 }
605
606
607 for ; itr.depth > 0 && itr.stack[itr.depth].index >= listNodeSize-1; itr.depth-- {
608 }
609
610
611 itr.seek(itr.index)
612
613 return index, value
614 }
615
616
617
618 func (itr *ListIterator[T]) Prev() (index int, value T) {
619
620 var empty T
621 if itr.Done() {
622 return -1, empty
623 }
624
625
626 elem := &itr.stack[itr.depth]
627 index, value = itr.index, elem.node.(*listLeafNode[T]).children[elem.index]
628
629
630 itr.index--
631 if itr.Done() {
632 return index, value
633 }
634
635
636 for ; itr.depth > 0 && itr.stack[itr.depth].index == 0; itr.depth-- {
637 }
638
639
640 itr.seek(itr.index)
641
642 return index, value
643 }
644
645
646
647 func (itr *ListIterator[T]) seek(index int) {
648
649 for {
650 elem := &itr.stack[itr.depth]
651 elem.index = ((itr.list.origin + index) >> (elem.node.depth() * listNodeBits)) & listNodeMask
652
653 switch node := elem.node.(type) {
654 case *listBranchNode[T]:
655 child := node.children[elem.index]
656 itr.stack[itr.depth+1] = listIteratorElem[T]{node: child}
657 itr.depth++
658 case *listLeafNode[T]:
659 return
660 }
661 }
662 }
663
664
// listIteratorElem is one entry of a ListIterator's stack.
type listIteratorElem[T any] struct {
	node  listNode[T] // node visited at this level
	index int         // slot selected within node
}
669
670
// Size thresholds at which map nodes switch representation.
const (
	maxArrayMapSize      = 8  // an array node holding this many entries converts to a hashed tree on the next new key
	maxBitmapIndexedSize = 16 // child-count threshold between bitmap-indexed and hash-array nodes
)

// Sizing constants for hashed map nodes.
const (
	mapNodeBits = 5                // hash bits consumed per tree level
	mapNodeSize = 1 << mapNodeBits // slots per node (32)
	mapNodeMask = mapNodeSize - 1  // mask selecting one level's hash fragment
)
682
683
684
685
686
// Map is an unordered key/value collection stored as a tree of mapNode
// values. Set and Delete leave the receiver untouched and return a map
// reflecting the change. Keys are hashed and compared through hasher.
type Map[K, V any] struct {
	size   int           // number of key/value pairs
	root   mapNode[K, V] // top-level node; nil while the map is empty
	hasher Hasher[K]     // hashes and compares keys; may be nil until the first set
}
692
693
694
695
696 func NewMap[K, V any](hasher Hasher[K]) *Map[K, V] {
697 return &Map[K, V]{
698 hasher: hasher,
699 }
700 }
701
702
703
704
705
706 func NewMapOf[K comparable, V any](hasher Hasher[K], entries map[K]V) *Map[K, V] {
707 m := &Map[K, V]{
708 hasher: hasher,
709 }
710 for k, v := range entries {
711 m.set(k, v, true)
712 }
713 return m
714 }
715
716
717 func (m *Map[K, V]) Len() int {
718 return m.size
719 }
720
721
722 func (m *Map[K, V]) clone() *Map[K, V] {
723 other := *m
724 return &other
725 }
726
727
728
729
730 func (m *Map[K, V]) Get(key K) (value V, ok bool) {
731 var empty V
732 if m.root == nil {
733 return empty, false
734 }
735 keyHash := m.hasher.Hash(key)
736 return m.root.get(key, 0, keyHash, m.hasher)
737 }
738
739
740
741
742
743 func (m *Map[K, V]) Set(key K, value V) *Map[K, V] {
744 return m.set(key, value, false)
745 }
746
747 func (m *Map[K, V]) set(key K, value V, mutable bool) *Map[K, V] {
748
749 hasher := m.hasher
750 if hasher == nil {
751 hasher = NewHasher(key)
752 }
753
754
755 other := m
756 if !mutable {
757 other = m.clone()
758 }
759 other.hasher = hasher
760
761
762 if m.root == nil {
763 other.size = 1
764 other.root = &mapArrayNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}}
765 return other
766 }
767
768
769
770 var resized bool
771 other.root = m.root.set(key, value, 0, hasher.Hash(key), hasher, mutable, &resized)
772 if resized {
773 other.size++
774 }
775 return other
776 }
777
778
779
780 func (m *Map[K, V]) Delete(key K) *Map[K, V] {
781 return m.delete(key, false)
782 }
783
784 func (m *Map[K, V]) delete(key K, mutable bool) *Map[K, V] {
785
786 if m.root == nil {
787 return m
788 }
789
790
791 var resized bool
792 newRoot := m.root.delete(key, 0, m.hasher.Hash(key), m.hasher, mutable, &resized)
793 if !resized {
794 return m
795 }
796
797
798 other := m
799 if !mutable {
800 other = m.clone()
801 }
802
803
804 other.size = m.size - 1
805 other.root = newRoot
806 return other
807 }
808
809
810 func (m *Map[K, V]) Iterator() *MapIterator[K, V] {
811 itr := &MapIterator[K, V]{m: m}
812 itr.First()
813 return itr
814 }
815
816
// MapBuilder assembles a Map by applying changes in place. Once Map() has
// been called the builder holds no map and its other methods assert.
type MapBuilder[K, V any] struct {
	m *Map[K, V] // map under construction; nil after Map() is called
}
820
821
822 func NewMapBuilder[K, V any](hasher Hasher[K]) *MapBuilder[K, V] {
823 return &MapBuilder[K, V]{m: NewMap[K, V](hasher)}
824 }
825
826
827
828 func (b *MapBuilder[K, V]) Map() *Map[K, V] {
829 assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map")
830 m := b.m
831 b.m = nil
832 return m
833 }
834
835
836 func (b *MapBuilder[K, V]) Len() int {
837 assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
838 return b.m.Len()
839 }
840
841
842 func (b *MapBuilder[K, V]) Get(key K) (value V, ok bool) {
843 assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
844 return b.m.Get(key)
845 }
846
847
848 func (b *MapBuilder[K, V]) Set(key K, value V) {
849 assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
850 b.m = b.m.set(key, value, true)
851 }
852
853
854 func (b *MapBuilder[K, V]) Delete(key K) {
855 assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
856 b.m = b.m.delete(key, true)
857 }
858
859
860 func (b *MapBuilder[K, V]) Iterator() *MapIterator[K, V] {
861 assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation")
862 return b.m.Iterator()
863 }
864
865
// mapNode is the behaviour shared by every node type of a Map tree.
// shift is the number of hash bits already consumed by the levels above;
// keyHash is the hash of key; h hashes and compares keys.
//
// set and delete return the node that takes the receiver's place; delete
// returns nil when nothing remains. They set *resized to true when a key
// was added or removed. With mutable false the receiver is not changed.
type mapNode[K, V any] interface {
	get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool)
	set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V]
	delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V]
}

// Compile-time checks that each node type implements mapNode.
var _ mapNode[string, any] = (*mapArrayNode[string, any])(nil)
var _ mapNode[string, any] = (*mapBitmapIndexedNode[string, any])(nil)
var _ mapNode[string, any] = (*mapHashArrayNode[string, any])(nil)
var _ mapNode[string, any] = (*mapValueNode[string, any])(nil)
var _ mapNode[string, any] = (*mapHashCollisionNode[string, any])(nil)
877
878
// mapLeafNode is a mapNode that stores entries sharing a single key hash
// and can report that hash.
type mapLeafNode[K, V any] interface {
	mapNode[K, V]
	keyHashValue() uint32
}

// Compile-time checks that the leaf node types implement mapLeafNode.
var _ mapLeafNode[string, any] = (*mapValueNode[string, any])(nil)
var _ mapLeafNode[string, any] = (*mapHashCollisionNode[string, any])(nil)
886
887
888
889
// mapArrayNode stores entries in a flat slice that is searched linearly
// with Hasher.Equal. Map.set uses it as the root of a map that receives
// its first key.
type mapArrayNode[K, V any] struct {
	entries []mapEntry[K, V] // key/value pairs in insertion order
}
893
894
895 func (n *mapArrayNode[K, V]) indexOf(key K, h Hasher[K]) int {
896 for i := range n.entries {
897 if h.Equal(n.entries[i].key, key) {
898 return i
899 }
900 }
901 return -1
902 }
903
904
905 func (n *mapArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
906 i := n.indexOf(key, h)
907 if i == -1 {
908 return value, false
909 }
910 return n.entries[i].value, true
911 }
912
913
914
915 func (n *mapArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
916 idx := n.indexOf(key, h)
917
918
919 if idx == -1 {
920 *resized = true
921 }
922
923
924
925 if idx == -1 && len(n.entries) >= maxArrayMapSize {
926 var node mapNode[K, V] = newMapValueNode(h.Hash(key), key, value)
927 for _, entry := range n.entries {
928 node = node.set(entry.key, entry.value, 0, h.Hash(entry.key), h, false, resized)
929 }
930 return node
931 }
932
933
934 if mutable {
935 if idx != -1 {
936 n.entries[idx] = mapEntry[K, V]{key, value}
937 } else {
938 n.entries = append(n.entries, mapEntry[K, V]{key, value})
939 }
940 return n
941 }
942
943
944
945 var other mapArrayNode[K, V]
946 if idx != -1 {
947 other.entries = make([]mapEntry[K, V], len(n.entries))
948 copy(other.entries, n.entries)
949 other.entries[idx] = mapEntry[K, V]{key, value}
950 } else {
951 other.entries = make([]mapEntry[K, V], len(n.entries)+1)
952 copy(other.entries, n.entries)
953 other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value}
954 }
955 return &other
956 }
957
958
959
960 func (n *mapArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
961 idx := n.indexOf(key, h)
962
963
964 if idx == -1 {
965 return n
966 }
967 *resized = true
968
969
970 if len(n.entries) == 1 {
971 return nil
972 }
973
974
975 if mutable {
976 copy(n.entries[idx:], n.entries[idx+1:])
977 n.entries[len(n.entries)-1] = mapEntry[K, V]{}
978 n.entries = n.entries[:len(n.entries)-1]
979 return n
980 }
981
982
983 other := &mapArrayNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)}
984 copy(other.entries[:idx], n.entries[:idx])
985 copy(other.entries[idx:], n.entries[idx+1:])
986 return other
987 }
988
989
990
991
// mapBitmapIndexedNode is a sparse interior node. Bit i of bitmap is set
// when a child exists for hash fragment i; the children are packed into
// nodes in ascending fragment order, so a child's position is the number
// of set bits below its own bit.
type mapBitmapIndexedNode[K, V any] struct {
	bitmap uint32          // one bit per hash fragment that has a child
	nodes  []mapNode[K, V] // children, ordered by hash fragment
}
996
997
998 func (n *mapBitmapIndexedNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
999 bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
1000 if (n.bitmap & bit) == 0 {
1001 return value, false
1002 }
1003 child := n.nodes[bits.OnesCount32(n.bitmap&(bit-1))]
1004 return child.get(key, shift+mapNodeBits, keyHash, h)
1005 }
1006
1007
1008
1009 func (n *mapBitmapIndexedNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1010
1011 keyHashFrag := (keyHash >> shift) & mapNodeMask
1012
1013
1014 bit := uint32(1) << keyHashFrag
1015 exists := (n.bitmap & bit) != 0
1016
1017
1018 if !exists {
1019 *resized = true
1020 }
1021
1022
1023 idx := bits.OnesCount32(n.bitmap & (bit - 1))
1024
1025
1026
1027 var newNode mapNode[K, V]
1028 if exists {
1029 newNode = n.nodes[idx].set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized)
1030 } else {
1031 newNode = newMapValueNode(keyHash, key, value)
1032 }
1033
1034
1035
1036 if !exists && len(n.nodes) > maxBitmapIndexedSize {
1037 var other mapHashArrayNode[K, V]
1038 for i := uint(0); i < uint(len(other.nodes)); i++ {
1039 if n.bitmap&(uint32(1)<<i) != 0 {
1040 other.nodes[i] = n.nodes[other.count]
1041 other.count++
1042 }
1043 }
1044 other.nodes[keyHashFrag] = newNode
1045 other.count++
1046 return &other
1047 }
1048
1049
1050 if mutable {
1051 if exists {
1052 n.nodes[idx] = newNode
1053 } else {
1054 n.bitmap |= bit
1055 n.nodes = append(n.nodes, nil)
1056 copy(n.nodes[idx+1:], n.nodes[idx:])
1057 n.nodes[idx] = newNode
1058 }
1059 return n
1060 }
1061
1062
1063
1064 other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap | bit}
1065 if exists {
1066 other.nodes = make([]mapNode[K, V], len(n.nodes))
1067 copy(other.nodes, n.nodes)
1068 other.nodes[idx] = newNode
1069 } else {
1070 other.nodes = make([]mapNode[K, V], len(n.nodes)+1)
1071 copy(other.nodes, n.nodes[:idx])
1072 other.nodes[idx] = newNode
1073 copy(other.nodes[idx+1:], n.nodes[idx:])
1074 }
1075 return other
1076 }
1077
1078
1079
1080
1081 func (n *mapBitmapIndexedNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1082 bit := uint32(1) << ((keyHash >> shift) & mapNodeMask)
1083
1084
1085 if (n.bitmap & bit) == 0 {
1086 return n
1087 }
1088
1089
1090 idx := bits.OnesCount32(n.bitmap & (bit - 1))
1091
1092
1093 child := n.nodes[idx]
1094 newChild := child.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized)
1095
1096
1097 if !*resized {
1098 return n
1099 }
1100
1101
1102 if newChild == nil {
1103
1104 if len(n.nodes) == 1 {
1105 return nil
1106 }
1107
1108
1109 if mutable {
1110 n.bitmap ^= bit
1111 copy(n.nodes[idx:], n.nodes[idx+1:])
1112 n.nodes[len(n.nodes)-1] = nil
1113 n.nodes = n.nodes[:len(n.nodes)-1]
1114 return n
1115 }
1116
1117
1118 other := &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap ^ bit, nodes: make([]mapNode[K, V], len(n.nodes)-1)}
1119 copy(other.nodes[:idx], n.nodes[:idx])
1120 copy(other.nodes[idx:], n.nodes[idx+1:])
1121 return other
1122 }
1123
1124
1125 other := n
1126 if !mutable {
1127 other = &mapBitmapIndexedNode[K, V]{bitmap: n.bitmap, nodes: make([]mapNode[K, V], len(n.nodes))}
1128 copy(other.nodes, n.nodes)
1129 }
1130
1131
1132 other.nodes[idx] = newChild
1133 return other
1134 }
1135
1136
1137
// mapHashArrayNode is a dense interior node with one slot per possible
// hash fragment. A slot is nil when no child exists for that fragment.
type mapHashArrayNode[K, V any] struct {
	count uint                       // number of non-nil slots
	nodes [mapNodeSize]mapNode[K, V] // children, addressed directly by hash fragment
}
1142
1143
1144 func (n *mapHashArrayNode[K, V]) clone() *mapHashArrayNode[K, V] {
1145 other := *n
1146 return &other
1147 }
1148
1149
1150 func (n *mapHashArrayNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
1151 node := n.nodes[(keyHash>>shift)&mapNodeMask]
1152 if node == nil {
1153 return value, false
1154 }
1155 return node.get(key, shift+mapNodeBits, keyHash, h)
1156 }
1157
1158
1159 func (n *mapHashArrayNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1160 idx := (keyHash >> shift) & mapNodeMask
1161 node := n.nodes[idx]
1162
1163
1164
1165 var newNode mapNode[K, V]
1166 if node == nil {
1167 *resized = true
1168 newNode = newMapValueNode(keyHash, key, value)
1169 } else {
1170 newNode = node.set(key, value, shift+mapNodeBits, keyHash, h, mutable, resized)
1171 }
1172
1173
1174 other := n
1175 if !mutable {
1176 other = n.clone()
1177 }
1178
1179
1180 if node == nil {
1181 other.count++
1182 }
1183 other.nodes[idx] = newNode
1184 return other
1185 }
1186
1187
1188
1189
1190 func (n *mapHashArrayNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1191 idx := (keyHash >> shift) & mapNodeMask
1192 node := n.nodes[idx]
1193
1194
1195 if node == nil {
1196 return n
1197 }
1198
1199
1200 newNode := node.delete(key, shift+mapNodeBits, keyHash, h, mutable, resized)
1201 if !*resized {
1202 return n
1203 }
1204
1205
1206 if newNode == nil && n.count <= maxBitmapIndexedSize {
1207 other := &mapBitmapIndexedNode[K, V]{nodes: make([]mapNode[K, V], 0, n.count-1)}
1208 for i, child := range n.nodes {
1209 if child != nil && uint32(i) != idx {
1210 other.bitmap |= 1 << uint(i)
1211 other.nodes = append(other.nodes, child)
1212 }
1213 }
1214 return other
1215 }
1216
1217
1218 other := n
1219 if !mutable {
1220 other = n.clone()
1221 }
1222
1223
1224 other.nodes[idx] = newNode
1225 if newNode == nil {
1226 other.count--
1227 }
1228 return other
1229 }
1230
1231
1232
1233
// mapValueNode is a leaf holding a single key/value pair together with the
// key's hash.
type mapValueNode[K, V any] struct {
	keyHash uint32 // hash of key
	key     K
	value   V
}
1239
1240
1241 func newMapValueNode[K, V any](keyHash uint32, key K, value V) *mapValueNode[K, V] {
1242 return &mapValueNode[K, V]{
1243 keyHash: keyHash,
1244 key: key,
1245 value: value,
1246 }
1247 }
1248
1249
1250 func (n *mapValueNode[K, V]) keyHashValue() uint32 {
1251 return n.keyHash
1252 }
1253
1254
1255 func (n *mapValueNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
1256 if !h.Equal(n.key, key) {
1257 return value, false
1258 }
1259 return n.value, true
1260 }
1261
1262
1263
1264
1265
1266 func (n *mapValueNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1267
1268 if h.Equal(n.key, key) {
1269
1270 if mutable {
1271 n.value = value
1272 return n
1273 }
1274
1275 return newMapValueNode(n.keyHash, key, value)
1276 }
1277
1278 *resized = true
1279
1280
1281 if n.keyHash != keyHash {
1282 return mergeIntoNode[K, V](n, shift, keyHash, key, value)
1283 }
1284
1285
1286 return &mapHashCollisionNode[K, V]{keyHash: keyHash, entries: []mapEntry[K, V]{
1287 {key: n.key, value: n.value},
1288 {key: key, value: value},
1289 }}
1290 }
1291
1292
1293 func (n *mapValueNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1294
1295 if !h.Equal(n.key, key) {
1296 return n
1297 }
1298
1299
1300 *resized = true
1301 return nil
1302 }
1303
1304
1305
// mapHashCollisionNode is a leaf holding several entries whose keys all
// share the same hash. Entries are searched linearly with Hasher.Equal.
type mapHashCollisionNode[K, V any] struct {
	keyHash uint32           // hash shared by every key in entries
	entries []mapEntry[K, V] // colliding key/value pairs
}
1310
1311
1312 func (n *mapHashCollisionNode[K, V]) keyHashValue() uint32 {
1313 return n.keyHash
1314 }
1315
1316
1317
1318 func (n *mapHashCollisionNode[K, V]) indexOf(key K, h Hasher[K]) int {
1319 for i := range n.entries {
1320 if h.Equal(n.entries[i].key, key) {
1321 return i
1322 }
1323 }
1324 return -1
1325 }
1326
1327
1328 func (n *mapHashCollisionNode[K, V]) get(key K, shift uint, keyHash uint32, h Hasher[K]) (value V, ok bool) {
1329 for i := range n.entries {
1330 if h.Equal(n.entries[i].key, key) {
1331 return n.entries[i].value, true
1332 }
1333 }
1334 return value, false
1335 }
1336
1337
1338 func (n *mapHashCollisionNode[K, V]) set(key K, value V, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1339
1340 if n.keyHash != keyHash {
1341 *resized = true
1342 return mergeIntoNode[K, V](n, shift, keyHash, key, value)
1343 }
1344
1345
1346 if mutable {
1347 if idx := n.indexOf(key, h); idx == -1 {
1348 *resized = true
1349 n.entries = append(n.entries, mapEntry[K, V]{key, value})
1350 } else {
1351 n.entries[idx] = mapEntry[K, V]{key, value}
1352 }
1353 return n
1354 }
1355
1356
1357
1358 other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash}
1359 if idx := n.indexOf(key, h); idx == -1 {
1360 *resized = true
1361 other.entries = make([]mapEntry[K, V], len(n.entries)+1)
1362 copy(other.entries, n.entries)
1363 other.entries[len(other.entries)-1] = mapEntry[K, V]{key, value}
1364 } else {
1365 other.entries = make([]mapEntry[K, V], len(n.entries))
1366 copy(other.entries, n.entries)
1367 other.entries[idx] = mapEntry[K, V]{key, value}
1368 }
1369 return other
1370 }
1371
1372
1373
1374
1375 func (n *mapHashCollisionNode[K, V]) delete(key K, shift uint, keyHash uint32, h Hasher[K], mutable bool, resized *bool) mapNode[K, V] {
1376 idx := n.indexOf(key, h)
1377
1378
1379 if idx == -1 {
1380 return n
1381 }
1382
1383
1384 *resized = true
1385
1386
1387 if len(n.entries) == 2 {
1388 return &mapValueNode[K, V]{
1389 keyHash: n.keyHash,
1390 key: n.entries[idx^1].key,
1391 value: n.entries[idx^1].value,
1392 }
1393 }
1394
1395
1396 if mutable {
1397 copy(n.entries[idx:], n.entries[idx+1:])
1398 n.entries[len(n.entries)-1] = mapEntry[K, V]{}
1399 n.entries = n.entries[:len(n.entries)-1]
1400 return n
1401 }
1402
1403
1404 other := &mapHashCollisionNode[K, V]{keyHash: n.keyHash, entries: make([]mapEntry[K, V], len(n.entries)-1)}
1405 copy(other.entries[:idx], n.entries[:idx])
1406 copy(other.entries[idx:], n.entries[idx+1:])
1407 return other
1408 }
1409
1410
1411
1412 func mergeIntoNode[K, V any](node mapLeafNode[K, V], shift uint, keyHash uint32, key K, value V) mapNode[K, V] {
1413 idx1 := (node.keyHashValue() >> shift) & mapNodeMask
1414 idx2 := (keyHash >> shift) & mapNodeMask
1415
1416
1417 other := &mapBitmapIndexedNode[K, V]{bitmap: (1 << idx1) | (1 << idx2)}
1418 if idx1 == idx2 {
1419 other.nodes = []mapNode[K, V]{mergeIntoNode(node, shift+mapNodeBits, keyHash, key, value)}
1420 } else {
1421 if newNode := newMapValueNode(keyHash, key, value); idx1 < idx2 {
1422 other.nodes = []mapNode[K, V]{node, newNode}
1423 } else {
1424 other.nodes = []mapNode[K, V]{newNode, node}
1425 }
1426 }
1427 return other
1428 }
1429
1430
// mapEntry is a single key/value pair as stored by array and
// hash-collision nodes.
type mapEntry[K, V any] struct {
	key   K
	value V
}
1435
1436
1437
// MapIterator walks the entries of a Map by traversing its node tree.
type MapIterator[K, V any] struct {
	m *Map[K, V] // map being iterated

	stack [32]mapIteratorElem[K, V] // path of nodes from the root to the current node
	depth int                       // position of the current node within stack; -1 once iteration is finished
}
1444
1445
1446 func (itr *MapIterator[K, V]) Done() bool {
1447 return itr.depth == -1
1448 }
1449
1450
1451 func (itr *MapIterator[K, V]) First() {
1452
1453 if itr.m.root == nil {
1454 itr.depth = -1
1455 return
1456 }
1457
1458
1459 itr.stack[0] = mapIteratorElem[K, V]{node: itr.m.root}
1460 itr.depth = 0
1461 itr.first()
1462 }
1463
1464
1465 func (itr *MapIterator[K, V]) Next() (key K, value V, ok bool) {
1466
1467 if itr.Done() {
1468 return key, value, false
1469 }
1470
1471
1472 elem := &itr.stack[itr.depth]
1473 switch node := elem.node.(type) {
1474 case *mapArrayNode[K, V]:
1475 entry := &node.entries[elem.index]
1476 key, value = entry.key, entry.value
1477 case *mapValueNode[K, V]:
1478 key, value = node.key, node.value
1479 case *mapHashCollisionNode[K, V]:
1480 entry := &node.entries[elem.index]
1481 key, value = entry.key, entry.value
1482 }
1483
1484
1485
1486 itr.next()
1487 return key, value, true
1488 }
1489
1490
// next moves the iterator to the following leaf entry. Starting at the current
// stack level it either steps to the next entry of the current leaf node, or
// pops levels (decrementing depth) until it finds a branch with an unvisited
// child and then descends to that child's first entry via first(). When every
// level is exhausted the loop ends with depth == -1.
func (itr *MapIterator[K, V]) next() {
	for ; itr.depth >= 0; itr.depth-- {
		elem := &itr.stack[itr.depth]

		switch node := elem.node.(type) {
		case *mapArrayNode[K, V]:
			// Leaf with several entries: step to the next entry if one remains.
			if elem.index < len(node.entries)-1 {
				elem.index++
				return
			}

		case *mapBitmapIndexedNode[K, V]:
			// Branch: move to the next child, if any, and descend into it.
			if elem.index < len(node.nodes)-1 {
				elem.index++
				itr.stack[itr.depth+1].node = node.nodes[elem.index]
				itr.depth++
				itr.first()
				return
			}

		case *mapHashArrayNode[K, V]:
			// Branch whose child slots may be nil: find the next non-nil child.
			for i := elem.index + 1; i < len(node.nodes); i++ {
				if node.nodes[i] != nil {
					elem.index = i
					itr.stack[itr.depth+1].node = node.nodes[elem.index]
					itr.depth++
					itr.first()
					return
				}
			}

		case *mapValueNode[K, V]:
			// Single-value leaf: nothing more here, pop to the parent.
			continue

		case *mapHashCollisionNode[K, V]:
			// Leaf with several entries: step to the next entry if one remains.
			if elem.index < len(node.entries)-1 {
				elem.index++
				return
			}
		}
	}
}
1533
1534
1535
// first descends from the node at stack[depth] to its first leaf entry,
// pushing one stack element per level. Bitmap-indexed nodes always continue
// with child 0; hash array nodes continue with their first non-nil child. Any
// other node type is treated as a leaf: its index is reset to 0 and the
// descent stops.
func (itr *MapIterator[K, V]) first() {
	for ; ; itr.depth++ {
		elem := &itr.stack[itr.depth]

		switch node := elem.node.(type) {
		case *mapBitmapIndexedNode[K, V]:
			elem.index = 0
			itr.stack[itr.depth+1].node = node.nodes[0]

		case *mapHashArrayNode[K, V]:
			for i := 0; i < len(node.nodes); i++ {
				if node.nodes[i] != nil {
					elem.index = i
					itr.stack[itr.depth+1].node = node.nodes[i]
					// Leaves only the inner scan loop; the outer loop then
					// increments depth and continues with the chosen child.
					break
				}
			}

		default:
			elem.index = 0
			return
		}
	}
}
1560
1561
// mapIteratorElem is one level of a MapIterator's stack: a node and the
// position currently being visited within it.
type mapIteratorElem[K, V any] struct {
	node  mapNode[K, V]
	index int // entry index for leaf nodes, child index for branch nodes
}
1566
1567
const (
	// sortedMapNodeSize is the maximum number of elements a sorted map node
	// (branch or leaf) may hold; a node that grows past it is split in two.
	sortedMapNodeSize = 32
)
1571
1572
1573
1574
1575
// SortedMap is a map whose keys are kept in the order defined by a Comparer.
// The exported Set and Delete methods return a new map and leave the receiver
// unchanged.
type SortedMap[K, V any] struct {
	size     int                 // number of key/value pairs
	root     sortedMapNode[K, V] // nil while the map is empty
	comparer Comparer[K]         // may be nil until the first set, which then picks a default via NewComparer
}
1581
1582
1583
1584
1585 func NewSortedMap[K, V any](comparer Comparer[K]) *SortedMap[K, V] {
1586 return &SortedMap[K, V]{
1587 comparer: comparer,
1588 }
1589 }
1590
1591
1592
1593
1594
1595 func NewSortedMapOf[K comparable, V any](comparer Comparer[K], entries map[K]V) *SortedMap[K, V] {
1596 m := &SortedMap[K, V]{
1597 comparer: comparer,
1598 }
1599 for k, v := range entries {
1600 m.set(k, v, true)
1601 }
1602 return m
1603 }
1604
1605
// Len returns the number of key/value pairs in the sorted map.
func (m *SortedMap[K, V]) Len() int {
	return m.size
}
1609
1610
1611
1612 func (m *SortedMap[K, V]) Get(key K) (V, bool) {
1613 if m.root == nil {
1614 var v V
1615 return v, false
1616 }
1617 return m.root.get(key, m.comparer)
1618 }
1619
1620
// Set returns a copy of the map with key set to value. The receiver is left
// unchanged.
func (m *SortedMap[K, V]) Set(key K, value V) *SortedMap[K, V] {
	return m.set(key, value, false)
}
1624
1625 func (m *SortedMap[K, V]) set(key K, value V, mutable bool) *SortedMap[K, V] {
1626
1627 comparer := m.comparer
1628 if comparer == nil {
1629 comparer = NewComparer(key)
1630 }
1631
1632
1633 other := m
1634 if !mutable {
1635 other = m.clone()
1636 }
1637 other.comparer = comparer
1638
1639
1640 if m.root == nil {
1641 other.size = 1
1642 other.root = &sortedMapLeafNode[K, V]{entries: []mapEntry[K, V]{{key: key, value: value}}}
1643 return other
1644 }
1645
1646
1647
1648 var resized bool
1649 newRoot, splitNode := m.root.set(key, value, comparer, mutable, &resized)
1650 if splitNode != nil {
1651 newRoot = newSortedMapBranchNode(newRoot, splitNode)
1652 }
1653
1654
1655 other.size = m.size
1656 other.root = newRoot
1657 if resized {
1658 other.size++
1659 }
1660 return other
1661 }
1662
1663
1664
// Delete returns a copy of the map with key removed. If the key is not
// present the receiver itself is returned.
func (m *SortedMap[K, V]) Delete(key K) *SortedMap[K, V] {
	return m.delete(key, false)
}
1668
1669 func (m *SortedMap[K, V]) delete(key K, mutable bool) *SortedMap[K, V] {
1670
1671 if m.root == nil {
1672 return m
1673 }
1674
1675
1676 var resized bool
1677 newRoot := m.root.delete(key, m.comparer, mutable, &resized)
1678 if !resized {
1679 return m
1680 }
1681
1682
1683 other := m
1684 if !mutable {
1685 other = m.clone()
1686 }
1687
1688
1689 other.size = m.size - 1
1690 other.root = newRoot
1691 return other
1692 }
1693
1694
1695 func (m *SortedMap[K, V]) clone() *SortedMap[K, V] {
1696 other := *m
1697 return &other
1698 }
1699
1700
1701 func (m *SortedMap[K, V]) Iterator() *SortedMapIterator[K, V] {
1702 itr := &SortedMapIterator[K, V]{m: m}
1703 itr.First()
1704 return itr
1705 }
1706
1707
// SortedMapBuilder builds a SortedMap by applying Set and Delete operations in
// place on a single underlying map instead of copying on every change.
type SortedMapBuilder[K, V any] struct {
	m *SortedMap[K, V] // map under construction; nil after Map() has been called
}
1711
1712
// NewSortedMapBuilder returns a builder that starts from an empty sorted map
// using comparer to order keys.
func NewSortedMapBuilder[K, V any](comparer Comparer[K]) *SortedMapBuilder[K, V] {
	return &SortedMapBuilder[K, V]{m: NewSortedMap[K, V](comparer)}
}
1716
1717
1718
// Map returns the map that was built and invalidates the builder: the
// builder's reference is cleared, so any further call on it (including a
// second call to Map) fails its assertion and panics.
func (b *SortedMapBuilder[K, V]) Map() *SortedMap[K, V] {
	assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map")
	m := b.m
	b.m = nil // hand over ownership so later builder calls cannot mutate the returned map
	return m
}
1725
1726
// Len returns the number of key/value pairs in the map being built. It panics
// if Map() has already been called on the builder.
func (b *SortedMapBuilder[K, V]) Len() int {
	assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
	return b.m.Len()
}
1731
1732
// Get returns the value stored for key in the map being built; ok is false
// when the key is absent. It panics if Map() has already been called on the
// builder.
func (b *SortedMapBuilder[K, V]) Get(key K) (value V, ok bool) {
	assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
	return b.m.Get(key)
}
1737
1738
// Set stores value under key, updating the underlying map in place. It panics
// if Map() has already been called on the builder.
func (b *SortedMapBuilder[K, V]) Set(key K, value V) {
	assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
	b.m = b.m.set(key, value, true)
}
1743
1744
// Delete removes key, updating the underlying map in place. It panics if
// Map() has already been called on the builder.
func (b *SortedMapBuilder[K, V]) Delete(key K) {
	assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
	b.m = b.m.delete(key, true)
}
1749
1750
// Iterator returns a new iterator over the map being built. It panics if
// Map() has already been called on the builder.
func (b *SortedMapBuilder[K, V]) Iterator() *SortedMapIterator[K, V] {
	assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation")
	return b.m.Iterator()
}
1755
1756
// sortedMapNode is the interface shared by the branch and leaf nodes of a
// SortedMap's tree.
type sortedMapNode[K, V any] interface {
	// minKey returns the smallest key stored in the node's subtree.
	minKey() K
	// indexOf returns the position within the node that is relevant for key.
	indexOf(key K, c Comparer[K]) int
	// get looks up key; ok is false when it is absent.
	get(key K, c Comparer[K]) (value V, ok bool)
	// set stores value under key and returns the updated node plus, when the
	// node had to be split, a second node holding the upper half (nil
	// otherwise). *resized is set to true when a new key was added.
	set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V])
	// delete removes key and returns the updated node, or nil when the node
	// became empty. *resized is set to true when the key was found.
	delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V]
}
1764
// Compile-time checks that both node types implement sortedMapNode.
var _ sortedMapNode[string, any] = (*sortedMapBranchNode[string, any])(nil)
var _ sortedMapNode[string, any] = (*sortedMapLeafNode[string, any])(nil)
1767
1768
// sortedMapBranchNode is an interior node of the sorted map tree. It holds an
// ordered list of child nodes, each tagged with that child's minimum key.
type sortedMapBranchNode[K, V any] struct {
	elems []sortedMapBranchElem[K, V]
}
1772
1773
1774 func newSortedMapBranchNode[K, V any](children ...sortedMapNode[K, V]) *sortedMapBranchNode[K, V] {
1775
1776 elems := make([]sortedMapBranchElem[K, V], len(children))
1777 for i, child := range children {
1778 elems[i] = sortedMapBranchElem[K, V]{
1779 key: child.minKey(),
1780 node: child,
1781 }
1782 }
1783
1784 return &sortedMapBranchNode[K, V]{elems: elems}
1785 }
1786
1787
// minKey returns the smallest key in the subtree by asking the first child
// node for its own minimum key.
func (n *sortedMapBranchNode[K, V]) minKey() K {
	return n.elems[0].node.minKey()
}
1791
1792
1793 func (n *sortedMapBranchNode[K, V]) indexOf(key K, c Comparer[K]) int {
1794 if idx := sort.Search(len(n.elems), func(i int) bool { return c.Compare(n.elems[i].key, key) == 1 }); idx > 0 {
1795 return idx - 1
1796 }
1797 return 0
1798 }
1799
1800
1801 func (n *sortedMapBranchNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) {
1802 idx := n.indexOf(key, c)
1803 return n.elems[idx].node.get(key, c)
1804 }
1805
1806
// set stores value under key in the appropriate child and returns the updated
// branch. If the branch ends up with more than sortedMapNodeSize elements it
// is split in half and both halves are returned; otherwise the second return
// value is nil. When mutable is true the receiver is modified in place; when
// false the receiver is left untouched and a copy is returned. *resized is
// only written by the child's set call.
func (n *sortedMapBranchNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) {
	idx := n.indexOf(key, c)

	// Delegate the insert to the child; it may come back split in two.
	newNode, splitNode := n.elems[idx].node.set(key, value, c, mutable, resized)

	// In-place update.
	if mutable {
		n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode}
		if splitNode != nil {
			// Grow by one, shift the tail right and place the split node
			// directly after the updated child.
			n.elems = append(n.elems, sortedMapBranchElem[K, V]{})
			copy(n.elems[idx+1:], n.elems[idx:])
			n.elems[idx+1] = sortedMapBranchElem[K, V]{key: splitNode.minKey(), node: splitNode}
		}

		// Split this branch if it overflowed. The lower half's capacity is
		// capped so a later append to it cannot overwrite the upper half,
		// which shares the same backing array.
		if len(n.elems) > sortedMapNodeSize {
			splitIdx := len(n.elems) / 2
			newNode := &sortedMapBranchNode[K, V]{elems: n.elems[:splitIdx:splitIdx]}
			splitNode := &sortedMapBranchNode[K, V]{elems: n.elems[splitIdx:]}
			return newNode, splitNode
		}
		return n, nil
	}

	// Copy-on-write update: build a fresh element slice, either the same
	// length (child replaced) or one longer (child replaced and split node
	// inserted after it).
	var other sortedMapBranchNode[K, V]
	if splitNode == nil {
		other.elems = make([]sortedMapBranchElem[K, V], len(n.elems))
		copy(other.elems, n.elems)
		other.elems[idx] = sortedMapBranchElem[K, V]{
			key:  newNode.minKey(),
			node: newNode,
		}
	} else {
		other.elems = make([]sortedMapBranchElem[K, V], len(n.elems)+1)
		copy(other.elems[:idx], n.elems[:idx])
		copy(other.elems[idx+1:], n.elems[idx:])
		other.elems[idx] = sortedMapBranchElem[K, V]{
			key:  newNode.minKey(),
			node: newNode,
		}
		other.elems[idx+1] = sortedMapBranchElem[K, V]{
			key:  splitNode.minKey(),
			node: splitNode,
		}
	}

	// Split the copy if it overflowed.
	if len(other.elems) > sortedMapNodeSize {
		splitIdx := len(other.elems) / 2
		newNode := &sortedMapBranchNode[K, V]{elems: other.elems[:splitIdx:splitIdx]}
		splitNode := &sortedMapBranchNode[K, V]{elems: other.elems[splitIdx:]}
		return newNode, splitNode
	}

	// No split needed.
	return &other, nil
}
1867
1868
1869
// delete removes key from the appropriate child and returns the updated
// branch. The receiver is returned unchanged when the key was not found
// (*resized stays false). nil is returned when the last remaining child became
// empty. When mutable is true the receiver is modified in place; otherwise a
// copy is returned.
func (n *sortedMapBranchNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] {
	idx := n.indexOf(key, c)

	// Delegate to the child; nothing to do if the key was not there.
	newNode := n.elems[idx].node.delete(key, c, mutable, resized)
	if !*resized {
		return n
	}

	// The child became empty and must be removed from this branch.
	if newNode == nil {
		// It was the only child, so this branch is now empty as well.
		if len(n.elems) == 1 {
			return nil
		}

		// In place: shift the tail left and clear the vacated last slot so
		// it no longer references the removed child.
		if mutable {
			copy(n.elems[idx:], n.elems[idx+1:])
			n.elems[len(n.elems)-1] = sortedMapBranchElem[K, V]{}
			n.elems = n.elems[:len(n.elems)-1]
			return n
		}

		// Copy-on-write: new branch without the element at idx.
		other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems)-1)}
		copy(other.elems[:idx], n.elems[:idx])
		copy(other.elems[idx:], n.elems[idx+1:])
		return other
	}

	// The child still has entries: replace it, refreshing its minimum key.
	if mutable {
		n.elems[idx] = sortedMapBranchElem[K, V]{key: newNode.minKey(), node: newNode}
		return n
	}

	// Copy-on-write replacement of the child.
	other := &sortedMapBranchNode[K, V]{elems: make([]sortedMapBranchElem[K, V], len(n.elems))}
	copy(other.elems, n.elems)
	other.elems[idx] = sortedMapBranchElem[K, V]{
		key:  newNode.minKey(),
		node: newNode,
	}
	return other
}
1916
// sortedMapBranchElem is one slot of a branch node: a child node together with
// the minimum key of that child, which is used to route lookups.
type sortedMapBranchElem[K, V any] struct {
	key  K
	node sortedMapNode[K, V]
}
1921
1922
// sortedMapLeafNode is a leaf of the sorted map tree. It holds the actual
// key/value entries, kept in key order.
type sortedMapLeafNode[K, V any] struct {
	entries []mapEntry[K, V]
}
1926
1927
// minKey returns the key of the first entry. It indexes entries[0] directly,
// so the leaf must not be empty.
func (n *sortedMapLeafNode[K, V]) minKey() K {
	return n.entries[0].key
}
1931
1932
1933 func (n *sortedMapLeafNode[K, V]) indexOf(key K, c Comparer[K]) int {
1934 return sort.Search(len(n.entries), func(i int) bool {
1935 return c.Compare(n.entries[i].key, key) != -1
1936 })
1937 }
1938
1939
1940 func (n *sortedMapLeafNode[K, V]) get(key K, c Comparer[K]) (value V, ok bool) {
1941 idx := n.indexOf(key, c)
1942
1943
1944 if idx == len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 {
1945 return value, false
1946 }
1947
1948
1949 return n.entries[idx].value, true
1950 }
1951
1952
1953
// set stores value under key in this leaf and returns the updated leaf. If the
// leaf ends up with more than sortedMapNodeSize entries it is split in half
// and both halves are returned; otherwise the second return value is nil.
// *resized is set to true when the key was not already present. When mutable
// is true the receiver is modified in place; otherwise a copy is returned.
func (n *sortedMapLeafNode[K, V]) set(key K, value V, c Comparer[K], mutable bool, resized *bool) (sortedMapNode[K, V], sortedMapNode[K, V]) {
	// Locate the key's position (or insertion point) and whether it exists.
	idx := n.indexOf(key, c)
	exists := idx < len(n.entries) && c.Compare(n.entries[idx].key, key) == 0

	// In-place update.
	if mutable {
		if !exists {
			// Make room: grow by one and shift the tail right.
			*resized = true
			n.entries = append(n.entries, mapEntry[K, V]{})
			copy(n.entries[idx+1:], n.entries[idx:])
		}
		n.entries[idx] = mapEntry[K, V]{key: key, value: value}

		// Split if the leaf overflowed. The lower half's capacity is capped
		// so a later append to it cannot overwrite the upper half, which
		// shares the same backing array.
		if len(n.entries) > sortedMapNodeSize {
			splitIdx := len(n.entries) / 2
			newNode := &sortedMapLeafNode[K, V]{entries: n.entries[:splitIdx:splitIdx]}
			splitNode := &sortedMapLeafNode[K, V]{entries: n.entries[splitIdx:]}
			return newNode, splitNode
		}
		return n, nil
	}

	// Copy-on-write update: build a fresh entry slice, either the same length
	// (overwrite) or one longer (insert).
	var newEntries []mapEntry[K, V]
	if exists {
		newEntries = make([]mapEntry[K, V], len(n.entries))
		copy(newEntries, n.entries)
		newEntries[idx] = mapEntry[K, V]{key: key, value: value}
	} else {
		*resized = true
		newEntries = make([]mapEntry[K, V], len(n.entries)+1)
		copy(newEntries[:idx], n.entries[:idx])
		newEntries[idx] = mapEntry[K, V]{key: key, value: value}
		copy(newEntries[idx+1:], n.entries[idx:])
	}

	// Split the copy if it overflowed.
	if len(newEntries) > sortedMapNodeSize {
		splitIdx := len(newEntries) / 2
		newNode := &sortedMapLeafNode[K, V]{entries: newEntries[:splitIdx:splitIdx]}
		splitNode := &sortedMapLeafNode[K, V]{entries: newEntries[splitIdx:]}
		return newNode, splitNode
	}

	// No split needed.
	return &sortedMapLeafNode[K, V]{entries: newEntries}, nil
}
2004
2005
2006
2007 func (n *sortedMapLeafNode[K, V]) delete(key K, c Comparer[K], mutable bool, resized *bool) sortedMapNode[K, V] {
2008 idx := n.indexOf(key, c)
2009
2010
2011 if idx >= len(n.entries) || c.Compare(n.entries[idx].key, key) != 0 {
2012 return n
2013 }
2014 *resized = true
2015
2016
2017 if len(n.entries) == 1 {
2018 return nil
2019 }
2020
2021
2022 if mutable {
2023 copy(n.entries[idx:], n.entries[idx+1:])
2024 n.entries[len(n.entries)-1] = mapEntry[K, V]{}
2025 n.entries = n.entries[:len(n.entries)-1]
2026 return n
2027 }
2028
2029
2030 other := &sortedMapLeafNode[K, V]{entries: make([]mapEntry[K, V], len(n.entries)-1)}
2031 copy(other.entries[:idx], n.entries[:idx])
2032 copy(other.entries[idx:], n.entries[idx+1:])
2033 return other
2034 }
2035
2036
2037
// SortedMapIterator iterates over the key/value pairs of a SortedMap in either
// direction. It keeps an explicit stack of the nodes between the root and the
// current leaf.
type SortedMapIterator[K, V any] struct {
	m *SortedMap[K, V] // map being iterated

	stack [32]sortedMapIteratorElem[K, V] // nodes on the path from the root (stack[0]) to the current leaf
	depth int                             // index of the current element in stack; -1 when iteration is done
}
2044
2045
// Done returns true if no more key/value pairs remain in the iterator, which
// is signalled by a depth of -1.
func (itr *SortedMapIterator[K, V]) Done() bool {
	return itr.depth == -1
}
2049
2050
2051 func (itr *SortedMapIterator[K, V]) First() {
2052 if itr.m.root == nil {
2053 itr.depth = -1
2054 return
2055 }
2056 itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
2057 itr.depth = 0
2058 itr.first()
2059 }
2060
2061
2062 func (itr *SortedMapIterator[K, V]) Last() {
2063 if itr.m.root == nil {
2064 itr.depth = -1
2065 return
2066 }
2067 itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
2068 itr.depth = 0
2069 itr.last()
2070 }
2071
2072
2073
2074
2075 func (itr *SortedMapIterator[K, V]) Seek(key K) {
2076 if itr.m.root == nil {
2077 itr.depth = -1
2078 return
2079 }
2080 itr.stack[0] = sortedMapIteratorElem[K, V]{node: itr.m.root}
2081 itr.depth = 0
2082 itr.seek(key)
2083 }
2084
2085
2086
2087 func (itr *SortedMapIterator[K, V]) Next() (key K, value V, ok bool) {
2088
2089 if itr.Done() {
2090 return key, value, false
2091 }
2092
2093
2094 leafElem := &itr.stack[itr.depth]
2095 leafNode := leafElem.node.(*sortedMapLeafNode[K, V])
2096 leafEntry := &leafNode.entries[leafElem.index]
2097 key, value = leafEntry.key, leafEntry.value
2098
2099
2100 itr.next()
2101
2102
2103 return key, value, true
2104 }
2105
2106
// next moves the iterator to the following entry. It steps within the current
// leaf when possible; otherwise it pops levels (decrementing depth) until it
// finds a branch with a further child and descends to that child's first
// entry. When every level is exhausted the loop ends with depth == -1.
func (itr *SortedMapIterator[K, V]) next() {
	for ; itr.depth >= 0; itr.depth-- {
		elem := &itr.stack[itr.depth]

		switch node := elem.node.(type) {
		case *sortedMapLeafNode[K, V]:
			// More entries left in this leaf.
			if elem.index < len(node.entries)-1 {
				elem.index++
				return
			}
		case *sortedMapBranchNode[K, V]:
			// Move to the next child and descend to its first entry.
			if elem.index < len(node.elems)-1 {
				elem.index++
				itr.stack[itr.depth+1].node = node.elems[elem.index].node
				itr.depth++
				itr.first()
				return
			}
		}
	}
}
2128
2129
2130
2131 func (itr *SortedMapIterator[K, V]) Prev() (key K, value V, ok bool) {
2132
2133 if itr.Done() {
2134 return key, value, false
2135 }
2136
2137
2138 leafElem := &itr.stack[itr.depth]
2139 leafNode := leafElem.node.(*sortedMapLeafNode[K, V])
2140 leafEntry := &leafNode.entries[leafElem.index]
2141 key, value = leafEntry.key, leafEntry.value
2142
2143 itr.prev()
2144 return key, value, true
2145 }
2146
2147
// prev moves the iterator to the preceding entry. It steps back within the
// current leaf when possible; otherwise it pops levels (decrementing depth)
// until it finds a branch with an earlier child and descends to that child's
// last entry. When every level is exhausted the loop ends with depth == -1.
func (itr *SortedMapIterator[K, V]) prev() {
	for ; itr.depth >= 0; itr.depth-- {
		elem := &itr.stack[itr.depth]

		switch node := elem.node.(type) {
		case *sortedMapLeafNode[K, V]:
			// Earlier entries remain in this leaf.
			if elem.index > 0 {
				elem.index--
				return
			}
		case *sortedMapBranchNode[K, V]:
			// Move to the previous child and descend to its last entry.
			if elem.index > 0 {
				elem.index--
				itr.stack[itr.depth+1].node = node.elems[elem.index].node
				itr.depth++
				itr.last()
				return
			}
		}
	}
}
2169
2170
2171
// first descends from the node at stack[depth] to the first entry of its
// subtree, always following child 0 of each branch and pushing one stack
// element per level until a leaf is reached.
func (itr *SortedMapIterator[K, V]) first() {
	for {
		elem := &itr.stack[itr.depth]
		elem.index = 0

		switch node := elem.node.(type) {
		case *sortedMapBranchNode[K, V]:
			itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
			itr.depth++
		case *sortedMapLeafNode[K, V]:
			return
		}
	}
}
2186
2187
2188
// last descends from the node at stack[depth] to the last entry of its
// subtree, always following the final child of each branch and pushing one
// stack element per level until a leaf is reached.
func (itr *SortedMapIterator[K, V]) last() {
	for {
		elem := &itr.stack[itr.depth]

		switch node := elem.node.(type) {
		case *sortedMapBranchNode[K, V]:
			elem.index = len(node.elems) - 1
			itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
			itr.depth++
		case *sortedMapLeafNode[K, V]:
			elem.index = len(node.entries) - 1
			return
		}
	}
}
2204
2205
2206
// seek descends from the node at stack[depth] towards key, using each node's
// indexOf to choose the child (for branches) or entry position (for leaves).
// If the position in the leaf is past its last entry, the iterator is advanced
// with next() so that it lands on the following entry or becomes done.
func (itr *SortedMapIterator[K, V]) seek(key K) {
	for {
		elem := &itr.stack[itr.depth]
		elem.index = elem.node.indexOf(key, itr.m.comparer)

		switch node := elem.node.(type) {
		case *sortedMapBranchNode[K, V]:
			itr.stack[itr.depth+1] = sortedMapIteratorElem[K, V]{node: node.elems[elem.index].node}
			itr.depth++
		case *sortedMapLeafNode[K, V]:
			// Every entry in this leaf sorts before key: move on.
			if elem.index == len(node.entries) {
				itr.next()
			}
			return
		}
	}
}
2224
2225
// sortedMapIteratorElem is one level of a SortedMapIterator's stack: a node
// and the position currently being visited within it.
type sortedMapIteratorElem[K, V any] struct {
	node  sortedMapNode[K, V]
	index int // entry index for leaf nodes, child index for branch nodes
}
2230
2231
// Hasher hashes keys of type K and tests them for equality.
type Hasher[K any] interface {
	// Hash returns a 32-bit hash of key.
	Hash(key K) uint32

	// Equal reports whether a and b are considered the same key.
	Equal(a, b K) bool
}
2239
2240
2241 func NewHasher[K any](key K) Hasher[K] {
2242
2243 switch (any(key)).(type) {
2244 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string:
2245 return &defaultHasher[K]{}
2246 }
2247
2248
2249
2250 switch reflect.TypeOf(key).Kind() {
2251 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
2252 return &reflectHasher[K]{}
2253 }
2254
2255
2256
2257 panic(fmt.Sprintf("immutable.NewHasher: must set hasher for %T type", key))
2258 }
2259
2260
// hashString returns a 32-bit hash of value, computed byte by byte as
// hash = 31*hash + b with wrap-around on overflow. The empty string hashes
// to 0.
func hashString(value string) uint32 {
	var sum uint32
	for _, b := range []byte(value) {
		sum = sum*31 + uint32(b)
	}
	return sum
}
2268
2269
// reflectHasher implements Hasher using reflection, for key types whose
// underlying kind is an integer or a string.
type reflectHasher[K any] struct{}
2271
2272
2273 func (h *reflectHasher[K]) Hash(key K) uint32 {
2274 switch reflect.TypeOf(key).Kind() {
2275 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
2276 return hashUint64(uint64(reflect.ValueOf(key).Int()))
2277 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
2278 return hashUint64(reflect.ValueOf(key).Uint())
2279 case reflect.String:
2280 var hash uint32
2281 s := reflect.ValueOf(key).String()
2282 for i := 0; i < len(s); i++ {
2283 hash = 31*hash + uint32(s[i])
2284 }
2285 return hash
2286 }
2287 panic(fmt.Sprintf("immutable.reflectHasher.Hash: reflectHasher does not support %T type", key))
2288 }
2289
2290
2291
2292 func (h *reflectHasher[K]) Equal(a, b K) bool {
2293 switch reflect.TypeOf(a).Kind() {
2294 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
2295 return reflect.ValueOf(a).Int() == reflect.ValueOf(b).Int()
2296 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
2297 return reflect.ValueOf(a).Uint() == reflect.ValueOf(b).Uint()
2298 case reflect.String:
2299 return reflect.ValueOf(a).String() == reflect.ValueOf(b).String()
2300 }
2301 panic(fmt.Sprintf("immutable.reflectHasher.Equal: reflectHasher does not support %T type", a))
2302
2303 }
2304
2305
// hashUint64 folds a 64-bit value into a 32-bit hash. While the remaining
// value exceeds 32 bits it is divided by 0xffffffff and XORed into the
// accumulator; the low 32 bits of the accumulator are returned. Values that
// already fit in 32 bits are returned unchanged.
func hashUint64(value uint64) uint32 {
	const max32 = 0xffffffff

	folded := value
	for rest := value; rest > max32; {
		rest /= max32
		folded ^= rest
	}
	return uint32(folded)
}
2314
2315
// defaultHasher implements Hasher for keys whose dynamic type is one of the
// built-in integer types or string.
type defaultHasher[K any] struct{}
2317
2318
2319 func (h *defaultHasher[K]) Hash(key K) uint32 {
2320 switch x := (any(key)).(type) {
2321 case int:
2322 return hashUint64(uint64(x))
2323 case int8:
2324 return hashUint64(uint64(x))
2325 case int16:
2326 return hashUint64(uint64(x))
2327 case int32:
2328 return hashUint64(uint64(x))
2329 case int64:
2330 return hashUint64(uint64(x))
2331 case uint:
2332 return hashUint64(uint64(x))
2333 case uint8:
2334 return hashUint64(uint64(x))
2335 case uint16:
2336 return hashUint64(uint64(x))
2337 case uint32:
2338 return hashUint64(uint64(x))
2339 case uint64:
2340 return hashUint64(uint64(x))
2341 case uintptr:
2342 return hashUint64(uint64(x))
2343 case string:
2344 return hashString(x)
2345 }
2346 panic(fmt.Sprintf("immutable.defaultHasher.Hash: must set comparer for %T type", key))
2347 }
2348
2349
2350
// Equal reports whether a and b are equal, using Go's == on the two values
// boxed as interfaces (so both the dynamic type and the value must match).
func (h *defaultHasher[K]) Equal(a, b K) bool {
	return any(a) == any(b)
}
2354
2355
// Comparer defines an ordering over keys of type K; sorted maps use it to
// keep their keys in order.
type Comparer[K any] interface {
	// Compare reports the relative order of a and b. The comparers built into
	// this package return -1 if a is less than b, 1 if a is greater than b,
	// and 0 if they are equal.
	Compare(a, b K) int
}
2361
2362
2363
2364
2365 func NewComparer[K any](key K) Comparer[K] {
2366
2367 switch (any(key)).(type) {
2368 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, string:
2369 return &defaultComparer[K]{}
2370 }
2371
2372
2373 switch reflect.TypeOf(key).Kind() {
2374 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
2375 return &reflectComparer[K]{}
2376 }
2377
2378
2379 panic(fmt.Sprintf("immutable.NewComparer: must set comparer for %T type", key))
2380 }
2381
2382
// defaultComparer implements Comparer for keys whose dynamic type is one of
// the built-in integer types or string.
type defaultComparer[K any] struct{}
2384
2385
2386
// Compare returns -1 if i is less than j, 1 if i is greater than j, and 0 if
// they are equal. The dynamic type of i selects the comparison; j is
// type-asserted to that same type, so the call panics if j holds a different
// type. It also panics if the type of i is not a built-in integer type or
// string.
func (c *defaultComparer[K]) Compare(i K, j K) int {
	switch x := (any(i)).(type) {
	case int:
		return defaultCompare(x, (any(j)).(int))
	case int8:
		return defaultCompare(x, (any(j)).(int8))
	case int16:
		return defaultCompare(x, (any(j)).(int16))
	case int32:
		return defaultCompare(x, (any(j)).(int32))
	case int64:
		return defaultCompare(x, (any(j)).(int64))
	case uint:
		return defaultCompare(x, (any(j)).(uint))
	case uint8:
		return defaultCompare(x, (any(j)).(uint8))
	case uint16:
		return defaultCompare(x, (any(j)).(uint16))
	case uint32:
		return defaultCompare(x, (any(j)).(uint32))
	case uint64:
		return defaultCompare(x, (any(j)).(uint64))
	case uintptr:
		return defaultCompare(x, (any(j)).(uintptr))
	case string:
		return defaultCompare(x, (any(j)).(string))
	}
	panic(fmt.Sprintf("immutable.defaultComparer: must set comparer for %T type", i))
}
2416
2417
2418
2419 func defaultCompare[K constraints.Ordered](i, j K) int {
2420 if i < j {
2421 return -1
2422 } else if i > j {
2423 return 1
2424 }
2425 return 0
2426 }
2427
2428
// reflectComparer implements Comparer using reflection, for key types whose
// underlying kind is an integer or a string.
type reflectComparer[K any] struct{}
2430
2431
2432
2433 func (c *reflectComparer[K]) Compare(a, b K) int {
2434 switch reflect.TypeOf(a).Kind() {
2435 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
2436 if i, j := reflect.ValueOf(a).Int(), reflect.ValueOf(b).Int(); i < j {
2437 return -1
2438 } else if i > j {
2439 return 1
2440 }
2441 return 0
2442 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
2443 if i, j := reflect.ValueOf(a).Uint(), reflect.ValueOf(b).Uint(); i < j {
2444 return -1
2445 } else if i > j {
2446 return 1
2447 }
2448 return 0
2449 case reflect.String:
2450 return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String())
2451 }
2452 panic(fmt.Sprintf("immutable.reflectComparer.Compare: must set comparer for %T type", a))
2453 }
2454
// assert panics with message when condition is false and does nothing
// otherwise.
func assert(condition bool, message string) {
	if condition {
		return
	}
	panic(message)
}
2460
View as plain text