diff options
| author | Enzo Matsumiya <ematsumiya@suse.de> | 2024-05-13 07:58:24 -0600 |
|---|---|---|
| committer | Enzo Matsumiya <ematsumiya@suse.de> | 2024-05-13 08:06:44 -0600 |
| commit | cc8a11c7c11d92fd432ebd1be267ce116ef8f5f8 (patch) | |
| tree | d5052ac0f25ec5a664bced76b432ef848b403ddc | |
| parent | d1b24f6424993f2c3672b7ea602ff27e8325324a (diff) | |
| download | linux-cc8a11c7c11d92fd432ebd1be267ce116ef8f5f8.tar.gz linux-cc8a11c7c11d92fd432ebd1be267ce116ef8f5f8.tar.bz2 linux-cc8a11c7c11d92fd432ebd1be267ce116ef8f5f8.zip | |
smb: client: implement chained compression support
Introduce 'Pattern V1' algorithm (MS-SMB2) for pattern (repeated
single byte) scanning/matching.
Usage of this algorithm requires chained compression support to be
negotiated with the server, which is also done by this commit.
Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
| -rw-r--r-- | fs/smb/client/cifsglob.h | 1 | ||||
| -rw-r--r-- | fs/smb/client/compress.c | 267 | ||||
| -rw-r--r-- | fs/smb/client/compress.h | 12 | ||||
| -rw-r--r-- | fs/smb/client/compress/lz77.h | 18 | ||||
| -rw-r--r-- | fs/smb/client/smb2pdu.c | 38 | ||||
| -rw-r--r-- | fs/smb/client/transport.c | 2 |
6 files changed, 292 insertions, 46 deletions
diff --git a/fs/smb/client/cifsglob.h b/fs/smb/client/cifsglob.h index c366fae1f669..94607d750a56 100644 --- a/fs/smb/client/cifsglob.h +++ b/fs/smb/client/cifsglob.h @@ -772,6 +772,7 @@ struct TCP_Server_Info { struct { bool requested; /* "compress" mount option set*/ bool enabled; /* actually negotiated with server */ + bool chained; /* chained support negotiated with server */ __le16 alg; /* preferred alg negotiated with server */ } compression; __u16 signing_algorithm; diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index 6ddd21a5e19d..af14508c2b3b 100644 --- a/fs/smb/client/compress.c +++ b/fs/smb/client/compress.c @@ -23,7 +23,162 @@ #include "compress/lz77.h" #include "compress.h" -int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) +static void pattern_scan(const u8 *src, size_t src_len, + struct smb2_compression_pattern_v1 *fwd, + struct smb2_compression_pattern_v1 *bwd) +{ + const u8 *srcp = src, *end = src + src_len - 1; + u32 reps, plen = sizeof(*fwd); + + memset(fwd, 0, plen); + memset(bwd, 0, plen); + + /* Forward pattern scan. */ + while (++srcp <= end && *srcp == *src) { } + + reps = srcp - src; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + fwd->Pattern = src[0]; + fwd->Repetitions = cpu_to_le32(reps); + + src += reps; + } + + if (reps == src_len) + return; + + /* Backward pattern scan. */ + srcp = end; + while (--srcp >= src && *srcp == *end) { } + + reps = end - srcp; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + bwd->Pattern = *end; + bwd->Repetitions = cpu_to_le32(reps); + + return; + } + + return; +} + +static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained) +{ + struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 }; + struct smb2_compression_payload_hdr *phdr; + const u8 *srcp = src; + u8 *dstp = dst; + size_t plen = 0; + int ret = 0; + + if (!chained) + goto unchained; + + pattern_scan(srcp, src_len, &fwd, &bwd); + + if (le32_to_cpu(fwd.Repetitions) > 0) { + plen = sizeof(fwd); + if (plen > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, &fwd, plen); + dstp += plen; + *dst_len -= plen; + + srcp += le32_to_cpu(fwd.Repetitions); + src_len -= le32_to_cpu(fwd.Repetitions); + if (src_len == 0) + goto out; + } + + if (le32_to_cpu(bwd.Repetitions) > 0) { + src_len -= le32_to_cpu(bwd.Repetitions); + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + + if (src_len == 0) + goto out_bwd_pattern; + } + + /* Leftover uncompressed data with size not worth compressing. */ + if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) { + if (src_len > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(src_len); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, srcp, src_len); + dstp += src_len; + *dst_len -= src_len; + + goto out_bwd_pattern; + } +unchained: + plen = *dst_len - SMB_COMPRESS_PAYLOAD_HDR_LEN - sizeof(phdr->OriginalPayloadSize); + if (plen < *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen); + if (ret) { + *dst_len = 0; + return ret; + } + + if (chained) { + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->OriginalPayloadSize = cpu_to_le32(src_len); + plen += sizeof(phdr->OriginalPayloadSize); + phdr->Length = cpu_to_le32(plen); + plen += SMB_COMPRESS_PAYLOAD_HDR_LEN; + } + + dstp += plen; + *dst_len -= plen; +out_bwd_pattern: + if (bwd.Repetitions >= 64) { + plen = sizeof(bwd); + if (plen > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, &bwd, plen); + dstp += plen; + *dst_len -= plen; + } +out: + *dst_len = dstp - (u8 *)dst; + + return 0; +} + +int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq, bool chained) { struct smb2_compression_hdr *hdr; size_t buf_len, data_len; @@ -58,8 +213,12 @@ int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len); hdr->Offset = cpu_to_le32(buf_len); - hdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + hdr->Flags = chained; hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + if (chained) { + hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len); + } /* * Copy SMB2 header uncompressed to @dst. @@ -67,8 +226,7 @@ int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) */ memcpy(dst + SMB_COMPRESS_HDR_LEN, src_rq->rq_iov->iov_base, buf_len); - /* XXX: add other algs here as they're implemented */ - ret = lz77_compress(src, data_len, dst + SMB_COMPRESS_HDR_LEN + buf_len, &data_len); + ret = compress_data(src, data_len, dst + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained); err_free: kvfree(src); if (!ret) { @@ -84,47 +242,92 @@ err_free: int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len) { const struct smb2_compression_hdr *hdr; - size_t buf_len, data_len; + size_t slen, dlen; + const void *srcp = src; + void *dstp = dst; + bool chained; int ret; - hdr = src; - if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) + hdr = srcp; + chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED); + slen = le32_to_cpu(hdr->Offset); + + if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) return -EIO; - buf_len = le32_to_cpu(hdr->Offset); - data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); + /* Copy the uncompressed SMB2 READ header. */ + srcp += SMB_COMPRESS_HDR_LEN; + src_len -= SMB_COMPRESS_HDR_LEN; + memcpy(dstp, srcp, slen); + srcp += slen; + src_len -= slen; + dstp += slen; - /* - * Copy uncompressed data from the beginning of the payload. - * The remainder is all compressed data. - */ - src += SMB_COMPRESS_HDR_LEN; - memcpy(dst, src, buf_len); - src += buf_len; - src_len -= SMB_COMPRESS_HDR_LEN + buf_len; - *dst_len -= buf_len; - - ret = lz77_decompress(src, src_len, dst + buf_len, dst_len); - if (ret) - return ret; + if (!chained) { + slen = src_len; + goto unchained; + } + + while (src_len > 0) { + const struct smb2_compression_payload_hdr *phdr = srcp; + __le16 alg = phdr->CompressionAlgorithm; + + srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + slen = le32_to_cpu(phdr->Length); + dlen = slen; +unchained: + if (!smb_compress_alg_valid(alg, chained)) + return -EIO; + + /* XXX: add other algs here as they're implemented */ + if (alg == SMB3_COMPRESS_LZ77) { + if (chained) { + /* sizeof(OriginalPayloadSize) */ + srcp += 4; + slen -= 4; + dlen = le32_to_cpu(phdr->OriginalPayloadSize); + } + + ret = lz77_decompress(srcp, slen, dstp, &dlen); + if (ret) + return ret; + + if (chained) + src_len -= 4; + } else if (alg == SMB3_COMPRESS_PATTERN) { + struct smb2_compression_pattern_v1 *pattern; - if (*dst_len != data_len) { - cifs_dbg(VFS, "decompressed size mismatch: got %zu, expected %zu\n", - *dst_len, data_len); + pattern = (struct smb2_compression_pattern_v1 *)srcp; + dlen = le32_to_cpu(pattern->Repetitions); + + if (dlen == 0 || dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n", dlen); + return -ECONNRESET; + } + + memset(dstp, pattern->Pattern, dlen); + } else if (alg == SMB3_COMPRESS_NONE) { + memcpy(dstp, srcp, dlen); + } + + srcp += slen; + src_len -= slen; + dstp += dlen; + } + + *dst_len = dstp - dst; + if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len, + le32_to_cpu(hdr->OriginalCompressedSegmentSize)); return -ECONNRESET; } if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) { - cifs_dbg(VFS, "decompressed buffer is not an SMB2 message: ProtocolId 0x%x\n", + cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n", *(__le32 *)dst); return -ECONNRESET; } - /* - * @dst_len contains only the decompressed data size, add back - * the previously copied uncompressed size - */ - *dst_len += buf_len; - return 0; } diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h index 8369a1de0435..3f247befd87a 100644 --- a/fs/smb/client/compress.h +++ b/fs/smb/client/compress.h @@ -25,6 +25,9 @@ /* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_PAYLOAD_HDR_LEN 8 #define SMB_COMPRESS_MIN_LEN PAGE_SIZE +/* follows Windows implementation (as per MS-SMB2) */ +#define SMB_COMPRESS_PAYLOAD_MIN_LEN 1024 +#define SMB_COMPRESS_PATTERN_MIN_LEN 64 #define SMB_DECOMPRESS_MAX_LEN(_srv) \ (256 + SMB_COMPRESS_HDR_LEN + \ max_t(size_t, (_srv)->maxBuf, max_t(size_t, (_srv)->max_read, (_srv)->max_write))) @@ -39,7 +42,7 @@ struct smb_compress_ctx { }; #ifdef CONFIG_CIFS_COMPRESSION -int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq); +int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq, bool chained); int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len); /** @@ -112,9 +115,12 @@ static __always_inline bool should_compress(const struct cifs_tcon *tcon, const static __always_inline size_t decompressed_size(const void *buf) { const struct smb2_compression_hdr *hdr = buf; + size_t size = le32_to_cpu(hdr->OriginalCompressedSegmentSize); - return le32_to_cpu(hdr->Offset) + - le32_to_cpu(hdr->OriginalCompressedSegmentSize); + if (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED) + size += le32_to_cpu(hdr->Offset); + + return size; } #else /* CONFIG_CIFS_COMPRESSION */ #define smb_compress(arg1, arg2) (-EOPNOTSUPP) diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h index e26733a57954..d87701ce491f 100644 --- a/fs/smb/client/compress/lz77.h +++ b/fs/smb/client/compress/lz77.h @@ -13,6 +13,7 @@ #include <asm/ptrace.h> #include <linux/kernel.h> #include <linux/string.h> +#include <linux/const.h> #ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS #include <asm-generic/unaligned.h> #endif @@ -66,6 +67,23 @@ static __always_inline u32 lz77_log2(unsigned int x) return x ? ((u32)(31 - __builtin_clz(x))) : 0; } +/* + * Computes minimum memory required for LZ77 compression based on input length + * (@src_len). + * + * This minimum accounts for the worst case scenario, where no matches are found + * in the input buffer. + * + * For every literal (byte) written to the output buffer, a u32 flag is written + * as well. + * + * Avoid a couple extra instructions by rounding up adding an extra sizeof(u32). + */ +static __always_inline size_t lz77_compress_min_mem(size_t src_len) +{ + return src_len + ((src_len >> 3) + sizeof(u32)); +} + #ifdef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS static __always_inline u8 lz77_read8(const void *ptr) { diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c index 23801b028530..afd823c9cbca 100644 --- a/fs/smb/client/smb2pdu.c +++ b/fs/smb/client/smb2pdu.c @@ -589,12 +589,17 @@ build_compression_ctxt(struct smb2_compression_capabilities_context *pneg_ctxt) pneg_ctxt->DataLength = cpu_to_le16(sizeof(struct smb2_compression_capabilities_context) - sizeof(struct smb2_neg_context)); - pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(1); + + /* This enables the usage of the Pattern_V1 algorithm */ + pneg_ctxt->Flags = SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED; + pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(2); + /* - * Send the only algorithm we support (XXX: add others as they're + * Send the algorithms we support (XXX: add others as they're * implemented). */ pneg_ctxt->CompressionAlgorithms[0] = SMB3_COMPRESS_LZ77; + pneg_ctxt->CompressionAlgorithms[1] = SMB3_COMPRESS_PATTERN; } static unsigned int @@ -782,9 +787,13 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, struct smb2_compression_capabilities_context *ctxt) { unsigned int len = le16_to_cpu(ctxt->DataLength); - __le16 alg; + __le16 count, i; + int chained; server->compression.enabled = false; + server->compression.chained = false; + count = ctxt->CompressionAlgorithmCount; + chained = !!(ctxt->Flags & SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED); /* * Caller checked that DataLength remains within SMB boundary. We still @@ -796,18 +805,27 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, return; } - if (le16_to_cpu(ctxt->CompressionAlgorithmCount) != 1) { - pr_warn_once("invalid SMB3 compress algorithm count\n"); + if (unlikely(count != 2)) { + pr_warn_once("invalid SMB3 compress algorithm count '%u'\n", count); return; } - alg = ctxt->CompressionAlgorithms[0]; - if (!smb_compress_alg_valid(alg, false)) { - pr_warn_once("invalid compression algorithm '%u'\n", alg); - return; + for (i = 0; i < count; i++) { + __le16 alg = ctxt->CompressionAlgorithms[i]; + + if (!smb_compress_alg_valid(alg, false)) { + pr_warn_once("invalid compression algorithm '%u'\n", alg); + return; + } + + if (alg == SMB3_COMPRESS_PATTERN) { + if (chained) + server->compression.chained = true; + } else { + server->compression.alg = alg; + } } - server->compression.alg = alg; server->compression.enabled = true; } diff --git a/fs/smb/client/transport.c b/fs/smb/client/transport.c index cc8e091d3500..2da9a72d1d2e 100644 --- a/fs/smb/client/transport.c +++ b/fs/smb/client/transport.c @@ -427,7 +427,7 @@ static int compress_send_rqst(struct TCP_Server_Info *server, struct smb_rqst *r compr.rq_iov = &iov; compr.rq_nvec = 1; - ret = smb_compress(rqst, &compr); + ret = smb_compress(rqst, &compr, server->compression.chained); if (!ret) { to_send = &compr; } else { |
