// SPDX-License-Identifier: GPL-2.0
#include <linux/kernel.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/slab.h>
#include <linux/net.h>
#include <linux/compat.h>
#include <net/compat.h>
#include <linux/io_uring.h>
#include <uapi/linux/io_uring.h>
#include "io_uring.h"
#include "kbuf.h"
#include "alloc_cache.h"
#include "net.h"
#if defined(CONFIG_NET)
/* Per-request state for a socket shutdown; filled in by io_shutdown_prep(). */
struct io_shutdown {
/*
 * NOTE(review): every command struct in this chunk starts with
 * `struct file *file`; presumably io_kiocb_to_cmd() relies on that layout.
 * Confirm before reordering fields.
 */
struct file *file;
/* Shutdown mode read from sqe->len and passed to __sys_shutdown_sock(). */
int how;
};
/*
 * Per-request state for an accept operation.
 * NOTE(review): field meanings below are inferred from names only; the
 * prep/issue functions that fill them are not visible here — confirm.
 */
struct io_accept {
/* Kept first, like the other command structs in this chunk (layout-sensitive, presumably). */
struct file *file;
/* Presumably the user buffer that receives the peer address. */
struct sockaddr __user *addr;
/* Presumably the user pointer to that buffer's length. */
int __user *addr_len;
int flags;
/* Presumably the fixed-file slot to install the new file into — confirm. */
u32 file_slot;
/* Presumably a file-descriptor limit snapshot — confirm against prep. */
unsigned long nofile;
};
/*
 * Per-request state for creating a socket.
 * NOTE(review): the prep/issue code that fills this is not visible here;
 * the field meanings below are inferred from names — confirm.
 */
struct io_socket {
/* Kept first, like the other command structs in this chunk (layout-sensitive, presumably). */
struct file *file;
/* Presumably the socket(2) domain / type / protocol triple. */
int domain;
int type;
int protocol;
int flags;
/* Presumably the fixed-file slot for the created socket — confirm. */
u32 file_slot;
/* Presumably a file-descriptor limit snapshot — confirm against prep. */
unsigned long nofile;
};
/*
 * Per-request state for a connect operation.
 * NOTE(review): fields are filled by code outside this view — confirm.
 */
struct io_connect {
/* Kept first, like the other command structs in this chunk (layout-sensitive, presumably). */
struct file *file;
/* Presumably the user-space destination address and its length. */
struct sockaddr __user *addr;
int addr_len;
};
/* Per-request state shared by the send/recv (and *msg) operations. */
struct io_sr_msg {
/* Kept first, like the other command structs in this chunk (layout-sensitive, presumably). */
struct file *file;
/*
 * User pointer to the message: a native msghdr (used by
 * io_sendmsg_copy_hdr()), a compat msghdr, or a plain data buffer,
 * depending on the opcode.
 */
union {
struct compat_msghdr __user *umsg_compat;
struct user_msghdr __user *umsg;
void __user *buf;
};
/* MSG_* flags; also passed to sendmsg_copy_msghdr(). */
int msg_flags;
/* Presumably the requested transfer length — confirm. */
size_t len;
/* Presumably bytes already transferred across partial retries — confirm. */
size_t done_io;
/* Presumably io_uring-level per-request flags — confirm. */
unsigned int flags;
};
/* Mask combining the "polled" and "apoll multishot" request flags. */
#define IO_APOLL_MULTI_POLLED (REQ_F_POLLED | REQ_F_APOLL_MULTISHOT)
/*
 * io_shutdown_prep - validate a shutdown SQE and record its mode.
 * @req:	request being prepared
 * @sqe:	submission queue entry
 *
 * Fields that shutdown does not use (off, addr, rw_flags, buf_index,
 * splice_fd_in) must all be zero. The shutdown mode is taken from sqe->len.
 *
 * Returns 0 on success or -EINVAL if any unused field is set.
 */
int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
{
	struct io_shutdown *sd = io_kiocb_to_cmd(req);
	bool has_unused_fields;

	has_unused_fields = sqe->off || sqe->addr || sqe->rw_flags ||
			    sqe->buf_index || sqe->splice_fd_in;
	if (unlikely(has_unused_fields))
		return -EINVAL;

	sd->how = READ_ONCE(sqe->len);
	return 0;
}
int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
{
struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
struct socket *sock;
int ret;
if (issue_flags & IO_URING_F_NONBLOCK)
return -EAGAIN;
sock = sock_from_file(req->file);
if (unlikely(!sock))
return -ENOTSOCK;
ret = __sys_shutdown_sock(sock, shutdown->how);
io_req_set_res(req, ret, 0);
return IOU_OK;
}
static bool io_net_retry(struct socket *sock, int flags)
{
if (!(flags & MSG_WAITALL))
return false;
return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
}
/*
 * io_netmsg_recycle - try to return the request's msghdr to the ring cache.
 * @req:	request whose async_data may hold an io_async_msghdr
 * @issue_flags:	IO_URING_F_* issue flags
 *
 * Does nothing if there is no async header or if IO_URING_F_UNLOCKED is set
 * (presumably the cache may only be touched with the ring lock held). On a
 * successful cache put, the request drops its async_data reference.
 */
static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
{
	struct io_async_msghdr *hdr = req->async_data;

	if (hdr == NULL || (issue_flags & IO_URING_F_UNLOCKED))
		return;

	/* Cache rejected it: leave it for the normal cleanup path to free. */
	if (!io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache))
		return;

	req->async_data = NULL;
	req->flags &= ~REQ_F_ASYNC_DATA;
}
/*
 * io_recvmsg_alloc_async - obtain an io_async_msghdr for @req.
 * @req:	request to attach the header to
 * @issue_flags:	IO_URING_F_* issue flags
 *
 * Unless IO_URING_F_UNLOCKED is set, a recycled header is taken from the
 * ring's netmsg cache first. Otherwise the header comes from
 * io_alloc_async_data(). The header is not initialised here.
 *
 * Returns the header (also stored in req->async_data), or NULL if
 * allocation failed.
 */
static struct io_async_msghdr *io_recvmsg_alloc_async(struct io_kiocb *req,
						      unsigned int issue_flags)
{
	struct io_cache_entry *entry = NULL;

	if (!(issue_flags & IO_URING_F_UNLOCKED))
		entry = io_alloc_cache_get(&req->ctx->netmsg_cache);

	if (entry) {
		struct io_async_msghdr *hdr;

		hdr = container_of(entry, struct io_async_msghdr, cache);
		req->flags |= REQ_F_ASYNC_DATA;
		req->async_data = hdr;
		return hdr;
	}

	/* io_alloc_async_data() returns non-zero on failure. */
	return io_alloc_async_data(req) ? NULL : req->async_data;
}
/*
 * io_setup_async_msg - preserve on-stack msghdr state so the request can retry.
 * @req:	request being punted
 * @kmsg:	on-stack header state to preserve
 * @issue_flags:	IO_URING_F_* issue flags, forwarded to the allocator
 *
 * If the request already owns async data, nothing is copied. Otherwise a
 * header is obtained from io_recvmsg_alloc_async() and @kmsg is copied into
 * it. Pointers that referred into @kmsg itself are then re-pointed at the
 * copy.
 *
 * Returns -EAGAIN once the state is preserved (or already was). Returns
 * -ENOMEM if no header could be obtained, in which case @kmsg->free_iov is
 * freed.
 */
static int io_setup_async_msg(struct io_kiocb *req,
			      struct io_async_msghdr *kmsg,
			      unsigned int issue_flags)
{
	struct io_async_msghdr *async_msg = req->async_data;

	if (async_msg)
		return -EAGAIN;
	async_msg = io_recvmsg_alloc_async(req, issue_flags);
	if (!async_msg) {
		kfree(kmsg->free_iov);
		return -ENOMEM;
	}
	req->flags |= REQ_F_NEED_CLEANUP;
	memcpy(async_msg, kmsg, sizeof(*kmsg));

	/*
	 * msg_name referred to kmsg->addr; re-point it at our own copy. If no
	 * address was supplied, msg_name is NULL after the header copy and must
	 * stay NULL. Otherwise a retried sendmsg on a connected socket would see
	 * a non-NULL name with msg_namelen == 0.
	 */
	if (async_msg->msg.msg_name)
		async_msg->msg.msg_name = &async_msg->addr;

	/*
	 * If the iovec lives in kmsg->fast_iov, re-point the iterator at the
	 * same element of our copy. After a partial transfer the iterator may
	 * already sit past fast_iov[0], so keep its element offset rather than
	 * rewinding to the start.
	 */
	if (!kmsg->free_iov) {
		size_t fast_idx = kmsg->msg.msg_iter.iov - kmsg->fast_iov;

		async_msg->msg.msg_iter.iov = &async_msg->fast_iov[fast_idx];
	}
	return -EAGAIN;
}
/*
 * io_sendmsg_copy_hdr - import the user msghdr for a sendmsg request.
 * @req:	sendmsg request (its cmd data holds the user msghdr pointer)
 * @iomsg:	destination header state
 *
 * Points msg_name at iomsg->addr and free_iov at the inline fast_iov array
 * before handing off to sendmsg_copy_msghdr(). Returns that call's result.
 */
static int io_sendmsg_copy_hdr(struct io_kiocb *req,
			       struct io_async_msghdr *iomsg)
{
	struct io_sr_msg *sr = io_kiocb_to_cmd(req);
	int ret;

	iomsg->free_iov = iomsg->fast_iov;
	iomsg->msg.msg_name = &iomsg->addr;
	ret = sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
				  &iomsg->free_iov);
	return ret;
}
int io_sendmsg_prep_async(struct io_kiocb *req)