2 * Copyright (c) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission to use, copy, modify, and/or distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
/*
 * Forward declarations of the file-local (static) helpers defined below.
 *
 * NOTE(review): the helper named `remove` has the same name as the C
 * library's remove() declared in <stdio.h>. A static definition with that
 * name conflicts with the library declaration if <stdio.h> is ever included
 * in this translation unit (directly or through another header). Confirm it
 * is not, or consider renaming the helper.
 */
22 static AvlNode *mkNode(const void *key, const void *value);
23 static void freeNode(AvlNode *node);
25 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key);
27 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value);
28 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret);
29 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
31 static int sway(AvlNode **p, int sway);
32 static void balance(AvlNode **p, int side);
34 static bool checkBalances(AvlNode *node, int *height);
35 static bool checkOrder(AVL *avl);
36 static size_t countNode(AvlNode *node);
39 * Utility macros for converting between
40 * "balance" values (-1 or 1) and "side" values (0 or 1).
/*
 * bal(0) == -1, and bal(s) == +1 for any other s.
 * side(+1) == 1, and side(b) == 0 for any other b (including 0).
 * A "side" is used as an index into a node's lr[] child array
 * (0 selects lr[0], 1 selects lr[1]).
 * Each macro evaluates its argument exactly once and parenthesizes it.
 */
47 #define bal(side) ((side) == 0 ? -1 : 1)
48 #define side(bal) ((bal) == 1 ? 1 : 0)
/*
 * sign: only the signature of this helper is present in this extract.
 * insert() and remove() pass its result to side() and sway(), and sway()
 * requires -1 or 1. It therefore presumably maps a comparator result to
 * -1 / 0 / +1. Confirm against the full definition.
 */
50 static int sign(int cmp)
/*
 * Allocate a tree header and record the key comparison callback.
 * NOTE(review): this extract shows no NULL check on the malloc() result.
 * It also shows no initialisation of the other fields that the rest of the
 * file reads, such as avl->root and avl->count. Those lines are presumably
 * in the full definition; confirm.
 */
59 AVL *avl_new(AvlCompare compare)
61 AVL *avl = malloc(sizeof(*avl));
65 avl->compare = compare;
/*
 * Release a tree. Only the signature is present in this extract.
 * It presumably frees every node (via freeNode(avl->root)) and then the
 * header itself. Confirm.
 */
71 void avl_free(AVL *avl)
/*
 * Return the value stored under a key that compares equal to `key`
 * (according to avl->compare), or NULL if no node matches.
 * A stored NULL value is indistinguishable from a missing key here; use
 * avl_lookup_node() when that distinction matters.
 * The (void*) cast presumably drops a const qualifier on the stored value
 * (mkNode() takes `const void *value`).
 */
77 void *avl_lookup(const AVL *avl, const void *key)
79 AvlNode *found = lookup(avl, avl->root, key);
80 return found ? (void*) found->value : NULL;
/*
 * Return the node whose key compares equal to `key`, or NULL if there is
 * none (avl_lookup() treats a NULL result from lookup() as "not found").
 * Unlike avl_lookup(), this distinguishes a missing key from a stored
 * NULL value.
 */
83 AvlNode *avl_lookup_node(const AVL *avl, const void *key)
85 return lookup(avl, avl->root, key);
/*
 * Number of entries in the tree. Only the signature is present in this
 * extract. It presumably returns avl->count, the field that avl_insert()
 * and avl_check_invariants() read. Confirm.
 */
88 size_t avl_count(const AVL *avl)
/*
 * Insert `key`/`value` into the tree.
 * Returns true if the entry count changed, i.e. a new node was added.
 * Returns false otherwise, presumably when a node with an equal key
 * already existed.
 * The height-change result of the recursive insert() is intentionally
 * ignored at the root; only the count comparison matters to the caller.
 */
93 bool avl_insert(AVL *avl, const void *key, const void *value)
95 size_t old_count = avl->count;
96 insert(avl, &avl->root, key, value);
97 return avl->count != old_count;
/*
 * Remove the entry whose key compares equal to `key`.
 * remove() hands the unlinked node back through `node`, which stays NULL
 * when no entry matches. The rest of this function is not in this extract;
 * it presumably frees the node and returns whether one was found.
 */
100 bool avl_remove(AVL *avl, const void *key)
102 AvlNode *node = NULL;
104 remove(avl, &avl->root, key, &node);
/*
 * Allocate a new node for `key`/`value`. Field initialisation is not shown
 * in this extract.
 * NOTE(review): allocation failure is caught only by assert(). That check
 * is compiled out when NDEBUG is defined, so a NULL result would then be
 * used unchecked. Consider an explicit check (e.g. abort()) if release
 * builds must fail loudly.
 */
114 static AvlNode *mkNode(const void *key, const void *value)
116 AvlNode *node = malloc(sizeof(*node));
118 assert(node != NULL);
/*
 * Recursively free a subtree: both child subtrees (lr[0] and lr[1]) are
 * released through recursive calls, so recursion depth is bounded by the
 * subtree height. The NULL base case and the release of the node itself
 * are not shown in this extract.
 */
128 static void freeNode(AvlNode *node)
131 freeNode(node->lr[0]);
132 freeNode(node->lr[1]);
/*
 * Recursive binary search below `node` for a key equal to `key`.
 * It compares with avl->compare(key, node->key) and descends into lr[0]
 * or lr[1]. Presumably negative goes to lr[0], positive to lr[1], and zero
 * is a match, consistent with insert()'s use of side(sign(cmp)).
 * The NULL / not-found base case is not shown in this extract.
 */
137 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key)
144 cmp = avl->compare(key, node->key);
147 return lookup(avl, node->lr[0], key);
149 return lookup(avl, node->lr[1], key);
154 * Insert a key/value into a subtree, rebalancing if necessary.
156 * Return true if the subtree's height increased.
158 static bool insert(AVL *avl, AvlNode **p, const void *key, const void *value)
/* Presumably the empty-subtree case: a new leaf fills the empty slot. */
161 *p = mkNode(key, value);
/* Non-empty subtree: cmp is presumably normalised to -1, 0 or +1 by sign(). */
166 int cmp = sign(avl->compare(key, node->key));
/*
 * side(cmp) selects lr[0] for cmp == -1 and lr[1] for cmp == +1.
 * If that child subtree's height did not change, this node's balance does
 * not change either, so no further work is needed above it.
 */
174 if (!insert(avl, &node->lr[side(cmp)], key, value))
177 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
178 return sway(p, cmp) != 0;
183 * Remove the node matching the given key.
184 * If present, return the removed node through *ret .
185 * The returned node's lr and balance are meaningless.
187 * Return true if the subtree's height decreased.
189 static bool remove(AVL *avl, AvlNode **p, const void *key, AvlNode **ret)
195 int cmp = sign(avl->compare(key, node->key));
/*
 * Matching node (presumably cmp == 0) with two children: splice in a
 * replacement taken from one of its subtrees.
 * - side 0 takes the right-most node of lr[0] (the in-order predecessor).
 * - side 1 takes the left-most node of lr[1] (the in-order successor).
 */
201 if (node->lr[0] != NULL && node->lr[1] != NULL) {
202 AvlNode *replacement;
206 /* Pick a subtree to pull the replacement from such that
207 * this node doesn't have to be rebalanced. */
/*
 * side 0 is chosen when balance is -1 or 0; side 1 when balance is +1.
 * A height drop on `side` therefore moves the balance from -1 to 0,
 * from 0 to +1, or from +1 to 0. It never reaches +/-2.
 */
208 side = node->balance <= 0 ? 0 : 1;
210 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
/*
 * The replacement inherits the removed node's children and balance.
 * This must happen after removeExtremum(), which may have changed
 * node->lr[side].
 */
212 replacement->lr[0] = node->lr[0];
213 replacement->lr[1] = node->lr[1];
214 replacement->balance = node->balance;
/* balance is h1 - h0, so losing height on `side` shifts it by -bal(side). */
220 replacement->balance -= bal(side);
222 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
223 return replacement->balance == 0;
/* At most one child: presumably that child (or NULL) takes this node's slot. */
226 if (node->lr[0] != NULL)
/* Key not at this node: recurse into the child on the side given by cmp. */
234 if (!remove(avl, &node->lr[side(cmp)], key, ret))
/* Losing height on side(cmp) shifts this node's balance by -cmp. */
237 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
238 return sway(p, -cmp) == 0;
244 * Remove either the left-most (if side == 0) or right-most (if side == 1)
245 * node in a subtree, returning the removed node through *ret .
246 * The returned node's lr and balance are meaningless.
248 * The subtree must not be empty (i.e. *p must not be NULL).
250 * Return true if the subtree's height decreased.
252 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
/*
 * No child toward `side`: this node is the extremum. Its other child
 * (possibly NULL) takes its place in the parent.
 */
256 if (node->lr[side] == NULL) {
258 *p = node->lr[1 - side];
/* Otherwise keep descending toward `side`. */
262 if (!removeExtremum(&node->lr[side], side, ret))
/* Losing height on `side` shifts this node's balance by -bal(side). */
265 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
266 return sway(p, -bal(side)) == 0;
270 * Rebalance a node if necessary. Think of this function
271 * as a higher-level interface to balance().
273 * sway must be either -1 or 1, and indicates what was added to
274 * the balance of this node by a prior operation.
276 * Return the new balance of the subtree.
278 static int sway(AvlNode **p, int sway)
/*
 * Adding `sway` to a balance that already equals it would give +/-2.
 * In that case the stored balance is left at +/-1, and the node is rotated
 * through balance(), which treats the call as the -2 / +2 case for
 * side(sway) (0 for -1, 1 for +1).
 */
280 if ((*p)->balance != sway)
281 (*p)->balance += sway;
283 balance(p, side(sway));
/* Read through p, so this reflects whichever node is at *p after any rotation. */
285 return (*p)->balance;
289 * Perform tree rotations on an unbalanced node.
291 * side == 0 means the node's balance is -2 .
292 * side == 1 means the node's balance is +2 .
294 static void balance(AvlNode **p, int side)
297 *child = node->lr[side];
298 int opposite = 1 - side;
/*
 * `bal` here is presumably a local holding bal(side); its declaration is
 * not in this extract.
 * A single rotation is used unless the heavy child leans the opposite way,
 * in which case a double rotation is needed.
 */
301 if (child->balance != -bal) {
302 /* Left-left (side == 0) or right-right (side == 1) */
/* child moves up; its `opposite` subtree becomes node's `side` subtree. */
303 node->lr[side] = child->lr[opposite];
304 child->lr[opposite] = node;
/*
 * If child leaned toward `side` (balance == bal), both nodes end at 0.
 * If child was level (balance 0), child ends at -bal and node at bal.
 */
307 child->balance -= bal;
308 node->balance = -child->balance;
311 /* Left-right (side == 0) or right-left (side == 1) */
312 AvlNode *grandchild = child->lr[opposite];
/* grandchild moves up; its two subtrees are handed to node and child. */
314 node->lr[side] = grandchild->lr[opposite];
315 child->lr[opposite] = grandchild->lr[side];
316 grandchild->lr[side] = child;
317 grandchild->lr[opposite] = node;
/*
 * Whichever of node/child received grandchild's shorter subtree ends up
 * leaning away from it: -bal for node, bal for child. The other is
 * presumably reset to 0 in lines not shown in this extract.
 */
323 if (grandchild->balance == bal)
324 node->balance = -bal;
325 else if (grandchild->balance == -bal)
326 child->balance = bal;
328 grandchild->balance = 0;
333 /************************* avl_check_invariants() *************************/
/*
 * Verify the tree's structural invariants:
 * - per-node balance factors (checkBalances); the root height it reports
 *   is written to a local and ignored;
 * - the node count matches avl->count.
 * One condition line is not shown in this extract. It is presumably
 * checkOrder(avl), which is declared and defined in this file but not
 * called anywhere in the visible code. Confirm.
 */
335 bool avl_check_invariants(AVL *avl)
339 return checkBalances(avl->root, &dummy)
341 && countNode(avl->root) == avl->count;
/*
 * Recursively verify that every node's stored balance equals
 * height(lr[1]) - height(lr[0]) and lies in [-1, 1].
 * On success, stores this subtree's height in *height.
 * It presumably returns false at the first violation found (the return
 * statements and the NULL base case are not shown in this extract).
 */
344 static bool checkBalances(AvlNode *node, int *height)
349 if (!checkBalances(node->lr[0], &h0))
351 if (!checkBalances(node->lr[1], &h1))
354 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
357 *height = (h0 > h1 ? h0 : h1) + 1;
/*
 * Verify that keys visited by avl_foreach are strictly increasing under
 * avl->compare; an equal or out-of-order adjacent pair is a violation.
 * `last` / `last_set` remember the previously visited key.
 * This assumes avl_foreach walks in ascending key order; the macro is
 * defined outside this extract, so confirm.
 */
365 static bool checkOrder(AVL *avl)
368 const void *last = NULL;
369 bool last_set = false;
371 avl_foreach(i, avl) {
372 if (last_set && avl->compare(last, i.key) >= 0)
/*
 * Count the nodes in a subtree recursively (this node plus both child
 * subtrees). The NULL base case is not shown in this extract.
 */
381 static size_t countNode(AvlNode *node)
384 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
390 /************************* Traversal *************************/
/*
 * Start an iteration in direction `dir`.
 * It descends from the root along lr[dir], pushing each node it passes onto
 * iter->stack, and then exposes the key and value of the extreme node in
 * that direction.
 * `dir` is used directly as an lr[] index, so AvlDirection values are
 * presumably 0 and 1. Confirm in the header.
 * NOTE(review): no bound on stack_index is visible. This relies on the
 * stack's capacity being at least the tree height; confirm its declared
 * size. The empty-tree case is not shown in this extract.
 */
392 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
394 AvlNode *node = avl->root;
396 iter->stack_index = 0;
397 iter->direction = dir;
406 while (node->lr[dir] != NULL) {
407 iter->stack[iter->stack_index++] = node;
408 node = node->lr[dir];
411 iter->key = (void*) node->key;
412 iter->value = (void*) node->value;
416 void avl_iter_next(AvlIter *iter)
418 AvlNode *node = iter->node;
419 AvlDirection dir = iter->direction;
424 node = node->lr[1 - dir];
426 while (node->lr[dir] != NULL) {
427 iter->stack[iter->stack_index++] = node;
428 node = node->lr[dir];
430 } else if (iter->stack_index > 0) {
431 node = iter->stack[--iter->stack_index];
440 iter->key = (void*) node->key;
441 iter->value = (void*) node->value;