]> git.ozlabs.org Git - ccan/blobdiff - ccan/generator/generator.h
generator: Rewrite to use coroutine module
[ccan] / ccan / generator / generator.h
index 6b2bd92fe526411a56d254969d4d8b3e6f35ad55..7122f554dd8987cde3ed83d7a03db9a380a36bec 100644 (file)
@@ -3,10 +3,6 @@
 #define CCAN_GENERATOR_H
 #include "config.h"
 
-#if !HAVE_UCONTEXT
-#error Generators require working ucontext.h functions
-#endif
-
 #if !HAVE_TYPEOF
 #error Generators require typeof
 #endif
 #include <assert.h>
 #include <stddef.h>
 #include <stdbool.h>
-#include <ucontext.h>
 
-#include <ccan/ptrint/ptrint.h>
-#include <ccan/build_assert/build_assert.h>
 #include <ccan/cppmagic/cppmagic.h>
+#include <ccan/compiler/compiler.h>
+#include <ccan/coroutine/coroutine.h>
+
+#if !COROUTINE_AVAILABLE
+#error Generators require coroutines
+#endif
 
 /*
  * Internals - included just for the use of inlines and macros
  */
 
 struct generator_ {
-       ucontext_t gen;
-       ucontext_t caller;
+       struct coroutine_state gen;
+       struct coroutine_state caller;
        bool complete;
        void *base;
 };
@@ -40,20 +39,18 @@ static inline struct generator_ *generator_state_(const void *ret)
        return (struct generator_ *)ret - 1;
 }
 
+static inline void *generator_argp_(const void *ret)
+{
+       return generator_state_(ret)->base;
+}
+
 struct generator_incomplete_;
 
 #define generator_rtype_(gen_)                 \
        typeof((*(gen_))((struct generator_incomplete_ *)NULL))
 
-#if HAVE_POINTER_SAFE_MAKECONTEXT
-#define generator_wrapper_args_()      void *ret
-#else
-#define generator_wrapper_args_()      int lo, int hi
-#endif
-typedef void generator_wrapper_(generator_wrapper_args_());
-
-void *generator_new_(generator_wrapper_ *fn, size_t retsize);
-void generator_free_(void *ret);
+void *generator_new_(void (*fn)(void *), size_t retsize);
+void generator_free_(void *ret, size_t retsize);
 
 /*
  * API
@@ -77,8 +74,8 @@ void generator_free_(void *ret);
  * Example:
  *     generator_declare(count_to_3, int);
  */
-#define generator_declare(name_, rtype_)       \
-       generator_t(rtype_) name_(void)
+#define generator_declare(name_, rtype_, ...)  \
+       generator_t(rtype_) name_(generator_parms_outer_(__VA_ARGS__))
 
 /**
  * generator_def - define a generator function
@@ -97,32 +94,56 @@ void generator_free_(void *ret);
  *             generator_yield(3);
  *     }
  */
-#define generator_def_(name_, rtype_, storage_)                                \
-       static void name_##_generator_(rtype_ *ret_);                   \
-       static void name_##_generator__(generator_wrapper_args_())      \
+#define generator_parm_(t_, n_)                        t_ n_
+#define generator_parms_(...)                                          \
+       CPPMAGIC_2MAP(generator_parm_, __VA_ARGS__)
+#define generator_parms_inner_(...)                                    \
+       CPPMAGIC_IFELSE(CPPMAGIC_NONEMPTY(__VA_ARGS__))                 \
+               (, generator_parms_(__VA_ARGS__))()
+#define generator_parms_outer_(...)                                    \
+       CPPMAGIC_IFELSE(CPPMAGIC_NONEMPTY(__VA_ARGS__)) \
+               (generator_parms_(__VA_ARGS__))(void)
+#define generator_argfield_(t_, n_)            t_ n_;
+#define generator_argstruct_(...)                                      \
+       struct {                                                        \
+               CPPMAGIC_JOIN(, CPPMAGIC_2MAP(generator_argfield_,      \
+                                             __VA_ARGS__))             \
+       }
+#define generator_arg_unpack_(t_, n_)          args->n_
+#define generator_args_unpack_(...)            \
+       CPPMAGIC_IFELSE(CPPMAGIC_NONEMPTY(__VA_ARGS__))                 \
+               (, CPPMAGIC_2MAP(generator_arg_unpack_, __VA_ARGS__))()
+#define generator_arg_pack_(t_, n_)            args->n_ = n_
+#define generator_args_pack_(...)                                      \
+       CPPMAGIC_JOIN(;, CPPMAGIC_2MAP(generator_arg_pack_, __VA_ARGS__))
+#define generator_def_(name_, rtype_, storage_, ...)                   \
+       static void name_##_generator_(rtype_ *ret_                     \
+                                      generator_parms_inner_(__VA_ARGS__)); \
+       static void name_##_generator__(void *ret)                      \
        {                                                               \
                struct generator_ *gen;                                 \
-               CPPMAGIC_IFELSE(HAVE_POINTER_SAFE_MAKECONTEXT)          \
-                       ()                                              \
-                       (ptrdiff_t hilo = ((ptrdiff_t)hi << (8*sizeof(int))) \
-                               + (ptrdiff_t)lo;                        \
-                       rtype_ *ret = (rtype_ *)int2ptr(hilo);          \
-                       BUILD_ASSERT(sizeof(struct generator_ *)        \
-                                    <= 2*sizeof(int));)                \
+               UNNEEDED generator_argstruct_(__VA_ARGS__) *args;       \
                gen = generator_state_(ret);                            \
-               name_##_generator_(ret);                                \
+               args = generator_argp_(ret);                            \
+               name_##_generator_(ret generator_args_unpack_(__VA_ARGS__)); \
                gen->complete = true;                                   \
-               setcontext(&gen->caller);                               \
+               coroutine_jump(&gen->caller);                           \
                assert(0);                                              \
        }                                                               \
-       storage_ generator_t(rtype_) name_(void)                        \
+       storage_ generator_t(rtype_)                                    \
+       name_(generator_parms_outer_(__VA_ARGS__))                      \
        {                                                               \
-               return generator_new_(name_##_generator__,              \
-                                     sizeof(rtype_));                  \
+               generator_t(rtype_) gen = generator_new_(name_##_generator__, \
+                                                        sizeof(rtype_)); \
+               UNNEEDED generator_argstruct_(__VA_ARGS__) *args =      \
+                       generator_argp_(gen);                           \
+               generator_args_pack_(__VA_ARGS__);                      \
+               return gen;                                             \
        }                                                               \
-       static void name_##_generator_(rtype_ *ret_)
-#define generator_def(name_, rtype_)           \
-       generator_def_(name_, rtype_, )
+       static void name_##_generator_(rtype_ *ret_                     \
+                                      generator_parms_inner_(__VA_ARGS__))
+#define generator_def(name_, rtype_, ...)      \
+       generator_def_(name_, rtype_, , __VA_ARGS__)
 
 /**
  * generator_def_static - define a private / local generator function
@@ -132,8 +153,8 @@ void generator_free_(void *ret);
  * As generator_def, but the resulting generator function will be
  * local to this module.
  */
-#define generator_def_static(name_, rtype_)    \
-       generator_def_(name_, rtype_, static)
+#define generator_def_static(name_, rtype_, ...)       \
+       generator_def_(name_, rtype_, static, __VA_ARGS__)
 
 /**
  * generator_yield - yield (return) a value from a generator
@@ -147,10 +168,8 @@ void generator_free_(void *ret);
 #define generator_yield(val_)                                          \
        do {                                                            \
                struct generator_ *gen_ = generator_state_(ret_);       \
-               int rc;                                                 \
                *(ret_) = (val_);                                       \
-               rc = swapcontext(&gen_->gen, &gen_->caller);            \
-               assert(rc == 0);                                        \
+               coroutine_switch(&gen_->gen, &gen_->caller);            \
        } while (0)
 
 /**
@@ -165,13 +184,11 @@ void generator_free_(void *ret);
 static inline void *generator_next_(void *ret_)
 {
        struct generator_ *gen = generator_state_(ret_);
-       int rc;
 
        if (gen->complete)
                return NULL;
 
-       rc = swapcontext(&gen->caller, &gen->gen);
-       assert(rc == 0);
+       coroutine_switch(&gen->caller, &gen->gen);
 
        return gen->complete ? NULL : ret_;
 }
@@ -197,6 +214,7 @@ static inline void *generator_next_(void *ret_)
        })
 
 #define generator_free(gen_)                                   \
-       generator_free_((generator_rtype_(gen_) *)(gen_))
+       generator_free_((generator_rtype_(gen_) *)(gen_),       \
+                       sizeof(generator_rtype_(gen_)))
 
 #endif /* CCAN_GENERATOR_H */