summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEnzo Matsumiya <ematsumiya@suse.de>2024-05-10 19:45:36 -0300
committerEnzo Matsumiya <ematsumiya@suse.de>2024-05-12 17:51:52 -0600
commit9cfb5aaa34094979de64ca490e27b76d113bd81e (patch)
treec1b52ec1428a3a0790003f4f03b46ed7c431cee2
parent40afed03b810c78e4ca2d08bdfc645088b782709 (diff)
downloadlinux-9cfb5aaa34094979de64ca490e27b76d113bd81e.tar.gz
linux-9cfb5aaa34094979de64ca490e27b76d113bd81e.tar.bz2
linux-9cfb5aaa34094979de64ca490e27b76d113bd81e.zip
smb: client: implement chained compression support
Introduce 'Pattern V1' algorithm (MS-SMB2) for pattern (repeated single byte) scanning/matching. Usage of this algorithm requires chained compression support to be negotiated with the server, which is also done by this commit. Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
-rw-r--r--fs/smb/client/cifsglob.h1
-rw-r--r--fs/smb/client/compress.c272
-rw-r--r--fs/smb/client/compress.h14
-rw-r--r--fs/smb/client/compress/lz77.h18
-rw-r--r--fs/smb/client/smb2pdu.c38
-rw-r--r--fs/smb/client/transport.c2
6 files changed, 297 insertions, 48 deletions
diff --git a/fs/smb/client/cifsglob.h b/fs/smb/client/cifsglob.h
index c366fae1f669..94607d750a56 100644
--- a/fs/smb/client/cifsglob.h
+++ b/fs/smb/client/cifsglob.h
@@ -772,6 +772,7 @@ struct TCP_Server_Info {
struct {
bool requested; /* "compress" mount option set*/
bool enabled; /* actually negotiated with server */
+ bool chained; /* chained support negotiated with server */
__le16 alg; /* preferred alg negotiated with server */
} compression;
__u16 signing_algorithm;
diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c
index de55ddd122a5..7fcb8c796aef 100644
--- a/fs/smb/client/compress.c
+++ b/fs/smb/client/compress.c
@@ -25,7 +25,162 @@
#include "compress/lz77.h"
#include "compress.h"
-int smb_compress(void *buf, const void *data, size_t *len)
+static void pattern_scan(const u8 *src, size_t src_len,
+ struct smb2_compression_pattern_v1 *fwd,
+ struct smb2_compression_pattern_v1 *bwd)
+{
+ const u8 *srcp = src, *end = src + src_len - 1;
+ u32 reps, plen = sizeof(*fwd);
+
+ memset(fwd, 0, plen);
+ memset(bwd, 0, plen);
+
+ /* Forward pattern scan. */
+ while (++srcp <= end && *srcp == *src) { }
+
+ reps = srcp - src;
+ if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
+ fwd->Pattern = src[0];
+ fwd->Repetitions = cpu_to_le32(reps);
+
+ src += reps;
+ }
+
+ if (reps == src_len)
+ return;
+
+ /* Backward pattern scan. */
+ srcp = end;
+ while (--srcp >= src && *srcp == *end) { }
+
+ reps = end - srcp;
+ if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
+ bwd->Pattern = *end;
+ bwd->Repetitions = cpu_to_le32(reps);
+
+ return;
+ }
+
+ return;
+}
+
+static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained)
+{
+ struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 };
+ struct smb2_compression_payload_hdr *phdr;
+ const u8 *srcp = src;
+ u8 *dstp = dst;
+ size_t plen = 0;
+ int ret = 0;
+
+ if (!chained)
+ goto unchained;
+
+ pattern_scan(srcp, src_len, &fwd, &bwd);
+
+ if (le32_to_cpu(fwd.Repetitions) > 0) {
+ plen = sizeof(fwd);
+ if (plen > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(plen);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, &fwd, plen);
+ dstp += plen;
+ *dst_len -= plen;
+
+ srcp += le32_to_cpu(fwd.Repetitions);
+ src_len -= le32_to_cpu(fwd.Repetitions);
+ if (src_len == 0)
+ goto out;
+ }
+
+ if (le32_to_cpu(bwd.Repetitions) > 0) {
+ src_len -= le32_to_cpu(bwd.Repetitions);
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+
+ if (src_len == 0)
+ goto out_bwd_pattern;
+ }
+
+ /* Leftover uncompressed data with size not worth compressing. */
+ if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) {
+ if (src_len > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(src_len);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, srcp, src_len);
+ dstp += src_len;
+ *dst_len -= src_len;
+
+ goto out_bwd_pattern;
+ }
+unchained:
+ plen = *dst_len - SMB_COMPRESS_PAYLOAD_HDR_LEN - sizeof(phdr->OriginalPayloadSize);
+ if (plen < *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen);
+ if (ret) {
+ *dst_len = 0;
+ return ret;
+ }
+
+ if (chained) {
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->OriginalPayloadSize = cpu_to_le32(src_len);
+ plen += sizeof(phdr->OriginalPayloadSize);
+ phdr->Length = cpu_to_le32(plen);
+ plen += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ }
+
+ dstp += plen;
+ *dst_len -= plen;
+out_bwd_pattern:
+ if (bwd.Repetitions >= 64) {
+ plen = sizeof(bwd);
+ if (plen > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(plen);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, &bwd, plen);
+ dstp += plen;
+ *dst_len -= plen;
+ }
+out:
+ *dst_len = dstp - (u8 *)dst;
+
+ return 0;
+}
+
+int smb_compress(void *buf, const void *data, size_t *len, bool chained)
{
struct smb2_compression_hdr *hdr;
size_t buf_len, data_len;
@@ -35,15 +190,21 @@ int smb_compress(void *buf, const void *data, size_t *len)
data_len = *len;
*len = 0;
+ if (data_len < SMB_COMPRESS_MIN_LEN)
+ return -ENODATA;
+
hdr = buf;
hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len);
hdr->Offset = cpu_to_le32(buf_len);
- hdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ hdr->Flags = chained;
hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
+ if (chained) {
+ hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
+ hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len);
+ }
- /* XXX: add other algs here as they're implemented */
- ret = lz77_compress(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len);
+ ret = compress_data(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained);
if (!ret)
*len = SMB_COMPRESS_HDR_LEN + buf_len + data_len;
@@ -53,47 +214,92 @@ int smb_compress(void *buf, const void *data, size_t *len)
int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len)
{
const struct smb2_compression_hdr *hdr;
- size_t buf_len, data_len;
+ size_t slen, dlen;
+ const void *srcp = src;
+ void *dstp = dst;
+ bool chained;
int ret;
- hdr = src;
- if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
+ hdr = srcp;
+ chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED);
+ slen = le32_to_cpu(hdr->Offset);
+
+ if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
return -EIO;
- buf_len = le32_to_cpu(hdr->Offset);
- data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
-
- /*
- * Copy uncompressed data from the beginning of the payload.
- * The remainder is all compressed data.
- */
- src += SMB_COMPRESS_HDR_LEN;
- memcpy(dst, src, buf_len);
- src += buf_len;
- src_len -= SMB_COMPRESS_HDR_LEN + buf_len;
- *dst_len -= buf_len;
-
- ret = lz77_decompress(src, src_len, dst + buf_len, dst_len);
- if (ret)
- return ret;
+ /* Copy the uncompressed SMB2 READ header. */
+ srcp += SMB_COMPRESS_HDR_LEN;
+ src_len -= SMB_COMPRESS_HDR_LEN;
+ memcpy(dstp, srcp, slen);
+ srcp += slen;
+ src_len -= slen;
+ dstp += slen;
+
+ if (!chained) {
+ slen = src_len;
+ goto unchained;
+ }
+
+ while (src_len > 0) {
+ const struct smb2_compression_payload_hdr *phdr = srcp;
+ __le16 alg = phdr->CompressionAlgorithm;
+
+ srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ slen = le32_to_cpu(phdr->Length);
+ dlen = slen;
+unchained:
+ if (!smb_compress_alg_valid(alg, chained))
+ return -EIO;
+
+ /* XXX: add other algs here as they're implemented */
+ if (alg == SMB3_COMPRESS_LZ77) {
+ if (chained) {
+ /* sizeof(OriginalPayloadSize) */
+ srcp += 4;
+ slen -= 4;
+ dlen = le32_to_cpu(phdr->OriginalPayloadSize);
+ }
+
+ ret = lz77_decompress(srcp, slen, dstp, &dlen);
+ if (ret)
+ return ret;
- if (*dst_len != data_len) {
- cifs_dbg(VFS, "decompressed size mismatch: got %zu, expected %zu\n",
- *dst_len, data_len);
+ if (chained)
+ src_len -= 4;
+ } else if (alg == SMB3_COMPRESS_PATTERN) {
+ struct smb2_compression_pattern_v1 *pattern;
+
+ pattern = (struct smb2_compression_pattern_v1 *)srcp;
+ dlen = le32_to_cpu(pattern->Repetitions);
+
+ if (dlen == 0 || dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
+ cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n", dlen);
+ return -ECONNRESET;
+ }
+
+ memset(dstp, pattern->Pattern, dlen);
+ } else if (alg == SMB3_COMPRESS_NONE) {
+ memcpy(dstp, srcp, dlen);
+ }
+
+ srcp += slen;
+ src_len -= slen;
+ dstp += dlen;
+ }
+
+ *dst_len = dstp - dst;
+ if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
+ cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len,
+ le32_to_cpu(hdr->OriginalCompressedSegmentSize));
return -ECONNRESET;
}
if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) {
- cifs_dbg(VFS, "decompressed buffer is not an SMB2 message: ProtocolId 0x%x\n",
+ cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n",
*(__le32 *)dst);
return -ECONNRESET;
}
- /*
- * @dst_len contains only the decompressed data size, add back
- * the previously copied uncompressed size
- */
- *dst_len += buf_len;
-
return 0;
}
diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h
index ffc712fae45a..c5691c1db066 100644
--- a/fs/smb/client/compress.h
+++ b/fs/smb/client/compress.h
@@ -25,6 +25,9 @@
/* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */
#define SMB_COMPRESS_PAYLOAD_HDR_LEN 8
#define SMB_COMPRESS_MIN_LEN PAGE_SIZE
+/* follows Windows implementation (as per MS-SMB2) */
+#define SMB_COMPRESS_PAYLOAD_MIN_LEN 1024
+#define SMB_COMPRESS_PATTERN_MIN_LEN 64
#define SMB_DECOMPRESS_MAX_LEN(_srv) \
(256 + SMB_COMPRESS_HDR_LEN + \
max_t(size_t, (_srv)->maxBuf, max_t(size_t, (_srv)->max_read, (_srv)->max_write)))
@@ -40,7 +43,7 @@ struct smb_compress_ctx {
};
#ifdef CONFIG_CIFS_COMPRESSION
-int smb_compress(void *buf, const void *data, size_t *len);
+int smb_compress(void *buf, const void *data, size_t *len, bool chained);
int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len);
/**
@@ -113,12 +116,15 @@ static __always_inline bool should_compress(const struct cifs_tcon *tcon, const
static __always_inline size_t decompressed_size(const void *buf)
{
const struct smb2_compression_hdr *hdr = buf;
+ size_t size = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
- return le32_to_cpu(hdr->Offset) +
- le32_to_cpu(hdr->OriginalCompressedSegmentSize);
+ if (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED)
+ size += le32_to_cpu(hdr->Offset);
+
+ return size;
}
#else /* CONFIG_CIFS_COMPRESSION */
-#define smb_compress(arg1, arg2, arg3) (-EOPNOTSUPP)
+#define smb_compress(arg1, arg2, arg3, arg4) (-EOPNOTSUPP)
#define smb_decompress(arg1, arg2, arg3, arg4) (-EOPNOTSUPP)
#define smb_compress_alg_valid(arg1, arg2) (-EOPNOTSUPP)
#define should_compress(arg1, arg2) (false)
diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h
index e26733a57954..d87701ce491f 100644
--- a/fs/smb/client/compress/lz77.h
+++ b/fs/smb/client/compress/lz77.h
@@ -13,6 +13,7 @@
#include <asm/ptrace.h>
#include <linux/kernel.h>
#include <linux/string.h>
+#include <linux/const.h>
#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
#include <asm-generic/unaligned.h>
#endif
@@ -66,6 +67,23 @@ static __always_inline u32 lz77_log2(unsigned int x)
return x ? ((u32)(31 - __builtin_clz(x))) : 0;
}
+/*
+ * Computes minimum memory required for LZ77 compression based on input length
+ * (@src_len).
+ *
+ * This minimum accounts for the worst case scenario, where no matches are found
+ * in the input buffer.
+ *
+ * For every literal (byte) written to the output buffer, a u32 flag is written
+ * as well.
+ *
+ * Avoid a couple extra instructions by rounding up adding an extra sizeof(u32).
+ */
+static __always_inline size_t lz77_compress_min_mem(size_t src_len)
+{
+ return src_len + ((src_len >> 3) + sizeof(u32));
+}
+
#ifdef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS
static __always_inline u8 lz77_read8(const void *ptr)
{
diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c
index 23801b028530..afd823c9cbca 100644
--- a/fs/smb/client/smb2pdu.c
+++ b/fs/smb/client/smb2pdu.c
@@ -589,12 +589,17 @@ build_compression_ctxt(struct smb2_compression_capabilities_context *pneg_ctxt)
pneg_ctxt->DataLength =
cpu_to_le16(sizeof(struct smb2_compression_capabilities_context)
- sizeof(struct smb2_neg_context));
- pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(1);
+
+ /* This enables the usage of the Pattern_V1 algorithm */
+ pneg_ctxt->Flags = SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED;
+ pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(2);
+
/*
- * Send the only algorithm we support (XXX: add others as they're
+ * Send the algorithms we support (XXX: add others as they're
* implemented).
*/
pneg_ctxt->CompressionAlgorithms[0] = SMB3_COMPRESS_LZ77;
+ pneg_ctxt->CompressionAlgorithms[1] = SMB3_COMPRESS_PATTERN;
}
static unsigned int
@@ -782,9 +787,13 @@ static void decode_compress_ctx(struct TCP_Server_Info *server,
struct smb2_compression_capabilities_context *ctxt)
{
unsigned int len = le16_to_cpu(ctxt->DataLength);
- __le16 alg;
+ __le16 count, i;
+ int chained;
server->compression.enabled = false;
+ server->compression.chained = false;
+ count = ctxt->CompressionAlgorithmCount;
+ chained = !!(ctxt->Flags & SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED);
/*
* Caller checked that DataLength remains within SMB boundary. We still
@@ -796,18 +805,27 @@ static void decode_compress_ctx(struct TCP_Server_Info *server,
return;
}
- if (le16_to_cpu(ctxt->CompressionAlgorithmCount) != 1) {
- pr_warn_once("invalid SMB3 compress algorithm count\n");
+ if (unlikely(count != 2)) {
+ pr_warn_once("invalid SMB3 compress algorithm count '%u'\n", count);
return;
}
- alg = ctxt->CompressionAlgorithms[0];
- if (!smb_compress_alg_valid(alg, false)) {
- pr_warn_once("invalid compression algorithm '%u'\n", alg);
- return;
+ for (i = 0; i < count; i++) {
+ __le16 alg = ctxt->CompressionAlgorithms[i];
+
+ if (!smb_compress_alg_valid(alg, false)) {
+ pr_warn_once("invalid compression algorithm '%u'\n", alg);
+ return;
+ }
+
+ if (alg == SMB3_COMPRESS_PATTERN) {
+ if (chained)
+ server->compression.chained = true;
+ } else {
+ server->compression.alg = alg;
+ }
}
- server->compression.alg = alg;
server->compression.enabled = true;
}
diff --git a/fs/smb/client/transport.c b/fs/smb/client/transport.c
index 92cae9079880..0503d5ae2aae 100644
--- a/fs/smb/client/transport.c
+++ b/fs/smb/client/transport.c
@@ -451,7 +451,7 @@ static void compress_thread(struct work_struct *work)
size_t len = ctx->len;
int ret;
- ret = smb_compress(ctx->buf, ctx->data, &len);
+ ret = smb_compress(ctx->buf, ctx->data, &len, server->compression.chained);
if (!ret) {
iov[0].iov_base = ctx->buf;
iov[0].iov_len = len;