/*
* Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
* Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
*
* This software is available to you under a choice of one of two
* licenses. You may choose to be licensed under the terms of the GNU
* General Public License (GPL) Version 2, available from the file
* COPYING in the main directory of this source tree, or the
* OpenIB.org BSD license below:
*
* Redistribution and use in source and binary forms, with or
* without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above
* copyright notice, this list of conditions and the following
* disclaimer.
*
* - Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials
* provided with the distribution.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
* BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
* ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include <linux/module.h>
#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>
#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>
#include "tls.h"
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");
enum {
TLSV4,
TLSV6,
TLS_NUM_PROTS,
};
#define CHECK_CIPHER_DESC(cipher,ci) \
static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \
static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \
static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \
static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \
static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \
static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \
static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \
static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE);
#define __CIPHER_DESC(ci) \
.iv_offset = offsetof(struct ci, iv), \
.key_offset = offsetof(struct ci, key), \
.salt_offset = offsetof(struct ci, salt), \
.rec_seq_offset = offsetof(struct ci, rec_seq), \
.crypto_info = sizeof(struct ci)
#define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \
.nonce = cipher ## _IV_SIZE, \
.iv = cipher ## _IV_SIZE, \
.key = cipher ## _KEY_SIZE, \
.salt = cipher ## _SALT_SIZE, \
.tag = cipher ## _TAG_SIZE, \
.rec_seq = cipher ## _REC_SEQ_SIZE, \
.cipher_name = algname, \
.offloadable = _offloadable, \
__CIPHER_DESC(ci), \
}
#define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \
.nonce = 0, \
.iv = cipher ## _IV_SIZE, \
.key = cipher ## _KEY_SIZE, \
.salt = cipher ## _SALT_SIZE, \
.tag = cipher ## _TAG_SIZE, \
.rec_seq = cipher ## _REC_SEQ_SIZE, \
.cipher_name = algname, \
.offloadable = _offloadable, \
__CIPHER_DESC(ci), \
}
const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = {
CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true),
CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true),
CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false),
CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false),
CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false),
CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false),
CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false),
CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false),
};
CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256);
static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
const struct proto *base);
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
WRITE_ONCE(sk->sk_prot,
&tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
WRITE_ONCE(sk->sk_socket->ops,
&