]> git.ozlabs.org Git - ccan/blob - ccan/tal/talloc/talloc.c
fbe9b38474e923610bdcfefb39afee113d5ebcdc
[ccan] / ccan / tal / talloc / talloc.c
1 /* Licensed under LGPL - see LICENSE file for details */
2 #include <ccan/tal/talloc/talloc.h>
3 #include <ccan/take/take.h>
4 #include <errno.h>
5 #include <assert.h>
6
7 static void (*errorfn)(const char *msg) = (void *)abort;
8
9 static void COLD call_error(const char *msg)
10 {
11         errorfn(msg);
12 }
13
14 static void *error_on_null(void *p, const char *msg)
15 {
16         if (!p)
17                 call_error(msg);
18         return p;
19 }
20
21 void *tal_talloc_(const tal_t *ctx, size_t bytes, bool clear,
22                   const char *label)
23 {
24         void *ret;
25
26         if (clear)
27                 ret = _talloc_zero(ctx, bytes, label);
28         else
29                 ret = talloc_named_const(ctx, bytes, label);
30
31         return error_on_null(ret, "allocation failure");
32 }
33
34 void *tal_talloc_arr_(const tal_t *ctx, size_t bytes, size_t count, bool clear,
35                       const char *label)
36 {
37         void *ret;
38
39         if (clear)
40                 ret = _talloc_zero_array(ctx, bytes, count, label);
41         else
42                 ret = _talloc_array(ctx, bytes, count, label);
43
44         return error_on_null(ret, "array allocation failure");
45 }
46
47 void *tal_talloc_free_(const tal_t *ctx)
48 {
49         int saved_errno = errno;
50         talloc_free((void *)ctx);
51         errno = saved_errno;
52         return NULL;
53 }
54
55 bool tal_talloc_set_name_(tal_t *ctx, const char *name, bool literal)
56 {
57         if (!literal) {
58                 name = talloc_strdup(ctx, name);
59                 if (!name) {
60                         call_error("set_name allocation failure");
61                         return false;
62                 }
63         }
64         talloc_set_name_const(ctx, name);
65         return true;
66 }
67
68 const char *tal_talloc_name_(const tal_t *ctx)
69 {
70         const char *p = talloc_get_name(ctx);
71         if (p && unlikely(strcmp(p, "UNNAMED") == 0))
72                 p = NULL;
73         return p;
74 }
75
76 static bool adjust_size(size_t *size, size_t count)
77 {
78         /* Multiplication wrap */
79         if (count && unlikely(*size * count / *size != count))
80                 goto overflow;
81
82         *size *= count;
83
84         /* Make sure we don't wrap adding header. */
85         if (*size + 1024 < 1024)
86                 goto overflow;
87         return true;
88 overflow:
89         call_error("allocation size overflow");
90         return false;
91 }
92
93 void *tal_talloc_dup_(const tal_t *ctx, const void *p, size_t size,
94                       size_t n, size_t extra, const char *label)
95 {
96         void *ret;
97         size_t nbytes = size;
98
99         if (!adjust_size(&nbytes, n)) {
100                 if (taken(p))
101                         tal_free(p);
102                 return NULL;
103         }
104
105         /* Beware addition overflow! */
106         if (n + extra < n) {
107                 call_error("dup size overflow");
108                 if (taken(p))
109                         tal_free(p);
110                 return NULL;
111         }
112
113         if (taken(p)) {
114                 if (unlikely(!p))
115                         return NULL;
116                 if (unlikely(!tal_talloc_resize_((void **)&p, size, n + extra)))
117                         return tal_free(p);
118                 if (unlikely(!tal_steal(ctx, p)))
119                         return tal_free(p);
120                 return (void *)p;
121         }
122
123         ret = tal_talloc_arr_(ctx, size, n + extra, false, label);
124         if (ret)
125                 memcpy(ret, p, nbytes);
126         return ret;
127 }
128
129 bool tal_talloc_resize_(tal_t **ctxp, size_t size, size_t count)
130 {
131         tal_t *newp;
132
133         if (unlikely(count == 0)) {
134                 /* Don't free it! */
135                 newp = talloc_size(talloc_parent(*ctxp), 0);
136                 if (!newp) {
137                         call_error("Resize failure");
138                         return false;
139                 }
140                 talloc_free(*ctxp);
141                 *ctxp = newp;
142                 return true;
143         }
144
145         /* count is unsigned, not size_t, so check for overflow here! */
146         if ((unsigned)count != count) {
147                 call_error("Resize overflos");
148                 return false;
149         }
150
151         newp = _talloc_realloc_array(NULL, *ctxp, size, count, NULL);
152         if (!newp) {
153                 call_error("Resize failure");
154                 return false;
155         }
156         *ctxp = newp;
157         return true;
158 }
159
160 bool tal_talloc_expand_(tal_t **ctxp, const void *src, size_t size, size_t count)
161 {
162         bool ret = false;
163         size_t old_count = talloc_get_size(*ctxp) / size;
164
165         /* Check for additive overflow */
166         if (old_count + count < count) {
167                 call_error("dup size overflow");
168                 goto out;
169         }
170
171         /* Don't point src inside thing we're expanding! */
172         assert(src < *ctxp
173                || (char *)src >= (char *)(*ctxp) + (size * old_count));
174
175         if (!tal_talloc_resize_(ctxp, size, old_count + count))
176                 goto out;
177
178         memcpy((char *)*ctxp + size * old_count, src, count * size);
179         ret = true;
180
181 out:
182         if (taken(src))
183                 tal_free(src);
184         return ret;
185 }
186
187 /* Sucky inline hash table implementation, to avoid deps. */
188 #define HTABLE_BITS 10
189 struct destructor {
190         struct destructor *next;
191         const tal_t *ctx;
192         void (*destroy)(void *me);
193 };
194 static struct destructor *destr_hash[1 << HTABLE_BITS];
195
196 static unsigned int hash_ptr(const void *p)
197 {
198         unsigned long h = (unsigned long)p / sizeof(void *);
199
200         return (h ^ (h >> HTABLE_BITS)) & ((1 << HTABLE_BITS) - 1);
201 }
202
203 static int tal_talloc_destroy(const tal_t *ctx)
204 {
205         struct destructor **d = &destr_hash[hash_ptr(ctx)];
206         while (*d) {
207                 if ((*d)->ctx == ctx) {
208                         struct destructor *this = *d;
209                         this->destroy((void *)ctx);
210                         *d = this->next;
211                         talloc_free(this);
212                 }
213         }
214         return 0;
215 }
216
217 bool tal_talloc_add_destructor_(const tal_t *ctx, void (*destroy)(void *me))
218 {
219         struct destructor *d = talloc(ctx, struct destructor);
220         if (!d)
221                 return false;
222
223         d->next = destr_hash[hash_ptr(ctx)];
224         d->ctx = ctx;
225         d->destroy = destroy;
226         destr_hash[hash_ptr(ctx)] = d;
227         talloc_set_destructor(ctx, tal_talloc_destroy);
228         return true;
229 }
230
231 bool tal_talloc_del_destructor_(const tal_t *ctx, void (*destroy)(void *me))
232 {
233         struct destructor **d = &destr_hash[hash_ptr(ctx)];
234
235         while (*d) {
236                 if ((*d)->ctx == ctx && (*d)->destroy == destroy) {
237                         struct destructor *this = *d;
238                         *d = this->next;
239                         talloc_free(this);
240                         return true;
241                 }
242                 d = &(*d)->next;
243         }
244         return false;
245 }
246
247 void tal_talloc_set_backend_(void *(*alloc_fn)(size_t size),
248                              void *(*resize_fn)(void *, size_t size),
249                              void (*free_fn)(void *),
250                              void (*error_fn)(const char *msg))
251 {
252         assert(!alloc_fn);
253         assert(!resize_fn);
254         assert(!free_fn);
255         errorfn = error_fn;
256         talloc_set_abort_fn(error_fn);
257 }
258
259 bool tal_talloc_check_(const tal_t *ctx, const char *errorstr)
260 {
261         /* We can't really check, but this iterates (and may abort). */
262         return !ctx || talloc_total_blocks(ctx) >= 1;
263 }