UCX++
awaitable.h
1#pragma once
2
3#include <atomic>
4#include <cassert>
5#include <coroutine>
6#include <cstddef>
7#include <cstdint>
8#include <functional>
9#include <memory>
10#include <thread>
11#include <type_traits>
12#include <ucs/type/status.h>
13
14#include <ucp/api/ucp.h>
15
16#include "ucxpp/error.h"
17
18#include "ucxpp/detail/debug.h"
19
20namespace ucxpp {
21
23protected:
24 std::coroutine_handle<> h_;
25 ucs_status_t status_;
26 base_awaitable() : h_(nullptr), status_(UCS_OK) {}
27 bool check_request_ready(ucs_status_ptr_t request) {
28 if (UCS_PTR_IS_PTR(request)) [[unlikely]] {
29 status_ = UCS_INPROGRESS;
30 return false;
31 } else if (UCS_PTR_IS_ERR(request)) [[unlikely]] {
32 status_ = UCS_PTR_STATUS(request);
33 UCXPP_LOG_ERROR("%s", ::ucs_status_string(status_));
34 return true;
35 }
36
37 status_ = UCS_OK;
38 return true;
39 }
40};
41
42/* Common awaitable class for send-like callbacks */
43template <class Derived> class send_awaitable : public base_awaitable {
44public:
45 static void send_cb(void *request, ucs_status_t status, void *user_data) {
46 auto self = reinterpret_cast<Derived *>(user_data);
47 self->status_ = status;
48 ::ucp_request_free(request);
49 self->h_.resume();
50 }
51
52 ucp_request_param_t build_param() {
53 ucp_request_param_t send_param;
54 send_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
55 UCP_OP_ATTR_FIELD_USER_DATA |
56 UCP_OP_ATTR_FLAG_MULTI_SEND;
57 send_param.cb.send = &send_cb;
58 send_param.user_data = this;
59 return send_param;
60 }
61
62 bool await_suspend(std::coroutine_handle<> h) {
63 h_ = h;
64 return status_ == UCS_INPROGRESS;
65 }
66
67 void await_resume() const { check_ucs_status(status_, "operation failed"); }
68};
69
70class stream_send_awaitable : public send_awaitable<stream_send_awaitable> {
71 ucp_ep_h ep_;
72 void const *buffer_;
73 size_t length_;
74 friend class send_awaitable;
75
76public:
77 stream_send_awaitable(ucp_ep_h ep, void const *buffer, size_t length)
78 : ep_(ep), buffer_(buffer), length_(length) {}
79
80 bool await_ready() noexcept {
81 auto send_param = build_param();
82 auto request = ::ucp_stream_send_nbx(ep_, buffer_, length_, &send_param);
83 return check_request_ready(request);
84 }
85};
86
87class tag_send_awaitable : public send_awaitable<tag_send_awaitable> {
88 ucp_ep_h ep_;
89 ucp_tag_t tag_;
90 void const *buffer_;
91 size_t length_;
92 friend class send_awaitable;
93
94public:
95 tag_send_awaitable(ucp_ep_h ep, void const *buffer, size_t length,
96 ucp_tag_t tag)
97 : ep_(ep), tag_(tag), buffer_(buffer), length_(length) {}
98
99 bool await_ready() noexcept {
100 auto send_param = build_param();
101 auto request = ::ucp_tag_send_nbx(ep_, buffer_, length_, tag_, &send_param);
102 return check_request_ready(request);
103 }
104};
105
106class rma_put_awaitable : public send_awaitable<rma_put_awaitable> {
107 ucp_ep_h ep_;
108 void const *buffer_;
109 size_t length_;
110 uint64_t remote_addr_;
111 ucp_rkey_h rkey_;
112 friend class send_awaitable;
113
114public:
115 rma_put_awaitable(ucp_ep_h ep, void const *buffer, size_t length,
116 uint64_t remote_addr, ucp_rkey_h rkey)
117 : ep_(ep), buffer_(buffer), length_(length), remote_addr_(remote_addr),
118 rkey_(rkey) {}
119
120 bool await_ready() noexcept {
121 auto send_param = build_param();
122 auto request =
123 ::ucp_put_nbx(ep_, buffer_, length_, remote_addr_, rkey_, &send_param);
124 return check_request_ready(request);
125 }
126};
127
128class rma_get_awaitable : public send_awaitable<rma_get_awaitable> {
129 ucp_ep_h ep_;
130 void *buffer_;
131 size_t length_;
132 uint64_t remote_addr_;
133 ucp_rkey_h rkey_;
134 friend class send_awaitable;
135
136public:
137 rma_get_awaitable(ucp_ep_h ep, void *buffer, size_t length,
138 uint64_t remote_addr, ucp_rkey_h rkey)
139 : ep_(ep), buffer_(buffer), length_(length), remote_addr_(remote_addr),
140 rkey_(rkey) {}
141
142 bool await_ready() noexcept {
143 auto send_param = build_param();
144 auto request =
145 ::ucp_get_nbx(ep_, buffer_, length_, remote_addr_, rkey_, &send_param);
146 return check_request_ready(request);
147 }
148};
149
150template <class T>
151class rma_atomic_awaitable : public send_awaitable<rma_atomic_awaitable<T>> {
152 static_assert(sizeof(T) == 4 || sizeof(T) == 8, "Only 4-byte and 8-byte "
153 "integers are supported");
154 ucp_ep_h ep_;
155 ucp_atomic_op_t const op_;
156 void const *buffer_;
157 uint64_t remote_addr_;
158 ucp_rkey_h rkey_;
159 void *reply_buffer_;
160 friend class send_awaitable<rma_atomic_awaitable<T>>;
161
162public:
163 rma_atomic_awaitable(ucp_ep_h ep, ucp_atomic_op_t const op,
164 void const *buffer, uint64_t remote_addr,
165 ucp_rkey_h rkey, void *reply_buffer = nullptr)
166 : ep_(ep), op_(op), buffer_(buffer), remote_addr_(remote_addr),
167 rkey_(rkey), reply_buffer_(reply_buffer) {
168 if (op == UCP_ATOMIC_OP_SWAP || op == UCP_ATOMIC_OP_CSWAP) {
169 assert(reply_buffer != nullptr);
170 }
171 }
172
173 bool await_ready() noexcept {
174 auto send_param = this->build_param();
175 send_param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
176 send_param.datatype = ucp_dt_make_contig(sizeof(T));
177 if (reply_buffer_ != nullptr) {
178 send_param.op_attr_mask |= UCP_OP_ATTR_FIELD_REPLY_BUFFER;
179 send_param.reply_buffer = reply_buffer_;
180 }
181 auto request = ::ucp_atomic_op_nbx(ep_, op_, buffer_, 1, remote_addr_,
182 rkey_, &send_param);
183 return this->check_request_ready(request);
184 }
185};
186
/* These awaitables are not on the hot path, so they can afford to hold a shared_ptr */
188class endpoint;
189class ep_flush_awaitable : public send_awaitable<ep_flush_awaitable> {
190 std::shared_ptr<endpoint const> endpoint_;
191 friend class send_awaitable;
192
193public:
194 ep_flush_awaitable(std::shared_ptr<endpoint const> endpoint);
195 bool await_ready() noexcept;
196};
197
198class ep_close_awaitable : public send_awaitable<ep_close_awaitable> {
199 std::shared_ptr<endpoint> endpoint_;
200 friend class send_awaitable;
201
202public:
203 ep_close_awaitable(std::shared_ptr<endpoint> endpoint);
204 bool await_ready() noexcept;
205};
206
207class worker;
208class worker_flush_awaitable : public send_awaitable<worker_flush_awaitable> {
209 std::shared_ptr<worker> worker_;
210 friend class send_awaitable;
211
212public:
213 worker_flush_awaitable(std::shared_ptr<worker> worker);
214
215 bool await_ready() noexcept;
216};
217
218/* Common awaitable class for stream-recv-like callbacks */
220private:
221 ucp_ep_h ep_;
222 ucp_worker_h worker_;
223 size_t received_;
224 void *buffer_;
225 size_t length_;
226 void *request_;
227
228public:
229 stream_recv_awaitable(ucp_ep_h ep, void *buffer, size_t length)
230 : ep_(ep), received_(0), buffer_(buffer), length_(length) {}
231
232 stream_recv_awaitable(ucp_ep_h ep, ucp_worker_h worker, void *buffer,
233 size_t length, stream_recv_awaitable *&cancel)
234 : ep_(ep), worker_(worker), received_(0), buffer_(buffer),
235 length_(length) {
236 cancel = this;
237 }
238
239 static void stream_recv_cb(void *request, ucs_status_t status,
240 size_t received, void *user_data) {
241 auto self = reinterpret_cast<stream_recv_awaitable *>(user_data);
242 self->status_ = status;
243 self->received_ = received;
244 ::ucp_request_free(request);
245 self->h_.resume();
246 }
247
248 bool await_ready() noexcept {
249 ucp_request_param_t stream_recv_param;
250 stream_recv_param.op_attr_mask =
251 UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA;
252 stream_recv_param.cb.recv_stream = &stream_recv_cb;
253 stream_recv_param.user_data = this;
254 auto request = ::ucp_stream_recv_nbx(ep_, buffer_, length_, &received_,
255 &stream_recv_param);
256
257 if (!check_request_ready(request)) {
258 request_ = request;
259 return false;
260 }
261
262 return true;
263 }
264
265 bool await_suspend(std::coroutine_handle<> h) {
266 h_ = h;
267 return status_ == UCS_INPROGRESS;
268 }
269
270 size_t await_resume() const {
271 check_ucs_status(status_, "operation failed");
272 return received_;
273 }
274
275 void cancel() {
276 if (request_ != nullptr) {
277 ::ucp_request_cancel(worker_, request_);
278 }
279 }
280};
281
282/* Common awaitable class for tag-recv-like callbacks */
284
285private:
286 ucp_worker_h worker_;
287 void *request_;
288 void *buffer_;
289 size_t length_;
290 ucp_tag_t tag_;
291 ucp_tag_t tag_mask_;
292 ucp_tag_recv_info_t recv_info_;
293
294public:
295 tag_recv_awaitable(ucp_worker_h worker, void *buffer, size_t length,
296 ucp_tag_t tag, ucp_tag_t tag_mask)
297 : worker_(worker), request_(nullptr), buffer_(buffer), length_(length),
298 tag_(tag), tag_mask_(tag_mask) {}
299
300 tag_recv_awaitable(ucp_worker_h worker, void *buffer, size_t length,
301 ucp_tag_t tag, ucp_tag_t tag_mask,
302 tag_recv_awaitable *&cancel)
303 : tag_recv_awaitable(worker, buffer, length, tag, tag_mask) {
304 cancel = this;
305 }
306
307 static void tag_recv_cb(void *request, ucs_status_t status,
308 ucp_tag_recv_info_t const *tag_info,
309 void *user_data) {
310 auto self = reinterpret_cast<tag_recv_awaitable *>(user_data);
311 self->status_ = status;
312 self->recv_info_.length = tag_info->length;
313 self->recv_info_.sender_tag = tag_info->sender_tag;
314 ::ucp_request_free(request);
315 self->h_.resume();
316 }
317
318 bool await_ready() noexcept {
319 ucp_request_param_t tag_recv_param;
320 tag_recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
321 UCP_OP_ATTR_FIELD_USER_DATA |
322 UCP_OP_ATTR_FIELD_RECV_INFO;
323 tag_recv_param.cb.recv = &tag_recv_cb;
324 tag_recv_param.user_data = this;
325 tag_recv_param.recv_info.tag_info = &recv_info_;
326
327 auto request = ::ucp_tag_recv_nbx(worker_, buffer_, length_, tag_,
328 tag_mask_, &tag_recv_param);
329
330 if (!check_request_ready(request)) {
331 request_ = request;
332 return false;
333 }
334 return true;
335 }
336
337 bool await_suspend(std::coroutine_handle<> h) {
338 h_ = h;
339 return status_ == UCS_INPROGRESS;
340 }
341
342 std::pair<size_t, ucp_tag_t> await_resume() {
343 request_ = nullptr;
344 check_ucs_status(status_, "error in ucp_tag_recv_nbx");
345 return std::make_pair(recv_info_.length, recv_info_.sender_tag);
346 }
347
348 void cancel() {
349 if (request_) {
350 ::ucp_request_cancel(worker_, request_);
351 request_ = nullptr;
352 }
353 }
354};
355
356} // namespace ucxpp
Definition: awaitable.h:22
Abstraction for a UCX endpoint.
Definition: endpoint.h:27
Definition: awaitable.h:198
Definition: awaitable.h:189
Definition: awaitable.h:151
Definition: awaitable.h:128
Definition: awaitable.h:106
Definition: awaitable.h:43
Definition: awaitable.h:219
Definition: awaitable.h:70
Definition: awaitable.h:283
Definition: awaitable.h:87
Definition: awaitable.h:208
Abstraction for a UCX worker.
Definition: worker.h:20