36 files changed, 7755 insertions(+), 0 deletions(-)
diff --git a/MAINTAINERS b/MAINTAINERS
index bd5847e802de..ce7fd2e7aa1a 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -17850,6 +17850,14 @@ L:	linux-gpio@vger.kernel.org
 S:	Maintained
 F:	drivers/gpio/gpio-ws16c48.c
 
+WIREGUARD SECURE NETWORK TUNNEL
+M:	Jason A. Donenfeld <Jason@zx2c4.com>
+S:	Maintained
+F:	drivers/net/wireguard/
+F:	tools/testing/selftests/wireguard/
+L:	wireguard@lists.zx2c4.com
+L:	netdev@vger.kernel.org
+
 WISTRON LAPTOP BUTTON DRIVER
 M:	Miloslav Trmac <mitr@volny.cz>
 S:	Maintained
diff --git a/drivers/net/Kconfig b/drivers/net/Kconfig
index d02f12a5254e..ffe8d4e2b206 100644
--- a/drivers/net/Kconfig
+++ b/drivers/net/Kconfig
@@ -71,6 +71,47 @@ config DUMMY
 	  To compile this driver as a module, choose M here: the module
 	  will be called dummy.
 
+config WIREGUARD
+	tristate "WireGuard secure network tunnel"
+	depends on NET && INET
+	depends on IPV6 || !IPV6
+	select NET_UDP_TUNNEL
+	select DST_CACHE
+	select CRYPTO
+	select CRYPTO_LIB_CURVE25519
+	select CRYPTO_LIB_CHACHA20POLY1305
+	select CRYPTO_LIB_BLAKE2S
+	select CRYPTO_CHACHA20_X86_64 if X86 && 64BIT
+	select CRYPTO_POLY1305_X86_64 if X86 && 64BIT
+	select CRYPTO_BLAKE2S_X86 if X86 && 64BIT
+	select CRYPTO_CURVE25519_X86 if X86 && 64BIT
+	select CRYPTO_CHACHA20_NEON if (ARM || ARM64) && KERNEL_MODE_NEON
+	select CRYPTO_POLY1305_NEON if ARM64 && KERNEL_MODE_NEON
+	select CRYPTO_POLY1305_ARM if ARM
+	select CRYPTO_CURVE25519_NEON if ARM && KERNEL_MODE_NEON
+	select CRYPTO_CHACHA_MIPS if CPU_MIPS32_R2
+	select CRYPTO_POLY1305_MIPS if CPU_MIPS32 || (CPU_MIPS64 && 64BIT)
+	help
+	  WireGuard is a secure, fast, and easy-to-use replacement for IPSec
+	  that uses modern cryptography and clever networking tricks. It's
+	  designed to be fairly general purpose and abstract enough to fit most
+	  use cases, while at the same time remaining extremely simple to
+	  configure. See www.wireguard.com for more info.
+
+	  It's safe to say Y or M here, as the driver is very lightweight and
+	  is only in use when an administrator chooses to add an interface.
+
+config WIREGUARD_DEBUG
+	bool "Debugging checks and verbose messages"
+	depends on WIREGUARD
+	help
+	  This will write log messages for handshake and other events
+	  that occur for a WireGuard interface. It will also perform some
+	  extra validation checks and unit tests at various points. This is
+	  only useful for debugging.
+
+	  Say N here unless you know what you're doing.
+
 config EQUALIZER
 	tristate "EQL (serial line load balancing) support"
 	---help---
diff --git a/drivers/net/Makefile b/drivers/net/Makefile
index 0d3ba056cda3..953b7c12f0b0 100644
--- a/drivers/net/Makefile
+++ b/drivers/net/Makefile
@@ -10,6 +10,7 @@ obj-$(CONFIG_BONDING) += bonding/
 obj-$(CONFIG_IPVLAN) += ipvlan/
 obj-$(CONFIG_IPVTAP) += ipvlan/
 obj-$(CONFIG_DUMMY) += dummy.o
+obj-$(CONFIG_WIREGUARD) += wireguard/
 obj-$(CONFIG_EQUALIZER) += eql.o
 obj-$(CONFIG_IFB) += ifb.o
 obj-$(CONFIG_MACSEC) += macsec.o
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile
new file mode 100644
index 000000000000..fc52b2cb500b
--- /dev/null
+++ b/drivers/net/wireguard/Makefile
@@ -0,0 +1,18 @@
+ccflags-y := -O3
+ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
+ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
+wireguard-y := main.o
+wireguard-y += noise.o
+wireguard-y += device.o
+wireguard-y += peer.o
+wireguard-y += timers.o
+wireguard-y += queueing.o
+wireguard-y += send.o
+wireguard-y += receive.o
+wireguard-y += socket.o
+wireguard-y += peerlookup.o
+wireguard-y += allowedips.o
+wireguard-y += ratelimiter.o
+wireguard-y += cookie.o
+wireguard-y += netlink.o
+obj-$(CONFIG_WIREGUARD) := wireguard.o
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
new file mode 100644
index 000000000000..72667d5399c3
--- /dev/null
+++ b/drivers/net/wireguard/allowedips.c
@@ -0,0 +1,381 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include "allowedips.h"
+#include "peer.h"
+
+static void swap_endian(u8 *dst, const u8 *src, u8 bits)
+{
+	if (bits == 32) {
+		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
+	} else if (bits == 128) {
+		((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
+		((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
+	}
+}
+
+static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
+				 u8 cidr, u8 bits)
+{
+	node->cidr = cidr;
+	node->bit_at_a = cidr / 8U;
+#ifdef __LITTLE_ENDIAN
+	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
+#endif
+	node->bit_at_b = 7U - (cidr % 8U);
+	node->bitlen = bits;
+	memcpy(node->bits, src, bits / 8U);
+}
+
+#define CHOOSE_NODE(parent, key) \
+	parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
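The bit_at_a/bit_at_b precomputation above is subtle: after swap_endian() the key is stored as native-endian words, so on little-endian machines the byte index has to be flipped within each word before the bit can be extracted. A userspace sketch (my illustration, not part of the commit; the 192.168.2.1/23 values are arbitrary) showing that CHOOSE_NODE ends up selecting address bit `cidr`, i.e. the first bit past the node's prefix:

	#include <stdint.h>
	#include <stdio.h>
	#include <string.h>

	int main(void)
	{
		const uint8_t be_ip[4] = { 192, 168, 2, 1 };	/* 192.168.2.1 */
		uint8_t key[4], cidr = 23, bit_at_a, bit_at_b;
		uint32_t native;

		/* swap_endian() for bits == 32 on a little-endian host */
		native = (uint32_t)be_ip[0] << 24 | (uint32_t)be_ip[1] << 16 |
			 (uint32_t)be_ip[2] << 8 | be_ip[3];
		memcpy(key, &native, 4);

		/* copy_and_assign_cidr()'s precomputation, little-endian case:
		 * bit_at_a = (23 / 8) ^ 3 = 1, bit_at_b = 7 - (23 % 8) = 0 */
		bit_at_a = (cidr / 8U) ^ (4U - 1U) % 8U;
		bit_at_b = 7U - (cidr % 8U);

		/* CHOOSE_NODE()'s index expression: address bit 23 is the low
		 * bit of the third octet (2 -> 0b00000010), so this prints 0 */
		printf("%u\n", (unsigned)((key[bit_at_a] >> bit_at_b) & 1));
		return 0;
	}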
+
+static void node_free_rcu(struct rcu_head *rcu)
+{
+	kfree(container_of(rcu, struct allowedips_node, rcu));
+}
+
+static void push_rcu(struct allowedips_node **stack,
+		     struct allowedips_node __rcu *p, unsigned int *len)
+{
+	if (rcu_access_pointer(p)) {
+		WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
+		stack[(*len)++] = rcu_dereference_raw(p);
+	}
+}
+
+static void root_free_rcu(struct rcu_head *rcu)
+{
+	struct allowedips_node *node, *stack[128] = {
+		container_of(rcu, struct allowedips_node, rcu) };
+	unsigned int len = 1;
+
+	while (len > 0 && (node = stack[--len])) {
+		push_rcu(stack, node->bit[0], &len);
+		push_rcu(stack, node->bit[1], &len);
+		kfree(node);
+	}
+}
+
+static void root_remove_peer_lists(struct allowedips_node *root)
+{
+	struct allowedips_node *node, *stack[128] = { root };
+	unsigned int len = 1;
+
+	while (len > 0 && (node = stack[--len])) {
+		push_rcu(stack, node->bit[0], &len);
+		push_rcu(stack, node->bit[1], &len);
+		if (rcu_access_pointer(node->peer))
+			list_del(&node->peer_list);
+	}
+}
+
+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
+				struct wg_peer *peer, struct mutex *lock)
+{
+#define REF(p) rcu_access_pointer(p)
+#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
+#define PUSH(p) ({                                                     \
+		WARN_ON(IS_ENABLED(DEBUG) && len >= 128);              \
+		stack[len++] = p;                                      \
+	})
+
+	struct allowedips_node __rcu **stack[128], **nptr;
+	struct allowedips_node *node, *prev;
+	unsigned int len;
+
+	if (unlikely(!peer || !REF(*top)))
+		return;
+
+	for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
+		nptr = stack[len - 1];
+		node = DEREF(nptr);
+		if (!node) {
+			--len;
+			continue;
+		}
+		if (!prev || REF(prev->bit[0]) == node ||
+		    REF(prev->bit[1]) == node) {
+			if (REF(node->bit[0]))
+				PUSH(&node->bit[0]);
+			else if (REF(node->bit[1]))
+				PUSH(&node->bit[1]);
+		} else if (REF(node->bit[0]) == prev) {
+			if (REF(node->bit[1]))
+				PUSH(&node->bit[1]);
+		} else {
+			if (rcu_dereference_protected(node->peer,
+				lockdep_is_held(lock)) == peer) {
+				RCU_INIT_POINTER(node->peer, NULL);
+				list_del_init(&node->peer_list);
+				if (!node->bit[0] || !node->bit[1]) {
+					rcu_assign_pointer(*nptr, DEREF(
+					       &node->bit[!REF(node->bit[0])]));
+					call_rcu(&node->rcu, node_free_rcu);
+					node = DEREF(nptr);
+				}
+			}
+			--len;
+		}
+	}
+
+#undef REF
+#undef DEREF
+#undef PUSH
+}
+
+static unsigned int fls128(u64 a, u64 b)
+{
+	return a ? fls64(a) + 64U : fls64(b);
+}
+
+static u8 common_bits(const struct allowedips_node *node, const u8 *key,
+		      u8 bits)
+{
+	if (bits == 32)
+		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
+	else if (bits == 128)
+		return 128U - fls128(
+			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
+			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
+	return 0;
+}
+
+static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
+			   u8 bits)
+{
+	/* This could be much faster if it actually just compared the common
+	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
+	 * the rest, but it turns out that common_bits is already super fast on
+	 * modern processors, even taking into account the unfortunate bswap.
+	 * So, we just inline it like this instead.
+	 */
+	return common_bits(node, key, bits) >= node->cidr;
+}
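For reference, a userspace sketch of the mask-compare alternative that the comment above describes (my illustration; the commit keeps common_bits() instead because it is already fast enough). Since the keys are stored native-endian after swap_endian(), the mask applies directly to native u32 values:

	#include <stdbool.h>
	#include <stdint.h>

	static bool prefix_matches_masked(uint32_t node_bits_native,
					  uint32_t key_native, uint8_t cidr)
	{
		/* ~0 << (32 - cidr) is the netmask; cidr == 0 must be
		 * special-cased, since shifting a 32-bit value by 32 is
		 * undefined behavior in C. */
		uint32_t mask = cidr ? ~UINT32_C(0) << (32 - cidr) : 0;

		return ((node_bits_native ^ key_native) & mask) == 0;
	}

	int main(void)
	{
		/* 10.0.0.0/8 matches 10.1.2.3 (both as native-endian u32s) */
		return prefix_matches_masked(0x0a000000u, 0x0a010203u, 8) ? 0 : 1;
	}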
+
+static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
+					 const u8 *key)
+{
+	struct allowedips_node *node = trie, *found = NULL;
+
+	while (node && prefix_matches(node, key, bits)) {
+		if (rcu_access_pointer(node->peer))
+			found = node;
+		if (node->cidr == bits)
+			break;
+		node = rcu_dereference_bh(CHOOSE_NODE(node, key));
+	}
+	return found;
+}
+
+/* Returns a strong reference to a peer */
+static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
+			      const void *be_ip)
+{
+	/* Aligned so it can be passed to fls/fls64 */
+	u8 ip[16] __aligned(__alignof(u64));
+	struct allowedips_node *node;
+	struct wg_peer *peer = NULL;
+
+	swap_endian(ip, be_ip, bits);
+
+	rcu_read_lock_bh();
+retry:
+	node = find_node(rcu_dereference_bh(root), bits, ip);
+	if (node) {
+		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
+		if (!peer)
+			goto retry;
+	}
+	rcu_read_unlock_bh();
+	return peer;
+}
+
+static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
+			   u8 cidr, u8 bits, struct allowedips_node **rnode,
+			   struct mutex *lock)
+{
+	struct allowedips_node *node = rcu_dereference_protected(trie,
+						lockdep_is_held(lock));
+	struct allowedips_node *parent = NULL;
+	bool exact = false;
+
+	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
+		parent = node;
+		if (parent->cidr == cidr) {
+			exact = true;
+			break;
+		}
+		node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
+						 lockdep_is_held(lock));
+	}
+	*rnode = parent;
+	return exact;
+}
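lookup() deliberately takes its peer reference with wg_peer_get_maybe_zero() and retries from the root when that fails: under RCU, a node can still be reachable after the last reference to its peer has been dropped, and a new matching node may have been installed in the meantime. A minimal userspace sketch of that "increment unless zero" primitive, using C11 atomics (my illustration, not the kernel's refcount_t implementation):

	#include <stdatomic.h>
	#include <stddef.h>

	struct peer { atomic_int refs; };

	/* Take a reference only if the count is still nonzero; a zero count
	 * means the peer is already being freed, so the caller must retry
	 * its lookup rather than resurrect the object. */
	static struct peer *peer_get_maybe_zero(struct peer *p)
	{
		int r = p ? atomic_load(&p->refs) : 0;

		while (r != 0)
			if (atomic_compare_exchange_weak(&p->refs, &r, r + 1))
				return p;
		return NULL;
	}

	int main(void)
	{
		struct peer p = { .refs = 1 };
		return peer_get_maybe_zero(&p) ? 0 : 1;
	}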
+
+static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
+	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+	struct allowedips_node *node, *parent, *down, *newnode;
+
+	if (unlikely(cidr > bits || !peer))
+		return -EINVAL;
+
+	if (!rcu_access_pointer(*trie)) {
+		node = kzalloc(sizeof(*node), GFP_KERNEL);
+		if (unlikely(!node))
+			return -ENOMEM;
+		RCU_INIT_POINTER(node->peer, peer);
+		list_add_tail(&node->peer_list, &peer->allowedips_list);
+		copy_and_assign_cidr(node, key, cidr, bits);
+		rcu_assign_pointer(*trie, node);
+		return 0;
+	}
+	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
+		rcu_assign_pointer(node->peer, peer);
+		list_move_tail(&node->peer_list, &peer->allowedips_list);
+		return 0;
+	}
+
+	newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
+	if (unlikely(!newnode))
+		return -ENOMEM;
+	RCU_INIT_POINTER(newnode->peer, peer);
+	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
+	copy_and_assign_cidr(newnode, key, cidr, bits);
+
+	if (!node) {
+		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
+	} else {
+		down = rcu_dereference_protected(CHOOSE_NODE(node, key),
+						 lockdep_is_held(lock));
+		if (!down) {
+			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
+			return 0;
+		}
+	}
+	cidr = min(cidr, common_bits(down, key, bits));
+	parent = node;
+
+	if (newnode->cidr == cidr) {
+		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
+		if (!parent)
+			rcu_assign_pointer(*trie, newnode);
+		else
+			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
+					   newnode);
+	} else {
+		node = kzalloc(sizeof(*node), GFP_KERNEL);
+		if (unlikely(!node)) {
+			kfree(newnode);
+			return -ENOMEM;
+		}
+		INIT_LIST_HEAD(&node->peer_list);
+		copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+
+		rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+		rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
+		if (!parent)
+			rcu_assign_pointer(*trie, node);
+		else
+			rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
+					   node);
+	}
+	return 0;
+}
+
+void wg_allowedips_init(struct allowedips *table)
+{
+	table->root4 = table->root6 = NULL;
+	table->seq = 1;
+}
+
+void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
+{
+	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
+
+	++table->seq;
+	RCU_INIT_POINTER(table->root4, NULL);
+	RCU_INIT_POINTER(table->root6, NULL);
+	if (rcu_access_pointer(old4)) {
+		struct allowedips_node *node = rcu_dereference_protected(old4,
+							lockdep_is_held(lock));
+
+		root_remove_peer_lists(node);
+		call_rcu(&node->rcu, root_free_rcu);
+	}
+	if (rcu_access_pointer(old6)) {
+		struct allowedips_node *node = rcu_dereference_protected(old6,
+							lockdep_is_held(lock));
+
+		root_remove_peer_lists(node);
+		call_rcu(&node->rcu, root_free_rcu);
+	}
+}
+
+int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+	/* Aligned so it can be passed to fls */
+	u8 key[4] __aligned(__alignof(u32));
+
+	++table->seq;
+	swap_endian(key, (const u8 *)ip, 32);
+	return add(&table->root4, 32, key, cidr, peer, lock);
+}
+
+int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+	/* Aligned so it can be passed to fls64 */
+	u8 key[16] __aligned(__alignof(u64));
+
+	++table->seq;
+	swap_endian(key, (const u8 *)ip, 128);
+	return add(&table->root6, 128, key, cidr, peer, lock);
+}
+
+void wg_allowedips_remove_by_peer(struct allowedips *table,
+				  struct wg_peer *peer, struct mutex *lock)
+{
+	++table->seq;
+	walk_remove_by_peer(&table->root4, peer, lock);
+	walk_remove_by_peer(&table->root6, peer, lock);
+}
+
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
+{
+	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+
+	swap_endian(ip, node->bits, node->bitlen);
+	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
+	if (node->cidr)
+		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+
+	*cidr = node->cidr;
+	return node->bitlen == 32 ? AF_INET : AF_INET6;
+}
+
+/* Returns a strong reference to a peer */
+struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
+					 struct sk_buff *skb)
+{
+	if (skb->protocol == htons(ETH_P_IP))
+		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
+	else if (skb->protocol == htons(ETH_P_IPV6))
+		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
+	return NULL;
+}
+
+/* Returns a strong reference to a peer */
+struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
+					 struct sk_buff *skb)
+{
+	if (skb->protocol == htons(ETH_P_IP))
+		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
+	else if (skb->protocol == htons(ETH_P_IPV6))
+		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
+	return NULL;
+}
+
+#include "selftest/allowedips.c"
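The masking arithmetic in wg_allowedips_read_node() is worth a worked example: for cidr = 19, cidr_bytes = 3, and because `-node->cidr % 8U` is evaluated with unsigned modulo, `~0U << (-19 % 8U)` becomes `~0U << 5` = 0xe0, keeping only the top three bits of the third octet. A userspace sketch (my illustration, with a hypothetical 10.1.255.255/19 entry):

	#include <stdint.h>
	#include <stdio.h>
	#include <string.h>

	int main(void)
	{
		uint8_t ip[4] = { 10, 1, 255, 255 };
		unsigned int cidr = 19, cidr_bytes = (cidr + 7) / 8;

		/* Zero whole bytes past the prefix, then mask the partial
		 * byte, exactly as wg_allowedips_read_node() does. */
		memset(ip + cidr_bytes, 0, sizeof(ip) - cidr_bytes);
		ip[cidr_bytes - 1] &= ~0U << (-cidr % 8U);

		/* prints 10.1.224.0/19 */
		printf("%u.%u.%u.%u/%u\n", ip[0], ip[1], ip[2], ip[3], cidr);
		return 0;
	}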
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
new file mode 100644
index 000000000000..e5c83cafcef4
--- /dev/null
+++ b/drivers/net/wireguard/allowedips.h
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#ifndef _WG_ALLOWEDIPS_H
+#define _WG_ALLOWEDIPS_H
+
+#include <linux/mutex.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+
+struct wg_peer;
+
+struct allowedips_node {
+	struct wg_peer __rcu *peer;
+	struct allowedips_node __rcu *bit[2];
+	/* While it may seem scandalous that we waste space for v4,
+	 * we're alloc'ing to the nearest power of 2 anyway, so this
+	 * doesn't actually make a difference.
+	 */
+	u8 bits[16] __aligned(__alignof(u64));
+	u8 cidr, bit_at_a, bit_at_b, bitlen;
+
+	/* Keep rarely used list at bottom to be beyond cache line. */
+	union {
+		struct list_head peer_list;
+		struct rcu_head rcu;
+	};
+};
+
+struct allowedips {
+	struct allowedips_node __rcu *root4;
+	struct allowedips_node __rcu *root6;
+	u64 seq;
+};
+
+void wg_allowedips_init(struct allowedips *table);
+void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
+int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+			    u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+			    u8 cidr, struct wg_peer *peer, struct mutex *lock);
+void wg_allowedips_remove_by_peer(struct allowedips *table,
+				  struct wg_peer *peer, struct mutex *lock);
+/* The ip input pointer should be __aligned(__alignof(u64)) */
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr);
+
+/* These return a strong reference to a peer: */
+struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
+					 struct sk_buff *skb);
+struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
+					 struct sk_buff *skb);
+
+#ifdef DEBUG
+bool wg_allowedips_selftest(void);
+#endif
+
+#endif /* _WG_ALLOWEDIPS_H */
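To make the calling convention in this header concrete, here is a hedged usage sketch (my illustration; the function name and the 10.0.0.0/8 address are hypothetical, and the mutex stands in for whatever device-wide lock the caller already uses to serialize mutations — lookups themselves need only an RCU read side):

	static int example_add_allowed_ip(struct allowedips *table,
					  struct wg_peer *peer,
					  struct mutex *lock)
	{
		struct in_addr ip = { .s_addr = htonl(0x0a000000) }; /* 10.0.0.0 */
		int ret;

		/* All trie mutations are serialized by the caller's mutex,
		 * which is also passed in for lockdep annotations. */
		mutex_lock(lock);
		ret = wg_allowedips_insert_v4(table, &ip, 8, peer, lock);
		mutex_unlock(lock);
		return ret;
	}

The v6 variant is analogous, taking a struct in6_addr and a cidr up to 128.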
diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c
new file mode 100644
index 000000000000..4956f0499c19
--- /dev/null
+++ b/drivers/net/wireguard/cookie.c
@@ -0,0 +1,236 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include "cookie.h"
+#include "peer.h"
+#include "device.h"
+#include "messages.h"
+#include "ratelimiter.h"
+#include "timers.h"
+
+#include <crypto/blake2s.h>
+#include <crypto/chacha20poly1305.h>
+
+#include <net/ipv6.h>
+#include <crypto/algapi.h>
+
+void wg_cookie_checker_init(struct cookie_checker *checker,
+			    struct wg_device *wg)
+{
+	init_rwsem(&checker->secret_lock);
+	checker->secret_birthdate = ktime_get_coarse_boottime_ns();
+	get_random_bytes(checker->secret, NOISE_HASH_LEN);
+	checker->device = wg;
+}
+
+enum { COOKIE_KEY_LABEL_LEN = 8 };
+static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----";
+static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--";
+
+static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN],
+			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN],
+			   const u8 label[COOKIE_KEY_LABEL_LEN])
+{
+	struct blake2s_state blake;
+
+	blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN);
+	blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN);
+	blake2s_update(&blake, pubkey, NOISE_PUBLIC_KEY_LEN);
+	blake2s_final(&blake, key);
+}
+
+/* Must hold peer->handshake.static_identity->lock */
+void wg_cookie_checker_precompute_device_keys(struct cookie_checker *checker)
+{
+	if (likely(checker->device->static_identity.has_identity)) {
+		precompute_key(checker->cookie_encryption_key,
+			       checker->device->static_identity.static_public,
+			       cookie_key_label);
+		precompute_key(checker->message_mac1_key,
+			       checker->device->static_identity.static_public,
+			       mac1_key_label);
+	} else {
+		memset(checker->cookie_encryption_key, 0,
+		       NOISE_SYMMETRIC_KEY_LEN);
+		memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN);
+	}
+}
+
+void wg_cookie_checker_precompute_peer_keys(struct wg_peer *peer)
+{
+	precompute_key(peer->latest_cookie.cookie_decryption_key,
+		       peer->handshake.remote_static, cookie_key_label);
+	precompute_key(peer->latest_cookie.message_mac1_key,
+		       peer->handshake.remote_static, mac1_key_label);
+}
+
+void wg_cookie_init(struct cookie *cookie)
+{
+	memset(cookie, 0, sizeof(*cookie));
+	init_rwsem(&cookie->lock);
+}
+
+static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len,
+			 const u8 key[NOISE_SYMMETRIC_KEY_LEN])
+{
+	len = len - sizeof(struct message_macs) +
+	      offsetof(struct message_macs, mac1);
+	blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN);
+}
+
+static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
+			 const u8 cookie[COOKIE_LEN])
+{
+	len = len - sizeof(struct message_macs) +
+	      offsetof(struct message_macs, mac2);
+	blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN);
+}
+
+static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
+			struct cookie_checker *checker)
+{
+	struct blake2s_state state;
+
+	if (wg_birthdate_has_expired(checker->secret_birthdate,
+				     COOKIE_SECRET_MAX_AGE)) {
+		down_write(&checker->secret_lock);
+		checker->secret_birthdate = ktime_get_coarse_boottime_ns();
+		get_random_bytes(checker->secret, NOISE_HASH_LEN);
+		up_write(&checker->secret_lock);
+	}
+
+	down_read(&checker->secret_lock);
+
+	blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
+	if (skb->protocol == htons(ETH_P_IP))
+		blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr,
+			       sizeof(struct in_addr));
+	else if (skb->protocol == htons(ETH_P_IPV6))
+		blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr,
+			       sizeof(struct in6_addr));
+	blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
+	blake2s_final(&state, cookie);
+
+	up_read(&checker->secret_lock);
+}
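The two MACs work together with the helpers above: mac1 is keyed by BLAKE2s("mac1----" || responder_pubkey) and covers every byte of the message up to the mac1 field itself, while mac2 is keyed by the latest cookie and covers everything up to mac2. As a sketch (my illustration, reusing this file's own static helpers and assuming the struct message_handshake_initiation from messages.h carries a trailing struct message_macs member named macs), tagging an outgoing initiation would look roughly like:

	static void example_tag_initiation(struct message_handshake_initiation *msg,
					   const u8 peer_pubkey[NOISE_PUBLIC_KEY_LEN])
	{
		u8 key[NOISE_SYMMETRIC_KEY_LEN];

		/* Derive the mac1 key from the responder's static public key,
		 * then MAC the message up to (but excluding) mac1 itself. */
		precompute_key(key, peer_pubkey, mac1_key_label);
		compute_mac1(msg->macs.mac1, msg, sizeof(*msg), key);

		/* mac2 stays all-zero until the responder sends us a cookie. */
		memset(msg->macs.mac2, 0, COOKIE_LEN);
	}

In the driver itself this is handled by wg_cookie_add_mac_to_packet() below, which also caches the sent mac1 so a later cookie reply can be authenticated against it.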
+
+enum cookie_mac_state wg_cookie_validate_packet(struct cookie_checker *checker,
+						struct sk_buff *skb,
+						bool check_cookie)
+{
+	struct message_macs *macs = (struct message_macs *)
+		(skb->data + skb->len - sizeof(*macs));
+	enum cookie_mac_state ret;
+	u8 computed_mac[COOKIE_LEN];
+	u8 cookie[COOKIE_LEN];
+
+	ret = INVALID_MAC;
+	compute_mac1(computed_mac, skb->data, skb->len,
+		     checker->message_mac1_key);
+	if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN))
+		goto out;
+
+	ret = VALID_MAC_BUT_NO_COOKIE;
+
+	if (!check_cookie)
+		goto out;
+
+	make_cookie(cookie, skb, checker);
+
+	compute_mac2(computed_mac, skb->data, skb->len, cookie);
+	if (crypto_memneq(computed_mac, macs->mac2, COOKIE_LEN))
+		goto out;
+
+	ret = VALID_MAC_WITH_COOKIE_BUT_RATELIMITED;
+	if (!wg_ratelimiter_allow(skb, dev_net(checker->device->dev)))
+		goto out;
+
+	ret = VALID_MAC_WITH_COOKIE;
+
+out:
+	return ret;
+}
+
+void wg_cookie_add_mac_to_packet(void *message, size_t len,
+				 struct wg_peer *peer)
+{
+	struct message_macs *macs = (struct message_macs *)
+		((u8 *)message + len - sizeof(*macs));
+
+	down_write(&peer->latest_cookie.lock);
+	compute_mac1(macs->mac1, message, len,
+		     peer->latest_cookie.message_mac1_key);
+	memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN);
+	peer->latest_cookie.have_sent_mac1 = true;
+	up_write(&peer->latest_cookie.lock);
+
+	down_read(&peer->latest_cookie.lock);
+	if (peer->latest_cookie.is_valid &&
+	    !wg_birthdate_has_expired(peer->latest_cookie.birthdate,
+		COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY))
+		compute_mac2(macs->mac2, message, len,
+			     peer->latest_cookie.cookie);
+	else
+		memset(macs->mac2, 0, COOKIE_LEN);
+	up_read(&peer->latest_cookie.lock);
+}
+
+void wg_cookie_message_create(struct message_handshake_cookie *dst,
+			      struct sk_buff *skb, __le32 index,
+			      struct cookie_checker *checker)
+{
+	struct message_macs *macs = (struct message_macs *)
+		((u8 *)skb->data + skb->len - sizeof(*macs));
+	u8 cookie[COOKIE_LEN];
+
+	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE);
+	dst->receiver_index = index;
+	get_random_bytes_wait(dst->nonce, COOKIE_NONCE_LEN);
+
+	make_cookie(cookie, skb, checker);
+	xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN,
+				  macs->mac1, COOKIE_LEN, dst->nonce,
+				  checker->cookie_encryption_key);
+}
+
+void wg_cookie_message_consume(struct message_handshake_cookie *src,
+			       struct wg_device *wg)
+{
+	struct wg_peer *peer = NULL;
+	u8 cookie[COOKIE_LEN];
+	bool ret;
+
+	if (unlikely(!wg_index_hashtable_lookup(wg->index_hashtable,
+						INDEX_HASHTABLE_HANDSHAKE |
+						INDEX_HASHTABLE_KEYPAIR,
+						src->receiver_index, &peer)))
+		return;
+
+	down_read(&peer->latest_cookie.lock);
+	if (unlikely(!peer->latest_cookie.have_sent_mac1)) {
+		up_read(&peer->latest_cookie.lock);
+		goto out;
+	}
+	ret = xchacha20poly1305_decrypt(
+		cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie),
+		peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce,
+		peer->latest_cookie.cookie_decryption_key);
+	up_read(&peer->latest_cookie.lock);
+
+	if (ret) {
+		down_write(&peer->latest_cookie.lock);
+		memcpy(peer->latest_cookie.cookie, cookie, COOKIE_LEN);
+		peer->latest_cookie.birthdate = ktime_get_coarse_boottime_ns();
+		peer->la