// SPDX-License-Identifier: GPL-2.0-only
#include <linux/module.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/icmp.h>
#include <linux/udp.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <net/genetlink.h>
#include <net/gro.h>
#include <net/gue.h>
#include <net/fou.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/udp.h>
#include <net/udp_tunnel.h>
#include <uapi/linux/fou.h>
#include <uapi/linux/genetlink.h>
#include "fou_nl.h"
/*
 * struct fou - state for one FOU/GUE receive socket.
 * @sock: socket this instance is bound to. fou_from_sock() recovers this
 *        struct from the socket's sk_user_data, so the socket presumably
 *        points back here.
 * @protocol: IP protocol number of the inner payload. fou_udp_recv() hands
 *            the packet back as -@protocol.
 * @flags: option bits (presumably FOU_F_* values such as
 *         FOU_F_REMCSUM_NOPARTIAL; TODO confirm).
 * @port: UDP port, stored in network byte order (__be16).
 * @family: address family. fou_recv_pull() adjusts the IPv4 header when
 *          this is AF_INET and the IPv6 header otherwise.
 * @type: encapsulation type. Its values are not interpreted in the code
 *        visible here (NOTE(review): presumably direct FOU vs. GUE; confirm).
 * @list: list linkage, presumably into struct fou_net's @fou_list.
 * @rcu: RCU callback head, presumably used to defer freeing this struct.
 */
struct fou {
	struct socket *sock;
	u8 protocol;
	u8 flags;
	__be16 port;
	u8 family;
	u16 type;
	struct list_head list;
	struct rcu_head rcu;
};
/*
 * Option bit for the flags fields. The name suggests it disables
 * partial-checksum handling for remote checksum offload (compare the
 * 'nopartial' argument threaded through gue_remcsum()).
 * NOTE(review): confirm where this bit is set and tested.
 */
#define FOU_F_REMCSUM_NOPARTIAL BIT(0)
/*
 * struct fou_cfg - parameters describing a FOU instance to create.
 * NOTE(review): presumably filled from a configuration (netlink) request
 * and copied into struct fou; confirm against the creation path.
 * @type: encapsulation type (same-named field as in struct fou).
 * @protocol: inner IP protocol number (same-named field as in struct fou).
 * @flags: option bits, presumably FOU_F_* (same-named field as in
 *         struct fou).
 * @udp_config: UDP socket setup parameters (struct udp_port_cfg).
 */
struct fou_cfg {
	u16 type;
	u8 protocol;
	u8 flags;
	struct udp_port_cfg udp_config;
};
/*
 * Identifier for this module's per-network-namespace data.
 * NOTE(review): presumably the id passed to net_generic() to reach
 * struct fou_net; confirm.
 */
static unsigned int fou_net_id;

/*
 * struct fou_net - per-namespace FOU bookkeeping (presumably looked up via
 * fou_net_id).
 * @fou_list: list head, presumably of struct fou linked through fou::list.
 * @fou_lock: mutex, presumably serializing changes to @fou_list; TODO
 *            confirm which readers rely on RCU instead.
 */
struct fou_net {
	struct list_head fou_list;
	struct mutex fou_lock;
};
/*
 * fou_from_sock - look up the FOU state attached to a socket.
 * @sk: socket whose user-data pointer is read.
 *
 * The pointer is read with rcu_dereference_sk_user_data(), so callers are
 * expected to be inside an RCU read-side critical section.
 *
 * Return: the attached struct fou, or NULL if nothing is attached (callers
 * such as fou_udp_recv() check for NULL).
 */
static inline struct fou *fou_from_sock(struct sock *sk)
{
	struct fou *attached = rcu_dereference_sk_user_data(sk);

	return attached;
}
/*
 * fou_recv_pull - strip encapsulation bytes from the front of a received skb.
 * @skb: received packet; the bytes being removed begin at its UDP header.
 * @fou: FOU state; only @fou->family is consulted here.
 * @len: number of bytes to remove (UDP header plus any FOU/GUE header).
 *
 * Shrinks the outer IP length field by @len, pulls @len bytes off the skb,
 * updates the receive checksum state for the removed bytes, and resets the
 * transport header to the new start of data.
 *
 * Return: the value of iptunnel_pull_offloads(). A non-zero value is
 * treated as a drop by fou_udp_recv().
 */
static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len)
{
	if (fou->family != AF_INET) {
		struct ipv6hdr *ip6h = ipv6_hdr(skb);

		ip6h->payload_len = htons(ntohs(ip6h->payload_len) - len);
	} else {
		struct iphdr *iph = ip_hdr(skb);

		/* Only the length is rewritten; iph->check is not recomputed here. */
		iph->tot_len = htons(ntohs(iph->tot_len) - len);
	}

	/*
	 * The order matters. After __skb_pull() the transport header still
	 * points at the removed bytes, which is the region
	 * skb_postpull_rcsum() must account for. Only then is the transport
	 * header moved to the new data start.
	 */
	__skb_pull(skb, len);
	skb_postpull_rcsum(skb, udp_hdr(skb), len);
	skb_reset_transport_header(skb);

	return iptunnel_pull_offloads(skb);
}
/*
 * fou_udp_recv - receive handler for FOU packets that carry no extra
 * header beyond UDP.
 * @sk: receiving socket; its attached struct fou is looked up.
 * @skb: received packet, starting at the UDP header.
 *
 * Return:
 *   1              if no FOU state is attached to @sk; @skb is untouched.
 *   0              if stripping the UDP header failed; @skb has been freed.
 *   -fou->protocol otherwise, with the UDP header removed.
 *
 * NOTE(review): these values look like the UDP encap_rcv convention
 * (>0 deliver as normal UDP, 0 consumed, <0 resubmit as IP protocol -ret).
 * Confirm against the code that registers this handler.
 */
static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
{
	struct fou *ctx = fou_from_sock(sk);

	if (!ctx)
		return 1;

	if (!fou_recv_pull(skb, ctx, sizeof(struct udphdr)))
		return -ctx->protocol;

	/* Pull failed: the packet is dropped and consumed here. */
	kfree_skb(skb);
	return 0;
}
static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
void *data, size_t hdrlen, u8 ipproto,
bool nopartial)
{
__be16 *pd = data;
size_t start = ntohs(pd[0]);
size_t offset = ntohs(pd[1]);
size_t plen = sizeof(struct udphdr) + hdrlen +
max_t(size_t, offset + sizeof(u16), start);
if (skb->remcsum_offload)
return guehdr;
if (!pskb_may_pull(skb, plen))
return NULL;
guehdr = (struct guehdr *)&udp_hdr(skb)[1];
skb_remcsum_process(skb, (void *)guehdr + hdrlen,
start, offset, nopartial);
return guehdr;