...
1
16
17 package bitmask
18
19 import (
20 "fmt"
21 "math/bits"
22 "strconv"
23 )
24
25
// BitMask describes a mutable set of bit positions. The implementation in this
// file (bitMask) stores the set in a uint64, so it accepts positions 0-63.
type BitMask interface {
	// Mutators: these modify the receiver in place.

	// Add sets the given bit positions.
	Add(bits ...int) error
	// Remove clears the given bit positions.
	Remove(bits ...int) error
	// And intersects the receiver with each of the given masks.
	And(masks ...BitMask)
	// Or unions the receiver with each of the given masks.
	Or(masks ...BitMask)
	// Clear unsets every bit.
	Clear()
	// Fill sets every bit.
	Fill()

	// Predicates and comparisons: these only inspect the masks involved.

	// IsEmpty reports whether no bit is set.
	IsEmpty() bool
	// IsSet reports whether the given bit position is set.
	IsSet(bit int) bool
	// AnySet reports whether at least one of the given positions is set.
	AnySet(bits []int) bool
	// IsEqual reports whether both masks hold the same bits.
	IsEqual(mask BitMask) bool
	// IsNarrowerThan reports whether the receiver has fewer bits set than
	// mask; ties on the count are broken with IsLessThan.
	IsNarrowerThan(mask BitMask) bool
	// IsLessThan reports whether the receiver compares lower than mask.
	IsLessThan(mask BitMask) bool
	// IsGreaterThan reports whether the receiver compares higher than mask.
	IsGreaterThan(mask BitMask) bool

	// Accessors.

	// Count returns the number of bits that are set.
	Count() int
	// GetBits returns the positions of the bits that are set.
	GetBits() []int
	// String returns a textual (binary) rendering of the mask.
	String() string
}

// bitMask is the BitMask implementation used by this package: bit i of the
// underlying uint64 is 1 exactly when position i belongs to the set.
type bitMask uint64
46
47
48 func NewEmptyBitMask() BitMask {
49 s := bitMask(0)
50 return &s
51 }
52
53
54 func NewBitMask(bits ...int) (BitMask, error) {
55 s := bitMask(0)
56 err := (&s).Add(bits...)
57 if err != nil {
58 return nil, err
59 }
60 return &s, nil
61 }
62
63
64 func (s *bitMask) Add(bits ...int) error {
65 mask := *s
66 for _, i := range bits {
67 if i < 0 || i >= 64 {
68 return fmt.Errorf("bit number must be in range 0-63")
69 }
70 mask |= 1 << uint64(i)
71 }
72 *s = mask
73 return nil
74 }
75
76
77 func (s *bitMask) Remove(bits ...int) error {
78 mask := *s
79 for _, i := range bits {
80 if i < 0 || i >= 64 {
81 return fmt.Errorf("bit number must be in range 0-63")
82 }
83 mask &^= 1 << uint64(i)
84 }
85 *s = mask
86 return nil
87 }
88
89
90 func (s *bitMask) And(masks ...BitMask) {
91 for _, m := range masks {
92 *s &= *m.(*bitMask)
93 }
94 }
95
96
97 func (s *bitMask) Or(masks ...BitMask) {
98 for _, m := range masks {
99 *s |= *m.(*bitMask)
100 }
101 }
102
103
104 func (s *bitMask) Clear() {
105 *s = 0
106 }
107
108
109 func (s *bitMask) Fill() {
110 *s = bitMask(^uint64(0))
111 }
112
113
114 func (s *bitMask) IsEmpty() bool {
115 return *s == 0
116 }
117
118
119 func (s *bitMask) IsSet(bit int) bool {
120 if bit < 0 || bit >= 64 {
121 return false
122 }
123 return (*s & (1 << uint64(bit))) > 0
124 }
125
126
127 func (s *bitMask) AnySet(bits []int) bool {
128 for _, b := range bits {
129 if s.IsSet(b) {
130 return true
131 }
132 }
133 return false
134 }
135
136
137 func (s *bitMask) IsEqual(mask BitMask) bool {
138 return *s == *mask.(*bitMask)
139 }
140
141
142
143
144
145
146 func (s *bitMask) IsNarrowerThan(mask BitMask) bool {
147 if s.Count() == mask.Count() {
148 return s.IsLessThan(mask)
149 }
150 return s.Count() < mask.Count()
151 }
152
153
154 func (s *bitMask) IsLessThan(mask BitMask) bool {
155 return *s < *mask.(*bitMask)
156 }
157
158
159 func (s *bitMask) IsGreaterThan(mask BitMask) bool {
160 return *s > *mask.(*bitMask)
161 }
162
163
164 func (s *bitMask) String() string {
165 grouping := 2
166 for shift := 64 - grouping; shift > 0; shift -= grouping {
167 if *s > (1 << uint(shift)) {
168 return fmt.Sprintf("%0"+strconv.Itoa(shift+grouping)+"b", *s)
169 }
170 }
171 return fmt.Sprintf("%0"+strconv.Itoa(grouping)+"b", *s)
172 }
173
174
175 func (s *bitMask) Count() int {
176 return bits.OnesCount64(uint64(*s))
177 }
178
179
180 func (s *bitMask) GetBits() []int {
181 var bits []int
182 for i := uint64(0); i < 64; i++ {
183 if (*s & (1 << i)) > 0 {
184 bits = append(bits, int(i))
185 }
186 }
187 return bits
188 }
189
190
191 func And(first BitMask, masks ...BitMask) BitMask {
192 s := *first.(*bitMask)
193 s.And(masks...)
194 return &s
195 }
196
197
198 func Or(first BitMask, masks ...BitMask) BitMask {
199 s := *first.(*bitMask)
200 s.Or(masks...)
201 return &s
202 }
203
204
205
206 func IterateBitMasks(bits []int, callback func(BitMask)) {
207 var iterate func(bits, accum []int, size int)
208 iterate = func(bits, accum []int, size int) {
209 if len(accum) == size {
210 mask, _ := NewBitMask(accum...)
211 callback(mask)
212 return
213 }
214 for i := range bits {
215 iterate(bits[i+1:], append(accum, bits[i]), size)
216 }
217 }
218
219 for i := 1; i <= len(bits); i++ {
220 iterate(bits, []int{}, i)
221 }
222 }
223
View as plain text