12#include <ucs/type/status.h>
14#include <ucp/api/ucp.h>
16#include "ucxpp/error.h"
18#include "ucxpp/detail/debug.h"
24 std::coroutine_handle<> h_;
27 bool check_request_ready(ucs_status_ptr_t request) {
28 if (UCS_PTR_IS_PTR(request)) [[unlikely]] {
29 status_ = UCS_INPROGRESS;
31 }
else if (UCS_PTR_IS_ERR(request)) [[unlikely]] {
32 status_ = UCS_PTR_STATUS(request);
33 UCXPP_LOG_ERROR(
"%s", ::ucs_status_string(status_));
45 static void send_cb(
void *request, ucs_status_t status,
void *user_data) {
46 auto self =
reinterpret_cast<Derived *
>(user_data);
47 self->status_ = status;
48 ::ucp_request_free(request);
52 ucp_request_param_t build_param() {
53 ucp_request_param_t send_param;
54 send_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
55 UCP_OP_ATTR_FIELD_USER_DATA |
56 UCP_OP_ATTR_FLAG_MULTI_SEND;
57 send_param.cb.send = &send_cb;
58 send_param.user_data =
this;
62 bool await_suspend(std::coroutine_handle<> h) {
64 return status_ == UCS_INPROGRESS;
// Final step of `co_await` on a send-type awaitable: hands the recorded
// status_ to check_ucs_status (declared in ucxpp/error.h).
// status_ is written either by check_request_ready (error pointer or
// UCS_INPROGRESS) or by send_cb when the request completes.
// NOTE(review): presumably check_ucs_status throws on a non-OK status,
// tagging it with "operation failed" — confirm against ucxpp/error.h.
67 void await_resume()
const { check_ucs_status(status_,
"operation failed"); }
78 : ep_(ep), buffer_(buffer), length_(length) {}
80 bool await_ready()
noexcept {
81 auto send_param = build_param();
82 auto request = ::ucp_stream_send_nbx(ep_, buffer_, length_, &send_param);
83 return check_request_ready(request);
97 : ep_(ep), tag_(tag), buffer_(buffer), length_(length) {}
99 bool await_ready()
noexcept {
100 auto send_param = build_param();
101 auto request = ::ucp_tag_send_nbx(ep_, buffer_, length_, tag_, &send_param);
102 return check_request_ready(request);
110 uint64_t remote_addr_;
116 uint64_t remote_addr, ucp_rkey_h rkey)
117 : ep_(ep), buffer_(buffer), length_(length), remote_addr_(remote_addr),
120 bool await_ready()
noexcept {
121 auto send_param = build_param();
123 ::ucp_put_nbx(ep_, buffer_, length_, remote_addr_, rkey_, &send_param);
124 return check_request_ready(request);
132 uint64_t remote_addr_;
138 uint64_t remote_addr, ucp_rkey_h rkey)
139 : ep_(ep), buffer_(buffer), length_(length), remote_addr_(remote_addr),
142 bool await_ready()
noexcept {
143 auto send_param = build_param();
145 ::ucp_get_nbx(ep_, buffer_, length_, remote_addr_, rkey_, &send_param);
146 return check_request_ready(request);
152 static_assert(
sizeof(T) == 4 ||
sizeof(T) == 8,
"Only 4-byte and 8-byte "
153 "integers are supported");
155 ucp_atomic_op_t
const op_;
157 uint64_t remote_addr_;
164 void const *buffer, uint64_t remote_addr,
165 ucp_rkey_h rkey,
void *reply_buffer =
nullptr)
166 : ep_(ep), op_(op), buffer_(buffer), remote_addr_(remote_addr),
167 rkey_(rkey), reply_buffer_(reply_buffer) {
168 if (op == UCP_ATOMIC_OP_SWAP || op == UCP_ATOMIC_OP_CSWAP) {
169 assert(reply_buffer !=
nullptr);
173 bool await_ready()
noexcept {
174 auto send_param = this->build_param();
175 send_param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE;
176 send_param.datatype = ucp_dt_make_contig(
sizeof(T));
177 if (reply_buffer_ !=
nullptr) {
178 send_param.op_attr_mask |= UCP_OP_ATTR_FIELD_REPLY_BUFFER;
179 send_param.reply_buffer = reply_buffer_;
181 auto request = ::ucp_atomic_op_nbx(ep_, op_, buffer_, 1, remote_addr_,
183 return this->check_request_ready(request);
190 std::shared_ptr<endpoint const> endpoint_;
195 bool await_ready()
noexcept;
199 std::shared_ptr<endpoint> endpoint_;
204 bool await_ready()
noexcept;
209 std::shared_ptr<worker> worker_;
215 bool await_ready()
noexcept;
222 ucp_worker_h worker_;
230 : ep_(ep), received_(0), buffer_(buffer), length_(length) {}
234 : ep_(ep), worker_(
worker), received_(0), buffer_(buffer),
239 static void stream_recv_cb(
void *request, ucs_status_t status,
240 size_t received,
void *user_data) {
242 self->status_ = status;
243 self->received_ = received;
244 ::ucp_request_free(request);
248 bool await_ready()
noexcept {
249 ucp_request_param_t stream_recv_param;
250 stream_recv_param.op_attr_mask =
251 UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA;
252 stream_recv_param.cb.recv_stream = &stream_recv_cb;
253 stream_recv_param.user_data =
this;
254 auto request = ::ucp_stream_recv_nbx(ep_, buffer_, length_, &received_,
257 if (!check_request_ready(request)) {
265 bool await_suspend(std::coroutine_handle<> h) {
267 return status_ == UCS_INPROGRESS;
270 size_t await_resume()
const {
271 check_ucs_status(status_,
"operation failed");
276 if (request_ !=
nullptr) {
277 ::ucp_request_cancel(worker_, request_);
286 ucp_worker_h worker_;
292 ucp_tag_recv_info_t recv_info_;
296 ucp_tag_t tag, ucp_tag_t tag_mask)
297 : worker_(
worker), request_(
nullptr), buffer_(buffer), length_(length),
298 tag_(tag), tag_mask_(tag_mask) {}
301 ucp_tag_t tag, ucp_tag_t tag_mask,
307 static void tag_recv_cb(
void *request, ucs_status_t status,
308 ucp_tag_recv_info_t
const *tag_info,
311 self->status_ = status;
312 self->recv_info_.length = tag_info->length;
313 self->recv_info_.sender_tag = tag_info->sender_tag;
314 ::ucp_request_free(request);
318 bool await_ready()
noexcept {
319 ucp_request_param_t tag_recv_param;
320 tag_recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
321 UCP_OP_ATTR_FIELD_USER_DATA |
322 UCP_OP_ATTR_FIELD_RECV_INFO;
323 tag_recv_param.cb.recv = &tag_recv_cb;
324 tag_recv_param.user_data =
this;
325 tag_recv_param.recv_info.tag_info = &recv_info_;
327 auto request = ::ucp_tag_recv_nbx(worker_, buffer_, length_, tag_,
328 tag_mask_, &tag_recv_param);
330 if (!check_request_ready(request)) {
337 bool await_suspend(std::coroutine_handle<> h) {
339 return status_ == UCS_INPROGRESS;
342 std::pair<size_t, ucp_tag_t> await_resume() {
344 check_ucs_status(status_,
"error in ucp_tag_recv_nbx");
345 return std::make_pair(recv_info_.length, recv_info_.sender_tag);
350 ::ucp_request_cancel(worker_, request_);
Definition: awaitable.h:22
Abstraction for a UCX endpoint.
Definition: endpoint.h:27
Definition: awaitable.h:198
Definition: awaitable.h:189
Definition: awaitable.h:151
Definition: awaitable.h:128
Definition: awaitable.h:106
Definition: awaitable.h:43
Definition: awaitable.h:219
Definition: awaitable.h:70
Definition: awaitable.h:283
Definition: awaitable.h:87
Definition: awaitable.h:208
Abstraction for a UCX worker.
Definition: worker.h:20