RDMA++
Loading...
Searching...
No Matches
task.h
1#pragma once
2
3#include <cassert>
4#include <coroutine>
5#include <exception>
6#include <future>
7#include <utility>
8
9namespace rdmapp {
10
/// Result channel for a coroutine returning `T`: whatever the coroutine
/// `co_return`s is stored into `promise_`, whose future is handed out by the
/// owning promise type.
///
/// @tparam T the coroutine's result type (`void` is handled by the
///           specialization below).
template <class T> class value_returner {
public:
  /// Producer side of the result channel.
  std::promise<T> promise_;

  /// Invoked by `co_return expr;`.
  ///
  /// Accepts lvalues, rvalues and braced initializers (for `co_return {..};`
  /// the argument is a non-deduced context, so `U` falls back to `T`).
  /// Forwarding to std::promise::set_value lets its overloads copy lvalues
  /// and move rvalues; the previous `T&&`-only signature rejected lvalues
  /// such as `co_return some_const_ref_param;`.
  template <class U = T> void return_value(U &&value) {
    promise_.set_value(std::forward<U>(value));
  }
};

/// Specialization for coroutines returning `void`: `co_return;` (or flowing
/// off the end of the body) fulfils the promise without a value.
template <> class value_returner<void> {
public:
  /// Producer side of the (value-less) result channel.
  std::promise<void> promise_;

  /// Invoked by `co_return;`; marks the result as ready.
  void return_void() { promise_.set_value(); }
};
22
23template <class T, class CoroutineHandle>
24struct promise_base : public value_returner<T> {
25 std::suspend_never initial_suspend() { return {}; }
26 auto final_suspend() noexcept {
27 struct awaiter {
28 std::coroutine_handle<> release_detached_;
29 bool await_ready() noexcept { return false; }
30 std::coroutine_handle<>
31 await_suspend(CoroutineHandle suspended) noexcept {
32 if (suspended.promise().continuation_) {
33 return suspended.promise().continuation_;
34 } else {
35 if (release_detached_) {
36 release_detached_.destroy();
37 }
38 return std::noop_coroutine();
39 }
40 }
41 void await_resume() noexcept {}
42 };
43 return awaiter{release_detached_};
44 }
45
46 std::coroutine_handle<> continuation_;
47 std::coroutine_handle<> release_detached_;
48};
49
50template <class T> struct task {
52 : public promise_base<T, std::coroutine_handle<promise_type>> {
53 task<T> get_return_object() {
54 return std::coroutine_handle<promise_type>::from_promise(*this);
55 }
56 void unhandled_exception() {
57 this->promise_.set_exception(std::current_exception());
58 }
59 promise_type() : future_(this->promise_.get_future()) {}
60 std::future<T> &get_future() { return future_; }
61 void set_detached_task(std::coroutine_handle<promise_type> h) {
62 this->release_detached_ = h;
63 }
64 std::future<T> future_;
65 };
66
67 struct task_awaiter {
68 std::coroutine_handle<promise_type> h_;
69 task_awaiter(std::coroutine_handle<promise_type> h) : h_(h) {}
70 bool await_ready() { return h_.done(); }
71 auto await_suspend(std::coroutine_handle<> suspended) {
72 h_.promise().continuation_ = suspended;
73 }
74 auto await_resume() { return h_.promise().future_.get(); }
75 };
76
77 using coroutine_handle_type = std::coroutine_handle<promise_type>;
78
79 auto operator co_await() const { return task_awaiter(h_); }
80
81 ~task() {
82 if (!detached_) {
83 if (!h_.done()) {
84 h_.promise().set_detached_task(h_);
85 get_future().wait();
86 } else {
87 h_.destroy();
88 }
89 }
90 }
91 task(task &&other)
92 : h_(std::exchange(other.h_, nullptr)),
93 detached_(std::exchange(other.detached_, true)) {}
94 task(coroutine_handle_type h) : h_(h), detached_(false) {}
95 coroutine_handle_type h_;
96 bool detached_;
97 operator coroutine_handle_type() const { return h_; }
98 std::future<T> &get_future() const { return h_.promise().get_future(); }
99 void detach() {
100 assert(!detached_);
101 h_.promise().set_detached_task(h_);
102 detached_ = true;
103 }
104};
105
106} // namespace rdmapp
Definition task.h:11
Definition task.h:24
Definition task.h:52
Definition task.h:67
Definition task.h:50