lib/waiter: Ensure waiters are consistent during waiter_poll
authorJeremy Kerr <jk@ozlabs.org>
Fri, 17 May 2013 01:38:07 +0000 (09:38 +0800)
committerJeremy Kerr <jk@ozlabs.org>
Tue, 21 May 2013 07:29:43 +0000 (15:29 +0800)
We have a bug at the moment: if the waitset's->waiters array is updated
duing waiter_poll() (eg, a client connection is closed, and the client's
callback performs a waiter_remove()), then we may invoke callbacks for
incorrect waiters.

This change uses a consistent waiters array duing execution of
waiter_poll, so that any pollfds returned from poll() will result in
correct callback invocations.

This assumes that a waiter will only ever remove *itself* from the
waitset; otherwise, we may call a free()ed waiter.

Signed-off-by: Jeremy Kerr <jk@ozlabs.org>
lib/waiter/waiter.c

index bb25784bc1b1edbbe4a8428193127ed2836029b9..78ba045f8b18fe5088d8c11422175b4654f87752 100644 (file)
@@ -1,5 +1,6 @@
 
 #include <poll.h>
+#include <stdbool.h>
 #include <string.h>
 #include <assert.h>
 
@@ -18,8 +19,14 @@ struct waiter {
 struct waitset {
        struct waiter   **waiters;
        int             n_waiters;
+       bool            waiters_changed;
+
+       /* These are kept consistent over each call to waiter_poll, as
+        * set->waiters may be updated (by waiters' callbacks calling
+        * waiter_register or waiter_remove) during iteration. */
        struct pollfd   *pollfds;
-       int             n_pollfds;
+       struct waiter   **cur_waiters;
+       int             cur_n_waiters;
 };
 
 struct waitset *waitset_create(void *ctx)
@@ -38,19 +45,22 @@ struct waiter *waiter_register(struct waitset *set, int fd, int events,
 {
        struct waiter **waiters, *waiter;
 
+       waiter = talloc(set->waiters, struct waiter);
+       if (!waiter)
+               return NULL;
+
        waiters = talloc_realloc(set, set->waiters,
                        struct waiter *, set->n_waiters + 1);
 
-       if (!waiters)
+       if (!waiters) {
+               talloc_free(waiter);
                return NULL;
+       }
 
+       set->waiters_changed = true;
        set->waiters = waiters;
        set->n_waiters++;
 
-       waiter = talloc(set->waiters, struct waiter);
-       if (!waiter)
-               return NULL;
-
        set->waiters[set->n_waiters - 1] = waiter;
 
        waiter->set = set;
@@ -79,6 +89,7 @@ void waiter_remove(struct waiter *waiter)
 
        set->waiters = talloc_realloc(set->waiters, set->waiters,
                        struct waiter *, set->n_waiters);
+       set->waiters_changed = true;
 
        talloc_free(waiter);
 }
@@ -87,29 +98,42 @@ int waiter_poll(struct waitset *set)
 {
        int i, rc;
 
-       if (set->n_waiters != set->n_pollfds) {
-               set->pollfds = talloc_realloc(set, set->pollfds,
-                               struct pollfd, set->n_waiters);
-               set->n_pollfds = set->n_waiters;
-       }
+       /* If the waiters have been updated, we need to update our
+        * consistent copy */
+       if (set->waiters_changed) {
+
+               /* We need to reallocate if the count has changes */
+               if (set->cur_n_waiters != set->n_waiters) {
+                       set->cur_waiters = talloc_realloc(set, set->cur_waiters,
+                                       struct waiter *, set->n_waiters);
+                       set->pollfds = talloc_realloc(set, set->pollfds,
+                                       struct pollfd, set->n_waiters);
+                       set->cur_n_waiters = set->n_waiters;
+               }
+
+               /* Populate cur_waiters and pollfds from ->waiters data */
+               for (i = 0; i < set->n_waiters; i++) {
+                       set->pollfds[i].fd = set->waiters[i]->fd;
+                       set->pollfds[i].events = set->waiters[i]->events;
+                       set->pollfds[i].revents = 0;
+                       set->cur_waiters[i] = set->waiters[i];
+               }
 
-       for (i = 0; i < set->n_waiters; i++) {
-               set->pollfds[i].fd = set->waiters[i]->fd;
-               set->pollfds[i].events = set->waiters[i]->events;
-               set->pollfds[i].revents = 0;
+               set->waiters_changed = false;
        }
 
-       rc = poll(set->pollfds, set->n_waiters, -1);
+       rc = poll(set->pollfds, set->cur_n_waiters, -1);
 
        if (rc <= 0)
                return rc;
 
-       for (i = 0; i < set->n_waiters; i++) {
+       for (i = 0; i < set->cur_n_waiters; i++) {
                if (set->pollfds[i].revents) {
-                       rc = set->waiters[i]->callback(set->waiters[i]->arg);
+                       rc = set->cur_waiters[i]->callback(
+                                       set->cur_waiters[i]->arg);
 
                        if (rc)
-                               waiter_remove(set->waiters[i]);
+                               waiter_remove(set->cur_waiters[i]);
                }
        }