lib/waiter: Ensure waiters are consistent during waiter_poll
[petitboot] / lib / waiter / waiter.c
1
#include <assert.h>
#include <errno.h>
#include <poll.h>
#include <stdbool.h>
#include <string.h>

#include <talloc/talloc.h>

#include "waiter.h"
10
11 struct waiter {
12         struct waitset  *set;
13         int             fd;
14         int             events;
15         waiter_cb       callback;
16         void            *arg;
17 };
18
/*
 * A collection of waiters serviced together by waiter_poll().
 */
struct waitset {
	/* Live registration list, edited by waiter_register() and
	 * waiter_remove(); waiters_changed is raised on every edit so that
	 * waiter_poll() knows to rebuild its snapshot below. */
	struct waiter	**waiters;
	int		n_waiters;
	bool		waiters_changed;

	/* Snapshot of the list above, kept consistent over each call to
	 * waiter_poll(): set->waiters may be updated (by waiters' callbacks
	 * calling waiter_register or waiter_remove) during iteration.
	 * pollfds[i] describes cur_waiters[i]. */
	struct pollfd	*pollfds;
	struct waiter	**cur_waiters;
	int		cur_n_waiters;
};
31
32 struct waitset *waitset_create(void *ctx)
33 {
34         struct waitset *set = talloc_zero(ctx, struct waitset);
35         return set;
36 }
37
/*
 * Release @set. Anything allocated with @set as its talloc parent (such as
 * the waiter and pollfd arrays) is released along with it.
 */
void waitset_destroy(struct waitset *set)
{
	talloc_free(set);
}
42
43 struct waiter *waiter_register(struct waitset *set, int fd, int events,
44                 waiter_cb callback, void *arg)
45 {
46         struct waiter **waiters, *waiter;
47
48         waiter = talloc(set->waiters, struct waiter);
49         if (!waiter)
50                 return NULL;
51
52         waiters = talloc_realloc(set, set->waiters,
53                         struct waiter *, set->n_waiters + 1);
54
55         if (!waiters) {
56                 talloc_free(waiter);
57                 return NULL;
58         }
59
60         set->waiters_changed = true;
61         set->waiters = waiters;
62         set->n_waiters++;
63
64         set->waiters[set->n_waiters - 1] = waiter;
65
66         waiter->set = set;
67         waiter->fd = fd;
68         waiter->events = events;
69         waiter->callback = callback;
70         waiter->arg = arg;
71
72         return waiter;
73 }
74
75 void waiter_remove(struct waiter *waiter)
76 {
77         struct waitset *set = waiter->set;
78         int i;
79
80         for (i = 0; i < set->n_waiters; i++)
81                 if (set->waiters[i] == waiter)
82                         break;
83
84         assert(i < set->n_waiters);
85
86         set->n_waiters--;
87         memmove(&set->waiters[i], &set->waiters[i+1],
88                 (set->n_waiters - i) * sizeof(set->waiters[0]));
89
90         set->waiters = talloc_realloc(set->waiters, set->waiters,
91                         struct waiter *, set->n_waiters);
92         set->waiters_changed = true;
93
94         talloc_free(waiter);
95 }
96
97 int waiter_poll(struct waitset *set)
98 {
99         int i, rc;
100
101         /* If the waiters have been updated, we need to update our
102          * consistent copy */
103         if (set->waiters_changed) {
104
105                 /* We need to reallocate if the count has changes */
106                 if (set->cur_n_waiters != set->n_waiters) {
107                         set->cur_waiters = talloc_realloc(set, set->cur_waiters,
108                                         struct waiter *, set->n_waiters);
109                         set->pollfds = talloc_realloc(set, set->pollfds,
110                                         struct pollfd, set->n_waiters);
111                         set->cur_n_waiters = set->n_waiters;
112                 }
113
114                 /* Populate cur_waiters and pollfds from ->waiters data */
115                 for (i = 0; i < set->n_waiters; i++) {
116                         set->pollfds[i].fd = set->waiters[i]->fd;
117                         set->pollfds[i].events = set->waiters[i]->events;
118                         set->pollfds[i].revents = 0;
119                         set->cur_waiters[i] = set->waiters[i];
120                 }
121
122                 set->waiters_changed = false;
123         }
124
125         rc = poll(set->pollfds, set->cur_n_waiters, -1);
126
127         if (rc <= 0)
128                 return rc;
129
130         for (i = 0; i < set->cur_n_waiters; i++) {
131                 if (set->pollfds[i].revents) {
132                         rc = set->cur_waiters[i]->callback(
133                                         set->cur_waiters[i]->arg);
134
135                         if (rc)
136                                 waiter_remove(set->cur_waiters[i]);
137                 }
138         }
139
140         return 0;
141 }