From 2e8afe16eb831c227693c37a7da59215e5686134 Mon Sep 17 00:00:00 2001
From: Enzo Matsumiya <ematsumiya@suse.de>
Date: Wed, 14 Aug 2024 13:40:58 -0300
Subject: smb: client: lz77 fast works

Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
---
 fs/smb/client/compress/lz77.c | 121 ++++++++++++++++++++----------------------
 1 file changed, 59 insertions(+), 62 deletions(-)

diff --git a/fs/smb/client/compress/lz77.c b/fs/smb/client/compress/lz77.c
index 2cbfcf3ed9bf..7d672200d55f 100644
--- a/fs/smb/client/compress/lz77.c
+++ b/fs/smb/client/compress/lz77.c
@@ -40,12 +40,6 @@
 #define lz77_read(nbits, _ptr)		get_unaligned((const u ## nbits *)(_ptr))
 #define lz77_write(nbits, _ptr, _v)	put_unaligned_le ## nbits((_v), (typeof((_v)) *)(_ptr))
 
-struct lz77_flag {
-	u8 *pos;
-	u32 count;
-	long val;
-};
-
 static __always_inline u16 lz77_read16(const void *ptr)
 {
 	return lz77_read(16, ptr);
@@ -94,19 +88,9 @@ static __always_inline u32 lz77_count_common_bytes(const u64 diff)
 static __always_inline u32 lz77_match_len(const u8 *match, const u8 *cur, const u8 *end)
 {
 	const u8 *start = cur;
-	u32 step = sizeof(u32);
+	u32 step = sizeof(u64);
 	u64 diff;
 
-	if (cur > end - step)
-		return 0;
-
-	if (lz77_read32(cur) ^ lz77_read32(match))
-		return 0;
-
-	cur += step;
-	match += step;
-	step = sizeof(u64);
-
 	while (likely(cur < end - (step - 1))) {
 		diff = lz77_read64(cur) ^ lz77_read64(match);
 		if (!diff) {
@@ -214,39 +198,15 @@ static __always_inline void lz77_copy(u8 *dst, const u8 *src, size_t count)
 		memcpy(dst, src, count);
 }
 
-static u8 *lz77_write_literals(u8 *dst, const u8 *src, u32 count, struct lz77_flag *flags)
-{
-	const u8 *end = src + count;
-
-	while (likely(src < end)) {
-		u32 c = umin(count, 32 - flags->count);
-
-		lz77_copy(dst, src, c);
-
-		dst += c;
-		src += c;
-		count -= c;
-
-		flags->val <<= c;
-		flags->count += c;
-		if (flags->count == 32) {
-			lz77_write32(flags->pos, flags->val);
-			flags->count = 0;
-			flags->pos = dst;
-			dst += 4;
-		}
-	}
-
-	return dst;
-}
-
 noinline int lz77_compress(const u8 *src, u32 slen, u8 *dst, u32 *dlen)
 {
 	const u8 *srcp, *end, *anchor;
-	struct lz77_flag flags = { 0 };
-	u8 *dstp, *nib;
+	u8 *dstp, *nib, *flag_pos;
+	u32 flag_count = 0;
 	u64 *htable;
+	long flag = 0;
 	int ret;
+	unsigned long s, e;
 
 	srcp = src;
 	anchor = srcp;
@@ -254,15 +214,15 @@ noinline int lz77_compress(const u8 *src, u32 slen, u8 *dst, u32 *dlen)
 
 	dstp = dst;
 	nib = NULL;
-
-	/* Output buffer start with a 4 byte flags. */
-	flags.pos = dstp;
+	flag_pos = dstp;
 	dstp += 4;
 
 	htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
 	if (!htable)
 		return -ENOMEM;
 
+	s = jiffies;
+
 	/* Main loop. */
 	while (likely(srcp < end)) {
 		u32 offset, dist, len;
@@ -271,7 +231,6 @@ noinline int lz77_compress(const u8 *src, u32 slen, u8 *dst, u32 *dlen)
 		while (likely(srcp + LZ77_MATCH_MIN_LEN < end)) {
 			offset = srcp - src;
 			hash = lz77_hash_bytes(srcp);
-
 			dist = offset - htable[hash];
 			if (dist >= LZ77_MATCH_MIN_DIST && dist < LZ77_MATCH_MAX_DIST)
 				len = lz77_match_len(src + htable[hash], srcp, end);
@@ -286,7 +245,23 @@ noinline int lz77_compress(const u8 *src, u32 slen, u8 *dst, u32 *dlen)
 			srcp++;
 		}
 
-		dstp = lz77_write_literals(dstp, anchor, srcp - anchor, &flags);
+		while (likely(anchor < srcp)) {
+			u32 c = umin(srcp - anchor, 32 - flag_count);
+
+			lz77_copy(dstp, anchor, c);
+
+			dstp += c;
+			anchor += c;
+
+			flag <<= c;
+			flag_count += c;
+			if (flag_count == 32) {
+				lz77_write32(flag_pos, flag);
+				flag_count = 0;
+				flag_pos = dstp;
+				dstp += 4;
+			}
+		}
 
 		if (unlikely(srcp + LZ77_MATCH_MIN_LEN >= end))
 			goto leftovers;
@@ -295,30 +270,52 @@ noinline int lz77_compress(const u8 *src, u32 slen, u8 *dst, u32 *dlen)
 		srcp += len;
 		anchor = srcp;
 
-		flags.val = (flags.val << 1) | 1;
-		flags.count++;
-		if (flags.count == 32) {
-			lz77_write32(flags.pos, flags.val);
-			flags.count = 0;
-			flags.pos = dstp;
+		flag = (flag << 1) | 1;
+		flag_count++;
+		if (flag_count == 32) {
+			lz77_write32(flag_pos, flag);
+			flag_count = 0;
+			flag_pos = dstp;
 			dstp += 4;
 		}
 	}
 leftovers:
-	if (srcp < end)
-		dstp = lz77_write_literals(dstp, srcp, end - srcp, &flags);
+	if (srcp < end) {
+		while (likely(srcp < end)) {
+			u32 c = umin(end - srcp, 32 - flag_count);
+
+			lz77_copy(dstp, srcp, c);
+
+			dstp += c;
+			srcp += c;
+
+			flag <<= c;
+			flag_count += c;
+			if (flag_count == 32) {
+				lz77_write32(flag_pos, flag);
+				flag_count = 0;
+				flag_pos = dstp;
+				dstp += 4;
+			}
+		}
+	}
 
-	flags.val <<= (32 - flags.count);
-	flags.val |= (1 << (32 - flags.count)) - 1;
-	lz77_write32(flags.pos, flags.val);
+	flag <<= (32 - flag_count);
+	flag |= (1 << (32 - flag_count)) - 1;
+	lz77_write32(flag_pos, flag);
 
 	*dlen = dstp - (u8 *)dst;
+
+	e = jiffies;
+	pr_err("%s: (fast) took %ums to compress (slen=%u, dlen=%u)\n", __func__, jiffies_to_msecs(e - s), slen, *dlen);
+
 	if (*dlen < slen)
 		ret = 0;
 	else
 		ret = -EMSGSIZE;
-//out:
+
 	kvfree(htable);
 
+
 	return ret;
 }
-- 
cgit v1.2.3