// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>
#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>
/* Report whether @elem_first_coalesce permits growing the last element of
 * @msg in place.
 *
 * Two layouts of msg->sg.start / msg->sg.end are accepted:
 *  - end > start: true when @elem_first_coalesce is below end;
 *  - end < start: true when @elem_first_coalesce is above start or
 *    below end.
 * When start == end the answer is always false.
 */
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start)
		return elem_first_coalesce < msg->sg.end;

	if (msg->sg.end < msg->sg.start)
		return elem_first_coalesce > msg->sg.start ||
		       elem_first_coalesce < msg->sg.end;

	return false;
}
/* Grow @msg until msg->sg.size reaches @len, taking memory from the
 * socket's page fragment.
 *
 * @sk:                  socket that supplies the page frag and is charged
 * @msg:                 message to extend
 * @len:                 target size; the bytes already accounted in
 *                       msg->sg.size are subtracted before allocating
 * @elem_first_coalesce: forwarded to sk_msg_try_coalesce_ok() to decide
 *                       whether the last element may be extended in place
 *
 * Each pass either extends the element just before sg.end (when it ends
 * exactly where the page frag currently starts, on the same page) or fills
 * a fresh element at sg.end and takes a page reference for it.
 *
 * Returns 0 on success, -ENOMEM when the page frag cannot be refilled or
 * the write memory cannot be scheduled, and -ENOSPC when a new element is
 * needed but @msg is full.  Nothing is rolled back here on error: whatever
 * was appended in earlier passes stays in @msg.
 */
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *tail;
		u32 frag_off;
		int chunk, prev;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		/* Take as much of the remaining frag space as is still needed. */
		frag_off = pfrag->offset;
		chunk = min_t(int, len, pfrag->size - frag_off);
		if (!sk_wmem_schedule(sk, chunk))
			return -ENOMEM;

		/* Candidate for in-place growth: the element preceding sg.end. */
		prev = msg->sg.end;
		sk_msg_iter_var_prev(prev);
		tail = &msg->sg.data[prev];

		if (!sk_msg_try_coalesce_ok(msg, elem_first_coalesce) ||
		    sg_page(tail) != pfrag->page ||
		    tail->offset + tail->length != frag_off) {
			/* Not contiguous with the tail: start a new element. */
			if (sk_msg_full(msg))
				return -ENOSPC;

			tail = &msg->sg.data[msg->sg.end];
			sg_unmark_end(tail);
			sg_set_page(tail, pfrag->page, chunk, frag_off);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		} else {
			tail->length += chunk;
		}

		sk_mem_charge(sk, chunk);
		msg->sg.size += chunk;
		pfrag->offset += chunk;
		len -= chunk;
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
/* Make @dst reference @len bytes of @src, beginning @off bytes into @src.
 *
 * @sk:  socket charged for every byte referenced by @dst
 * @dst: message that receives the references
 * @src: message whose elements are walked from src->sg.start
 * @off: number of bytes of @src to skip first
 * @len: number of bytes to reference
 *
 * A source range that directly continues the last element of @dst (same
 * page, adjacent address) is merged into that element; otherwise it is
 * appended with sk_msg_page_add().
 *
 * Returns 0 on success, or -ENOSPC if @src runs out before @off/@len are
 * consumed or @dst has no free element.  Bytes already added to @dst before
 * the failure are left in place.
 */
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int idx = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, idx);
	struct scatterlist *sgd = NULL;
	u32 chunk;

	/* Step over whole source elements covered by the offset. */
	while (off && sge->length <= off) {
		off -= sge->length;
		sk_msg_iter_var_next(idx);
		if (idx == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, idx);
	}

	while (len) {
		chunk = min_t(u32, sge->length - off, len);

		/* Only refreshed while dst->sg.end is non-zero; otherwise the
		 * value from the previous pass (initially NULL) is kept.
		 */
		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd && sg_page(sge) == sg_page(sgd) &&
		    sg_virt(sge) + off == sg_virt(sgd) + sgd->length) {
			sgd->length += chunk;
			dst->sg.size += chunk;
		} else {
			if (sk_msg_full(dst))
				return -ENOSPC;
			sk_msg_page_add(dst, sg_page(sge), chunk,
					sge->offset + off);
		}

		/* The offset applies to the first copied element only. */
		off = 0;
		len -= chunk;
		sk_mem_charge(sk, chunk);

		sk_msg_iter_var_next(idx);
		if (idx == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, idx);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
/* Uncharge @bytes from @sk and consume them from the front of @msg.
 *
 * Elements are visited from msg->sg.start.  An element fully covered by
 * the remaining count is uncharged and has its length and offset reset to
 * zero; a partially covered element is shortened from the front and the
 * walk stops there.  The walk also stops once @bytes is exhausted or
 * msg->sg.end is reached.  msg->sg.start is then moved to the element
 * where the walk stopped.
 */
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int idx = msg->sg.start;

	for (;;) {
		struct scatterlist *sge = sk_msg_elem(msg, idx);

		if (bytes < sge->length) {
			/* Partial element: trim its head and finish. */
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		/* Whole element consumed. */
		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(idx);

		if (!bytes || idx == msg->sg.end)
			break;
	}

	msg->sg.start = idx;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);
/* Uncharge up to @bytes from @sk, spread over the elements of @msg.
 *
 * Every element from msg->sg.start up to (not including) msg->sg.end is
 * visited; for each one the smaller of the remaining @bytes and the
 * element length is uncharged.  The elements themselves and msg->sg.start
 * are not modified.  The walk does not stop early when @bytes reaches
 * zero; later elements are simply uncharged by zero.
 */
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int idx = msg->sg.start;

	do {
		struct scatterlist *sge = msg->sg.data + idx;
		int amount;

		if (bytes < sge->length)
			amount = bytes;
		else
			amount = sge->length;

		sk_mem_uncharge(sk, amount);
		bytes -= amount;
		sk_msg_iter_var_next(idx);
	} while (idx != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
bool charge)
{
struct scatterlist *sge = sk_msg_elem(msg, i);
u32 len = sge->length;
if (charge)
sk_mem_uncharge(sk, len);
if (!msg->skb)
put_page(sg_page(sge));