summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--MAINTAINERS8
-rw-r--r--drivers/net/Kconfig41
-rw-r--r--drivers/net/Makefile1
-rw-r--r--drivers/net/wireguard/Makefile18
-rw-r--r--drivers/net/wireguard/allowedips.c381
-rw-r--r--drivers/net/wireguard/allowedips.h59
-rw-r--r--drivers/net/wireguard/cookie.c236
-rw-r--r--drivers/net/wireguard/cookie.h59
-rw-r--r--drivers/net/wireguard/device.c458
-rw-r--r--drivers/net/wireguard/device.h73
-rw-r--r--drivers/net/wireguard/main.c64
-rw-r--r--drivers/net/wireguard/messages.h128
-rw-r--r--drivers/net/wireguard/netlink.c642
-rw-r--r--drivers/net/wireguard/netlink.h12
-rw-r--r--drivers/net/wireguard/noise.c828
-rw-r--r--drivers/net/wireguard/noise.h137
-rw-r--r--drivers/net/wireguard/peer.c240
-rw-r--r--drivers/net/wireguard/peer.h83
-rw-r--r--drivers/net/wireguard/peerlookup.c221
-rw-r--r--drivers/net/wireguard/peerlookup.h64
-rw-r--r--drivers/net/wireguard/queueing.c53
-rw-r--r--drivers/net/wireguard/queueing.h197
-rw-r--r--drivers/net/wireguard/ratelimiter.c223
-rw-r--r--drivers/net/wireguard/ratelimiter.h19
-rw-r--r--drivers/net/wireguard/receive.c595
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c683
-rw-r--r--drivers/net/wireguard/selftest/counter.c104
-rw-r--r--drivers/net/wireguard/selftest/ratelimiter.c226
-rw-r--r--drivers/net/wireguard/send.c413
-rw-r--r--drivers/net/wireguard/socket.c437
-rw-r--r--drivers/net/wireguard/socket.h44
-rw-r--r--drivers/net/wireguard/timers.c243
-rw-r--r--drivers/net/wireguard/timers.h31
-rw-r--r--drivers/net/wireguard/version.h1
-rw-r--r--include/uapi/linux/wireguard.h196
-rwxr-xr-xtools/testing/selftests/wireguard/netns.sh537
36 files changed, 7755 insertions, 0 deletions
diff --git a/MAINTAINERS b/MAINTAINERS
index bd5847e802de..ce7fd2e7aa1a 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -17850,6 +17850,14 @@ L: linux-gpio@vger.kernel.org
S: Maintained
F: drivers/gpio/gpio-ws16c48.c
+WIREGUARD SECURE NETWORK TUNNEL
+M: Jason A. Donenfeld <Jason@zx2c4.com>
+S: Maintained
+F: drivers/net/wireguard/
+F: tools/testing/selftests/wireguard/
+L: wireguard@lists.zx2c4.com
+L: netdev@vger.kernel.org
+
WISTRON LAPTOP BUTTON DRIVER
M: Miloslav Trmac <mitr@volny.cz>
S: Maintained
diff --git a/drivers/net/Kconfig b/drivers/net/Kconfig
index d02f12a5254e..ffe8d4e2b206 100644
--- a/drivers/net/Kconfig
+++ b/drivers/net/Kconfig
@@ -71,6 +71,47 @@ config DUMMY
To compile this driver as a module, choose M here: the module
will be called dummy.
+config WIREGUARD
+ tristate "WireGuard secure network tunnel"
+ depends on NET && INET
+ depends on IPV6 || !IPV6
+ select NET_UDP_TUNNEL
+ select DST_CACHE
+ select CRYPTO
+ select CRYPTO_LIB_CURVE25519
+ select CRYPTO_LIB_CHACHA20POLY1305
+ select CRYPTO_LIB_BLAKE2S
+ select CRYPTO_CHACHA20_X86_64 if X86 && 64BIT
+ select CRYPTO_POLY1305_X86_64 if X86 && 64BIT
+ select CRYPTO_BLAKE2S_X86 if X86 && 64BIT
+ select CRYPTO_CURVE25519_X86 if X86 && 64BIT
+ select CRYPTO_CHACHA20_NEON if (ARM || ARM64) && KERNEL_MODE_NEON
+ select CRYPTO_POLY1305_NEON if ARM64 && KERNEL_MODE_NEON
+ select CRYPTO_POLY1305_ARM if ARM
+ select CRYPTO_CURVE25519_NEON if ARM && KERNEL_MODE_NEON
+ select CRYPTO_CHACHA_MIPS if CPU_MIPS32_R2
+ select CRYPTO_POLY1305_MIPS if CPU_MIPS32 || (CPU_MIPS64 && 64BIT)
+ help
+ WireGuard is a secure, fast, and easy to use replacement for IPSec
+ that uses modern cryptography and clever networking tricks. It's
+ designed to be fairly general purpose and abstract enough to fit most
+ use cases, while at the same time remaining extremely simple to
+ configure. See www.wireguard.com for more info.
+
+ It's safe to say Y or M here, as the driver is very lightweight and
+ is only in use when an administrator chooses to add an interface.
+
+config WIREGUARD_DEBUG
+ bool "Debugging checks and verbose messages"
+ depends on WIREGUARD
+ help
+ This will write log messages for handshake and other events
+ that occur for a WireGuard interface. It will also perform some
+ extra validation checks and unit tests at various points. This is
+ only useful for debugging.
+
+ Say N here unless you know what you're doing.
+
config EQUALIZER
tristate "EQL (serial line load balancing) support"
---help---
diff --git a/drivers/net/Makefile b/drivers/net/Makefile
index 0d3ba056cda3..953b7c12f0b0 100644
--- a/drivers/net/Makefile
+++ b/drivers/net/Makefile
@@ -10,6 +10,7 @@ obj-$(CONFIG_BONDING) += bonding/
obj-$(CONFIG_IPVLAN) += ipvlan/
obj-$(CONFIG_IPVTAP) += ipvlan/
obj-$(CONFIG_DUMMY) += dummy.o
+obj-$(CONFIG_WIREGUARD) += wireguard/
obj-$(CONFIG_EQUALIZER) += eql.o
obj-$(CONFIG_IFB) += ifb.o
obj-$(CONFIG_MACSEC) += macsec.o
diff --git a/drivers/net/wireguard/Makefile b/drivers/net/wireguard/Makefile
new file mode 100644
index 000000000000..fc52b2cb500b
--- /dev/null
+++ b/drivers/net/wireguard/Makefile
@@ -0,0 +1,18 @@
+ccflags-y := -O3
+ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
+ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
+wireguard-y := main.o
+wireguard-y += noise.o
+wireguard-y += device.o
+wireguard-y += peer.o
+wireguard-y += timers.o
+wireguard-y += queueing.o
+wireguard-y += send.o
+wireguard-y += receive.o
+wireguard-y += socket.o
+wireguard-y += peerlookup.o
+wireguard-y += allowedips.o
+wireguard-y += ratelimiter.o
+wireguard-y += cookie.o
+wireguard-y += netlink.o
+obj-$(CONFIG_WIREGUARD) := wireguard.o
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
new file mode 100644
index 000000000000..72667d5399c3
--- /dev/null
+++ b/drivers/net/wireguard/allowedips.c
@@ -0,0 +1,381 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include "allowedips.h"
+#include "peer.h"
+
+static void swap_endian(u8 *dst, const u8 *src, u8 bits)
+{
+ if (bits == 32) {
+ *(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
+ } else if (bits == 128) {
+ ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
+ ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
+ }
+}
+
+static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
+ u8 cidr, u8 bits)
+{
+ node->cidr = cidr;
+ node->bit_at_a = cidr / 8U;
+#ifdef __LITTLE_ENDIAN
+ node->bit_at_a ^= (bits / 8U - 1U) % 8U;
+#endif
+ node->bit_at_b = 7U - (cidr % 8U);
+ node->bitlen = bits;
+ memcpy(node->bits, src, bits / 8U);
+}
+#define CHOOSE_NODE(parent, key) \
+ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+
+static void node_free_rcu(struct rcu_head *rcu)
+{
+ kfree(container_of(rcu, struct allowedips_node, rcu));
+}
+
+static void push_rcu(struct allowedips_node **stack,
+ struct allowedips_node __rcu *p, unsigned int *len)
+{
+ if (rcu_access_pointer(p)) {
+ WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
+ stack[(*len)++] = rcu_dereference_raw(p);
+ }
+}
+
+static void root_free_rcu(struct rcu_head *rcu)
+{
+ struct allowedips_node *node, *stack[128] = {
+ container_of(rcu, struct allowedips_node, rcu) };
+ unsigned int len = 1;
+
+ while (len > 0 && (node = stack[--len])) {
+ push_rcu(stack, node->bit[0], &len);
+ push_rcu(stack, node->bit[1], &len);
+ kfree(node);
+ }
+}
+
+static void root_remove_peer_lists(struct allowedips_node *root)
+{
+ struct allowedips_node *node, *stack[128] = { root };
+ unsigned int len = 1;
+
+ while (len > 0 && (node = stack[--len])) {
+ push_rcu(stack, node->bit[0], &len);
+ push_rcu(stack, node->bit[1], &len);
+ if (rcu_access_pointer(node->peer))
+ list_del(&node->peer_list);
+ }
+}
+
+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
+ struct wg_peer *peer, struct mutex *lock)
+{
+#define REF(p) rcu_access_pointer(p)
+#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
+#define PUSH(p) ({ \
+ WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
+ stack[len++] = p; \
+ })
+
+ struct allowedips_node __rcu **stack[128], **nptr;
+ struct allowedips_node *node, *prev;
+ unsigned int len;
+
+ if (unlikely(!peer || !REF(*top)))
+ return;
+
+ for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
+ nptr = stack[len - 1];
+ node = DEREF(nptr);
+ if (!node) {
+ --len;
+ continue;
+ }
+ if (!prev || REF(prev->bit[0]) == node ||
+ REF(prev->bit[1]) == node) {
+ if (REF(node->bit[0]))
+ PUSH(&node->bit[0]);
+ else if (REF(node->bit[1]))
+ PUSH(&node->bit[1]);
+ } else if (REF(node->bit[0]) == prev) {
+ if (REF(node->bit[1]))
+ PUSH(&node->bit[1]);
+ } else {
+ if (rcu_dereference_protected(node->peer,
+ lockdep_is_held(lock)) == peer) {
+ RCU_INIT_POINTER(node->peer, NULL);
+ list_del_init(&node->peer_list);
+ if (!node->bit[0] || !node->bit[1]) {
+ rcu_assign_pointer(*nptr, DEREF(
+ &node->bit[!REF(node->bit[0])]));
+ call_rcu(&node->rcu, node_free_rcu);
+ node = DEREF(nptr);
+ }
+ }
+ --len;
+ }
+ }
+
+#undef REF
+#undef DEREF
+#undef PUSH
+}
+
+static unsigned int fls128(u64 a, u64 b)
+{
+ return a ? fls64(a) + 64U : fls64(b);
+}
+
+static u8 common_bits(const struct allowedips_node *node, const u8 *key,
+ u8 bits)
+{
+ if (bits == 32)
+ return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
+ else if (bits == 128)
+ return 128U - fls128(
+ *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
+ *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
+ return 0;
+}
+
+static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
+ u8 bits)
+{
+ /* This could be much faster if it actually just compared the common
+ * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
+ * the rest, but it turns out that common_bits is already super fast on
+ * modern processors, even taking into account the unfortunate bswap.
+ * So, we just inline it like this instead.
+ */
+ return common_bits(node, key, bits) >= node->cidr;
+}
+
+static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
+ const u8 *key)
+{
+ struct allowedips_node *node = trie, *found = NULL;
+
+ while (node && prefix_matches(node, key, bits)) {
+ if (rcu_access_pointer(node->peer))
+ found = node;
+ if (node->cidr == bits)
+ break;
+ node = rcu_dereference_bh(CHOOSE_NODE(node, key));
+ }
+ return found;
+}
+
+/* Returns a strong reference to a peer */
+static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
+ const void *be_ip)
+{
+ /* Aligned so it can be passed to fls/fls64 */
+ u8 ip[16] __aligned(__alignof(u64));
+ struct allowedips_node *node;
+ struct wg_peer *peer = NULL;
+
+ swap_endian(ip, be_ip, bits);
+
+ rcu_read_lock_bh();
+retry:
+ node = find_node(rcu_dereference_bh(root), bits, ip);
+ if (node) {
+ peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
+ if (!peer)
+ goto retry;
+ }
+ rcu_read_unlock_bh();
+ return peer;
+}
+
+static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
+ u8 cidr, u8 bits, struct allowedips_node **rnode,
+ struct mutex *lock)
+{
+ struct allowedips_node *node = rcu_dereference_protected(trie,
+ lockdep_is_held(lock));
+ struct allowedips_node *parent = NULL;
+ bool exact = false;
+
+ while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
+ parent = node;
+ if (parent->cidr == cidr) {
+ exact = true;
+ break;
+ }
+ node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
+ lockdep_is_held(lock));
+ }
+ *rnode = parent;
+ return exact;
+}
+
+static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ struct allowedips_node *node, *parent, *down, *newnode;
+
+ if (unlikely(cidr > bits || !peer))
+ return -EINVAL;
+
+ if (!rcu_access_pointer(*trie)) {
+ node = kzalloc(sizeof(*node), GFP_KERNEL);
+ if (unlikely(!node))
+ return -ENOMEM;
+ RCU_INIT_POINTER(node->peer, peer);
+ list_add_tail(&node->peer_list, &peer->allowedips_list);
+ copy_and_assign_cidr(node, key, cidr, bits);
+ rcu_assign_pointer(*trie, node);
+ return 0;
+ }
+ if (node_placement(*trie, key, cidr, bits, &node, lock)) {
+ rcu_assign_pointer(node->peer, peer);
+ list_move_tail(&node->peer_list, &peer->allowedips_list);
+ return 0;
+ }
+
+ newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
+ if (unlikely(!newnode))
+ return -ENOMEM;
+ RCU_INIT_POINTER(newnode->peer, peer);
+ list_add_tail(&newnode->peer_list, &peer->allowedips_list);
+ copy_and_assign_cidr(newnode, key, cidr, bits);
+
+ if (!node) {
+ down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
+ } else {
+ down = rcu_dereference_protected(CHOOSE_NODE(node, key),
+ lockdep_is_held(lock));
+ if (!down) {
+ rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
+ return 0;
+ }
+ }
+ cidr = min(cidr, common_bits(down, key, bits));
+ parent = node;
+
+ if (newnode->cidr == cidr) {
+ rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
+ if (!parent)
+ rcu_assign_pointer(*trie, newnode);
+ else
+ rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
+ newnode);
+ } else {
+ node = kzalloc(sizeof(*node), GFP_KERNEL);
+ if (unlikely(!node)) {
+ kfree(newnode);
+ return -ENOMEM;
+ }
+ INIT_LIST_HEAD(&node->peer_list);
+ copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+
+ rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+ rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
+ if (!parent)
+ rcu_assign_pointer(*trie, node);
+ else
+ rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
+ node);
+ }
+ return 0;
+}
+
+void wg_allowedips_init(struct allowedips *table)
+{
+ table->root4 = table->root6 = NULL;
+ table->seq = 1;
+}
+
+void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
+{
+ struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
+
+ ++table->seq;
+ RCU_INIT_POINTER(table->root4, NULL);
+ RCU_INIT_POINTER(table->root6, NULL);
+ if (rcu_access_pointer(old4)) {
+ struct allowedips_node *node = rcu_dereference_protected(old4,
+ lockdep_is_held(lock));
+
+ root_remove_peer_lists(node);
+ call_rcu(&node->rcu, root_free_rcu);
+ }
+ if (rcu_access_pointer(old6)) {
+ struct allowedips_node *node = rcu_dereference_protected(old6,
+ lockdep_is_held(lock));
+
+ root_remove_peer_lists(node);
+ call_rcu(&node->rcu, root_free_rcu);
+ }
+}
+
+int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls */
+ u8 key[4] __aligned(__alignof(u32));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 32);
+ return add(&table->root4, 32, key, cidr, peer, lock);
+}
+
+int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock)
+{
+ /* Aligned so it can be passed to fls64 */
+ u8 key[16] __aligned(__alignof(u64));
+
+ ++table->seq;
+ swap_endian(key, (const u8 *)ip, 128);
+ return add(&table->root6, 128, key, cidr, peer, lock);
+}
+
+void wg_allowedips_remove_by_peer(struct allowedips *table,
+ struct wg_peer *peer, struct mutex *lock)
+{
+ ++table->seq;
+ walk_remove_by_peer(&table->root4, peer, lock);
+ walk_remove_by_peer(&table->root6, peer, lock);
+}
+
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
+{
+ const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+ swap_endian(ip, node->bits, node->bitlen);
+ memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
+ if (node->cidr)
+ ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
+
+ *cidr = node->cidr;
+ return node->bitlen == 32 ? AF_INET : AF_INET6;
+}
+
+/* Returns a strong reference to a peer */
+struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
+ struct sk_buff *skb)
+{
+ if (skb->protocol == htons(ETH_P_IP))
+ return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
+ else if (skb->protocol == htons(ETH_P_IPV6))
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
+ return NULL;
+}
+
+/* Returns a strong reference to a peer */
+struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
+ struct sk_buff *skb)
+{
+ if (skb->protocol == htons(ETH_P_IP))
+ return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
+ else if (skb->protocol == htons(ETH_P_IPV6))
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
+ return NULL;
+}
+
+#include "selftest/allowedips.c"
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
new file mode 100644
index 000000000000..e5c83cafcef4
--- /dev/null
+++ b/drivers/net/wireguard/allowedips.h
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#ifndef _WG_ALLOWEDIPS_H
+#define _WG_ALLOWEDIPS_H
+
+#include <linux/mutex.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+
+struct wg_peer;
+
+struct allowedips_node {
+ struct wg_peer __rcu *peer;
+ struct allowedips_node __rcu *bit[2];
+ /* While it may seem scandalous that we waste space for v4,
+ * we're alloc'ing to the nearest power of 2 anyway, so this
+ * doesn't actually make a difference.
+ */
+ u8 bits[16] __aligned(__alignof(u64));
+ u8 cidr, bit_at_a, bit_at_b, bitlen;
+
+ /* Keep rarely used list at bottom to be beyond cache line. */
+ union {
+ struct list_head peer_list;
+ struct rcu_head rcu;
+ };
+};
+
+struct allowedips {
+ struct allowedips_node __rcu *root4;
+ struct allowedips_node __rcu *root6;
+ u64 seq;
+};
+
+void wg_allowedips_init(struct allowedips *table);
+void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
+int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
+int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wg_peer *peer, struct mutex *lock);
+void wg_allowedips_remove_by_peer(struct allowedips *table,
+ struct wg_peer *peer, struct mutex *lock);
+/* The ip input pointer should be __aligned(__alignof(u64))) */
+int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr);
+
+/* These return a strong reference to a peer: */
+struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
+ struct sk_buff *skb);
+struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
+ struct sk_buff *skb);
+
+#ifdef DEBUG
+bool wg_allowedips_selftest(void);
+#endif
+
+#endif /* _WG_ALLOWEDIPS_H */
diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c
new file mode 100644
index 000000000000..4956f0499c19
--- /dev/null
+++ b/drivers/net/wireguard/cookie.c
@@ -0,0 +1,236 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include "cookie.h"
+#include "peer.h"
+#include "device.h"
+#include "messages.h"
+#include "ratelimiter.h"
+#include "timers.h"
+
+#include <crypto/blake2s.h>
+#include <crypto/chacha20poly1305.h>
+
+#include <net/ipv6.h>
+#include <crypto/algapi.h>
+
+void wg_cookie_checker_init(struct cookie_checker *checker,
+ struct wg_device *wg)
+{
+ init_rwsem(&checker->secret_lock);
+ checker->secret_birthdate = ktime_get_coarse_boottime_ns();
+ get_random_bytes(checker->secret, NOISE_HASH_LEN);
+ checker->device = wg;
+}
+
+enum { COOKIE_KEY_LABEL_LEN = 8 };
+static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----";
+static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--";
+
+static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ const u8 pubkey[NOISE_PUBLIC_KEY_LEN],
+ const u8 label[COOKIE_KEY_LABEL_LEN])
+{
+ struct blake2s_state blake;
+
+ blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN);
+ blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN);
+ blake2s_update(&blake, pubkey, NOISE_PUBLIC_KEY_LEN);
+ blake2s_final(&blake, key);
+}
+
+/* Must hold peer->handshake.static_identity->lock */
+void wg_cookie_checker_precompute_device_keys(struct cookie_checker *checker)
+{
+ if (likely(checker->device->static_identity.has_identity)) {
+ precompute_key(checker->cookie_encryption_key,
+ checker->device->static_identity.static_public,
+ cookie_key_label);
+ precompute_key(checker->message_mac1_key,
+ checker->device->static_identity.static_public,
+ mac1_key_label);
+ } else {
+ memset(checker->cookie_encryption_key, 0,
+ NOISE_SYMMETRIC_KEY_LEN);
+ memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN);
+ }
+}
+
+void wg_cookie_checker_precompute_peer_keys(struct wg_peer *peer)
+{
+ precompute_key(peer->latest_cookie.cookie_decryption_key,
+ peer->handshake.remote_static, cookie_key_label);
+ precompute_key(peer->latest_cookie.message_mac1_key,
+ peer->handshake.remote_static, mac1_key_label);
+}
+
+void wg_cookie_init(struct cookie *cookie)
+{
+ memset(cookie, 0, sizeof(*cookie));
+ init_rwsem(&cookie->lock);
+}
+
+static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len,
+ const u8 key[NOISE_SYMMETRIC_KEY_LEN])
+{
+ len = len - sizeof(struct message_macs) +
+ offsetof(struct message_macs, mac1);
+ blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN);
+}
+
+static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
+ const u8 cookie[COOKIE_LEN])
+{
+ len = len - sizeof(struct message_macs) +
+ offsetof(struct message_macs, mac2);
+ blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN);
+}
+
+static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
+ struct cookie_checker *checker)
+{
+ struct blake2s_state state;
+
+ if (wg_birthdate_has_expired(checker->secret_birthdate,
+ COOKIE_SECRET_MAX_AGE)) {
+ down_write(&checker->secret_lock);
+ checker->secret_birthdate = ktime_get_coarse_boottime_ns();
+ get_random_bytes(checker->secret, NOISE_HASH_LEN);
+ up_write(&checker->secret_lock);
+ }
+
+ down_read(&checker->secret_lock);
+
+ blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
+ if (skb->protocol == htons(ETH_P_IP))
+ blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr,
+ sizeof(struct in_addr));
+ else if (skb->protocol == htons(ETH_P_IPV6))
+ blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr,
+ sizeof(struct in6_addr));
+ blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
+ blake2s_final(&state, cookie);
+
+ up_read(&checker->secret_lock);
+}
+
+enum cookie_mac_state wg_cookie_validate_packet(struct cookie_checker *checker,
+ struct sk_buff *skb,
+ bool check_cookie)
+{
+ struct message_macs *macs = (struct message_macs *)
+ (skb->data + skb->len - sizeof(*macs));
+ enum cookie_mac_state ret;
+ u8 computed_mac[COOKIE_LEN];
+ u8 cookie[COOKIE_LEN];
+
+ ret = INVALID_MAC;
+ compute_mac1(computed_mac, skb->data, skb->len,
+ checker->message_mac1_key);
+ if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN))
+ goto out;
+
+ ret = VALID_MAC_BUT_NO_COOKIE;
+
+ if (!check_cookie)
+ goto out;
+
+ make_cookie(cookie, skb, checker);
+
+ compute_mac2(computed_mac, skb->data, skb->len, cookie);
+ if (crypto_memneq(computed_mac, macs->mac2, COOKIE_LEN))
+ goto out;
+
+ ret = VALID_MAC_WITH_COOKIE_BUT_RATELIMITED;
+ if (!wg_ratelimiter_allow(skb, dev_net(checker->device->dev)))
+ goto out;
+
+ ret = VALID_MAC_WITH_COOKIE;
+
+out:
+ return ret;
+}
+
+void wg_cookie_add_mac_to_packet(void *message, size_t len,
+ struct wg_peer *peer)
+{
+ struct message_macs *macs = (struct message_macs *)
+ ((u8 *)message + len - sizeof(*macs));
+
+ down_write(&peer->latest_cookie.lock);
+ compute_mac1(macs->mac1, message, len,
+ peer->latest_cookie.message_mac1_key);
+ memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN);
+ peer->latest_cookie.have_sent_mac1 = true;
+ up_write(&peer->latest_cookie.lock);
+
+ down_read(&peer->latest_cookie.lock);
+ if (peer->latest_cookie.is_valid &&
+ !wg_birthdate_has_expired(peer->latest_cookie.birthdate,
+ COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY))
+ compute_mac2(macs->mac2, message, len,
+ peer->latest_cookie.cookie);
+ else
+ memset(macs->mac2, 0, COOKIE_LEN);
+ up_read(&peer->latest_cookie.lock);
+}
+
+void wg_cookie_message_create(struct message_handshake_cookie *dst,
+ struct sk_buff *skb, __le32 index,
+ struct cookie_checker *checker)
+{
+ struct message_macs *macs = (struct message_macs *)
+ ((u8 *)skb->data + skb->len - sizeof(*macs));
+ u8 cookie[COOKIE_LEN];
+
+ dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE);
+ dst->receiver_index = index;
+ get_random_bytes_wait(dst->nonce, COOKIE_NONCE_LEN);
+
+ make_cookie(cookie, skb, checker);
+ xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN,
+ macs->mac1, COOKIE_LEN, dst->nonce,
+ checker->cookie_encryption_key);
+}
+
+void wg_cookie_message_consume(struct message_handshake_cookie *src,
+ struct wg_device *wg)
+{
+ struct wg_peer *peer = NULL;
+ u8 cookie[COOKIE_LEN];
+ bool ret;
+
+ if (unlikely(!wg_index_hashtable_lookup(wg->index_hashtable,
+ INDEX_HASHTABLE_HANDSHAKE |
+ INDEX_HASHTABLE_KEYPAIR,
+ src->receiver_index, &peer)))
+ return;
+
+ down_read(&peer->latest_cookie.lock);
+ if (unlikely(!peer->latest_cookie.have_sent_mac1)) {
+ up_read(&peer->latest_cookie.lock);
+ goto out;
+ }
+ ret = xchacha20poly1305_decrypt(
+ cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie),
+ peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce,
+ peer->latest_cookie.cookie_decryption_key);
+ up_read(&peer->latest_cookie.lock);
+
+ if (ret) {
+ down_write(&peer->latest_cookie.lock);
+ memcpy(peer->latest_cookie.cookie, cookie, COOKIE_LEN);
+ peer->latest_cookie.birthdate = ktime_get_coarse_boottime_ns();
+ peer->la