]> git.ozlabs.org Git - ccan/blob - ccan/tal/talloc/talloc.c
ab96ff78a7e878b1c630df847ef27316e501998d
[ccan] / ccan / tal / talloc / talloc.c
1 /* Licensed under LGPL - see LICENSE file for details */
2 #include <ccan/tal/talloc/talloc.h>
3 #include <errno.h>
4 #include <assert.h>
5
6 static void (*errorfn)(const char *msg) = (void *)abort;
7
8 static void COLD call_error(const char *msg)
9 {
10         errorfn(msg);
11 }
12
13 static void *error_on_null(void *p, const char *msg)
14 {
15         if (!p)
16                 call_error(msg);
17         return p;
18 }
19
20 void *tal_talloc_(const tal_t *ctx, size_t bytes, bool clear,
21                   const char *label)
22 {
23         void *ret;
24
25         if (clear)
26                 ret = _talloc_zero(ctx, bytes, label);
27         else
28                 ret = talloc_named_const(ctx, bytes, label);
29
30         return error_on_null(ret, "allocation failure");
31 }
32
33 void *tal_talloc_arr_(const tal_t *ctx, size_t bytes, size_t count, bool clear,
34                       const char *label)
35 {
36         void *ret;
37
38         if (clear)
39                 ret = _talloc_zero_array(ctx, bytes, count, label);
40         else
41                 ret = _talloc_array(ctx, bytes, count, label);
42
43         return error_on_null(ret, "array allocation failure");
44 }
45
46 void *tal_talloc_free_(const tal_t *ctx)
47 {
48         int saved_errno = errno;
49         talloc_free((void *)ctx);
50         errno = saved_errno;
51         return NULL;
52 }
53
54 bool tal_talloc_set_name_(tal_t *ctx, const char *name, bool literal)
55 {
56         if (!literal) {
57                 name = talloc_strdup(ctx, name);
58                 if (!name) {
59                         call_error("set_name allocation failure");
60                         return false;
61                 }
62         }
63         talloc_set_name_const(ctx, name);
64         return true;
65 }
66
67 const char *tal_talloc_name_(const tal_t *ctx)
68 {
69         const char *p = talloc_get_name(ctx);
70         if (p && unlikely(strcmp(p, "UNNAMED") == 0))
71                 p = NULL;
72         return p;
73 }
74
75 static bool adjust_size(size_t *size, size_t count)
76 {
77         /* Multiplication wrap */
78         if (count && unlikely(*size * count / *size != count))
79                 goto overflow;
80
81         *size *= count;
82
83         /* Make sure we don't wrap adding header. */
84         if (*size + 1024 < 1024)
85                 goto overflow;
86         return true;
87 overflow:
88         call_error("allocation size overflow");
89         return false;
90 }
91
92 void *tal_talloc_dup_(const tal_t *ctx, const void *p, size_t size,
93                       size_t n, size_t extra, const char *label)
94 {
95         void *ret;
96         size_t nbytes = size;
97
98         if (!adjust_size(&nbytes, n)) {
99                 if (taken(p))
100                         tal_free(p);
101                 return NULL;
102         }
103
104         /* Beware addition overflow! */
105         if (n + extra < n) {
106                 call_error("dup size overflow");
107                 if (taken(p))
108                         tal_free(p);
109                 return NULL;
110         }
111
112         if (taken(p)) {
113                 if (unlikely(!p))
114                         return NULL;
115                 if (unlikely(!tal_talloc_resize_((void **)&p, size, n + extra)))
116                         return tal_free(p);
117                 if (unlikely(!tal_steal(ctx, p)))
118                         return tal_free(p);
119                 return (void *)p;
120         }
121
122         ret = tal_talloc_arr_(ctx, size, n + extra, false, label);
123         if (ret)
124                 memcpy(ret, p, nbytes);
125         return ret;
126 }
127
128 bool tal_talloc_resize_(tal_t **ctxp, size_t size, size_t count)
129 {
130         tal_t *newp;
131
132         if (unlikely(count == 0)) {
133                 /* Don't free it! */
134                 newp = talloc_size(talloc_parent(*ctxp), 0);
135                 if (!newp) {
136                         call_error("Resize failure");
137                         return false;
138                 }
139                 talloc_free(*ctxp);
140                 *ctxp = newp;
141                 return true;
142         }
143
144         /* count is unsigned, not size_t, so check for overflow here! */
145         if ((unsigned)count != count) {
146                 call_error("Resize overflos");
147                 return false;
148         }
149
150         newp = _talloc_realloc_array(NULL, *ctxp, size, count, NULL);
151         if (!newp) {
152                 call_error("Resize failure");
153                 return false;
154         }
155         *ctxp = newp;
156         return true;
157 }
158
159 bool tal_talloc_expand_(tal_t **ctxp, const void *src, size_t size, size_t count)
160 {
161         bool ret = false;
162         size_t old_count = talloc_get_size(*ctxp) / size;
163
164         /* Check for additive overflow */
165         if (old_count + count < count) {
166                 call_error("dup size overflow");
167                 goto out;
168         }
169
170         /* Don't point src inside thing we're expanding! */
171         assert(src < *ctxp
172                || (char *)src >= (char *)(*ctxp) + (size * old_count));
173
174         if (!tal_talloc_resize_(ctxp, size, old_count + count))
175                 goto out;
176
177         memcpy((char *)*ctxp + size * old_count, src, count * size);
178         ret = true;
179
180 out:
181         if (taken(src))
182                 tal_free(src);
183         return ret;
184 }
185
186 /* Sucky inline hash table implementation, to avoid deps. */
187 #define HTABLE_BITS 10
188 struct destructor {
189         struct destructor *next;
190         const tal_t *ctx;
191         void (*destroy)(void *me);
192 };
193 static struct destructor *destr_hash[1 << HTABLE_BITS];
194
195 static unsigned int hash_ptr(const void *p)
196 {
197         unsigned long h = (unsigned long)p / sizeof(void *);
198
199         return (h ^ (h >> HTABLE_BITS)) & ((1 << HTABLE_BITS) - 1);
200 }
201
202 static int tal_talloc_destroy(const tal_t *ctx)
203 {
204         struct destructor **d = &destr_hash[hash_ptr(ctx)];
205         while (*d) {
206                 if ((*d)->ctx == ctx) {
207                         struct destructor *this = *d;
208                         this->destroy((void *)ctx);
209                         *d = this->next;
210                         talloc_free(this);
211                 }
212         }
213         return 0;
214 }
215
216 bool tal_talloc_add_destructor_(const tal_t *ctx, void (*destroy)(void *me))
217 {
218         struct destructor *d = talloc(ctx, struct destructor);
219         if (!d)
220                 return false;
221
222         d->next = destr_hash[hash_ptr(ctx)];
223         d->ctx = ctx;
224         d->destroy = destroy;
225         destr_hash[hash_ptr(ctx)] = d;
226         talloc_set_destructor(ctx, tal_talloc_destroy);
227         return true;
228 }
229
230 bool tal_talloc_del_destructor_(const tal_t *ctx, void (*destroy)(void *me))
231 {
232         struct destructor **d = &destr_hash[hash_ptr(ctx)];
233
234         while (*d) {
235                 if ((*d)->ctx == ctx && (*d)->destroy == destroy) {
236                         struct destructor *this = *d;
237                         *d = this->next;
238                         talloc_free(this);
239                         return true;
240                 }
241                 d = &(*d)->next;
242         }
243         return false;
244 }
245
246 void tal_talloc_set_backend_(void *(*alloc_fn)(size_t size),
247                              void *(*resize_fn)(void *, size_t size),
248                              void (*free_fn)(void *),
249                              void (*error_fn)(const char *msg))
250 {
251         assert(!alloc_fn);
252         assert(!resize_fn);
253         assert(!free_fn);
254         errorfn = error_fn;
255         talloc_set_abort_fn(error_fn);
256 }
257
258 bool tal_talloc_check_(const tal_t *ctx, const char *errorstr)
259 {
260         /* We can't really check, but this iterates (and may abort). */
261         return !ctx || talloc_total_blocks(ctx) >= 1;
262 }