diff options
| author | Enzo Matsumiya <ematsumiya@suse.de> | 2025-12-01 15:27:36 -0300 |
|---|---|---|
| committer | Enzo Matsumiya <ematsumiya@suse.de> | 2026-01-23 17:47:19 -0300 |
| commit | b1027bc4f590db7a7a42e226674cb36987700c62 (patch) | |
| tree | b05ac36d93d7827a76ad1f20e4079a545f1e7ab9 | |
| parent | 173840904122599acf1c5fe52f8f9e42f5fc3729 (diff) | |
| download | linux-b1027bc4f590db7a7a42e226674cb36987700c62.tar.gz linux-b1027bc4f590db7a7a42e226674cb36987700c62.tar.bz2 linux-b1027bc4f590db7a7a42e226674cb36987700c62.zip | |
smb: client: add chained compression and pattern scanning
Introduce chained de/compression support and an implementation of
the "Pattern V1" algorithm.
The Pattern V1 algorithm scans the beginning and end of a buffer,
looking for repeating bytes.
In order to use it, MS-SMB2 requires chained compression to be
enabled/negotiated with the server.
This commit implements both, and negotiates chained compression by
default (falling back to unchained compression if the server doesn't
support it).
Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
| -rw-r--r-- | fs/smb/client/cifs_debug.c | 21 | ||||
| -rw-r--r-- | fs/smb/client/cifsglob.h | 6 | ||||
| -rw-r--r-- | fs/smb/client/compress.c | 416 | ||||
| -rw-r--r-- | fs/smb/client/smb2pdu.c | 51 |
4 files changed, 417 insertions, 77 deletions
diff --git a/fs/smb/client/cifs_debug.c b/fs/smb/client/cifs_debug.c index 1fb71d2d31b5..1292f83ec23f 100644 --- a/fs/smb/client/cifs_debug.c +++ b/fs/smb/client/cifs_debug.c @@ -574,14 +574,25 @@ skip_rdma: } seq_puts(m, "\nCompression: "); - if (!IS_ENABLED(CONFIG_CIFS_COMPRESSION)) + if (!IS_ENABLED(CONFIG_CIFS_COMPRESSION)) { seq_puts(m, "no built-in support"); - else if (!server->compression.requested) + } else if (!server->compression.requested) { seq_puts(m, "disabled on mount"); - else if (server->compression.enabled) - seq_printf(m, "enabled (%s)", compression_alg_str(server->compression.alg)); - else + } else if (server->compression.enabled) { + seq_printf(m, "enabled, chained: %s, algs: ", + str_yes_no(server->compression.chained)); + + for (j = 0; j < SMB3_COMPRESS_MAX_ALGS; j++) { + __le16 alg = server->compression.algs[j]; + + if (alg == 0) + continue; + + seq_printf(m, "%s ", compression_alg_str(alg)); + } + } else { seq_puts(m, "disabled (not supported by this server)"); + } /* Show negotiated encryption cipher, even if not required */ seq_puts(m, "\nEncryption: "); diff --git a/fs/smb/client/cifsglob.h b/fs/smb/client/cifsglob.h index 203e2aaa3c25..cfd0532ef152 100644 --- a/fs/smb/client/cifsglob.h +++ b/fs/smb/client/cifsglob.h @@ -813,9 +813,11 @@ struct TCP_Server_Info { unsigned int rdma_readwrite_threshold; unsigned int retrans; struct { - bool requested; /* "compress" mount option set*/ + bool requested; /* "compress" mount option set */ bool enabled; /* actually negotiated with server */ - __le16 alg; /* preferred alg negotiated with server */ + bool chained; /* chained compression negotiated with server */ +#define SMB3_COMPRESS_MAX_ALGS 2 + __le16 algs[SMB3_COMPRESS_MAX_ALGS]; /* algs negotiated with server (pref order) */ } compression; __u16 signing_algorithm; __le16 cipher_type; diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index 147e3ceb13ca..ba51888a3278 100644 --- a/fs/smb/client/compress.c +++ 
b/fs/smb/client/compress.c @@ -17,11 +17,15 @@ #include <linux/uio.h> #include <linux/sort.h> #include <linux/iov_iter.h> +#include <linux/count_zeros.h> +#include <linux/unaligned.h> #include "cifsglob.h" #include "cifs_debug.h" #include "compress/lz77.h" #include "compress.h" +#define decompress_err(fmt, ...) cifs_dbg(VFS, "Decompression failure: " fmt, __VA_ARGS__) + /* * The heuristic_*() functions below try to determine data compressibility. * @@ -262,6 +266,225 @@ out: return ret; } +/* + * {fwd,bwd}_repeating_bytes() counts the number of repeating bytes from start and end of @src, + * respectively. + * + * They use the same 8-byte read strategy as lz77_match_len() (compress/lz77.c), but here we're not + * reading from a window buffer, but rather just an u64 filled with the first/last byte. + * + * Code is duplicated in order to avoid branches e.g. for advancing/rewinding ptrs. + * + * Also, bwd scanning has small, but impactful, details that makes the code unfit for a generic + * macro/function. + */ +static inline u32 fwd_repeating_bytes(const u8 *src, const u8 *end) +{ + const u8 *start = src; + u64 repeat; + + memset(&repeat, *src, sizeof(u64)); + + do { + const u64 diff = __get_unaligned_cpu64(src) ^ repeat; + + if (!diff) { + src += sizeof(u64); + + continue; + } + + /* This computes the number of common bytes in @diff. 
*/ + src += count_trailing_zeros(diff) >> 3; + + return (src - start); + } while (likely(src + sizeof(u64) < end)); + + while (src < end && get_unaligned(src) == *start) + src++; + + return (src - start); +} + +static inline u32 bwd_repeating_bytes(const u8 *src, const u8 *end) +{ + const u8 *cur = --end; + u64 repeat; + + memset(&repeat, *cur, sizeof(u64)); + cur -= sizeof(u64); + + do { + /* we're scanning from the end, so we need to reverse the bytes for this to work */ + const u64 diff = swab64(__get_unaligned_cpu64(cur)) ^ repeat; + u64 common; + + if (!diff) { + cur -= sizeof(u64); + + continue; + } + + cur += sizeof(u64); + common = (count_trailing_zeros(diff) >> 3) + 1; + if (common > 1) + cur -= common; + + return (end - cur); + } while (likely(cur - sizeof(u64) >= src)); + + while (cur >= src && get_unaligned(cur) == (repeat & 0xff)) + cur--; + + return (end - cur); +} + +/* + * Pattern V1 compression algorithm for scanning repeating characters in a buffer. + * + * Refs: MS-SMB2 "3.1.4.4.1 Algorithm for Scanning Data Patterns V1" + */ +static void pattern_scan(const u8 *src, u32 slen, struct smb2_compression_pattern_v1 *fwd, + struct smb2_compression_pattern_v1 *bwd) +{ + u32 reps = fwd_repeating_bytes(src, src + slen); + + if (reps >= 64) { + fwd->Pattern = *src; + fwd->Repetitions = cpu_to_le32(reps); + + /* don't scan from end if <64 bytes left */ + if (slen - reps < 64) + return; + } + + reps = bwd_repeating_bytes(src + reps, src + slen); + if (reps >= 64) { + bwd->Pattern = *(src + slen - 1); + bwd->Repetitions = cpu_to_le32(reps); + } +} + +/* + * Add a payload header to @dst and set its fields. + * Return advanced @dst. 
+ */ +static inline void *append_payload_hdr(void *dst, __le16 alg, u32 len) +{ + struct smb2_compression_payload_hdr *phdr = dst; + + phdr->CompressionAlgorithm = alg; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(len); + + return dst + SMB_COMPRESS_PAYLOAD_HDR_LEN; +} + +static int compress_data(const void *src, u32 slen, void *dst, u32 *dlen, bool chained) +{ + struct smb2_compression_pattern_v1 fwd = {}, bwd = {}; + const u32 orig_slen = slen; + const void *srcp = src; + void *dstp = dst; + u32 reps; + + /* Pattern_V1 scan + compression. */ + if (chained) { + pattern_scan(srcp, slen, &fwd, &bwd); + + /* add fwd here, bwd is added at the end */ + reps = le32_to_cpu(fwd.Repetitions); + if (reps) { + dstp = append_payload_hdr(dstp, SMB3_COMPRESS_PATTERN, sizeof(fwd)); + memcpy(dstp, &fwd, sizeof(fwd)); + + srcp += reps; + slen -= reps; + dstp += sizeof(fwd); + *dlen -= sizeof(fwd); + } + + reps = le32_to_cpu(bwd.Repetitions); + if (reps) { + slen -= reps; + *dlen -= sizeof(bwd); + } + } + + /* + * LZ77 compression. + * If chained, these are remaining bytes from Pattern_V1 scan. + * If unchained, this is the whole payload buffer. + */ + if (slen > 1024) { + struct smb2_compression_payload_hdr *phdr = NULL; + u32 len; + int ret; + + if (chained) { + /* length is set after compression here */ + phdr = dstp; + dstp = append_payload_hdr(dstp, SMB3_COMPRESS_LZ77, 0); + phdr->OriginalPayloadSize = cpu_to_le32(slen); + + dstp += sizeof(phdr->OriginalPayloadSize); + *dlen -= sizeof(phdr->OriginalPayloadSize); + } + + len = *dlen; + ret = lz77_compress(srcp, slen, dstp, &len); + if (ret) { + if (ret != -EMSGSIZE || + (le32_to_cpu(fwd.Repetitions) < orig_slen / 2 && + le32_to_cpu(bwd.Repetitions) < orig_slen / 2)) + return ret; + + /* + * If LZ77 compression failed, but we still got a ~50% compression from + * Pattern_V1, copy the remaining bytes as-is. + * Rewind @dstp for that. 
+ */ + if (phdr) { + memset(phdr, 0, sizeof(*phdr)); + dstp = phdr; + } + + goto leftovers; + } + + if (chained) + /* + * payload length for LZ* compression includes the size of + * OriginalPayloadSize field. + * + * Ref: MS-SMB2 3.1.4.4 "Compressing the Message" + */ + phdr->Length = cpu_to_le32(len + sizeof(phdr->OriginalPayloadSize)); + + dstp += len; + *dlen -= len; + } else if (slen > 0) { + /* leftovers not worth compressing, add a payload with 'NONE' algorithm */ +leftovers: + dstp = append_payload_hdr(dstp, SMB3_COMPRESS_NONE, slen); + memcpy(dstp, srcp, slen); + dstp += slen; + } + + reps = le32_to_cpu(bwd.Repetitions); + if (reps >= 64) { + dstp = append_payload_hdr(dstp, SMB3_COMPRESS_PATTERN, sizeof(bwd)); + memcpy(dstp, &bwd, sizeof(bwd)); + + /* slen and dlen were subtracted earlier */ + dstp += sizeof(bwd); + } + + *dlen = dstp - dst; + + return 0; +} + bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq) { const struct smb2_hdr *shdr = rq->rq_iov->iov_base; @@ -296,14 +519,12 @@ bool should_decompress(struct TCP_Server_Info *server, const void *buf) len = decompressed_size(buf); if (len < SMB_COMPRESS_HDR_LEN + server->vals->read_rsp_size) { - cifs_dbg(VFS, "decompression failure: compressed message too small (%u)\n", len); - + decompress_err("compressed message too small (%u)\n", len); return false; } if (len > SMB_DECOMPRESS_MAX_LEN(server)) { - cifs_dbg(VFS, "decompression failure: uncompressed message too big (%u)\n", len); - + decompress_err("uncompressed message too big (%u)\n", len); return false; } @@ -313,8 +534,8 @@ bool should_decompress(struct TCP_Server_Info *server, const void *buf) int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn) { struct iov_iter iter; - u32 slen, dlen; void *src, *dst = NULL; + u32 slen, dlen; int ret; if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base) @@ -339,13 +560,18 @@ int smb_compress(struct TCP_Server_Info *server, struct 
smb_rqst *rq, compress_s } dlen = lz77_calc_dlen(slen); + + /* Pattern_V1 payloads (fwd and bwd) are included in @dst */ + if (server->compression.chained) + dlen += (sizeof(struct smb2_compression_pattern_v1) * 2); + dst = kvzalloc(dlen, GFP_KERNEL); if (!dst) { ret = -ENOMEM; goto err_free; } - ret = lz77_compress(src, slen, dst, &dlen); + ret = compress_data(src, slen, dst, &dlen, server->compression.chained); if (!ret) { struct smb2_compression_hdr hdr = { 0 }; struct smb_rqst comp_rq = { .rq_nvec = 3, }; @@ -353,10 +579,23 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen); - hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77; - hdr.Flags = SMB2_COMPRESSION_FLAG_NONE; hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len); + /* + * For chained compression: + * - flags must be set to chained + * - algorithm must be 'NONE', as each payload hdr contains the algorithm used + * - original size must include SMB2 header size + */ + if (server->compression.chained) { + hdr.Flags = SMB2_COMPRESSION_FLAG_CHAINED; + hdr.CompressionAlgorithm = SMB3_COMPRESS_NONE; + hdr.OriginalCompressedSegmentSize += hdr.Offset; + } else { + hdr.Flags = SMB2_COMPRESSION_FLAG_NONE; + hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77; + } + iov[0].iov_base = &hdr; iov[0].iov_len = sizeof(hdr); iov[1] = rq->rq_iov[0]; @@ -378,57 +617,128 @@ err_free: int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen) { - const struct smb2_compression_hdr *hdr = src; - struct smb2_hdr *shdr; - u32 buf_len, data_len; + const struct smb2_compression_hdr *hdr; + u32 expected, unchained_offset = 0, orig_dlen = *dlen; + const void *srcp, *end; + bool chained; + void *dstp; int ret; - if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) - return -EIO; - - buf_len = le32_to_cpu(hdr->Offset); - data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); - - if (unlikely(*dlen != 
buf_len + data_len)) - return -EINVAL; - - /* - * Copy uncompressed data from the beginning of the payload. - * The remainder is all compressed data. - */ - src += SMB_COMPRESS_HDR_LEN; - slen -= SMB_COMPRESS_HDR_LEN; - memcpy(dst, src, buf_len); - - src += buf_len; - slen -= buf_len; - *dlen = data_len; - - ret = lz77_decompress(src, slen, dst + buf_len, dlen); - if (ret) - return ret; - - shdr = dst; - - if (*dlen != data_len) { - cifs_dbg(VFS, "decompression failure: size mismatch (got %u, expected %u, mid=%llu)\n", - *dlen, data_len, le64_to_cpu(shdr->MessageId)); + hdr = src; + expected = le32_to_cpu(hdr->OriginalCompressedSegmentSize); + chained = le32_to_cpu(hdr->Flags & SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED); + if (unlikely(!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)) { + ret = -EIO; + goto out; + } - return -EINVAL; + /* always start from the first payload header */ + srcp = src + sizeof(hdr->ProtocolId) + sizeof(hdr->OriginalCompressedSegmentSize); + end = src + slen; + dstp = dst; + + while (srcp + SMB_COMPRESS_PAYLOAD_HDR_LEN <= end) { + const struct smb2_compression_payload_hdr *phdr; + __le16 alg; + u32 len; + + phdr = srcp; + alg = phdr->CompressionAlgorithm; + if (unlikely(!smb_compress_alg_valid(alg, true))) { + ret = -EIO; + goto out; + } + + len = le32_to_cpu(phdr->Length); + srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + + if (alg == SMB3_COMPRESS_NONE) { + memcpy(dstp, srcp, len); + srcp += len; + dstp += len; + } else if (alg == SMB3_COMPRESS_PATTERN) { + const struct smb2_compression_pattern_v1 *pattern = srcp; + + len = le32_to_cpu(pattern->Repetitions); + if (unlikely(len == 0 || len > expected)) { + decompress_err("corrupt compressed data (Pattern_V1 repetitions == %u)\n", + len); + ret = -ECONNRESET; + goto out; + } + + memset(dstp, pattern->Pattern, len); + srcp += sizeof(*pattern); + dstp += len; + } else { + u32 orig_plen, plen; + + if (chained) { + orig_plen = le32_to_cpu(phdr->OriginalPayloadSize); + srcp += 
sizeof(plen); + len -= sizeof(plen); + } else { + /* + * When unchained, the SMB2 header is uncompressed, so copy it as-is + * before decompressing the data. + * We then expect the decompressed payload to have the full expected + * decompressed size. + */ + memcpy(dstp, srcp, len); + srcp += len; + dstp += len; + unchained_offset = len; + orig_plen = expected; + len = slen - SMB_COMPRESS_HDR_LEN - len; + } + + /* + * Set plen to whatever we have available in @dst. + * + * Note we must accept @len (source size) to be greater than @plen, because + * some servers/configurations might simply send a compressed payload that + * decompresses to a smaller size, so as long as we can fit it in @dst, we + * keep going. + */ + plen = orig_dlen - (dstp - dst); + if (plen < orig_plen) { + decompress_err("would overflow (available %u, needed %u)\n", plen, + orig_plen); + ret = -EIO; + goto out; + } + + ret = lz77_decompress(srcp, len, dstp, &plen); + if (ret) + goto out; + + if (unlikely(plen != orig_plen)) { + decompress_err("LZ77 decompress fail (got %u bytes, expected %u)\n", + plen, orig_plen); + ret = -EIO; + goto out; + } + + srcp += len; + dstp += plen; + } } - if (shdr->ProtocolId != SMB2_PROTO_NUMBER) { - cifs_dbg(VFS, "decompression failure: buffer is not an SMB2 message: got ProtocolId 0x%x (mid=%llu)\n", - le32_to_cpu(shdr->ProtocolId), le64_to_cpu(shdr->MessageId)); + ret = 0; + *dlen = dstp - dst; - return -EINVAL; + if (unlikely(*dlen - unchained_offset != expected)) { + decompress_err("uncompressed size mismatch (got %u, expected %u)\n", + *dlen - unchained_offset, expected); + ret = -ECONNRESET; + goto out; } - /* - * @dlen contains only the decompressed data size, add back the previously copied - * hdr->Offset size. 
- */ - *dlen += buf_len; - - return 0; + if (unlikely(((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER)) { + decompress_err("buffer is not an SMB2 message (got ProtocolId 0x%x)\n", + *(__le32 *)dst); + ret = -ECONNRESET; + } +out: + return ret; } diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c index 845907904b56..81f4c02f045b 100644 --- a/fs/smb/client/smb2pdu.c +++ b/fs/smb/client/smb2pdu.c @@ -600,13 +600,18 @@ static void build_compression_ctxt(struct smb2_compression_capabilities_context *pneg_ctxt) { pneg_ctxt->ContextType = SMB2_COMPRESSION_CAPABILITIES; - pneg_ctxt->DataLength = - cpu_to_le16(sizeof(struct smb2_compression_capabilities_context) - - sizeof(struct smb2_neg_context)); - pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(3); - pneg_ctxt->CompressionAlgorithms[0] = SMB3_COMPRESS_LZ77; - pneg_ctxt->CompressionAlgorithms[1] = SMB3_COMPRESS_LZ77_HUFF; - pneg_ctxt->CompressionAlgorithms[2] = SMB3_COMPRESS_LZNT1; + pneg_ctxt->DataLength = cpu_to_le16(sizeof(*pneg_ctxt) - sizeof(struct smb2_neg_context)); + pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(2); + pneg_ctxt->Flags = SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED; + pneg_ctxt->CompressionAlgorithms[0] = SMB3_COMPRESS_PATTERN; + pneg_ctxt->CompressionAlgorithms[1] = SMB3_COMPRESS_LZ77; + + /* + * TODO (implement in order of priority): + * - SMB3_COMPRESS_LZ4 + * - SMB3_COMPRESS_LZ77_HUFF + * - SMB3_COMPRESS_LZNT1 + */ } static unsigned int @@ -794,9 +799,12 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, struct smb2_compression_capabilities_context *ctxt) { unsigned int len = le16_to_cpu(ctxt->DataLength); - __le16 alg; + int i, count; + /* reset in case of changes in between reconnects */ server->compression.enabled = false; + server->compression.chained = false; + memset(server->compression.algs, 0, SMB3_COMPRESS_MAX_ALGS * sizeof(__le16)); /* * Caller checked that DataLength remains within SMB boundary. 
We still @@ -808,21 +816,30 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, return; } - if (le16_to_cpu(ctxt->CompressionAlgorithmCount) != 1) { - pr_warn_once("invalid SMB3 compress algorithm count\n"); + count = le16_to_cpu(ctxt->CompressionAlgorithmCount); + if (count == 0 || count > 4) { + pr_warn_once("invalid SMB3 compress algorithm count (%d)\n", count); return; } - alg = ctxt->CompressionAlgorithms[0]; + for (i = 0; i < count; i++) { + __le16 alg = ctxt->CompressionAlgorithms[i]; - /* 'NONE' (0) compressor type is never negotiated */ - if (alg == 0 || le16_to_cpu(alg) > 3) { - pr_warn_once("invalid compression algorithm '%u'\n", alg); - return; + /* only supported algorithms for now, update this as new ones are implemented */ + if (alg != SMB3_COMPRESS_PATTERN && alg != SMB3_COMPRESS_LZ77) { + pr_warn_once("invalid compression algorithm '%u'\n", le16_to_cpu(alg)); + continue; + } + + server->compression.algs[i] = alg; } - server->compression.alg = alg; - server->compression.enabled = true; + /* we need at least 1 supported algorithm */ + if (server->compression.algs[0] != 0) { + server->compression.enabled = true; + if (le32_to_cpu(ctxt->Flags & SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED)) + server->compression.chained = true; + } } static int decode_encrypt_ctx(struct TCP_Server_Info *server, |
