From 56877938281173b5a29748731af7dd56abd34a68 Mon Sep 17 00:00:00 2001 From: Enzo Matsumiya Date: Fri, 19 Apr 2024 11:12:03 -0300 Subject: smb: client: implement chained de/compression support Introduce 'Pattern V1' algorithm (MS-SMB2) for pattern (repeated single byte) scanning/matching. Usage of this algorithm requires chained compression support to be negotiated with the server, which is also done by this commit. Signed-off-by: Enzo Matsumiya --- fs/smb/client/cifsglob.h | 1 + fs/smb/client/compress.c | 277 +++++++++++++++++++++++++++++++++++++----- fs/smb/client/compress.h | 18 ++- fs/smb/client/compress/lz77.h | 18 +++ fs/smb/client/smb2ops.c | 6 +- fs/smb/client/smb2pdu.c | 38 ++++-- fs/smb/client/transport.c | 4 +- 7 files changed, 310 insertions(+), 52 deletions(-) diff --git a/fs/smb/client/cifsglob.h b/fs/smb/client/cifsglob.h index de80a302a379..b1cd534b72df 100644 --- a/fs/smb/client/cifsglob.h +++ b/fs/smb/client/cifsglob.h @@ -772,6 +772,7 @@ struct TCP_Server_Info { struct { bool requested; /* "compress" mount option set*/ bool enabled; /* actually negotiated with server */ + bool chained; /* chained support negotiated with server */ __le16 alg; /* preferred alg negotiated with server */ } compression; __u16 signing_algorithm; diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index 9ab7cb1dabf9..5628ba678a2d 100644 --- a/fs/smb/client/compress.c +++ b/fs/smb/client/compress.c @@ -25,7 +25,153 @@ #include "compress/lz77.h" #include "compress.h" -int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) +static void pattern_scan(const u8 *src, size_t src_len, + struct smb2_compression_pattern_v1 *fwd, + struct smb2_compression_pattern_v1 *bwd) +{ + const u8 *srcp = src, *end = src + src_len - 1; + u32 reps, plen = sizeof(*fwd); + + memset(fwd, 0, plen); + memset(bwd, 0, plen); + + /* Forward pattern scan. 
*/ + while (srcp < end && *srcp++ == *src) { } + + reps = srcp - src; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + fwd->Pattern = src[0]; + fwd->Repetitions = cpu_to_le32(reps); + if (reps == src_len) + return; + } + +#if 0 + /* Backward pattern scan. */ + srcp = end; + while (srcp >= src && *srcp-- == *end) { } + + reps = end - srcp; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + bwd->Pattern = *end; + bwd->Repetitions = cpu_to_le32(reps); + + return; + } +#endif + + return; +} + +static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained) +{ + struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 }; + struct smb2_compression_payload_hdr *phdr; + const u8 *srcp = src; + u8 *dstp = dst; + size_t reps, plen = 0; + int ret = 0; + + if (!chained) + goto unchained; + + pattern_scan(srcp, src_len, &fwd, &bwd); + + reps = le32_to_cpu(fwd.Repetitions); + if (reps > 0) { + plen = sizeof(fwd); + if (plen > *dst_len) { + pr_err("%s: 1 plen %zu, dst %zu\n", __func__, plen, *dst_len); + return -EMSGSIZE; + } + + srcp += reps; + src_len -= reps; + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + + memcpy(dstp, &fwd, plen); + dstp += plen; + *dst_len -= plen; + + if (src_len == 0) + goto out; + } + + /* Leftover uncompressed data with size not worth compressing. 
*/ + if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) { + if (src_len > *dst_len) { + pr_err("%s: 2 plen %zu, dst %zu\n", __func__, plen, *dst_len); + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(src_len); + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + + memcpy(dstp, srcp, src_len); + dstp += src_len; + *dst_len -= src_len; + + goto out_bwd_pattern; + } +unchained: + plen = *dst_len; + if (plen < *dst_len) { + pr_err("%s: 3 plen %zu, dst %zu, lz77 %zu\n", __func__, plen, *dst_len, lz77_compress_min_mem(src_len)); + *dst_len = 0; + return -EMSGSIZE; + } + + ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen); + if (ret) { + *dst_len = 0; + return ret; + } + + if (chained) { + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->OriginalPayloadSize = cpu_to_le32(src_len); + plen += sizeof(phdr->OriginalPayloadSize); + phdr->Length = cpu_to_le32(plen); + } + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN + plen; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN + plen; +out_bwd_pattern: + if (bwd.Repetitions >= 64) { +#if 0 + plen = sizeof(bwd); + if (plen > *dst_len) { + pr_err("%s: 4 plen %zu, dst %zu\n", __func__, plen, *dst_len); + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + + memcpy(dstp, &bwd, plen); + dstp += plen; +#endif + } +out: + *dst_len = dstp - (u8 *)dst; + + return 0; +} + +int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq, bool chained) { struct smb2_compression_hdr *hdr; size_t buf_len, data_len, dst_len; @@ -33,6 +179,10 @@ int 
smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) void *src, *dst; int ret; + dst_rq->rq_iov->iov_base = NULL; + dst_rq->rq_iov->iov_len = 0; + dst_rq->rq_nvec = 0; + buf_len = src_rq->rq_iov->iov_len; data_len = iov_iter_count(&iter); if (data_len < SMB_COMPRESS_MIN_LEN) @@ -47,7 +197,7 @@ int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) goto err_free; } - dst_len = SMB_COMPRESS_HDR_LEN + buf_len + data_len; + dst_len = SMB_COMPRESS_HDR_LEN + buf_len + lz77_compress_min_mem(data_len); dst = kvzalloc(dst_len, GFP_KERNEL); if (!dst) { ret = -ENOMEM; @@ -58,13 +208,16 @@ int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq) hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len); hdr->Offset = cpu_to_le32(buf_len); - hdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + hdr->Flags = chained; hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + if (chained) { + hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len); + } memcpy(dst + SMB_COMPRESS_HDR_LEN, src_rq->rq_iov->iov_base, buf_len); - /* XXX: add other algs here as they're implemented */ - ret = lz77_compress(src, data_len, dst + SMB_COMPRESS_HDR_LEN + buf_len, &data_len); + ret = compress_data(src, data_len, dst + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained); err_free: if (!ret) { dst_rq->rq_nvec = 1; @@ -78,50 +231,110 @@ err_free: return ret; } -int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len) +int smb_decompress(const void *src, size_t usrc_len, void *dst, size_t *dst_len) { const struct smb2_compression_hdr *hdr; - size_t buf_len, data_len; + const void *srcp = src; + void *dstp = dst; + ssize_t src_len = usrc_len; + bool chained; + size_t plen; int ret; - hdr = src; - if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) + hdr = srcp; + chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED); + plen = le32_to_cpu(hdr->Offset); + + if 
(!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) return -EIO; - buf_len = le32_to_cpu(hdr->Offset); - data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); + /* Copy the uncompressed SMB2 READ header. */ + srcp += SMB_COMPRESS_HDR_LEN; + src_len -= SMB_COMPRESS_HDR_LEN; + memcpy(dstp, srcp, plen); + srcp += plen; + src_len -= plen; + dstp += plen; + + if (!chained) { + plen = src_len; + goto unchained; + } /* - * Copy uncompressed data from the beginning of the payload. - * The remainder is all compressed data. + * Move to the first payload header + * (sizeof(hdr->ProtocolId) + sizeof(hdr->OriginalCompressedSegmentSize)) */ - src += SMB_COMPRESS_HDR_LEN; - memcpy(dst, src, buf_len); - src += buf_len; - src_len -= SMB_COMPRESS_HDR_LEN + buf_len; - *dst_len -= buf_len; - - ret = lz77_decompress(src, src_len, dst + buf_len, dst_len); - if (ret) - return ret; + srcp += 8; + src_len -= 8; + + while (src_len > 0) { + const struct smb2_compression_payload_hdr *phdr = srcp; + __le16 alg = phdr->CompressionAlgorithm; + + srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + plen = le32_to_cpu(phdr->Length); +unchained: + if (!smb_compress_alg_valid(alg, chained)) + return -EIO; + + /* XXX: add other algs here as they're implemented */ + if (alg == SMB3_COMPRESS_LZ77) { + size_t orig = plen; + + if (chained) { + /* sizeof(OriginalPayloadSize) */ + srcp += 4; + plen -= 4; + orig = le32_to_cpu(phdr->OriginalPayloadSize); + } + + ret = lz77_decompress(srcp, plen, dstp, &orig); + if (ret) + return ret; + + srcp += plen; + src_len -= plen; + dstp += orig; + } else if (alg == SMB3_COMPRESS_PATTERN) { + struct smb2_compression_pattern_v1 *pattern; + u32 reps; - if (*dst_len != data_len) { - cifs_dbg(VFS, "decompressed size mismatch: got %zu, expected %zu\n", - *dst_len, data_len); + pattern = (struct smb2_compression_pattern_v1 *)srcp; + reps = le32_to_cpu(pattern->Repetitions); + + if (reps == 0 || reps > 
le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "corrupt compressed data (%u pattern repetitions)\n", reps); + return -ECONNRESET; + } + + memset(dstp, pattern->Pattern, reps); + srcp += sizeof(*pattern); + src_len -= sizeof(*pattern); + dstp += reps; +#if 0 + } else if (alg == SMB3_COMPRESS_NONE) { + memcpy(dstp, srcp, plen); + srcp += plen; + src_len -= plen; + dstp += plen; +#endif + } + } + + *dst_len = dstp - dst; + if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len, + le32_to_cpu(hdr->OriginalCompressedSegmentSize)); return -ECONNRESET; } if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) { - cifs_dbg(VFS, "decompressed buffer is not an SMB2 message: ProtocolId 0x%x\n", + cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n", *(__le32 *)dst); return -ECONNRESET; } - /* - * @dst_len contains only the decompressed data size, add back - * the previously copied uncompressed size - */ - *dst_len += buf_len; - return 0; } diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h index 88896a76fe2b..19249bbb9d68 100644 --- a/fs/smb/client/compress.h +++ b/fs/smb/client/compress.h @@ -25,12 +25,15 @@ /* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_PAYLOAD_HDR_LEN 8 #define SMB_COMPRESS_MIN_LEN PAGE_SIZE +/* follows Windows implementation (as per MS-SMB2) */ +#define SMB_COMPRESS_PAYLOAD_MIN_LEN 1024 +#define SMB_COMPRESS_PATTERN_MIN_LEN 64 #define SMB_DECOMPRESS_MAX_LEN(_srv) \ (256 + SMB_COMPRESS_HDR_LEN + \ max_t(size_t, (_srv)->maxBuf, max_t(size_t, (_srv)->max_read, (_srv)->max_write))) #ifdef CONFIG_CIFS_COMPRESSION -int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq); +int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq, bool chained); int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len); /* @@ -121,14 
+124,17 @@ static __always_inline bool should_compress(const struct cifs_tcon *tcon, const static __always_inline size_t decompressed_size(const void *buf) { const struct smb2_compression_hdr *hdr = buf; + size_t size = le32_to_cpu(hdr->OriginalCompressedSegmentSize); - return le32_to_cpu(hdr->Offset) + - le32_to_cpu(hdr->OriginalCompressedSegmentSize); + if (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED) + size += le32_to_cpu(hdr->Offset); + + return size; } #else /* CONFIG_CIFS_COMPRESSION */ -#define smb_compress(arg1, arg2) (-EOPNOTSUPP) -#define smb_decompress(arg1, arg2) (-EOPNOTSUPP) -#define smb_compress_fatal_err(arg1) (false) +#define smb_compress(arg1, arg2, arg3) (-EOPNOTSUPP) +#define smb_decompress(arg1, arg2, arg3, arg4) (-EOPNOTSUPP) +#define smb_compress_fatal_err(arg) (false) #define smb_compress_alg_valid(arg1, arg2) (-EOPNOTSUPP) #define should_compress(arg1, arg2, arg3) (false) #define decompress_size(arg1) (0) diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h index ea4281a87f79..98f91b254ac1 100644 --- a/fs/smb/client/compress/lz77.h +++ b/fs/smb/client/compress/lz77.h @@ -13,6 +13,7 @@ #include #include #include +#include #ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS #include #endif @@ -66,6 +67,23 @@ static __always_inline u32 lz77_log2(unsigned int x) return x ? ((u32)(31 - __builtin_clz(x))) : 0; } +/* + * Computes minimum memory required for LZ77 compression based on input length + * (@src_len). + * + * This minimum accounts for the worst case scenario, where no matches are found + * in the input buffer. + * + * For every 32 literals (bytes) written to the output buffer, a u32 flag is + * written as well (the src_len >> 3 term). + * + * Avoid a couple of extra instructions by always adding an extra sizeof(u32). 
+ */ +static __always_inline size_t lz77_compress_min_mem(size_t src_len) +{ + return src_len + ((src_len >> 3) + sizeof(u32)); +} + #ifdef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS static __always_inline u8 lz77_read8(const void *ptr) { diff --git a/fs/smb/client/smb2ops.c b/fs/smb/client/smb2ops.c index 34ee8ba24c1d..f61c265d6920 100644 --- a/fs/smb/client/smb2ops.c +++ b/fs/smb/client/smb2ops.c @@ -4944,6 +4944,7 @@ static void decompress_thread(struct work_struct *work) } ret = smb_decompress(ctx->buf, ctx->len, dst, &dst_len); + pr_err("%s: (async) ret %d, dst %zu, expected %zu, + hdr %zu\n", __func__, ret, dst_len, decompressed_size(ctx->buf), decompressed_size(ctx->buf) + SMB_COMPRESS_HDR_LEN); if (ret) goto err_free; @@ -4995,8 +4996,8 @@ static int receive_compressed(struct TCP_Server_Info *server, char **bufs, struc dst_len = decompressed_size(src); /* offload large decompressions to %decompress_rq */ - if (src_len > CIFSMaxBufSize + MAX_HEADER_SIZE(server) || - dst_len > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) { + //if (src_len > CIFSMaxBufSize + MAX_HEADER_SIZE(server) || + if (dst_len > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) { struct decompress_offload_ctx *ctx; ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); @@ -5017,6 +5018,7 @@ static int receive_compressed(struct TCP_Server_Info *server, char **bufs, struc dst = cifs_buf_get(); ret = smb_decompress(src, src_len, dst, &dst_len); + pr_err("%s: ret %d, dst %zu, expected %zu, + hdr %zu\n", __func__, ret, dst_len, decompressed_size(src), decompressed_size(src) + SMB_COMPRESS_HDR_LEN); if (ret) { cifs_buf_release(dst); goto err_free; diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c index 23801b028530..afd823c9cbca 100644 --- a/fs/smb/client/smb2pdu.c +++ b/fs/smb/client/smb2pdu.c @@ -589,12 +589,17 @@ build_compression_ctxt(struct smb2_compression_capabilities_context *pneg_ctxt) pneg_ctxt->DataLength = cpu_to_le16(sizeof(struct smb2_compression_capabilities_context) - sizeof(struct 
smb2_neg_context)); - pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(1); + + /* This enables the usage of the Pattern_V1 algorithm */ + pneg_ctxt->Flags = SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED; + pneg_ctxt->CompressionAlgorithmCount = cpu_to_le16(2); + /* - * Send the only algorithm we support (XXX: add others as they're + * Send the algorithms we support (XXX: add others as they're * implemented). */ pneg_ctxt->CompressionAlgorithms[0] = SMB3_COMPRESS_LZ77; + pneg_ctxt->CompressionAlgorithms[1] = SMB3_COMPRESS_PATTERN; } static unsigned int @@ -782,9 +787,13 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, struct smb2_compression_capabilities_context *ctxt) { unsigned int len = le16_to_cpu(ctxt->DataLength); - __le16 alg; + __le16 count, i; + int chained; server->compression.enabled = false; + server->compression.chained = false; + count = ctxt->CompressionAlgorithmCount; + chained = !!(ctxt->Flags & SMB2_COMPRESSION_CAPABILITIES_FLAG_CHAINED); /* * Caller checked that DataLength remains within SMB boundary. 
We still @@ -796,18 +805,27 @@ static void decode_compress_ctx(struct TCP_Server_Info *server, return; } - if (le16_to_cpu(ctxt->CompressionAlgorithmCount) != 1) { - pr_warn_once("invalid SMB3 compress algorithm count\n"); + if (unlikely(count != 2)) { + pr_warn_once("invalid SMB3 compress algorithm count '%u'\n", count); return; } - alg = ctxt->CompressionAlgorithms[0]; - if (!smb_compress_alg_valid(alg, false)) { - pr_warn_once("invalid compression algorithm '%u'\n", alg); - return; + for (i = 0; i < count; i++) { + __le16 alg = ctxt->CompressionAlgorithms[i]; + + if (!smb_compress_alg_valid(alg, false)) { + pr_warn_once("invalid compression algorithm '%u'\n", alg); + return; + } + + if (alg == SMB3_COMPRESS_PATTERN) { + if (chained) + server->compression.chained = true; + } else { + server->compression.alg = alg; + } } - server->compression.alg = alg; server->compression.enabled = true; } diff --git a/fs/smb/client/transport.c b/fs/smb/client/transport.c index 4eb347747dee..7c16ed1a4b26 100644 --- a/fs/smb/client/transport.c +++ b/fs/smb/client/transport.c @@ -428,7 +428,7 @@ static int smb_send_rqst(struct TCP_Server_Info *server, int num_rqst, struct smb_rqst *rqst, int flags) { - struct send_req_vars *vars; + struct send_req_vars *vars = NULL; struct smb_rqst *cur_rqst; bool compress, encrypt; struct kvec *iov; @@ -453,7 +453,7 @@ smb_send_rqst(struct TCP_Server_Info *server, int num_rqst, /* ->iov_base is allocated in smb_compress() */ rqst = compressed; - rc = smb_compress(orig, rqst); + rc = smb_compress(orig, rqst, server->compression.chained); if (rc) { if (smb_compress_fatal_err(rc)) goto err_free; -- cgit v1.2.3