2 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
/* Forward declarations of the file-local (static) helpers behind the public
 * avl_* functions. */
/* Node allocation and recursive release. */
28 static AvlNode *mkNode(const void *key, const void *value);
29 static void freeNode(AvlNode *node);
/* Recursive key search starting at a given subtree root. */
31 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key);
/* Recursive structural updates. insert() returns whether the subtree's height
 * grew; remove() and removeExtremum() return whether it shrank.
 * NOTE(review): remove() has the same name as the standard remove() declared in
 * <stdio.h>. If this file ever includes that header, this static declaration
 * conflicts with it. Consider renaming it (e.g. removeNode). */
33 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value);
34 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret);
35 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
/* Rebalancing: sway() applies a +1 or -1 change to a node's balance and calls
 * balance() to rotate when the node would otherwise become doubly heavy. */
37 static int sway(AvlNode **p, int sway);
38 static void balance(AvlNode **p, int side);
/* Invariant checks for avl_check_invariants(). checkBalances() and countNode()
 * are called there; checkOrder() presumably is too (call not visible here). */
40 static bool checkBalances(AvlNode *node, int *height);
41 static bool checkOrder(AVL *avl);
42 static size_t countNode(AvlNode *node);
45 * Utility macros for converting between
46 * "balance" values (-1 or 1) and "side" values (0 or 1).
/* bal(0) == -1 and bal(1) == +1 (any non-zero argument also yields +1). */
53 #define bal(side) ((side) == 0 ? -1 : 1)
/* side(+1) == 1 and side(-1) == 0 (any argument other than 1 also yields 0).
 * Each macro evaluates its argument exactly once. */
54 #define side(bal) ((bal) == 1 ? 1 : 0)
/* Normalize a comparator result.
 * NOTE(review): the body is not visible in this excerpt. Callers pass the result
 * to side() and to sway(), and sway() requires -1 or +1. So this presumably
 * returns -1, 0 or +1 for negative, zero and positive input. Confirm. */
56 static int sign(int cmp)
/* Allocate a new AVL tree object that orders its keys with `compare`.
 * NOTE(review): the statements between the malloc() and the assignment are not
 * visible here. Confirm that a NULL result from malloc() is handled, and that
 * the root and count fields are initialized before use. */
65 AVL *avl_new(AvlCompare compare)
67 AVL *avl = malloc(sizeof(*avl));
71 avl->compare = compare;
/* Release a tree created by avl_new().
 * NOTE(review): the body is not visible in this excerpt. It presumably releases
 * all nodes via freeNode() and then the AVL object itself. Confirm, including
 * whether keys/values are left for the caller to free. */
77 void avl_free(AVL *avl)
/* Return the value stored under a key that compares equal to `key`, or NULL if
 * lookup() finds no such node.
 * A NULL result is ambiguous when NULL values have been stored. Use
 * avl_lookup_node() to tell "absent" apart from "present with a NULL value". */
83 void *avl_lookup(const AVL *avl, const void *key)
85 AvlNode *found = lookup(avl, avl->root, key);
/* The cast to void * discards const-ness; mkNode() receives the value as
 * const void *. */
86 return found ? (void*) found->value : NULL;
/* Return the node whose key compares equal to `key`, searching from the root.
 * The result is whatever lookup() returns; avl_lookup() treats NULL as "not
 * found". Unlike avl_lookup(), this lets callers distinguish a missing key from
 * a stored NULL value. */
89 AvlNode *avl_lookup_node(const AVL *avl, const void *key)
91 return lookup(avl, avl->root, key);
/* Return the number of entries in the tree.
 * NOTE(review): the body is not visible here. It presumably returns avl->count,
 * the counter that avl_insert() and avl_check_invariants() rely on. */
94 size_t avl_count(const AVL *avl)
/* Insert `key` -> `value` into the tree.
 * Returns true if a new node was created, detected as a change in avl->count.
 * Returns false if the count is unchanged, presumably because an equal key was
 * already present.
 * NOTE(review): what happens to that existing entry's key/value is decided
 * inside insert() on lines not visible here. */
99 bool avl_insert(AVL *avl, const void *key, const void *value)
101 size_t old_count = avl->count;
/* insert()'s return value reports height growth, not whether a node was added,
 * so it is deliberately ignored here. */
102 insert(avl, &avl->root, key, value);
103 return avl->count != old_count;
/* Remove the entry whose key compares equal to `key`.
 * remove() hands the unlinked node back through `node`. The pointer stays NULL
 * if no key matches.
 * NOTE(review): the rest of the body is not visible here. It presumably frees
 * the unlinked node and returns whether one was found. Confirm. */
106 bool avl_remove(AVL *avl, const void *key)
/* Stays NULL unless remove() finds and unlinks a matching node. */
108 AvlNode *node = NULL;
110 remove(avl, &avl->root, key, &node);
/* Allocate a single tree node for the given key/value pair (the field
 * initialization is on lines not visible here).
 * NOTE(review): allocation failure is only caught by assert(). That check is
 * compiled out when NDEBUG is defined, so in such builds a NULL from malloc()
 * would be used unchecked. Consider an unconditional check. */
120 static AvlNode *mkNode(const void *key, const void *value)
122 AvlNode *node = malloc(sizeof(*node));
124 assert(node != NULL);
/* Recursively release the subtree rooted at `node`, visiting both children
 * (lr[0] and lr[1]) first.
 * NOTE(review): the NULL-node guard and the free() of `node` itself are not
 * visible here; presumably they are on the missing lines. Recursion depth
 * equals the tree height. */
134 static void freeNode(AvlNode *node)
137 freeNode(node->lr[0]);
138 freeNode(node->lr[1]);
/* Recursive binary search for `key` in the subtree rooted at `node`, comparing
 * with avl->compare(key, node->key).
 * It descends into lr[0] or lr[1] depending on the comparison. The branch
 * conditions are not visible here; presumably it goes left for cmp < 0 and right
 * for cmp > 0, matching insert()'s side(sign(cmp)) mapping. It presumably
 * returns NULL on reaching an empty subtree, since avl_lookup() treats NULL as
 * "not found". */
143 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key)
150 cmp = avl->compare(key, node->key);
153 return lookup(avl, node->lr[0], key);
155 return lookup(avl, node->lr[1], key);
160 * Insert a key/value into a subtree, rebalancing if necessary.
162 * Return true if the subtree's height increased.
164 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value)
/* NOTE(review): the cmp == 0 (equal key) branch and the avl->count increment
 * for a freshly created node are on lines not visible here. avl_insert()
 * relies on avl->count changing exactly when a node is created. */
/* Empty slot: the new node becomes this subtree's root. */
167 *p = mkNode(key, value);
/* Normalized comparison: side(cmp) selects the child to descend into, and cmp
 * itself is the balance change passed to sway() below. */
172 int cmp = sign(avl->compare(key, node->key));
/* If the child subtree on the key's side did not grow, this one did not either
 * (the early return is on a line not visible here). */
180 if (!insert(avl, &node->lr[side(cmp)], key, value))
183 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
184 return sway(p, cmp) != 0;
189 * Remove the node matching the given key.
190 * If present, return the removed node through *ret .
191 * The returned node's lr and balance are meaningless.
193 * Return true if the subtree's height decreased.
195 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret)
/* NOTE(review): this helper shares its name with the standard remove() from
 * <stdio.h>. Including that header in this file would make this definition
 * conflict with it. */
/* Normalized comparison; side(cmp) picks the child to search when keys differ. */
201 int cmp = sign(avl->compare(key, node->key));
/* Matching node with two children (presumably inside a cmp == 0 branch whose
 * condition is not visible): replace it with an extremum node taken from one
 * of its subtrees. */
207 if (node->lr[0] != NULL && node->lr[1] != NULL) {
208 AvlNode *replacement;
212 /* Pick a subtree to pull the replacement from such that
213 * this node doesn't have to be rebalanced. */
/* balance is (right height - left height). If it is <= 0, the left subtree is
 * at least as tall, so take its right-most node (the in-order predecessor).
 * Otherwise take the left-most node of the right subtree (the successor).
 * `side` and `shrunk` are locals declared on lines not visible here. `side` is
 * not followed by '(', so it does not invoke the side() macro. */
214 side = node->balance <= 0 ? 0 : 1;
216 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
/* The replacement inherits the removed node's children and balance. */
218 replacement->lr[0] = node->lr[0];
219 replacement->lr[1] = node->lr[1];
220 replacement->balance = node->balance;
/* The `side` subtree lost one level, so the balance shifts away from it. The
 * replacement came from the side that was at least as tall, so the result stays
 * within [-1, 1] and no rotation is needed.
 * NOTE(review): this adjustment is only correct when `shrunk` is true. The
 * guard, and linking `replacement` into *p, are presumably on the lines not
 * visible here. Confirm. */
226 replacement->balance -= bal(side);
228 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
229 return replacement->balance == 0;
/* Matching node with at most one child: that child (or NULL) takes its place.
 * The assignments are on lines not visible here. */
232 if (node->lr[0] != NULL)
/* Key not at this node: recurse toward it. If nothing below shrank, this
 * subtree did not shrink either. */
240 if (!remove(avl, &node->lr[side(cmp)], key, ret))
243 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
244 return sway(p, -cmp) == 0;
250 * Remove either the left-most (if side == 0) or right-most (if side == 1)
251 * node in a subtree, returning the removed node through *ret .
252 * The returned node's lr and balance are meaningless.
254 * The subtree must not be empty (i.e. *p must not be NULL).
256 * Return true if the subtree's height decreased.
258 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
/* No child on `side`: this node is the extremum. Its only possible child (on the
 * opposite side) takes its place. Reporting the node through *ret and returning
 * true are presumably on lines not visible here. */
262 if (node->lr[side] == NULL) {
264 *p = node->lr[1 - side];
/* Otherwise keep descending toward `side`. If that subtree did not shrink,
 * neither did this one. */
268 if (!removeExtremum(&node->lr[side], side, ret))
271 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
272 return sway(p, -bal(side)) == 0;
276 * Rebalance a node if necessary. Think of this function
277 * as a higher-level interface to balance().
279 * sway must be either -1 or 1, and indicates what was added to
280 * the balance of this node by a prior operation.
282 * Return the new balance of the subtree.
284 static int sway(AvlNode **p, int sway)
/* If the balance is not already equal to `sway`, adding it keeps the balance in
 * [-1, 1] (assuming it started there), so just apply it. */
286 if ((*p)->balance != sway)
287 (*p)->balance += sway;
/* Otherwise the node would become doubly heavy toward side(sway), so rotate.
 * This is presumably the else branch; the `else` keyword is not visible here. */
289 balance(p, side(sway));
/* Read through *p because balance() takes AvlNode ** and may install a new
 * subtree root there. */
291 return (*p)->balance;
295 * Perform tree rotations on an unbalanced node.
297 * side == 0 means the node's balance is -2 .
298 * side == 1 means the node's balance is +2 .
300 static void balance(AvlNode **p, int side)
/* `node` (presumably *p) and `child` (the heavy child) are declared together;
 * the first declarator is on a line not visible here. `bal` below is a local
 * variable, presumably initialized to bal(side) on a hidden line. A function-like
 * macro name is not expanded unless followed by '(', so it does not clash with
 * the bal() macro. */
303 *child = node->lr[side];
304 int opposite = 1 - side;
/* The child is either even or leaning the same way as the node, so a single
 * rotation suffices. */
307 if (child->balance != -bal) {
308 /* Left-left (side == 0) or right-right (side == 1) */
/* Single rotation: the child's inner subtree moves under node, and node becomes
 * the child's `opposite` child. Installing the child in *p is on a hidden line. */
309 node->lr[side] = child->lr[opposite];
310 child->lr[opposite] = node;
/* If the child leaned toward `side`, both nodes end up even. If the child was
 * even, the child now leans toward `opposite` and node toward `side`. */
313 child->balance -= bal;
314 node->balance = -child->balance;
317 /* Left-right (side == 0) or right-left (side == 1) */
/* Double rotation: the grandchild becomes the subtree root, with child on its
 * `side` and node on its `opposite` side. The grandchild's two subtrees are
 * handed to child and node respectively. */
318 AvlNode *grandchild = child->lr[opposite];
320 node->lr[side] = grandchild->lr[opposite];
321 child->lr[opposite] = grandchild->lr[side];
322 grandchild->lr[side] = child;
323 grandchild->lr[opposite] = node;
/* node and child are presumably reset to 0 and the grandchild installed in *p on
 * lines not visible here. Whichever of them received the grandchild's shorter
 * subtree ends up leaning away from it. */
329 if (grandchild->balance == bal)
330 node->balance = -bal;
331 else if (grandchild->balance == -bal)
332 child->balance = bal;
334 grandchild->balance = 0;
339 /************************* avl_check_invariants() *************************/
/* Debugging aid: verify the tree's structural invariants.
 * checkBalances() confirms that every stored balance equals (right height - left
 * height) and lies in [-1, 1]. countNode() confirms that the number of reachable
 * nodes equals avl->count. `dummy` (declared on a line not visible here)
 * receives the overall tree height, which is discarded.
 * NOTE(review): one line of this expression is not visible. It presumably also
 * calls checkOrder(avl). Confirm. */
341 bool avl_check_invariants(AVL *avl)
345 return checkBalances(avl->root, &dummy)
347 && countNode(avl->root) == avl->count;
/* Recursively verify the balance factors of the subtree at `node` and report its
 * height through *height.
 * Each node's stored balance must equal h1 - h0 (right minus left subtree
 * height) and be within [-1, 1].
 * NOTE(review): the NULL-node base case (presumably height 0), the declarations
 * of h0/h1 and the return statements are not visible here. */
350 static bool checkBalances(AvlNode *node, int *height)
355 if (!checkBalances(node->lr[0], &h0))
357 if (!checkBalances(node->lr[1], &h1))
360 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
363 *height = (h0 > h1 ? h0 : h1) + 1;
/* Verify that traversing the tree with avl_foreach (presumably in ascending
 * order) yields strictly increasing keys under avl->compare: each key must
 * compare greater than the one before it.
 * NOTE(review): updating `last`/`last_set` after each key and the return
 * statements are on lines not visible here. */
371 static bool checkOrder(AVL *avl)
374 const void *last = NULL;
/* The first key has no predecessor, so the comparison is skipped until a key has
 * been recorded. */
375 bool last_set = false;
377 avl_foreach(i, avl) {
378 if (last_set && avl->compare(last, i.key) >= 0)
/* Count the nodes in the subtree rooted at `node`: this node plus both
 * subtrees. The NULL base case (presumably returning 0) is on a line not visible
 * here. */
387 static size_t countNode(AvlNode *node)
390 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
396 /************************* Traversal *************************/
/* Position `iter` on the first node in direction `dir`.
 * Starting at the root, it follows lr[dir] as far as possible and pushes each
 * ancestor onto iter->stack, so the iterator can later climb back up. `dir` is
 * used directly as a child index, so it is presumably 0 (left-most first) or 1
 * (right-most first). Confirm against AvlDirection.
 * NOTE(review): the loop dereferences `node`, so the empty-tree case
 * (avl->root == NULL) must be handled on the lines not visible here. The pushes
 * onto iter->stack are not bounds-checked here; its capacity, defined elsewhere,
 * must cover the maximum tree height. */
398 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
400 AvlNode *node = avl->root;
402 iter->stack_index = 0;
403 iter->direction = dir;
412 while (node->lr[dir] != NULL) {
413 iter->stack[iter->stack_index++] = node;
414 node = node->lr[dir];
/* Expose the current entry. The casts discard const-ness; mkNode() receives key
 * and value as const void *. */
417 iter->key = (void*) node->key;
418 iter->value = (void*) node->value;
422 void avl_iter_next(AvlIter *iter)
424 AvlNode *node = iter->node;
425 AvlDirection dir = iter->direction;
430 node = node->lr[1 - dir];
432 while (node->lr[dir] != NULL) {
433 iter->stack[iter->stack_index++] = node;
434 node = node->lr[dir];
436 } else if (iter->stack_index > 0) {
437 node = iter->stack[--iter->stack_index];
446 iter->key = (void*) node->key;
447 iter->value = (void*) node->value;