Skip to content

Commit

Permalink
Add support for udp multicast
Browse files Browse the repository at this point in the history
Also fix processing of return value of send(2) and sendto(2) as macOS 15
on CI seems to return 0 for a multicast packet

Signed-off-by: Paul Guyot <[email protected]>
  • Loading branch information
pguyot committed Jan 23, 2025
1 parent 459ba1b commit 805fb76
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `externalterm_to_term_with_roots` to efficiently preserve roots when allocating memory for external terms.
- Added `erl_epmd` client implementation to epmd using `socket` module
- Added support for socket asynchronous API for `recv`, `recvfrom` and `accept`.
- Added support for UDP multicast with socket API.

### Changed

Expand Down
8 changes: 6 additions & 2 deletions libs/estdlib/src/socket.erl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@
}.
-type in_addr() :: {0..255, 0..255, 0..255, 0..255}.
-type port_number() :: 0..65535.
-type ip_mreq() :: #{multiaddr := in_addr(), interface := in_addr()}.

-type socket_option() ::
{socket, reuseaddr | linger | type}
| {otp, recvbuf}.
| {otp, recvbuf}
| {ip, add_membership}.

-export_type([
socket/0,
Expand All @@ -80,7 +82,8 @@
sockaddr_in/0,
in_addr/0,
port_number/0,
socket_option/0
socket_option/0,
ip_mreq/0
]).

-define(DEFAULT_BACKLOG, 4).
Expand Down Expand Up @@ -647,6 +650,7 @@ getopt(_Socket, _SocketOption) ->
%% <tr><td>`{socket, reuseaddr}'</td><td>`boolean()'</td></tr>
%% <tr><td>`{socket, linger}'</td><td>`#{onoff => boolean(), linger => non_neg_integer()}'</td></tr>
%% <tr><td>`{otp, recvbuf}'</td><td>`non_neg_integer()'</td></tr>
%% <tr><td>`{ip, add_membership}'</td><td>`ip_mreq()'</td></tr>
%% </table>
%%
%% Example:
Expand Down
77 changes: 63 additions & 14 deletions src/libAtomVM/otp_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ static const char *const port_atom = ATOM_STR("\x4", "port");
static const char *const rcvbuf_atom = ATOM_STR("\x6", "rcvbuf");
static const char *const reuseaddr_atom = ATOM_STR("\x9", "reuseaddr");
static const char *const type_atom = ATOM_STR("\x4", "type");
static const char *const add_membership_atom = ATOM_STR("\xE", "add_membership");

#define CLOSED_FD 0

Expand Down Expand Up @@ -221,12 +222,14 @@ enum otp_socket_setopt_level
{
OtpSocketInvalidSetoptLevel = 0,
OtpSocketSetoptLevelSocket,
OtpSocketSetoptLevelOTP
OtpSocketSetoptLevelOTP,
OtpSocketSetoptLevelIP
};

static const AtomStringIntPair otp_socket_setopt_level_table[] = {
{ ATOM_STR("\x6", "socket"), OtpSocketSetoptLevelSocket },
{ ATOM_STR("\x3", "otp"), OtpSocketSetoptLevelOTP },
{ ATOM_STR("\x2", "ip"), OtpSocketSetoptLevelIP },
SELECT_INT_DEFAULT(OtpSocketInvalidSetoptLevel)
};

Expand Down Expand Up @@ -604,7 +607,7 @@ static term nif_socket_open(Context *ctx, int argc, term argv[])
}

term socket_term = term_alloc_tuple(2, &ctx->heap);
uint64_t ref_ticks = globalcontext_get_ref_ticks(ctx->global);
uint64_t ref_ticks = globalcontext_get_ref_ticks(global);
rsrc_obj->socket_ref_ticks = ref_ticks;
term ref = term_from_ref_ticks(ref_ticks, &ctx->heap);
term_put_tuple_element(socket_term, 0, obj);
Expand Down Expand Up @@ -1261,8 +1264,8 @@ static term nif_socket_setopt(Context *ctx, int argc, term argv[])
return OK_ATOM;
#endif
} else if (globalcontext_is_term_equal_to_atom_string(global, opt, linger_atom)) {
term onoff = interop_kv_get_value(value, onoff_atom, ctx->global);
term linger = interop_kv_get_value(value, linger_atom, ctx->global);
term onoff = interop_kv_get_value(value, onoff_atom, global);
term linger = interop_kv_get_value(value, linger_atom, global);
VALIDATE_VALUE(linger, term_is_integer);

#if OTP_SOCKET_BSD
Expand Down Expand Up @@ -1323,6 +1326,52 @@ static term nif_socket_setopt(Context *ctx, int argc, term argv[])
}
}

#if OTP_SOCKET_BSD
case OtpSocketSetoptLevelIP: {
term opt = term_get_tuple_element(level_tuple, 1);
if (globalcontext_is_term_equal_to_atom_string(global, opt, add_membership_atom)) {
// socket:setopt(Socket, {ip, add_membership_atom}, Req :: ip_mreq())

if (UNLIKELY(!term_is_map(value))) {
TRACE("socket:setopt: ip add_membership_atom value must be a map");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

term multiaddr = interop_kv_get_value(value, ATOM_STR("\x9", "multiaddr"), global);
if (UNLIKELY(!term_is_tuple(multiaddr) || term_get_tuple_arity(multiaddr) != 4)) {
TRACE("socket:setopt: ip add_membership_atom multiaddr value must be an IP addr");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

term interface = interop_kv_get_value(value, ATOM_STR("\x9", "interface"), global);
if (UNLIKELY(!term_is_tuple(interface) || term_get_tuple_arity(interface) != 4)) {
TRACE("socket:setopt: ip add_membership_atom interface value must be an IP addr");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

struct ip_mreq option_value;
option_value.imr_multiaddr.s_addr = htonl(inet_addr4_to_uint32(multiaddr));
option_value.imr_interface.s_addr = htonl(inet_addr4_to_uint32(interface));

int res = setsockopt(rsrc_obj->fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &option_value, sizeof(option_value));

SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
if (UNLIKELY(res != 0)) {
return make_errno_tuple(ctx);
} else {
return OK_ATOM;
}
} else {
TRACE("socket:setopt: Unsupported ip option");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_option_atom), ctx);
}
}
#endif

default: {
TRACE("socket:setopt: Unsupported level");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
Expand Down Expand Up @@ -1538,9 +1587,9 @@ static term nif_socket_bind(Context *ctx, int argc, term argv[])
ip_addr_set_loopback(false, &ip_addr);
#endif
} else if (term_is_map(sockaddr)) {
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), ctx->global);
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), global);
port_u16 = term_to_int(port);
term addr = interop_kv_get_value(sockaddr, addr_atom, ctx->global);
term addr = interop_kv_get_value(sockaddr, addr_atom, global);
if (globalcontext_is_term_equal_to_atom_string(global, addr, any_atom)) {
#if OTP_SOCKET_BSD
serveraddr.sin_addr.s_addr = htonl(INADDR_ANY);
Expand Down Expand Up @@ -1764,7 +1813,7 @@ static term nif_socket_accept(Context *ctx, int argc, term argv[])
}

term socket_term = term_alloc_tuple(2, &ctx->heap);
uint64_t ref_ticks = globalcontext_get_ref_ticks(ctx->global);
uint64_t ref_ticks = globalcontext_get_ref_ticks(global);
conn_rsrc_obj->socket_ref_ticks = ref_ticks;
term ref = term_from_ref_ticks(ref_ticks, &ctx->heap);
term_put_tuple_element(socket_term, 0, new_resource);
Expand Down Expand Up @@ -1808,7 +1857,7 @@ static term nif_socket_accept(Context *ctx, int argc, term argv[])
// return EAGAIN
LWIP_END();
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(posix_errno_to_term(EAGAIN, ctx->global), ctx);
return make_error_tuple(posix_errno_to_term(EAGAIN, global), ctx);
}
LWIP_END();
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
Expand Down Expand Up @@ -2285,13 +2334,13 @@ static ssize_t do_socket_send(struct SocketResource *rsrc_obj, const uint8_t *bu
} else {
sent_data = send(rsrc_obj->fd, buf, len, 0);
}
if (sent_data == 0) {
return SocketClosed;
}
if (sent_data < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return SocketWouldBlock;
}
if (errno == EBADF || errno == ECONNRESET) {
return SocketClosed;
}
return SocketOtherError;
}
return sent_data;
Expand Down Expand Up @@ -2430,7 +2479,7 @@ static term nif_socket_send_internal(Context *ctx, int argc, term argv[], bool i
RAISE_ERROR(OUT_OF_MEMORY_ATOM);
}

term rest = term_maybe_create_sub_binary(data, sent_data, rest_len, &ctx->heap, ctx->global);
term rest = term_maybe_create_sub_binary(data, sent_data, rest_len, &ctx->heap, global);
return port_create_tuple2(ctx, OK_ATOM, rest);

} else if (sent_data == 0) {
Expand Down Expand Up @@ -2535,8 +2584,8 @@ static term nif_socket_connect(Context *ctx, int argc, term argv[])

SMP_RWLOCK_RDLOCK(rsrc_obj->socket_lock);
term sockaddr = argv[1];
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), ctx->global);
term addr = interop_kv_get_value(sockaddr, addr_atom, ctx->global);
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), global);
term addr = interop_kv_get_value(sockaddr, addr_atom, global);
if (term_is_invalid_term(addr)) {
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
RAISE_ERROR(BADARG_ATOM);
Expand Down
34 changes: 34 additions & 0 deletions tests/libs/estdlib/test_udp_socket.erl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ test() ->
ok = test_timeout(),
ok = test_nowait(),
ok = test_setopt_getopt(),
ok = test_multicast(),
ok.

-define(PACKET_SIZE, 7).
Expand Down Expand Up @@ -290,3 +291,36 @@ test_setopt_getopt() ->
{error, closed} = socket:getopt(Socket, {socket, type}),
{error, closed} = socket:setopt(Socket, {socket, reuseaddr}, true),
ok.

test_multicast() ->
{ok, SocketRecv} = socket:open(inet, dgram, udp),
SocketRecvAddr = #{
family => inet, addr => {0, 0, 0, 0}, port => 8042
},
ok = socket:setopt(SocketRecv, {socket, reuseaddr}, true),
ok = socket:bind(SocketRecv, SocketRecvAddr),
ok = socket:setopt(SocketRecv, {ip, add_membership}, #{
multiaddr => {224, 0, 0, 42}, interface => {0, 0, 0, 0}
}),

{ok, SocketSender} = socket:open(inet, dgram, udp),
ok = socket:sendto(SocketSender, <<"42">>, #{
family => inet, addr => {224, 0, 0, 42}, port => 8042
}),
{ok, SocketSenderAddr} = socket:sockname(SocketSender),
SocketSenderAddrPort = maps:get(port, SocketSenderAddr),

{ok, {SocketSenderAddrFrom, <<"42">>}} = socket:recvfrom(SocketRecv, 2, 500),
{error, timeout} = socket:recvfrom(SocketRecv, 2, 0),
SocketSenderAddrPort = maps:get(port, SocketSenderAddrFrom),

ok = socket:sendto(SocketRecv, <<"43">>, #{
family => inet, addr => {224, 0, 0, 42}, port => 8042
}),
{ok, {SocketRecvAddrFrom, <<"43">>}} = socket:recvfrom(SocketRecv, 2, 500),
{error, timeout} = socket:recvfrom(SocketRecv, 2, 0),
8042 = maps:get(port, SocketRecvAddrFrom),

ok = socket:close(SocketRecv),
ok = socket:close(SocketSender),
ok.

0 comments on commit 805fb76

Please sign in to comment.