UCX++
task.h
1#pragma once
2
3#include <cassert>
4#include <coroutine>
5#include <exception>
6#include <functional>
7#include <future>
8#include <utility>
9
10#include "ucxpp/detail/debug.h"
11
12namespace ucxpp {
13
/// Mixin that gives a coroutine promise a `return_value` member for tasks
/// producing a value of type `T`.
///
/// The produced value is published through `promise_`, a `std::promise<T>`.
template <typename T> class value_returner {
public:
  /// Channel through which the coroutine result is delivered.
  std::promise<T> promise_;

  /// Invoked for `co_return expr;` and stores the result in `promise_`.
  ///
  /// @param value Result of the coroutine; passed on to
  ///              `std::promise<T>::set_value` with its value category kept.
  void return_value(T &&value) {
    // Same cast that std::forward<T> performs on this parameter.
    promise_.set_value(static_cast<T &&>(value));
  }
};
19
20template <> class value_returner<void> {
21public:
22 std::promise<void> promise_;
23 void return_void() { promise_.set_value(); }
24};
25
26template <class T, class CoroutineHandle>
27struct promise_base : public value_returner<T> {
28 std::suspend_never initial_suspend() { return {}; }
29 auto final_suspend() noexcept {
30 struct awaiter {
31 std::coroutine_handle<> release_detached_;
32 bool await_ready() noexcept { return false; }
33 std::coroutine_handle<>
34 await_suspend(CoroutineHandle suspended) noexcept {
35 if (suspended.promise().continuation_) {
36 return suspended.promise().continuation_;
37 } else {
38 if (release_detached_) {
39 release_detached_.destroy();
40 }
41 return std::noop_coroutine();
42 }
43 }
44 void await_resume() noexcept {}
45 };
46 return awaiter{release_detached_};
47 }
48
49 std::coroutine_handle<> continuation_;
50 std::coroutine_handle<> release_detached_;
51};
52
53template <class T> struct task {
55 : public promise_base<T, std::coroutine_handle<promise_type>> {
56 task<T> get_return_object() {
57 return std::coroutine_handle<promise_type>::from_promise(*this);
58 }
59 void unhandled_exception() {
60 this->promise_.set_exception(std::current_exception());
61 }
62 promise_type() : future_(this->promise_.get_future()) {}
63 std::future<T> &get_future() { return future_; }
64 void set_detached_task(std::coroutine_handle<promise_type> h) {
65 this->release_detached_ = h;
66 }
67 std::future<T> future_;
68 };
69
70 struct task_awaiter {
71 std::coroutine_handle<promise_type> h_;
72 task_awaiter(std::coroutine_handle<promise_type> h) : h_(h) {}
73 bool await_ready() { return h_.done(); }
74 auto await_suspend(std::coroutine_handle<> suspended) {
75 h_.promise().continuation_ = suspended;
76 }
77 auto await_resume() { return h_.promise().future_.get(); }
78 };
79
80 using coroutine_handle_type = std::coroutine_handle<promise_type>;
81
82 auto operator co_await() const { return task_awaiter(h_); }
83
84 ~task() {
85 if (!detached_) {
86 if (!h_.done()) {
87 h_.promise().set_detached_task(h_);
88 get_future().get();
89 } else {
90 h_.destroy();
91 }
92 }
93 }
94 task(task &&other)
95 : h_(std::exchange(other.h_, nullptr)),
96 detached_(std::exchange(other.detached_, true)) {}
97 task(coroutine_handle_type h) : h_(h), detached_(false) {}
98 coroutine_handle_type h_;
99 bool detached_;
100 operator coroutine_handle_type() const { return h_; }
101 std::future<T> &get_future() const { return h_.promise().get_future(); }
102 void detach() {
103 assert(!detached_);
104 h_.promise().set_detached_task(h_);
105 detached_ = true;
106 }
107};
108
109} // namespace ucxpp
Definition: task.h:14
Definition: task.h:27
Definition: task.h:55
Definition: task.h:70
Definition: task.h:53