Add A-star module
[ccan] / ccan / a_star / a_star.c
1 /*
2         Copyright (C) 2016 Stephen M. Cameron
3         Author: Stephen M. Cameron
4
5         This file is part of Spacenerds In Space.
6
7         Spacenerds in Space is free software; you can redistribute it and/or modify
8         it under the terms of the GNU General Public License as published by
9         the Free Software Foundation; either version 2 of the License, or
10         (at your option) any later version.
11
12         Spacenerds in Space is distributed in the hope that it will be useful,
13         but WITHOUT ANY WARRANTY; without even the implied warranty of
14         MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15         GNU General Public License for more details.
16
17         You should have received a copy of the GNU General Public License
18         along with Spacenerds in Space; if not, write to the Free Software
19         Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
20 */
#include "a_star.h"

#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
27
/*
 * A fixed-capacity set of opaque node pointers (membership by pointer
 * identity).  node[] holds maxmembers slots, of which the first nmembers
 * are in use.
 */
struct nodeset {
	int nmembers;
	int maxmembers;
	void *node[]; /* C99 flexible array member (was GNU zero-length array) */
};
33
34 static int nodeset_empty(struct nodeset *n)
35 {
36         return (n->nmembers == 0);
37 }
38
39 static void nodeset_add_node(struct nodeset *n, void *node)
40 {
41         int i;
42
43         for (i = 0; i < n->nmembers; i++) {
44                 if (n->node[i] == node)
45                         return;
46         }
47         assert(n->nmembers < n->maxmembers);
48         n->node[n->nmembers] = node;
49         n->nmembers++;
50 }
51
52 static void nodeset_remove_node(struct nodeset *n, void *node)
53 {
54         int i;
55
56         for (i = 0; i < n->nmembers; i++) {
57                 if (n->node[i] != node)
58                         continue;
59                 if (i == n->nmembers - 1) {
60                         n->node[i] = NULL;
61                         n->nmembers--;
62                         return;
63                 }
64                 n->node[i] = n->node[n->nmembers - 1];
65                 n->nmembers--;
66                 n->node[n->nmembers] = NULL;
67                 return;
68         }
69 }
70
71 static int nodeset_contains_node(struct nodeset *n, void *node)
72 {
73         int i;
74
75         for (i = 0; i < n->nmembers; i++)
76                 if (n->node[i] == node)
77                         return 1;
78         return 0;
79 }
80
81 static struct nodeset *nodeset_new(int maxnodes)
82 {
83         struct nodeset *n;
84
85         n = malloc(sizeof(*n) + maxnodes * sizeof(void *));
86         memset(n, 0, sizeof(*n) + maxnodes * sizeof(void *));
87         n->maxmembers = maxnodes;
88         return n;
89 }
90
/* One entry of a node_map: records @from as the predecessor of @to. */
struct node_pair {
	void *from;
	void *to;
};
94
95 struct node_map {
96         int nelements;
97         __extension__ struct node_pair p[0];
98 };
99
/* A (node, score) pair stored in a score_map. */
struct score_entry {
	void *node;	/* key, compared by pointer; NULL marks an unused slot */
	float score;
};
104
105 struct score_map {
106         int nelements;
107         __extension__ struct score_entry s[0];
108 };
109
110 static float score_map_get_score(struct score_map *m, void *node)
111 {
112         int i;
113
114         for (i = 0; i < m->nelements; i++)
115                 if (m->s[i].node == node)
116                         return m->s[i].score;
117         assert(0);
118 }
119
120 static void *lowest_score(struct nodeset *candidates, struct score_map *s)
121 {
122
123         int i;
124         float score, lowest_score;
125         void *lowest = NULL;
126
127         for (i = 0; i < candidates->nmembers; i++) {
128                 score = score_map_get_score(s, candidates->node[i]);
129                 if (lowest != NULL && score > lowest_score)
130                         continue;
131                 lowest = candidates->node[i];
132                 lowest_score = score;
133         }
134         return lowest;
135 }
136
137 static struct score_map *score_map_new(int maxnodes)
138 {
139         struct score_map *s;
140
141         s = malloc(sizeof(*s) + sizeof(struct score_entry) * maxnodes);
142         memset(s, 0, sizeof(*s) + sizeof(struct score_entry) * maxnodes);
143         s->nelements = maxnodes;
144         return s;
145 }
146
147 static void score_map_add_score(struct score_map *s, void *node, float score)
148 {
149         int i;
150
151         for (i = 0; i < s->nelements; i++) {
152                 if (s->s[i].node != node)
153                         continue;
154                 s->s[i].score = score;
155                 return;
156         }
157         for (i = 0; i < s->nelements; i++) {
158                 if (s->s[i].node != NULL)
159                         continue;
160                 s->s[i].node = node;
161                 s->s[i].score = score;
162                 return;
163         }
164         assert(0);
165 }
166
167 static struct node_map *node_map_new(int maxnodes)
168 {
169         struct node_map *n;
170
171         n = malloc(sizeof(*n) + sizeof(struct node_pair) * maxnodes);
172         memset(n, 0, sizeof(*n) + sizeof(struct node_pair) * maxnodes);
173         n->nelements = maxnodes;
174         return n;
175 }
176
177 static void node_map_set_from(struct node_map *n, void *to, void *from)
178 {
179         int i;
180
181         for (i = 0; i < n->nelements; i++) {
182                 if (n->p[i].to != to)
183                         continue;
184                 n->p[i].from = from;
185                 return;
186         }
187         /* didn't find it, pick a NULL entry */
188         for (i = 0; i < n->nelements; i++) {
189                 if (n->p[i].to != NULL)
190                         continue;
191                 n->p[i].to = to;
192                 n->p[i].from = from;
193                 return;
194         }
195         assert(0); /* should never get here */
196 }
197
198 static void *node_map_get_from(struct node_map *n, void *to)
199 {
200         int i;
201
202         for (i = 0; i < n->nelements; i++)
203                 if (n->p[i].to == to)
204                         return n->p[i].from;
205         return NULL;
206 }
207
/*
 * reconstruct_path() - follow came_from links backwards from @current.
 *
 * Stores in *path a newly malloc()ed array holding @current followed by each
 * successive predecessor, and its length in *nodecount.  At most @maxnodes
 * entries are written, which bounds the array and guarantees termination
 * even if came_from were ever to contain a cycle.  If @maxnodes <= 0 or the
 * allocation fails, *path is set to NULL and *nodecount to 0.  The caller
 * must free() *path.
 */
static void reconstruct_path(struct node_map *came_from, void *current, void ***path, int *nodecount, int maxnodes)
{
	void **p;
	int i;

	*path = NULL;
	*nodecount = 0;
	if (maxnodes <= 0)
		return; /* no room even for @current */
	p = calloc((size_t) maxnodes, sizeof(*p));
	if (!p)
		return;

	p[0] = current;
	i = 1;
	while (i < maxnodes && (current = node_map_get_from(came_from, current)))
		p[i++] = current;
	*nodecount = i;
	*path = p;
}
226
227 struct a_star_path *a_star(void *context, void *start, void *goal,
228                                 int maxnodes,
229                                 a_star_node_cost_fn distance,
230                                 a_star_node_cost_fn cost_estimate,
231                                 a_star_neighbor_iterator_fn nth_neighbor)
232 {
233         struct nodeset *openset, *closedset;
234         struct node_map *came_from;
235         struct score_map *gscore, *fscore;
236         void *neighbor, *current;
237         float tentative_gscore;
238         int i, n;
239         void **answer = NULL;
240         int answer_count = 0;
241         struct a_star_path *return_value;
242
243         closedset = nodeset_new(maxnodes);
244         openset = nodeset_new(maxnodes);
245         came_from = node_map_new(maxnodes);
246         gscore = score_map_new(maxnodes);
247         fscore = score_map_new(maxnodes);
248
249         nodeset_add_node(openset, start);
250         score_map_add_score(gscore, start, 0.0);
251         score_map_add_score(fscore, start, cost_estimate(context, start, goal));
252
253         while (!nodeset_empty(openset)) {
254                 current = lowest_score(openset, fscore);
255                 if (current == goal) {
256                         reconstruct_path(came_from, current, &answer, &answer_count, maxnodes);
257                         break;
258                 }
259                 nodeset_remove_node(openset, current);
260                 nodeset_add_node(closedset, current);
261                 n = 0;
262                 while ((neighbor = nth_neighbor(context, current, n))) {
263                         n++;
264                         if (nodeset_contains_node(closedset, neighbor))
265                                 continue;
266                         tentative_gscore = score_map_get_score(gscore, current) + distance(context, current, neighbor);
267                         if (!nodeset_contains_node(openset, neighbor))
268                                 nodeset_add_node(openset, neighbor);
269                         else if (tentative_gscore >= score_map_get_score(gscore, neighbor))
270                                 continue;
271                         node_map_set_from(came_from, neighbor, current);
272                         score_map_add_score(gscore, neighbor, tentative_gscore);
273                         score_map_add_score(fscore, neighbor,
274                                         score_map_get_score(gscore, neighbor) +
275                                                 cost_estimate(context, neighbor, goal));
276                 }
277         }
278         free(closedset);
279         free(openset);
280         free(came_from);
281         free(gscore);
282         free(fscore);
283         if (answer_count == 0) {
284                 return_value = NULL;
285         } else {
286                 return_value = malloc(sizeof(*return_value) + sizeof(return_value->path[0]) * answer_count);
287                 return_value->node_count = answer_count;
288                 for (i = 0; i < answer_count; i++) {
289                         return_value->path[answer_count - i - 1] = answer[i];
290                 }
291         }
292         free(answer);
293         return return_value;
294 }