// SPDX-License-Identifier: GPL-2.0
#include <errno.h>
#include <error.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <limits.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <linux/rtnetlink.h>
#include <linux/genetlink.h>
#include "linux/mptcp.h"
#ifndef MPTCP_PM_NAME
#define MPTCP_PM_NAME "mptcp_pm"
#endif
#ifndef MPTCP_PM_EVENTS
#define MPTCP_PM_EVENTS "mptcp_pm_events"
#endif
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
/*
 * Print the command-line usage of this tool on stderr, then terminate
 * the process with exit status 0.
 *
 * @argv: the program's argument vector; only argv[0] (the program
 *        name) is used, as the prefix of the first usage line.
 */
static void syntax(char *argv[])
{
	/* summary line: must list every sub-command detailed below */
	fprintf(stderr, "%s add|ann|rem|csf|dsf|get|set|del|flush|dump|limits|events|listen|accept [<args>]\n", argv[0]);
	fprintf(stderr, "\tadd [flags signal|subflow|backup|fullmesh] [id <nr>] [dev <name>] <ip>\n");
	fprintf(stderr, "\tann <local-ip> id <local-id> token <token> [port <local-port>] [dev <name>]\n");
	fprintf(stderr, "\trem id <local-id> token <token>\n");
	fprintf(stderr, "\tcsf lip <local-ip> lid <local-id> rip <remote-ip> rport <remote-port> token <token>\n");
	fprintf(stderr, "\tdsf lip <local-ip> lport <local-port> rip <remote-ip> rport <remote-port> token <token>\n");
	fprintf(stderr, "\tdel <id> [<ip>]\n");
	fprintf(stderr, "\tget <id>\n");
	fprintf(stderr, "\tset [<ip>] [id <nr>] flags [no]backup|[no]fullmesh [port <nr>] [token <token>] [rip <ip>] [rport <port>]\n");
	fprintf(stderr, "\tflush\n");
	fprintf(stderr, "\tdump\n");
	fprintf(stderr, "\tlimits [<rcv addr max> <subflow max>]\n");
	fprintf(stderr, "\tevents\n");
	fprintf(stderr, "\tlisten <local-ip> <local-port>\n");
	/* NOTE(review): exit status 0 kept for compatibility with existing callers */
	exit(0);
}
/*
 * Write a generic-netlink request header at the start of @data.
 *
 * The buffer receives a struct nlmsghdr addressed to @family with
 * NLM_F_REQUEST set, followed (at its aligned end) by a struct
 * genlmsghdr carrying @cmd and @version. nlmsg_len is set to cover
 * exactly those two headers. Other header fields are not touched, so
 * the caller is expected to hand in a zeroed buffer.
 *
 * Returns the byte offset inside @data where attributes may be appended.
 */
static int init_genl_req(char *data, int family, int cmd, int version)
{
	struct nlmsghdr *nlh = (void *)data;
	const int genl_hdr_pos = NLMSG_ALIGN(sizeof(*nlh));
	struct genlmsghdr *genl = (void *)(data + genl_hdr_pos);

	nlh->nlmsg_len = NLMSG_LENGTH(GENL_HDRLEN);
	nlh->nlmsg_type = family;
	nlh->nlmsg_flags = NLM_F_REQUEST;

	genl->cmd = cmd;
	genl->version = version;

	return genl_hdr_pos + NLMSG_ALIGN(sizeof(*genl));
}
/*
 * Report an NLMSG_ERROR message received from the kernel on stderr.
 *
 * A non-zero error code is printed together with its strerror() text.
 * If the kernel attached extended-ack TLVs (NLM_F_ACK_TLVS), the text
 * message (NLMSGERR_ATTR_MSG) and the offending attribute offset
 * (NLMSGERR_ATTR_OFFS) are printed too, whether or not the error code
 * is zero.
 *
 * Terminates the process via error() if @nh is too short to hold a
 * struct nlmsgerr.
 */
static void nl_error(struct nlmsghdr *nh)
{
	struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nh);
	int len = nh->nlmsg_len - sizeof(*nh);
	struct rtattr *attrs;
	size_t tlv_off;
	uint32_t off;

	/* cast: a negative len must not be promoted to a huge size_t */
	if (len < (int)sizeof(struct nlmsgerr))
		error(1, 0, "netlink error message truncated %d min %zu", len,
		      sizeof(struct nlmsgerr));

	if (err->error)
		fprintf(stderr, "netlink error %d (%s)\n", err->error,
			strerror(-err->error));

	if (!(nh->nlmsg_flags & NLM_F_ACK_TLVS))
		return;

	/*
	 * The ext-ack TLVs follow struct nlmsgerr and, unless the kernel
	 * capped the echo (NLM_F_CAPPED), the payload of the echoed request.
	 */
	tlv_off = sizeof(*err);
	if (!(nh->nlmsg_flags & NLM_F_CAPPED)) {
		if (err->msg.nlmsg_len < NLMSG_HDRLEN ||
		    err->msg.nlmsg_len - NLMSG_HDRLEN > (size_t)len)
			return;
		tlv_off += err->msg.nlmsg_len - NLMSG_HDRLEN;
	}
	tlv_off = NLMSG_ALIGN(tlv_off);
	if (tlv_off >= (size_t)len)
		return;

	len -= (int)tlv_off;
	attrs = (struct rtattr *)((char *)err + tlv_off);
	while (RTA_OK(attrs, len)) {
		/* bounded: the string attribute is not trusted to be NUL-terminated */
		if (attrs->rta_type == NLMSGERR_ATTR_MSG)
			fprintf(stderr, "netlink ext ack msg: %.*s\n",
				(int)RTA_PAYLOAD(attrs),
				(char *)RTA_DATA(attrs));
		if (attrs->rta_type == NLMSGERR_ATTR_OFFS &&
		    RTA_PAYLOAD(attrs) >= sizeof(off)) {
			memcpy(&off, RTA_DATA(attrs), sizeof(off));
			fprintf(stderr, "netlink err off %d\n", (int)off);
		}
		attrs = RTA_NEXT(attrs, len);
	}
}
/*
 * Print one attribute of an MPTCP netlink event on stderr, formatted as
 * ",<name>:<value>". Attribute types not handled here are skipped.
 */
static void capture_events_print_attr(struct rtattr *attr)
{
	void *data = RTA_DATA(attr);
	char buf[INET6_ADDRSTRLEN];

	switch (attr->rta_type) {
	case MPTCP_ATTR_TOKEN:
		fprintf(stderr, ",token:%u", *(__u32 *)data);
		break;
	case MPTCP_ATTR_FAMILY:
		fprintf(stderr, ",family:%u", *(__u16 *)data);
		break;
	case MPTCP_ATTR_LOC_ID:
		fprintf(stderr, ",loc_id:%u", *(__u8 *)data);
		break;
	case MPTCP_ATTR_REM_ID:
		fprintf(stderr, ",rem_id:%u", *(__u8 *)data);
		break;
	case MPTCP_ATTR_SADDR4:
		if (inet_ntop(AF_INET, data, buf, sizeof(buf)) != NULL)
			fprintf(stderr, ",saddr4:%s", buf);
		break;
	case MPTCP_ATTR_SADDR6:
		if (inet_ntop(AF_INET6, data, buf, sizeof(buf)) != NULL)
			fprintf(stderr, ",saddr6:%s", buf);
		break;
	case MPTCP_ATTR_DADDR4:
		if (inet_ntop(AF_INET, data, buf, sizeof(buf)) != NULL)
			fprintf(stderr, ",daddr4:%s", buf);
		break;
	case MPTCP_ATTR_DADDR6:
		if (inet_ntop(AF_INET6, data, buf, sizeof(buf)) != NULL)
			fprintf(stderr, ",daddr6:%s", buf);
		break;
	case MPTCP_ATTR_SPORT:
		fprintf(stderr, ",sport:%u", ntohs(*(__u16 *)data));
		break;
	case MPTCP_ATTR_DPORT:
		fprintf(stderr, ",dport:%u", ntohs(*(__u16 *)data));
		break;
	case MPTCP_ATTR_BACKUP:
		fprintf(stderr, ",backup:%u", *(__u8 *)data);
		break;
	case MPTCP_ATTR_ERROR:
		fprintf(stderr, ",error:%u", *(__u8 *)data);
		break;
	case MPTCP_ATTR_SERVER_SIDE:
		fprintf(stderr, ",server_side:%u", *(__u8 *)data);
		break;
	default:
		break;
	}
}

/*
 * Join the multicast group @event_group on the netlink socket @fd and
 * print every received generic-netlink event to stderr, one
 * "type:<cmd>,<attr>:<value>..." line per recv(). Messages whose genl
 * command is 0 are skipped.
 *
 * Never returns normally: the loop runs forever, and any socket error or
 * NLMSG_ERROR message terminates the process via error().
 */
static int capture_events(int fd, int event_group)
{
	/* aligned: the buffer is reinterpreted as struct nlmsghdr below */
	uint8_t buffer[NLMSG_ALIGN(sizeof(struct nlmsghdr)) +
		       NLMSG_ALIGN(sizeof(struct genlmsghdr)) + 1024]
		__attribute__((aligned(NLMSG_ALIGNTO)));
	struct genlmsghdr *ghdr;
	struct rtattr *attrs;
	struct nlmsghdr *nh;
	int res_len;
	int msg_len;
	fd_set rfds;

	if (setsockopt(fd, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
		       &event_group, sizeof(event_group)) < 0)
		error(1, errno, "could not join the " MPTCP_PM_EVENTS " mcast group");

	do {
		FD_ZERO(&rfds);
		FD_SET(fd, &rfds);

		if (select(fd + 1, &rfds, NULL, NULL, NULL) < 0)
			error(1, errno, "error in select() on NL socket");

		res_len = recv(fd, buffer, sizeof(buffer), 0);
		if (res_len < 0)
			error(1, errno, "error on recv() from NL socket");

		nh = (struct nlmsghdr *)buffer;
		for (; NLMSG_OK(nh, res_len); nh = NLMSG_NEXT(nh, res_len)) {
			/* errnum 0: there is no errno to report here */
			if (nh->nlmsg_type == NLMSG_ERROR)
				error(1, 0, "received invalid NL message");

			ghdr = (struct genlmsghdr *)NLMSG_DATA(nh);
			if (ghdr->cmd == 0)
				continue;

			fprintf(stderr, "type:%d", ghdr->cmd);

			msg_len = nh->nlmsg_len - NLMSG_LENGTH(GENL_HDRLEN);
			attrs = (struct rtattr *)((char *)ghdr + GENL_HDRLEN);
			for (; RTA_OK(attrs, msg_len); attrs = RTA_NEXT(attrs, msg_len))
				capture_events_print_attr(attrs);
		}
		fprintf(stderr, "\n");
	} while (1);

	return 0;
}
/*
 * Do a netlink command and, if @max > 0, fetch the reply.
 *
 * Sends @len bytes starting at @nh (nlmsg_len is overwritten with @len)
 * to the kernel on @fd. If @max > 0, a single reply of at most @max bytes
 * is received into the same buffer, overwriting the request.
 *
 * Every NLMSG_ERROR in the reply is reported through nl_error(). The
 * process exits if any of them carries a non-zero error code; an error
 * code of 0 is a plain ack and is not treated as a failure. Send and
 * receive failures also terminate the process via error().
 *
 * Returns 0 when no reply was requested, otherwise the number of bytes
 * received.
 */
static int do_nl_req(int fd, struct nlmsghdr *nh, int len, int max)
{
	struct sockaddr_nl nladdr = { .nl_family = AF_NETLINK };
	socklen_t addr_len;
	void *data = nh;
	int rem, ret;
	int err = 0;

	nh->nlmsg_len = len;
	ret = sendto(fd, data, len, 0, (void *)&nladdr, sizeof(nladdr));
	/* errno is only meaningful when sendto() itself failed */
	if (ret != len)
		error(1, ret < 0 ? errno : 0, "send netlink: %dB != %dB",
		      ret, len);
	if (max == 0)
		return 0;

	addr_len = sizeof(nladdr);
	rem = ret = recvfrom(fd, data, max, 0, (void *)&nladdr, &addr_len);
	if (ret < 0)
		error(1, errno, "recv netlink: %dB", ret);

	/* Beware: the NLMSG_NEXT macro updates the 'rem' argument */
	for (; NLMSG_OK(nh, rem); nh = NLMSG_NEXT(nh, rem)) {
		if (nh->nlmsg_type != NLMSG_ERROR)
			continue;

		/* reports the error; exits if the message is truncated */
		nl_error(nh);
		if (((struct nlmsgerr *)NLMSG_DATA(nh))->error)
			err = 1;
	}
	if (err)
		error(1, 0, "bailing out due to netlink error[s]");
	return ret;
}
static int genl_parse_getfamily(struct nlmsghdr *nlh, int *pm_family,
int *events_mcast_grp)
{
struct genlmsghdr *ghdr = NLMSG_DATA(nlh);
int len = nlh->nlmsg_len;
struct rtattr *attrs;
struct rtattr *grps;
struct rtattr *grp;
int got_events_grp;
int got_family;
int grps_len;
int grp_len;
if (nlh->nlmsg_type != GENL_ID_CTRL)
error(1, errno, "Not a controller
|