ccecda7e32ad7a08d5ab36bb458d1519a61024b6
[ccan] / ccan / avl / avl.c
1 /*
2  * Copyright (c) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
3  *
4  * Permission to use, copy, modify, and/or distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  *
8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15  */
16
17 #include "avl.h"
18
19 #include <assert.h>
20 #include <stdlib.h>
21
22 static AvlNode *mkNode(const void *key, const void *value);
23 static void freeNode(AvlNode *node);
24
25 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key);
26
27 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value);
28 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret);
29 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
30
31 static int sway(AvlNode **p, int sway);
32 static void balance(AvlNode **p, int side);
33
34 static bool checkBalances(AvlNode *node, int *height);
35 static bool checkOrder(AVL *avl);
36 static size_t countNode(AvlNode *node);
37
38 /*
39  * Utility macros for converting between
40  * "balance" values (-1 or 1) and "side" values (0 or 1).
41  *
42  * bal(0)   == -1
43  * bal(1)   == +1
44  * side(-1) == 0
45  * side(+1) == 1
46  */
47 #define bal(side) ((side) == 0 ? -1 : 1)
48 #define side(bal) ((bal)  == 1 ?  1 : 0)
49
50 static int sign(int cmp)
51 {
52         if (cmp < 0)
53                 return -1;
54         if (cmp == 0)
55                 return 0;
56         return 1;
57 }
58
59 AVL *avl_new(AvlCompare compare)
60 {
61         AVL *avl = malloc(sizeof(*avl));
62         
63         assert(avl != NULL);
64         
65         avl->compare = compare;
66         avl->root = NULL;
67         avl->count = 0;
68         return avl;
69 }
70
71 void avl_free(AVL *avl)
72 {
73         freeNode(avl->root);
74         free(avl);
75 }
76
77 void *avl_lookup(const AVL *avl, const void *key)
78 {
79         AvlNode *found = lookup(avl, avl->root, key);
80         return found ? (void*) found->value : NULL;
81 }
82
83 AvlNode *avl_lookup_node(const AVL *avl, const void *key)
84 {
85         return lookup(avl, avl->root, key);
86 }
87
88 size_t avl_count(const AVL *avl)
89 {
90         return avl->count;
91 }
92
93 bool avl_insert(AVL *avl, const void *key, const void *value)
94 {
95         size_t old_count = avl->count;
96         insert(avl, &avl->root, key, value);
97         return avl->count != old_count;
98 }
99
100 bool avl_remove(AVL *avl, const void *key)
101 {
102         AvlNode *node = NULL;
103         
104         remove(avl, &avl->root, key, &node);
105         
106         if (node == NULL) {
107                 return false;
108         } else {
109                 free(node);
110                 return true;
111         }
112 }
113
114 static AvlNode *mkNode(const void *key, const void *value)
115 {
116         AvlNode *node = malloc(sizeof(*node));
117         
118         assert(node != NULL);
119         
120         node->key = key;
121         node->value = value;
122         node->lr[0] = NULL;
123         node->lr[1] = NULL;
124         node->balance = 0;
125         return node;
126 }
127
128 static void freeNode(AvlNode *node)
129 {
130         if (node) {
131                 freeNode(node->lr[0]);
132                 freeNode(node->lr[1]);
133                 free(node);
134         }
135 }
136
137 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key)
138 {
139         int cmp;
140         
141         if (node == NULL)
142                 return NULL;
143         
144         cmp = avl->compare(key, node->key);
145         
146         if (cmp < 0)
147                 return lookup(avl, node->lr[0], key);
148         if (cmp > 0)
149                 return lookup(avl, node->lr[1], key);
150         return node;
151 }
152
153 /*
154  * Insert a key/value into a subtree, rebalancing if necessary.
155  *
156  * Return true if the subtree's height increased.
157  */
158 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value)
159 {
160         if (*p == NULL) {
161                 *p = mkNode(key, value);
162                 avl->count++;
163                 return true;
164         } else {
165                 AvlNode *node = *p;
166                 int      cmp  = sign(avl->compare(key, node->key));
167                 
168                 if (cmp == 0) {
169                         node->key = key;
170                         node->value = value;
171                         return false;
172                 }
173                 
174                 if (!insert(avl, &node->lr[side(cmp)], key, value))
175                         return false;
176                 
177                 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
178                 return sway(p, cmp) != 0;
179         }
180 }
181
182 /*
183  * Remove the node matching the given key.
184  * If present, return the removed node through *ret .
185  * The returned node's lr and balance are meaningless.
186  *
187  * Return true if the subtree's height decreased.
188  */
189 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret)
190 {
191         if (*p == NULL) {
192                 return false;
193         } else {
194                 AvlNode *node = *p;
195                 int      cmp  = sign(avl->compare(key, node->key));
196                 
197                 if (cmp == 0) {
198                         *ret = node;
199                         avl->count--;
200                         
201                         if (node->lr[0] != NULL && node->lr[1] != NULL) {
202                                 AvlNode *replacement;
203                                 int      side;
204                                 bool     shrunk;
205                                 
206                                 /* Pick a subtree to pull the replacement from such that
207                                  * this node doesn't have to be rebalanced. */
208                                 side = node->balance <= 0 ? 0 : 1;
209                                 
210                                 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
211                                 
212                                 replacement->lr[0]   = node->lr[0];
213                                 replacement->lr[1]   = node->lr[1];
214                                 replacement->balance = node->balance;
215                                 *p = replacement;
216                                 
217                                 if (!shrunk)
218                                         return false;
219                                 
220                                 replacement->balance -= bal(side);
221                                 
222                                 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
223                                 return replacement->balance == 0;
224                         }
225                         
226                         if (node->lr[0] != NULL)
227                                 *p = node->lr[0];
228                         else
229                                 *p = node->lr[1];
230                         
231                         return true;
232                         
233                 } else {
234                         if (!remove(avl, &node->lr[side(cmp)], key, ret))
235                                 return false;
236                         
237                         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
238                         return sway(p, -cmp) == 0;
239                 }
240         }
241 }
242
243 /*
244  * Remove either the left-most (if side == 0) or right-most (if side == 1)
245  * node in a subtree, returning the removed node through *ret .
246  * The returned node's lr and balance are meaningless.
247  *
248  * The subtree must not be empty (i.e. *p must not be NULL).
249  *
250  * Return true if the subtree's height decreased.
251  */
252 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
253 {
254         AvlNode *node = *p;
255         
256         if (node->lr[side] == NULL) {
257                 *ret = node;
258                 *p = node->lr[1 - side];
259                 return true;
260         }
261         
262         if (!removeExtremum(&node->lr[side], side, ret))
263                 return false;
264         
265         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
266         return sway(p, -bal(side)) == 0;
267 }
268
269 /*
270  * Rebalance a node if necessary.  Think of this function
271  * as a higher-level interface to balance().
272  *
273  * sway must be either -1 or 1, and indicates what was added to
274  * the balance of this node by a prior operation.
275  *
276  * Return the new balance of the subtree.
277  */
278 static int sway(AvlNode **p, int sway)
279 {
280         if ((*p)->balance != sway)
281                 (*p)->balance += sway;
282         else
283                 balance(p, side(sway));
284         
285         return (*p)->balance;
286 }
287
288 /*
289  * Perform tree rotations on an unbalanced node.
290  *
291  * side == 0 means the node's balance is -2 .
292  * side == 1 means the node's balance is +2 .
293  */
294 static void balance(AvlNode **p, int side)
295 {
296         AvlNode  *node  = *p,
297                  *child = node->lr[side];
298         int opposite    = 1 - side;
299         int bal         = bal(side);
300         
301         if (child->balance != -bal) {
302                 /* Left-left (side == 0) or right-right (side == 1) */
303                 node->lr[side]      = child->lr[opposite];
304                 child->lr[opposite] = node;
305                 *p = child;
306                 
307                 child->balance -= bal;
308                 node->balance = -child->balance;
309                 
310         } else {
311                 /* Left-right (side == 0) or right-left (side == 1) */
312                 AvlNode *grandchild = child->lr[opposite];
313                 
314                 node->lr[side]           = grandchild->lr[opposite];
315                 child->lr[opposite]      = grandchild->lr[side];
316                 grandchild->lr[side]     = child;
317                 grandchild->lr[opposite] = node;
318                 *p = grandchild;
319                 
320                 node->balance       = 0;
321                 child->balance      = 0;
322                 
323                 if (grandchild->balance == bal)
324                         node->balance  = -bal;
325                 else if (grandchild->balance == -bal)
326                         child->balance = bal;
327                 
328                 grandchild->balance = 0;
329         }
330 }
331
332
333 /************************* avl_check_invariants() *************************/
334
335 bool avl_check_invariants(AVL *avl)
336 {
337         int    dummy;
338         
339         return checkBalances(avl->root, &dummy)
340             && checkOrder(avl)
341             && countNode(avl->root) == avl->count;
342 }
343
344 static bool checkBalances(AvlNode *node, int *height)
345 {
346         if (node) {
347                 int h0, h1;
348                 
349                 if (!checkBalances(node->lr[0], &h0))
350                         return false;
351                 if (!checkBalances(node->lr[1], &h1))
352                         return false;
353                 
354                 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
355                         return false;
356                 
357                 *height = (h0 > h1 ? h0 : h1) + 1;
358                 return true;
359         } else {
360                 *height = 0;
361                 return true;
362         }
363 }
364
365 static bool checkOrder(AVL *avl)
366 {
367         AvlIter     i;
368         const void *last     = NULL;
369         bool        last_set = false;
370         
371         avl_foreach(i, avl) {
372                 if (last_set && avl->compare(last, i.key) >= 0)
373                         return false;
374                 last     = i.key;
375                 last_set = true;
376         }
377         
378         return true;
379 }
380
381 static size_t countNode(AvlNode *node)
382 {
383         if (node)
384                 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
385         else
386                 return 0;
387 }
388
389
390 /************************* Traversal *************************/
391
392 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
393 {
394         AvlNode *node = avl->root;
395         
396         iter->stack_index = 0;
397         iter->direction   = dir;
398         
399         if (node == NULL) {
400                 iter->key      = NULL;
401                 iter->value    = NULL;
402                 iter->node     = NULL;
403                 return;
404         }
405         
406         while (node->lr[dir] != NULL) {
407                 iter->stack[iter->stack_index++] = node;
408                 node = node->lr[dir];
409         }
410         
411         iter->key   = (void*) node->key;
412         iter->value = (void*) node->value;
413         iter->node  = node;
414 }
415
416 void avl_iter_next(AvlIter *iter)
417 {
418         AvlNode     *node = iter->node;
419         AvlDirection dir  = iter->direction;
420         
421         if (node == NULL)
422                 return;
423         
424         node = node->lr[1 - dir];
425         if (node != NULL) {
426                 while (node->lr[dir] != NULL) {
427                         iter->stack[iter->stack_index++] = node;
428                         node = node->lr[dir];
429                 }
430         } else if (iter->stack_index > 0) {
431                 node = iter->stack[--iter->stack_index];
432         } else {
433                 iter->key      = NULL;
434                 iter->value    = NULL;
435                 iter->node     = NULL;
436                 return;
437         }
438         
439         iter->node  = node;
440         iter->key   = (void*) node->key;
441         iter->value = (void*) node->value;
442 }