Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
socket_server.cpp
Go to the documentation of this file.
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <limits>
#include <span>
#include <string>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <utility>
13
14// Platform-specific event notification includes
15#ifdef __APPLE__
16#include <sys/event.h> // kqueue on macOS/BSD
17#else
18#include <sys/epoll.h> // epoll on Linux
19#endif
20
21namespace bb::ipc {
22
23SocketServer::SocketServer(std::string socket_path, int initial_max_clients)
24 : socket_path_(std::move(socket_path))
25 , initial_max_clients_(initial_max_clients)
26{
27 const size_t reserve_size = initial_max_clients > 0 ? static_cast<size_t>(initial_max_clients) : 10;
28 client_fds_.reserve(reserve_size);
29 recv_buffers_.reserve(reserve_size);
30}
31
36
41
43{
44 // Close all client connections
45 for (int fd : client_fds_) {
46 if (fd >= 0) {
47 ::close(fd);
48 }
49 }
50 client_fds_.clear();
51 fd_to_client_id_.clear();
52 num_clients_ = 0;
53
54 if (fd_ >= 0) {
55 ::close(fd_);
56 fd_ = -1;
57 }
58
59 if (listen_fd_ >= 0) {
61 listen_fd_ = -1;
62 }
63
64 // Clean up socket file
65 ::unlink(socket_path_.c_str());
66}
67
69{
70 // Look for existing free slot
71 for (size_t i = 0; i < client_fds_.size(); i++) {
72 if (client_fds_[i] == -1) {
73 return static_cast<int>(i);
74 }
75 }
76
77 // No free slot found, allocate new one at end
78 return static_cast<int>(client_fds_.size());
79}
80
81bool SocketServer::send(int client_id, const void* data, size_t len)
82{
83 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size() ||
84 client_fds_[static_cast<size_t>(client_id)] < 0) {
85 errno = EINVAL;
86 return false;
87 }
88
89 int fd = client_fds_[static_cast<size_t>(client_id)];
90
91 // Send length prefix (4 bytes)
92 auto msg_len = static_cast<uint32_t>(len);
93 ssize_t n = ::send(fd, &msg_len, sizeof(msg_len), 0);
94 if (n < 0 || static_cast<size_t>(n) != sizeof(msg_len)) {
95 return false;
96 }
97
98 // Send message data
99 n = ::send(fd, data, len, 0);
100 if (n < 0) {
101 return false;
102 }
103 const auto bytes_sent = static_cast<size_t>(n);
104 return bytes_sent == len;
105}
106
107void SocketServer::release(int client_id, size_t message_size)
108{
109 // No-op for sockets - message already consumed from kernel buffer during receive()
110 (void)client_id;
111 (void)message_size;
112}
113
115{
116 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size() ||
117 client_fds_[static_cast<size_t>(client_id)] < 0) {
118 return {};
119 }
120
121 int fd = client_fds_[static_cast<size_t>(client_id)];
122 const auto client_idx = static_cast<size_t>(client_id);
123
124 // Ensure buffers are sized for this client
125 if (client_idx >= recv_buffers_.size()) {
126 recv_buffers_.resize(client_idx + 1);
127 }
128
129 // Read length prefix (4 bytes) - must loop until all bytes received (MSG_WAITALL unreliable on macOS)
130 uint32_t msg_len = 0;
131 size_t total_read = 0;
132 while (total_read < sizeof(msg_len)) {
133 ssize_t n = ::recv(fd, reinterpret_cast<uint8_t*>(&msg_len) + total_read, sizeof(msg_len) - total_read, 0);
134 if (n < 0) {
135 if (errno == EINTR) {
136 continue; // Interrupted, retry
137 }
138 return {};
139 }
140 if (n == 0) {
141 // Client disconnected
142 disconnect_client(client_id);
143 return {};
144 }
145 total_read += static_cast<size_t>(n);
146 }
147
148 // Resize buffer if needed to fit length prefix + message
149 size_t total_size = sizeof(uint32_t) + msg_len;
150 if (recv_buffers_[client_idx].size() < total_size) {
151 recv_buffers_[client_idx].resize(total_size);
152 }
153
154 // Store length prefix in buffer
155 std::memcpy(recv_buffers_[client_idx].data(), &msg_len, sizeof(uint32_t));
156
157 // Read message data - must loop until all bytes received (MSG_WAITALL unreliable on macOS)
158 total_read = 0;
159 while (total_read < msg_len) {
160 ssize_t n =
161 ::recv(fd, recv_buffers_[client_idx].data() + sizeof(uint32_t) + total_read, msg_len - total_read, 0);
162 if (n < 0) {
163 if (errno == EINTR) {
164 continue; // Interrupted, retry
165 }
166 disconnect_client(client_id);
167 return {};
168 }
169 if (n == 0) {
170 // Client disconnected mid-message
171 disconnect_client(client_id);
172 return {};
173 }
174 total_read += static_cast<size_t>(n);
175 }
176
177 return std::span<const uint8_t>(recv_buffers_[client_idx].data() + sizeof(uint32_t), msg_len);
178}
179
180#ifdef __APPLE__
// ============================================================================
// macOS Implementation (kqueue, non-blocking listen socket, blocking client sockets, accept-until-EAGAIN)
// ============================================================================
184
186{
187 if (listen_fd_ >= 0) {
188 return true; // Already listening
189 }
190
191 // Remove any existing socket file
192 ::unlink(socket_path_.c_str());
193
194 // Create socket
195 listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
196 if (listen_fd_ < 0) {
197 return false;
198 }
199
200 // Set non-blocking mode (required for accept-until-EAGAIN pattern)
201 int flags = fcntl(listen_fd_, F_GETFL, 0);
202 if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
204 listen_fd_ = -1;
205 return false;
206 }
207
208 // Bind to path
209 struct sockaddr_un addr;
210 std::memset(&addr, 0, sizeof(addr));
211 addr.sun_family = AF_UNIX;
212 std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1);
213
214 if (bind(listen_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
216 listen_fd_ = -1;
217 return false;
218 }
219
220 // Listen with backlog
221 int backlog = initial_max_clients_ > 0 ? initial_max_clients_ : 10;
222 if (::listen(listen_fd_, backlog) < 0) {
224 listen_fd_ = -1;
225 ::unlink(socket_path_.c_str());
226 return false;
227 }
228
229 // Create kqueue instance
230 fd_ = kqueue();
231 if (fd_ < 0) {
233 listen_fd_ = -1;
234 ::unlink(socket_path_.c_str());
235 return false;
236 }
237
238 // Add listen socket to kqueue
239 struct kevent ev;
240 EV_SET(&ev, listen_fd_, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr);
241 if (kevent(fd_, &ev, 1, nullptr, 0, nullptr) < 0) {
242 ::close(fd_);
243 fd_ = -1;
245 listen_fd_ = -1;
246 ::unlink(socket_path_.c_str());
247 return false;
248 }
249
250 return true;
251}
252
254{
255 if (listen_fd_ < 0) {
256 errno = EINVAL;
257 return -1;
258 }
259
260 // Accept all pending connections (loop until EAGAIN)
261 // Non-blocking socket ensures this returns immediately
262 int last_client_id = -1;
263
264 while (true) {
265 int client_fd = ::accept(listen_fd_, nullptr, nullptr);
266
267 if (client_fd < 0) {
268 // Check if this is expected (no more connections) or a real error
269 if (errno == EAGAIN || errno == EWOULDBLOCK) {
270 // No more pending connections - expected, break
271 break;
272 }
273 // Real error - but if we already accepted some, return success
274 if (last_client_id >= 0) {
275 break;
276 }
277 // No connections accepted and got real error
278 return -1;
279 }
280
281 // Set client socket to BLOCKING mode (inherited non-blocking from listen socket)
282 // This avoids busy-waiting in recv() - we only recv after kqueue signals data ready
283 int flags = fcntl(client_fd, F_GETFL, 0);
284 if (flags >= 0) {
285 fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK);
286 }
287
288 // Find free slot (or allocate new one)
289 int client_id = find_free_slot();
290
291 // Store client fd
292 const auto client_id_unsigned = static_cast<size_t>(client_id);
293 if (client_id_unsigned >= client_fds_.size()) {
294 client_fds_.resize(client_id_unsigned + 1, -1);
295 }
296 client_fds_[static_cast<size_t>(client_id)] = client_fd;
297 fd_to_client_id_[client_fd] = client_id;
298 num_clients_++;
299
300 // Add client to kqueue
301 struct kevent kev;
302 EV_SET(&kev, client_fd, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr);
303 if (kevent(fd_, &kev, 1, nullptr, 0, nullptr) < 0) {
304 disconnect_client(client_id);
305 // Continue trying to accept other pending connections
306 continue;
307 }
308
309 last_client_id = client_id;
310 }
311
312 return last_client_id;
313}
314
315int SocketServer::wait_for_data(uint64_t timeout_ns)
316{
317 if (fd_ < 0) {
318 errno = EINVAL;
319 return -1;
320 }
321
322 struct kevent ev;
323 struct timespec timeout;
324 struct timespec* timeout_ptr = nullptr;
325
326 if (timeout_ns > 0) {
327 timeout.tv_sec = static_cast<time_t>(timeout_ns / 1000000000ULL);
328 timeout.tv_nsec = static_cast<long>(timeout_ns % 1000000000ULL);
329 timeout_ptr = &timeout;
330 } else if (timeout_ns == 0) {
331 timeout.tv_sec = 0;
332 timeout.tv_nsec = 0;
333 timeout_ptr = &timeout;
334 }
335
336 int n = kevent(fd_, nullptr, 0, &ev, 1, timeout_ptr);
337 if (n <= 0) {
338 return -1;
339 }
340
341 int ready_fd = static_cast<int>(ev.ident);
342
343 // Check if it's listen socket (new connection) or client data
344 if (ready_fd == listen_fd_) {
345 errno = EAGAIN; // Signal caller to call accept
346 return -1;
347 }
348
349 // Find which client
350 auto it = fd_to_client_id_.find(ready_fd);
351 if (it == fd_to_client_id_.end()) {
352 errno = ENOENT;
353 return -1;
354 }
355
356 return it->second;
357}
358
359void SocketServer::disconnect_client(int client_id)
360{
361 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size()) {
362 return;
363 }
364
365 int fd = client_fds_[static_cast<size_t>(client_id)];
366 if (fd >= 0) {
367 // For kqueue, we don't need explicit deletion - closing the fd removes it automatically
368 // But we can explicitly remove it for clarity
369 struct kevent ev;
370 EV_SET(&ev, fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr);
371 kevent(fd_, &ev, 1, nullptr, 0, nullptr);
372
373 ::close(fd);
374 fd_to_client_id_.erase(fd);
375 client_fds_[static_cast<size_t>(client_id)] = -1;
376 num_clients_--;
377 }
378}
379
380#else
381
// ============================================================================
// Linux Implementation (epoll, non-blocking listen socket, blocking client sockets, accept-until-EAGAIN)
// ============================================================================
385
387{
388 if (listen_fd_ >= 0) {
389 return true; // Already listening
390 }
391
392 // Remove any existing socket file
393 ::unlink(socket_path_.c_str());
394
395 // Create socket
396 listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
397 if (listen_fd_ < 0) {
398 return false;
399 }
400
401 // Set non-blocking mode (required for accept-until-EAGAIN pattern)
402 int flags = fcntl(listen_fd_, F_GETFL, 0);
403 if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
405 listen_fd_ = -1;
406 return false;
407 }
408
409 // Bind to path
410 struct sockaddr_un addr;
411 std::memset(&addr, 0, sizeof(addr));
412 addr.sun_family = AF_UNIX;
413 std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1);
414
415 if (bind(listen_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0) {
417 listen_fd_ = -1;
418 return false;
419 }
420
421 // Listen with backlog
422 int backlog = initial_max_clients_ > 0 ? initial_max_clients_ : 10;
423 if (::listen(listen_fd_, backlog) < 0) {
425 listen_fd_ = -1;
426 ::unlink(socket_path_.c_str());
427 return false;
428 }
429
430 // Create epoll instance
431 fd_ = epoll_create1(0);
432 if (fd_ < 0) {
434 listen_fd_ = -1;
435 ::unlink(socket_path_.c_str());
436 return false;
437 }
438
439 // Add listen socket to epoll
440 struct epoll_event ev;
441 ev.events = EPOLLIN;
442 ev.data.fd = listen_fd_;
443 if (epoll_ctl(fd_, EPOLL_CTL_ADD, listen_fd_, &ev) < 0) {
444 ::close(fd_);
445 fd_ = -1;
447 listen_fd_ = -1;
448 ::unlink(socket_path_.c_str());
449 return false;
450 }
451
452 return true;
453}
454
456{
457 if (listen_fd_ < 0) {
458 errno = EINVAL;
459 return -1;
460 }
461
462 // Accept all pending connections (loop until EAGAIN)
463 // Non-blocking socket ensures this returns immediately
464 int last_client_id = -1;
465
466 while (true) {
467 int client_fd = ::accept(listen_fd_, nullptr, nullptr);
468
469 if (client_fd < 0) {
470 // Check if this is expected (no more connections) or a real error
471 if (errno == EAGAIN || errno == EWOULDBLOCK) {
472 // No more pending connections - expected, break
473 break;
474 }
475 // Real error - but if we already accepted some, return success
476 if (last_client_id >= 0) {
477 break;
478 }
479 // No connections accepted and got real error
480 return -1;
481 }
482
483 // Set client socket to BLOCKING mode (inherited non-blocking from listen socket)
484 // This avoids busy-waiting in recv() - we only recv after epoll signals data ready
485 int flags = fcntl(client_fd, F_GETFL, 0);
486 if (flags >= 0) {
487 fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK);
488 }
489
490 // Find free slot (or allocate new one)
491 int client_id = find_free_slot();
492
493 // Store client fd
494 const auto client_id_unsigned = static_cast<size_t>(client_id);
495 if (client_id_unsigned >= client_fds_.size()) {
496 client_fds_.resize(client_id_unsigned + 1, -1);
497 }
498 client_fds_[static_cast<size_t>(client_id)] = client_fd;
499 fd_to_client_id_[client_fd] = client_id;
500 num_clients_++;
501
502 // Add client to epoll
503 struct epoll_event client_ev;
504 client_ev.events = EPOLLIN;
505 client_ev.data.fd = client_fd;
506 if (epoll_ctl(fd_, EPOLL_CTL_ADD, client_fd, &client_ev) < 0) {
507 disconnect_client(client_id);
508 // Continue trying to accept other pending connections
509 continue;
510 }
511
512 last_client_id = client_id;
513 }
514
515 return last_client_id;
516}
517
518int SocketServer::wait_for_data(uint64_t timeout_ns)
519{
520 if (fd_ < 0) {
521 errno = EINVAL;
522 return -1;
523 }
524
525 struct epoll_event ev;
526 int timeout_ms = timeout_ns > 0 ? static_cast<int>(timeout_ns / 1000000) : -1;
527 int n = epoll_wait(fd_, &ev, 1, timeout_ms);
528 if (n <= 0) {
529 return -1;
530 }
531
532 // Check if it's listen socket (new connection) or client data
533 if (ev.data.fd == listen_fd_) {
534 errno = EAGAIN; // Signal caller to call accept
535 return -1;
536 }
537
538 // Find which client
539 auto it = fd_to_client_id_.find(ev.data.fd);
540 if (it == fd_to_client_id_.end()) {
541 errno = ENOENT;
542 return -1;
543 }
544
545 return it->second;
546}
547
549{
550 if (client_id < 0 || static_cast<size_t>(client_id) >= client_fds_.size()) {
551 return;
552 }
553
554 int fd = client_fds_[static_cast<size_t>(client_id)];
555 if (fd >= 0) {
556 epoll_ctl(fd_, EPOLL_CTL_DEL, fd, nullptr);
557 ::close(fd);
558 fd_to_client_id_.erase(fd);
559 client_fds_[static_cast<size_t>(client_id)] = -1;
560 num_clients_--;
561 }
562}
563
564#endif
565
566} // namespace bb::ipc
int accept() override
Accept a new client connection (optional for some transports)
SocketServer(std::string socket_path, int initial_max_clients)
std::vector< std::vector< uint8_t > > recv_buffers_
bool listen() override
Start listening for client connections.
std::vector< int > client_fds_
void close() override
Close the server and all client connections.
std::unordered_map< int, int > fd_to_client_id_
void disconnect_client(int client_id)
bool send(int client_id, const void *data, size_t len) override
Send a message to a specific client.
std::span< const uint8_t > receive(int client_id) override
Receive next message from a specific client.
void release(int client_id, size_t message_size) override
Release/consume the previously received message.
int wait_for_data(uint64_t timeout_ns) override
Wait for data from any connected client.
const std::vector< MemoryValue > data
STL namespace.
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
uint8_t len