htable: avoid branch in calculating perfect bit.
[ccan] / ccan / htable / htable.c
1 /* Licensed under LGPLv2+ - see LICENSE file for details */
2 #include <ccan/htable/htable.h>
3 #include <ccan/compiler/compiler.h>
4 #include <stdlib.h>
5 #include <stdio.h>
6 #include <limits.h>
7 #include <stdbool.h>
8 #include <assert.h>
9 #include <string.h>
10
11 /* We use 0x1 as deleted marker. */
12 #define HTABLE_DELETED (0x1)
13
14 /* perfect_bitnum 63 means there's no perfect bitnum */
15 #define NO_PERFECT_BIT (sizeof(uintptr_t) * CHAR_BIT - 1)
16
17 static void *htable_default_alloc(struct htable *ht, size_t len)
18 {
19         return calloc(len, 1);
20 }
21
22 static void htable_default_free(struct htable *ht, void *p)
23 {
24         free(p);
25 }
26
27 static void *(*htable_alloc)(struct htable *, size_t) = htable_default_alloc;
28 static void (*htable_free)(struct htable *, void *) = htable_default_free;
29
30 void htable_set_allocator(void *(*alloc)(struct htable *, size_t len),
31                           void (*free)(struct htable *, void *p))
32 {
33         if (!alloc)
34                 alloc = htable_default_alloc;
35         if (!free)
36                 free = htable_default_free;
37         htable_alloc = alloc;
38         htable_free = free;
39 }
40
41 /* We clear out the bits which are always the same, and put metadata there. */
42 static inline uintptr_t get_extra_ptr_bits(const struct htable *ht,
43                                            uintptr_t e)
44 {
45         return e & ht->common_mask;
46 }
47
48 static inline void *get_raw_ptr(const struct htable *ht, uintptr_t e)
49 {
50         return (void *)((e & ~ht->common_mask) | ht->common_bits);
51 }
52
53 static inline uintptr_t make_hval(const struct htable *ht,
54                                   const void *p, uintptr_t bits)
55 {
56         return ((uintptr_t)p & ~ht->common_mask) | bits;
57 }
58
59 static inline bool entry_is_valid(uintptr_t e)
60 {
61         return e > HTABLE_DELETED;
62 }
63
64 static inline uintptr_t ht_perfect_mask(const struct htable *ht)
65 {
66         return (uintptr_t)2 << ht->perfect_bitnum;
67 }
68
69 static inline uintptr_t get_hash_ptr_bits(const struct htable *ht,
70                                           size_t hash)
71 {
72         /* Shuffling the extra bits (as specified in mask) down the
73          * end is quite expensive.  But the lower bits are redundant, so
74          * we fold the value first. */
75         return (hash ^ (hash >> ht->bits))
76                 & ht->common_mask & ~ht_perfect_mask(ht);
77 }
78
79 void htable_init(struct htable *ht,
80                  size_t (*rehash)(const void *elem, void *priv), void *priv)
81 {
82         struct htable empty = HTABLE_INITIALIZER(empty, NULL, NULL);
83         *ht = empty;
84         ht->rehash = rehash;
85         ht->priv = priv;
86         ht->table = &ht->common_bits;
87 }
88
89 static inline size_t ht_max(const struct htable *ht)
90 {
91         return ((size_t)3 << ht->bits) / 4;
92 }
93
94 static inline size_t ht_max_with_deleted(const struct htable *ht)
95 {
96         return ((size_t)9 << ht->bits) / 10;
97 }
98
99 bool htable_init_sized(struct htable *ht,
100                        size_t (*rehash)(const void *, void *),
101                        void *priv, size_t expect)
102 {
103         htable_init(ht, rehash, priv);
104
105         /* Don't go insane with sizing. */
106         for (ht->bits = 1; ((size_t)3 << ht->bits) / 4 < expect; ht->bits++) {
107                 if (ht->bits == 30)
108                         break;
109         }
110
111         ht->table = htable_alloc(ht, sizeof(size_t) << ht->bits);
112         if (!ht->table) {
113                 ht->table = &ht->common_bits;
114                 return false;
115         }
116         (void)htable_debug(ht, HTABLE_LOC);
117         return true;
118 }
119         
120 void htable_clear(struct htable *ht)
121 {
122         if (ht->table != &ht->common_bits)
123                 htable_free(ht, (void *)ht->table);
124         htable_init(ht, ht->rehash, ht->priv);
125 }
126
127 bool htable_copy_(struct htable *dst, const struct htable *src)
128 {
129         uintptr_t *htable = htable_alloc(dst, sizeof(size_t) << src->bits);
130
131         if (!htable)
132                 return false;
133
134         *dst = *src;
135         dst->table = htable;
136         memcpy(dst->table, src->table, sizeof(size_t) << src->bits);
137         return true;
138 }
139
140 static size_t hash_bucket(const struct htable *ht, size_t h)
141 {
142         return h & ((1 << ht->bits)-1);
143 }
144
145 static void *htable_val(const struct htable *ht,
146                         struct htable_iter *i, size_t hash, uintptr_t perfect)
147 {
148         uintptr_t h2 = get_hash_ptr_bits(ht, hash) | perfect;
149
150         while (ht->table[i->off]) {
151                 if (ht->table[i->off] != HTABLE_DELETED) {
152                         if (get_extra_ptr_bits(ht, ht->table[i->off]) == h2)
153                                 return get_raw_ptr(ht, ht->table[i->off]);
154                 }
155                 i->off = (i->off + 1) & ((1 << ht->bits)-1);
156                 h2 &= ~perfect;
157         }
158         return NULL;
159 }
160
161 void *htable_firstval_(const struct htable *ht,
162                        struct htable_iter *i, size_t hash)
163 {
164         i->off = hash_bucket(ht, hash);
165         return htable_val(ht, i, hash, ht_perfect_mask(ht));
166 }
167
168 void *htable_nextval_(const struct htable *ht,
169                       struct htable_iter *i, size_t hash)
170 {
171         i->off = (i->off + 1) & ((1 << ht->bits)-1);
172         return htable_val(ht, i, hash, 0);
173 }
174
175 void *htable_first_(const struct htable *ht, struct htable_iter *i)
176 {
177         for (i->off = 0; i->off < (size_t)1 << ht->bits; i->off++) {
178                 if (entry_is_valid(ht->table[i->off]))
179                         return get_raw_ptr(ht, ht->table[i->off]);
180         }
181         return NULL;
182 }
183
184 void *htable_next_(const struct htable *ht, struct htable_iter *i)
185 {
186         for (i->off++; i->off < (size_t)1 << ht->bits; i->off++) {
187                 if (entry_is_valid(ht->table[i->off]))
188                         return get_raw_ptr(ht, ht->table[i->off]);
189         }
190         return NULL;
191 }
192
193 void *htable_prev_(const struct htable *ht, struct htable_iter *i)
194 {
195         for (;;) {
196                 if (!i->off)
197                         return NULL;
198                 i->off --;
199                 if (entry_is_valid(ht->table[i->off]))
200                         return get_raw_ptr(ht, ht->table[i->off]);
201         }
202 }
203
204 /* This does not expand the hash table, that's up to caller. */
205 static void ht_add(struct htable *ht, const void *new, size_t h)
206 {
207         size_t i;
208         uintptr_t perfect = ht_perfect_mask(ht);
209
210         i = hash_bucket(ht, h);
211
212         while (entry_is_valid(ht->table[i])) {
213                 perfect = 0;
214                 i = (i + 1) & ((1 << ht->bits)-1);
215         }
216         ht->table[i] = make_hval(ht, new, get_hash_ptr_bits(ht, h)|perfect);
217 }
218
219 static COLD bool double_table(struct htable *ht)
220 {
221         unsigned int i;
222         size_t oldnum = (size_t)1 << ht->bits;
223         uintptr_t *oldtable, e;
224
225         oldtable = ht->table;
226         ht->table = htable_alloc(ht, sizeof(size_t) << (ht->bits+1));
227         if (!ht->table) {
228                 ht->table = oldtable;
229                 return false;
230         }
231         ht->bits++;
232
233         /* If we lost our "perfect bit", get it back now. */
234         if (ht->perfect_bitnum == NO_PERFECT_BIT && ht->common_mask) {
235                 for (i = 0; i < sizeof(ht->common_mask) * CHAR_BIT; i++) {
236                         if (ht->common_mask & ((size_t)2 << i)) {
237                                 ht->perfect_bitnum = i;
238                                 break;
239                         }
240                 }
241         }
242
243         if (oldtable != &ht->common_bits) {
244                 for (i = 0; i < oldnum; i++) {
245                         if (entry_is_valid(e = oldtable[i])) {
246                                 void *p = get_raw_ptr(ht, e);
247                                 ht_add(ht, p, ht->rehash(p, ht->priv));
248                         }
249                 }
250                 htable_free(ht, oldtable);
251         }
252         ht->deleted = 0;
253
254         (void)htable_debug(ht, HTABLE_LOC);
255         return true;
256 }
257
258 static COLD void rehash_table(struct htable *ht)
259 {
260         size_t start, i;
261         uintptr_t e, perfect = ht_perfect_mask(ht);
262
263         /* Beware wrap cases: we need to start from first empty bucket. */
264         for (start = 0; ht->table[start]; start++);
265
266         for (i = 0; i < (size_t)1 << ht->bits; i++) {
267                 size_t h = (i + start) & ((1 << ht->bits)-1);
268                 e = ht->table[h];
269                 if (!e)
270                         continue;
271                 if (e == HTABLE_DELETED)
272                         ht->table[h] = 0;
273                 else if (!(e & perfect)) {
274                         void *p = get_raw_ptr(ht, e);
275                         ht->table[h] = 0;
276                         ht_add(ht, p, ht->rehash(p, ht->priv));
277                 }
278         }
279         ht->deleted = 0;
280         (void)htable_debug(ht, HTABLE_LOC);
281 }
282
283 /* We stole some bits, now we need to put them back... */
284 static COLD void update_common(struct htable *ht, const void *p)
285 {
286         unsigned int i;
287         uintptr_t maskdiff, bitsdiff;
288
289         if (ht->elems == 0) {
290                 /* Always reveal one bit of the pointer in the bucket,
291                  * so it's not zero or HTABLE_DELETED (1), even if
292                  * hash happens to be 0.  Assumes (void *)1 is not a
293                  * valid pointer. */
294                 for (i = sizeof(uintptr_t)*CHAR_BIT - 1; i > 0; i--) {
295                         if ((uintptr_t)p & ((uintptr_t)1 << i))
296                                 break;
297                 }
298
299                 ht->common_mask = ~((uintptr_t)1 << i);
300                 ht->common_bits = ((uintptr_t)p & ht->common_mask);
301                 ht->perfect_bitnum = 0;
302                 (void)htable_debug(ht, HTABLE_LOC);
303                 return;
304         }
305
306         /* Find bits which are unequal to old common set. */
307         maskdiff = ht->common_bits ^ ((uintptr_t)p & ht->common_mask);
308
309         /* These are the bits which go there in existing entries. */
310         bitsdiff = ht->common_bits & maskdiff;
311
312         for (i = 0; i < (size_t)1 << ht->bits; i++) {
313                 if (!entry_is_valid(ht->table[i]))
314                         continue;
315                 /* Clear the bits no longer in the mask, set them as
316                  * expected. */
317                 ht->table[i] &= ~maskdiff;
318                 ht->table[i] |= bitsdiff;
319         }
320
321         /* Take away those bits from our mask, bits and perfect bit. */
322         ht->common_mask &= ~maskdiff;
323         ht->common_bits &= ~maskdiff;
324         if (ht_perfect_mask(ht) & maskdiff)
325                 ht->perfect_bitnum = NO_PERFECT_BIT;
326         (void)htable_debug(ht, HTABLE_LOC);
327 }
328
329 bool htable_add_(struct htable *ht, size_t hash, const void *p)
330 {
331         if (ht->elems+1 > ht_max(ht) && !double_table(ht))
332                 return false;
333         if (ht->elems+1 + ht->deleted > ht_max_with_deleted(ht))
334                 rehash_table(ht);
335         assert(p);
336         if (((uintptr_t)p & ht->common_mask) != ht->common_bits)
337                 update_common(ht, p);
338
339         ht_add(ht, p, hash);
340         ht->elems++;
341         return true;
342 }
343
344 bool htable_del_(struct htable *ht, size_t h, const void *p)
345 {
346         struct htable_iter i;
347         void *c;
348
349         for (c = htable_firstval(ht,&i,h); c; c = htable_nextval(ht,&i,h)) {
350                 if (c == p) {
351                         htable_delval(ht, &i);
352                         return true;
353                 }
354         }
355         return false;
356 }
357
358 void htable_delval_(struct htable *ht, struct htable_iter *i)
359 {
360         assert(i->off < (size_t)1 << ht->bits);
361         assert(entry_is_valid(ht->table[i->off]));
362
363         ht->elems--;
364         ht->table[i->off] = HTABLE_DELETED;
365         ht->deleted++;
366 }
367
368 struct htable *htable_check(const struct htable *ht, const char *abortstr)
369 {
370         void *p;
371         struct htable_iter i;
372         size_t n = 0;
373
374         /* Use non-DEBUG versions here, to avoid infinite recursion with
375          * CCAN_HTABLE_DEBUG! */
376         for (p = htable_first_(ht, &i); p; p = htable_next_(ht, &i)) {
377                 struct htable_iter i2;
378                 void *c;
379                 size_t h = ht->rehash(p, ht->priv);
380                 bool found = false;
381
382                 n++;
383
384                 /* Open-code htable_get to avoid CCAN_HTABLE_DEBUG */
385                 for (c = htable_firstval_(ht, &i2, h);
386                      c;
387                      c = htable_nextval_(ht, &i2, h)) {
388                         if (c == p) {
389                                 found = true;
390                                 break;
391                         }
392                 }
393
394                 if (!found) {
395                         if (abortstr) {
396                                 fprintf(stderr,
397                                         "%s: element %p in position %zu"
398                                         " cannot find itself\n",
399                                         abortstr, p, i.off);
400                                 abort();
401                         }
402                         return NULL;
403                 }
404         }
405         if (n != ht->elems) {
406                 if (abortstr) {
407                         fprintf(stderr,
408                                 "%s: found %zu elems, expected %zu\n",
409                                 abortstr, n, ht->elems);
410                         abort();
411                 }
412                 return NULL;
413         }
414
415         return (struct htable *)ht;
416 }