// SPDX-License-Identifier: GPL-2.0
#include <linux/kernel.h>
#include <linux/errno.h>
#include <linux/dma-map-ops.h>
#include <linux/mm.h>
#include <linux/nospec.h>
#include <linux/io_uring.h>
#include <linux/netdevice.h>
#include <linux/rtnetlink.h>
#include <linux/skbuff_ref.h>
#include <linux/anon_inodes.h>
#include <net/page_pool/helpers.h>
#include <net/page_pool/memory_provider.h>
#include <net/netlink.h>
#include <net/netdev_queues.h>
#include <net/netdev_rx_queue.h>
#include <net/tcp.h>
#include <net/rps.h>
#include <trace/events/page_pool.h>
#include <uapi/linux/io_uring.h>
#include "io_uring.h"
#include "kbuf.h"
#include "memmap.h"
#include "zcrx.h"
#include "rsrc.h"
/* Mask of area flags recognised here; currently only IORING_ZCRX_AREA_DMABUF. */
#define IO_ZCRX_AREA_SUPPORTED_FLAGS (IORING_ZCRX_AREA_DMABUF)
/*
 * DMA attribute set combining DMA_ATTR_SKIP_CPU_SYNC and
 * DMA_ATTR_WEAK_ORDERING. NOTE(review): the users of this mask are not
 * visible in this part of the file.
 */
#define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)
/* Fetch the zcrx interface queue stored in the page pool's mp_priv field. */
static inline struct io_zcrx_ifq *io_pp_to_ifq(struct page_pool *pp)
{
	struct io_zcrx_ifq *ifq = pp->mp_priv;

	return ifq;
}
/*
 * Map a net_iov to the io_zcrx_area that embeds its owning net_iov_area
 * (the area's 'nia' member).
 */
static inline struct io_zcrx_area *io_zcrx_iov_to_area(const struct net_iov *niov)
{
	return container_of(net_iov_owner(niov), struct io_zcrx_area, nia);
}
/*
 * Look up the entry of the area's pages[] array that corresponds to @niov.
 * Asserts (under lockdep) that the area is not dma-buf backed.
 */
static inline struct page *io_zcrx_iov_page(const struct net_iov *niov)
{
	struct io_zcrx_area *zc_area = io_zcrx_iov_to_area(niov);
	unsigned shift_in_pages;
	unsigned page_idx;

	lockdep_assert(!zc_area->mem.is_dmabuf);

	/* Scale the niov index from niov-sized units to PAGE_SIZE units. */
	shift_in_pages = zc_area->ifq->niov_shift - PAGE_SHIFT;
	page_idx = net_iov_idx(niov) << shift_in_pages;

	return zc_area->mem.pages[page_idx];
}
static int io_area_max_shift(struct io_zcrx_mem *mem)
{
struct sg_table *sgt = mem->sgt;
struct scatterlist *sg;
unsigned shift = -1U;
unsigned i;
for_each_sgtable_dma_sg(sgt, sg, i)
shift = min(shift, __ffs(sg_dma_len(sg)));
return shift;
}
/*
 * Walk the DMA-mapped scatterlist of @area and hand out consecutive
 * niov-sized (1 << ifq->niov_shift) DMA addresses to the area's niovs.
 *
 * Returns 0 on success, -EINVAL if a segment length is not a multiple of the
 * niov size, and -EFAULT if setting an address fails or the segments do not
 * cover exactly nia.num_niovs niovs.
 */
static int io_populate_area_dma(struct io_zcrx_ifq *ifq,
				struct io_zcrx_area *area)
{
	unsigned chunk = 1U << ifq->niov_shift;
	struct sg_table *table = area->mem.sgt;
	struct scatterlist *ent;
	unsigned n, next = 0;

	for_each_sgtable_dma_sg(table, ent, n) {
		unsigned long remaining = sg_dma_len(ent);
		dma_addr_t addr = sg_dma_address(ent);

		if (WARN_ON_ONCE(remaining % chunk))
			return -EINVAL;

		/* Carve this segment into chunks until it or the niovs run out. */
		for (; remaining && next < area->nia.num_niovs;
		     remaining -= chunk, addr += chunk, next++) {
			if (net_mp_niov_set_dma_addr(&area->nia.niovs[next], addr))
				return -EFAULT;
		}
	}

	/* Every niov must have received an address. */
	if (WARN_ON_ONCE(next != area->nia.num_niovs))
		return -EFAULT;

	return 0;
}
/*
 * Tear down whatever dma-buf state is recorded in @mem: unmap the attachment
 * if an sg table exists, detach if attached, dma_buf_put() the buffer if one
 * is held, then clear all three pointers. Nothing is touched when
 * CONFIG_DMA_SHARED_BUFFER is disabled.
 */
static void io_release_dmabuf(struct io_zcrx_mem *mem)
{
	if (IS_ENABLED(CONFIG_DMA_SHARED_BUFFER)) {
		/* Mapping first, then the attachment, then the buffer itself. */
		if (mem->sgt) {
			dma_buf_unmap_attachment_unlocked(mem->attach, mem->sgt,
							  DMA_FROM_DEVICE);
		}
		if (mem->attach) {
			dma_buf_detach(mem->dmabuf, mem->attach);
		}
		if (mem->dmabuf) {
			dma_buf_put(mem->dmabuf);
		}

		mem->dmabuf = NULL;
		mem->attach = NULL;
		mem->sgt = NULL;
	}
}
static int io_import_dmabuf(struct io_zcrx_ifq *ifq,
struct io_zcrx_mem *mem,
struct io_uring_zcrx_area_reg *area_reg)
{
unsigned long off = (unsigned long)area_reg->addr;
unsigned long len = (unsigned long)area_reg->len;
unsigned long total_size = 0;
struct scatterlist *sg;
int dmabuf_fd = area_reg->dmabuf_fd;
int i, ret;
if (!ifq->dev)
return -EINVAL;
if (off)
return -EINVAL;
if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
return -EINVAL;
mem->is_dmabuf = true;
mem->dmabuf = dma_buf_get(dmabuf_fd);
if (IS_ERR(mem->dmabuf)) {
ret = PTR_ERR(mem->dmabuf);
mem->dmabuf = NULL;
goto err;
}
mem->attach = dma_buf_attach(mem->dmabuf, ifq->dev);
if (IS_ERR(mem->attach)) {
ret = PTR_ERR(mem->attach);
mem->attach = NULL;
goto err;
}
mem->sgt = dma_buf_map_attachment_unlocked(mem->attach, DMA_FROM_DEVICE);
if (IS_ERR(mem->sgt)) {
ret = PTR_ERR(mem->sgt);
mem->sgt = NULL;
goto err;
}
for_each_sgtable_dma_sg(mem->sgt, sg, i)
total_size += sg_dma_len(sg);
if (total_size != len) {
ret = -EINVAL;
goto err;
}
mem->size = len;
return 0;
err:
io_release_dmabuf(mem);
return ret;
}
/*
 * Sum folio_nr_pages() over the folios behind @pages. A folio is counted once
 * per run of consecutive entries that resolve to it; only the immediately
 * preceding folio is remembered.
 */
static unsigned long io_count_account_pages(struct page **pages, unsigned nr_pages)
{
	struct folio *prev = NULL;
	unsigned long total = 0;
	int idx;

	for (idx = 0; idx < nr_pages; idx++) {
		struct folio *cur = page_folio(pages[idx]);

		if (cur != prev) {
			total += folio_nr_pages(cur);
			prev = cur;
		}
	}

	return total;
}
static int io_import_umem(struct io_zcrx_ifq *ifq,
struct io_zcrx_mem
|