]> git.ozlabs.org Git - ccan/blob - ccan/avl/avl.c
build_assert: relicense to public domain.
[ccan] / ccan / avl / avl.c
1 /*
2  * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
3  * 
4  * Permission is hereby granted, free of charge, to any person obtaining a copy
5  * of this software and associated documentation files (the "Software"), to deal
6  * in the Software without restriction, including without limitation the rights
7  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8  * copies of the Software, and to permit persons to whom the Software is
9  * furnished to do so, subject to the following conditions:
10  * 
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  * 
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20  * THE SOFTWARE.
21  */
22
23 #include "avl.h"
24
25 #include <assert.h>
26 #include <stdlib.h>
27
28 static AvlNode *mkNode(const void *key, const void *value);
29 static void freeNode(AvlNode *node);
30
31 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key);
32
33 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value);
34 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret);
35 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
36
37 static int sway(AvlNode **p, int sway);
38 static void balance(AvlNode **p, int side);
39
40 static bool checkBalances(AvlNode *node, int *height);
41 static bool checkOrder(AVL *avl);
42 static size_t countNode(AvlNode *node);
43
44 /*
45  * Utility macros for converting between
46  * "balance" values (-1 or 1) and "side" values (0 or 1).
47  *
48  * bal(0)   == -1
49  * bal(1)   == +1
50  * side(-1) == 0
51  * side(+1) == 1
52  */
53 #define bal(side) ((side) == 0 ? -1 : 1)
54 #define side(bal) ((bal)  == 1 ?  1 : 0)
55
56 static int sign(int cmp)
57 {
58         if (cmp < 0)
59                 return -1;
60         if (cmp == 0)
61                 return 0;
62         return 1;
63 }
64
65 AVL *avl_new(AvlCompare compare)
66 {
67         AVL *avl = malloc(sizeof(*avl));
68         
69         assert(avl != NULL);
70         
71         avl->compare = compare;
72         avl->root = NULL;
73         avl->count = 0;
74         return avl;
75 }
76
77 void avl_free(AVL *avl)
78 {
79         freeNode(avl->root);
80         free(avl);
81 }
82
83 void *avl_lookup(const AVL *avl, const void *key)
84 {
85         AvlNode *found = lookup(avl, avl->root, key);
86         return found ? (void*) found->value : NULL;
87 }
88
89 AvlNode *avl_lookup_node(const AVL *avl, const void *key)
90 {
91         return lookup(avl, avl->root, key);
92 }
93
94 size_t avl_count(const AVL *avl)
95 {
96         return avl->count;
97 }
98
99 bool avl_insert(AVL *avl, const void *key, const void *value)
100 {
101         size_t old_count = avl->count;
102         insert(avl, &avl->root, key, value);
103         return avl->count != old_count;
104 }
105
106 bool avl_remove(AVL *avl, const void *key)
107 {
108         AvlNode *node = NULL;
109         
110         remove(avl, &avl->root, key, &node);
111         
112         if (node == NULL) {
113                 return false;
114         } else {
115                 free(node);
116                 return true;
117         }
118 }
119
120 static AvlNode *mkNode(const void *key, const void *value)
121 {
122         AvlNode *node = malloc(sizeof(*node));
123         
124         assert(node != NULL);
125         
126         node->key = key;
127         node->value = value;
128         node->lr[0] = NULL;
129         node->lr[1] = NULL;
130         node->balance = 0;
131         return node;
132 }
133
134 static void freeNode(AvlNode *node)
135 {
136         if (node) {
137                 freeNode(node->lr[0]);
138                 freeNode(node->lr[1]);
139                 free(node);
140         }
141 }
142
143 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key)
144 {
145         int cmp;
146         
147         if (node == NULL)
148                 return NULL;
149         
150         cmp = avl->compare(key, node->key);
151         
152         if (cmp < 0)
153                 return lookup(avl, node->lr[0], key);
154         if (cmp > 0)
155                 return lookup(avl, node->lr[1], key);
156         return node;
157 }
158
159 /*
160  * Insert a key/value into a subtree, rebalancing if necessary.
161  *
162  * Return true if the subtree's height increased.
163  */
164 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value)
165 {
166         if (*p == NULL) {
167                 *p = mkNode(key, value);
168                 avl->count++;
169                 return true;
170         } else {
171                 AvlNode *node = *p;
172                 int      cmp  = sign(avl->compare(key, node->key));
173                 
174                 if (cmp == 0) {
175                         node->key = key;
176                         node->value = value;
177                         return false;
178                 }
179                 
180                 if (!insert(avl, &node->lr[side(cmp)], key, value))
181                         return false;
182                 
183                 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
184                 return sway(p, cmp) != 0;
185         }
186 }
187
188 /*
189  * Remove the node matching the given key.
190  * If present, return the removed node through *ret .
191  * The returned node's lr and balance are meaningless.
192  *
193  * Return true if the subtree's height decreased.
194  */
195 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret)
196 {
197         if (*p == NULL) {
198                 return false;
199         } else {
200                 AvlNode *node = *p;
201                 int      cmp  = sign(avl->compare(key, node->key));
202                 
203                 if (cmp == 0) {
204                         *ret = node;
205                         avl->count--;
206                         
207                         if (node->lr[0] != NULL && node->lr[1] != NULL) {
208                                 AvlNode *replacement;
209                                 int      side;
210                                 bool     shrunk;
211                                 
212                                 /* Pick a subtree to pull the replacement from such that
213                                  * this node doesn't have to be rebalanced. */
214                                 side = node->balance <= 0 ? 0 : 1;
215                                 
216                                 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
217                                 
218                                 replacement->lr[0]   = node->lr[0];
219                                 replacement->lr[1]   = node->lr[1];
220                                 replacement->balance = node->balance;
221                                 *p = replacement;
222                                 
223                                 if (!shrunk)
224                                         return false;
225                                 
226                                 replacement->balance -= bal(side);
227                                 
228                                 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
229                                 return replacement->balance == 0;
230                         }
231                         
232                         if (node->lr[0] != NULL)
233                                 *p = node->lr[0];
234                         else
235                                 *p = node->lr[1];
236                         
237                         return true;
238                         
239                 } else {
240                         if (!remove(avl, &node->lr[side(cmp)], key, ret))
241                                 return false;
242                         
243                         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
244                         return sway(p, -cmp) == 0;
245                 }
246         }
247 }
248
249 /*
250  * Remove either the left-most (if side == 0) or right-most (if side == 1)
251  * node in a subtree, returning the removed node through *ret .
252  * The returned node's lr and balance are meaningless.
253  *
254  * The subtree must not be empty (i.e. *p must not be NULL).
255  *
256  * Return true if the subtree's height decreased.
257  */
258 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
259 {
260         AvlNode *node = *p;
261         
262         if (node->lr[side] == NULL) {
263                 *ret = node;
264                 *p = node->lr[1 - side];
265                 return true;
266         }
267         
268         if (!removeExtremum(&node->lr[side], side, ret))
269                 return false;
270         
271         /* If tree's balance became 0, it means the tree's height shrank due to removal. */
272         return sway(p, -bal(side)) == 0;
273 }
274
275 /*
276  * Rebalance a node if necessary.  Think of this function
277  * as a higher-level interface to balance().
278  *
279  * sway must be either -1 or 1, and indicates what was added to
280  * the balance of this node by a prior operation.
281  *
282  * Return the new balance of the subtree.
283  */
284 static int sway(AvlNode **p, int sway)
285 {
286         if ((*p)->balance != sway)
287                 (*p)->balance += sway;
288         else
289                 balance(p, side(sway));
290         
291         return (*p)->balance;
292 }
293
294 /*
295  * Perform tree rotations on an unbalanced node.
296  *
297  * side == 0 means the node's balance is -2 .
298  * side == 1 means the node's balance is +2 .
299  */
300 static void balance(AvlNode **p, int side)
301 {
302         AvlNode  *node  = *p,
303                  *child = node->lr[side];
304         int opposite    = 1 - side;
305         int bal         = bal(side);
306         
307         if (child->balance != -bal) {
308                 /* Left-left (side == 0) or right-right (side == 1) */
309                 node->lr[side]      = child->lr[opposite];
310                 child->lr[opposite] = node;
311                 *p = child;
312                 
313                 child->balance -= bal;
314                 node->balance = -child->balance;
315                 
316         } else {
317                 /* Left-right (side == 0) or right-left (side == 1) */
318                 AvlNode *grandchild = child->lr[opposite];
319                 
320                 node->lr[side]           = grandchild->lr[opposite];
321                 child->lr[opposite]      = grandchild->lr[side];
322                 grandchild->lr[side]     = child;
323                 grandchild->lr[opposite] = node;
324                 *p = grandchild;
325                 
326                 node->balance       = 0;
327                 child->balance      = 0;
328                 
329                 if (grandchild->balance == bal)
330                         node->balance  = -bal;
331                 else if (grandchild->balance == -bal)
332                         child->balance = bal;
333                 
334                 grandchild->balance = 0;
335         }
336 }
337
338
339 /************************* avl_check_invariants() *************************/
340
341 bool avl_check_invariants(AVL *avl)
342 {
343         int    dummy;
344         
345         return checkBalances(avl->root, &dummy)
346             && checkOrder(avl)
347             && countNode(avl->root) == avl->count;
348 }
349
350 static bool checkBalances(AvlNode *node, int *height)
351 {
352         if (node) {
353                 int h0, h1;
354                 
355                 if (!checkBalances(node->lr[0], &h0))
356                         return false;
357                 if (!checkBalances(node->lr[1], &h1))
358                         return false;
359                 
360                 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
361                         return false;
362                 
363                 *height = (h0 > h1 ? h0 : h1) + 1;
364                 return true;
365         } else {
366                 *height = 0;
367                 return true;
368         }
369 }
370
371 static bool checkOrder(AVL *avl)
372 {
373         AvlIter     i;
374         const void *last     = NULL;
375         bool        last_set = false;
376         
377         avl_foreach(i, avl) {
378                 if (last_set && avl->compare(last, i.key) >= 0)
379                         return false;
380                 last     = i.key;
381                 last_set = true;
382         }
383         
384         return true;
385 }
386
387 static size_t countNode(AvlNode *node)
388 {
389         if (node)
390                 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
391         else
392                 return 0;
393 }
394
395
396 /************************* Traversal *************************/
397
398 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
399 {
400         AvlNode *node = avl->root;
401         
402         iter->stack_index = 0;
403         iter->direction   = dir;
404         
405         if (node == NULL) {
406                 iter->key      = NULL;
407                 iter->value    = NULL;
408                 iter->node     = NULL;
409                 return;
410         }
411         
412         while (node->lr[dir] != NULL) {
413                 iter->stack[iter->stack_index++] = node;
414                 node = node->lr[dir];
415         }
416         
417         iter->key   = (void*) node->key;
418         iter->value = (void*) node->value;
419         iter->node  = node;
420 }
421
422 void avl_iter_next(AvlIter *iter)
423 {
424         AvlNode     *node = iter->node;
425         AvlDirection dir  = iter->direction;
426         
427         if (node == NULL)
428                 return;
429         
430         node = node->lr[1 - dir];
431         if (node != NULL) {
432                 while (node->lr[dir] != NULL) {
433                         iter->stack[iter->stack_index++] = node;
434                         node = node->lr[dir];
435                 }
436         } else if (iter->stack_index > 0) {
437                 node = iter->stack[--iter->stack_index];
438         } else {
439                 iter->key      = NULL;
440                 iter->value    = NULL;
441                 iter->node     = NULL;
442                 return;
443         }
444         
445         iter->node  = node;
446         iter->key   = (void*) node->key;
447         iter->value = (void*) node->value;
448 }