From 173840904122599acf1c5fe52f8f9e42f5fc3729 Mon Sep 17 00:00:00 2001 From: Enzo Matsumiya Date: Mon, 1 Dec 2025 13:22:44 -0300 Subject: smb: client: add support for decompressing READs Implement decompression support for SMB2 READ messages. Signed-off-by: Enzo Matsumiya --- fs/smb/client/compress.c | 86 ++++++++++++++++++-- fs/smb/client/compress.h | 42 +++++++++- fs/smb/client/compress/lz77.c | 185 ++++++++++++++++++++++++++++++++++++++++++ fs/smb/client/compress/lz77.h | 1 + fs/smb/client/smb2ops.c | 141 +++++++++++++++++++++++++++----- fs/smb/client/smb2pdu.c | 6 ++ 6 files changed, 435 insertions(+), 26 deletions(-) diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index 7349f1dc6b1f..147e3ceb13ca 100644 --- a/fs/smb/client/compress.c +++ b/fs/smb/client/compress.c @@ -17,12 +17,8 @@ #include #include #include - #include "cifsglob.h" -#include "../common/smb2pdu.h" -#include "cifsproto.h" -#include "smb2proto.h" - +#include "cifs_debug.h" #include "compress/lz77.h" #include "compress.h" @@ -291,6 +287,29 @@ bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq) return (shdr->Command == SMB2_READ); } +bool should_decompress(struct TCP_Server_Info *server, const void *buf) +{ + u32 len; + + if (!is_compress_hdr(buf)) + return false; + + len = decompressed_size(buf); + if (len < SMB_COMPRESS_HDR_LEN + server->vals->read_rsp_size) { + cifs_dbg(VFS, "decompression failure: compressed message too small (%u)\n", len); + + return false; + } + + if (len > SMB_DECOMPRESS_MAX_LEN(server)) { + cifs_dbg(VFS, "decompression failure: uncompressed message too big (%u)\n", len); + + return false; + } + + return true; +} + int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn) { struct iov_iter iter; @@ -356,3 +375,60 @@ err_free: return ret; } + +int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen) +{ + const struct smb2_compression_hdr *hdr = src; + struct smb2_hdr *shdr; + u32 buf_len, data_len; + int ret; + + if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) + return -EIO; + + buf_len = le32_to_cpu(hdr->Offset); + data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); + + if (unlikely(*dlen != buf_len + data_len)) + return -EINVAL; + + /* + * Copy uncompressed data from the beginning of the payload. + * The remainder is all compressed data. + */ + src += SMB_COMPRESS_HDR_LEN; + slen -= SMB_COMPRESS_HDR_LEN; + memcpy(dst, src, buf_len); + + src += buf_len; + slen -= buf_len; + *dlen = data_len; + + ret = lz77_decompress(src, slen, dst + buf_len, dlen); + if (ret) + return ret; + + shdr = dst; + + if (*dlen != data_len) { + cifs_dbg(VFS, "decompression failure: size mismatch (got %u, expected %u, mid=%llu)\n", + *dlen, data_len, le64_to_cpu(shdr->MessageId)); + + return -EINVAL; + } + + if (shdr->ProtocolId != SMB2_PROTO_NUMBER) { + cifs_dbg(VFS, "decompression failure: buffer is not an SMB2 message: got ProtocolId 0x%x (mid=%llu)\n", + le32_to_cpu(shdr->ProtocolId), le64_to_cpu(shdr->MessageId)); + + return -EINVAL; + } + + /* + * @dlen contains only the decompressed data size, add back the previously copied + * hdr->Offset size. + */ + *dlen += buf_len; + + return 0; +} diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h index f3ed1d3e52fb..28f4e3ffbfb4 100644 --- a/fs/smb/client/compress.h +++ b/fs/smb/client/compress.h @@ -17,19 +17,48 @@ #include #include +#include + #include "../common/smb2pdu.h" -#include "cifsglob.h" /* sizeof(smb2_compression_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_HDR_LEN 16 /* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_PAYLOAD_HDR_LEN 8 #define SMB_COMPRESS_MIN_LEN PAGE_SIZE +#define SMB_DECOMPRESS_MAX_LEN(_srv) \ + (256 + SMB_COMPRESS_HDR_LEN + max3((_srv)->maxBuf, (_srv)->max_read, (_srv)->max_write)) + +static __always_inline u32 decompressed_size(const void *buf) +{ + const struct smb2_compression_hdr *hdr = buf; + + if (!buf) + return 0; + + return le32_to_cpu(hdr->Offset) + + le32_to_cpu(hdr->OriginalCompressedSegmentSize); +} + +static __always_inline bool is_compress_hdr(const void *buf) +{ + const struct smb2_compression_hdr *hdr = buf; + + if (!buf) + return false; + + return (hdr->ProtocolId == SMB2_COMPRESSION_TRANSFORM_ID); +} #ifdef CONFIG_CIFS_COMPRESSION +struct TCP_Server_Info; +struct cifs_tcon; +struct smb_rqst; + typedef int (*compress_send_fn)(struct TCP_Server_Info *, int, struct smb_rqst *); int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn); +int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen); /** * should_compress() - Determines if a request (write) or the response to a @@ -47,6 +76,7 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s * Return false otherwise. */ bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq); +bool should_decompress(struct TCP_Server_Info *server, const void *buf); /** * smb_compress_alg_valid() - Validate a compression algorithm. @@ -77,11 +107,21 @@ static inline int smb_compress(void *unused1, void *unused2, void *unused3) return -EOPNOTSUPP; } +static inline int smb_decompress(const void *unused1, u32 unused2, void *unused3, u32 *unused4) +{ + return -EOPNOTSUPP; +} + static inline bool should_compress(void *unused1, void *unused2) { return false; } +static inline bool should_decompress(void *unused1, void *unused2) +{ + return false; +} + static inline int smb_compress_alg_valid(__le16 unused1, bool unused2) { return -EOPNOTSUPP; diff --git a/fs/smb/client/compress/lz77.c b/fs/smb/client/compress/lz77.c index 1d0cb4a22f5f..4a1d62c07921 100644 --- a/fs/smb/client/compress/lz77.c +++ b/fs/smb/client/compress/lz77.c @@ -28,6 +28,16 @@ static __always_inline u8 lz77_read8(const u8 *ptr) return get_unaligned(ptr); } +static __always_inline u16 lz77_read16(const u16 *ptr) +{ + return get_unaligned(ptr); +} + +static __always_inline u32 lz77_read32(const u32 *ptr) +{ + return get_unaligned(ptr); +} + static __always_inline u64 lz77_read64(const u64 *ptr) { return get_unaligned(ptr); @@ -250,3 +260,178 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) return -EMSGSIZE; } + +static __always_inline const void *lz77_decode_match_len(const void *src, const void **nib, u32 *m) +{ + u32 mlen = *m & 7; + + if (mlen == 7) { + if (!*nib) { + *nib = src; + mlen = lz77_read8(src) % 16; + src++; + } else { + mlen = lz77_read8(*nib) >> 4; + *nib = NULL; + } + + if (mlen == 15) { + mlen = lz77_read8(src); + src++; + + if (mlen == 255) { + mlen = lz77_read16(src); + src += 2; + + if (mlen == 0) { + mlen = lz77_read32(src); + src += 4; + } + + /* Unexpected match len. */ + if (unlikely(mlen < 15 + 7)) { + *m = U32_MAX; + return src; + } + + mlen -= (15 + 7); + } + + mlen += 15; + } + + mlen += 7; + } + + mlen += 3; + *m = mlen; + + return src; +} + +int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen) +{ + const void *srcp, *end, *nib = NULL; + u32 flag, flag_count = 0; + void *dstp, *dst_end; + + if (!dlen || *dlen < slen) + return -EINVAL; + + srcp = src; + end = srcp + slen; + dstp = dst; + dst_end = dstp + *dlen; + *dlen = 0; + + do { + u32 c, dist, len; + + /* + * Read flag. + * + * The flag is a 32-bit bitmap where 0s are literal bytes (to be copied from @srcp + * to @dstp) and 1s are matches (decoded from @srcp). + * + * Compressed payload always starts with a flag. + */ + if (flag_count == 0) { + flag = lz77_read32(srcp); + srcp += 4; + flag_count = 32; + } + + /* + * Decode flag. + * + * Here, we'll have 'c' literals to write. + * + * Notes: + * - @flag == 0 means literals are bound by @flag_count, not necessarily + * exactly 32 + * - here, 0 < @flag_count <= 32 + * - __builtin_clz() yields UB if arg is 0 + * - unlike lz77_compress(), we don't use a 'long' flag here because of CLZ + * in order to avoid extra math when sizeof(long) == 8 + */ + c = flag ? __builtin_clz(flag) : flag_count; + if (c) { + if (unlikely(dstp + c > dst_end)) + return -EIO; + + flag_count -= c; + flag <<= c; + + memcpy(dstp, srcp, c); + srcp += c; + dstp += c; + + if (flag_count == 0) + continue; + + if (unlikely(srcp >= end)) + break; + } + + /* + * Decode matches. + * + * Just as we read/write sequential literals, we do the same for possibly + * sequential matches. + */ + c = umin(__builtin_clz(~flag), flag_count); + while (c--) { + /* Store match symbol in @len */ + len = lz77_read16(srcp); + srcp += 2; + dist = (len >> 3) + 1; + + srcp = lz77_decode_match_len(srcp, &nib, &len); + + /* + * Check bogus match values. + * + * We don't know what compression parameters (e.g. match max dist, min len) + * the server is using, so check against min/max allowed by spec. + * + * Also check if within @dst boundaries. + */ + if (unlikely(!dist || dist >= SZ_8K || dstp - dst < dist)) + return -EIO; + + if (unlikely(len < 3 || len >= U32_MAX || dstp + len > dst_end)) + return -EIO; + + /* + * Dist and len are good. + * If non-overlapping memory, we can use memcpy() (common case). + * Otherwise, we have to do it byte by byte. + */ + if (len < dist) { + memcpy(dstp, dstp - dist, len); + dstp += len; + } else { + while (len--) { + lz77_write8(dstp, lz77_read8(dstp - dist)); + dstp++; + } + } + + if (unlikely(srcp >= end)) + break; + + flag_count--; + flag <<= 1U; + } + } while (srcp < end); + + /* + * We've now fully parsed @src (compressed buffer) without any processing errors. + * + * However, it's up to callers to determine if @dst contents and @dlen are according to + * their expectations (i.e. what SMB2 header indicates). + */ + *dlen = dstp - dst; + + return 0; +} diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h index 3c75b70b51b0..dc88e6a11c52 100644 --- a/fs/smb/client/compress/lz77.h +++ b/fs/smb/client/compress/lz77.h @@ -13,4 +13,5 @@ u32 lz77_calc_dlen(u32 slen); int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen); +int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen); #endif /* _SMB_COMPRESS_LZ77_H */ diff --git a/fs/smb/client/smb2ops.c b/fs/smb/client/smb2ops.c index 1e39f2165e42..f0cb2de4a778 100644 --- a/fs/smb/client/smb2ops.c +++ b/fs/smb/client/smb2ops.c @@ -30,6 +30,7 @@ #include "fs_context.h" #include "cached_dir.h" #include "reparse.h" +#include "compress.h" /* Change credits for different ops and return the total number of credits */ static int @@ -4629,12 +4630,17 @@ err_free: return rc; } +static __always_inline bool is_transform_hdr(const void *buf) +{ + const struct smb2_transform_hdr *trhdr = buf; + + return (trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM); +} + static int smb3_is_transform_hdr(void *buf) { - struct smb2_transform_hdr *trhdr = buf; - - return trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM; + return is_transform_hdr(buf) || is_compress_hdr(buf); } static int @@ -4843,8 +4849,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, } else if (buf_len >= data_offset + data_len) { /* read response payload is in buf */ + struct iov_iter it = rdata->subreq.io_iter; + WARN_ONCE(buffer, "read data can be either in buf or in buffer"); - copied = copy_to_iter(buf + data_offset, data_len, &rdata->subreq.io_iter); + copied = copy_to_iter(buf + data_offset, data_len, &it); if (copied == 0) return -EIO; rdata->got_bytes = copied; @@ -5141,6 +5149,87 @@ one_more: return ret; } +static int receive_compressed(struct TCP_Server_Info *server) +{ + struct mid_q_entry *mid; + void *src, *dst = NULL; + u32 slen, dlen; + int ret; + + slen = server->pdu_size; + src = kvzalloc(slen, GFP_KERNEL); + if (!src) + return -ENOMEM; + + ret = server->total_read; + memcpy(src, server->smallbuf, ret); + + ret = cifs_read_from_socket(server, src + ret, slen - ret); + if (ret < 0) + goto err_free; + + server->total_read += ret; + + dlen = decompressed_size(src); + dst = kvzalloc(dlen, GFP_KERNEL); + if (!dst) { + ret = -ENOMEM; + + goto err_free; + } + + ret = smb_decompress(src, slen, dst, &dlen); + if (ret) { + spin_lock(&server->srv_lock); + server->tcpStatus = CifsNeedReconnect; + spin_unlock(&server->srv_lock); + + goto err_free; + } + + mid = smb2_find_dequeue_mid(server, dst); + if (!mid) { + ret = -EIO; + + goto err_free; + } + + ret = handle_read_data(server, mid, dst, dlen, NULL, 0, true); + if (!ret) { +#ifdef CONFIG_CIFS_STATS2 + mid->when_received = jiffies; +#endif + if (server->ops->is_network_name_deleted) + server->ops->is_network_name_deleted(dst, server); + + mid->callback(mid); + } else { + spin_lock(&server->srv_lock); + if (server->tcpStatus == CifsNeedReconnect) { + spin_lock(&server->mid_queue_lock); + mid->mid_state = MID_RETRY_NEEDED; + spin_unlock(&server->mid_queue_lock); + spin_unlock(&server->srv_lock); + + mid->callback(mid); + } else { + spin_lock(&server->mid_queue_lock); + mid->mid_state = MID_REQUEST_SUBMITTED; + mid->deleted_from_q = false; + list_add_tail(&mid->qhead, &server->pending_mid_q); + spin_unlock(&server->mid_queue_lock); + spin_unlock(&server->srv_lock); + } + } + + release_mid(mid); +err_free: + kvfree(src); + kvfree(dst); + + return ret; +} + static int smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mids, char **bufs, int *num_mids) @@ -5150,26 +5239,36 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct smb2_transform_hdr *tr_hdr = (struct smb2_transform_hdr *)buf; unsigned int orig_len = le32_to_cpu(tr_hdr->OriginalMessageSize); - if (pdu_length < sizeof(struct smb2_transform_hdr) + - sizeof(struct smb2_hdr)) { - cifs_server_dbg(VFS, "Transform message is too small (%u)\n", - pdu_length); - cifs_reconnect(server, true); - return -ECONNABORTED; - } + if (is_transform_hdr(buf)) { + if (pdu_length < sizeof(struct smb2_transform_hdr) + + sizeof(struct smb2_hdr)) { + cifs_server_dbg(VFS, "Transform message is too small (%u)\n", pdu_length); + cifs_reconnect(server, true); - if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) { - cifs_server_dbg(VFS, "Transform message is broken\n"); - cifs_reconnect(server, true); - return -ECONNABORTED; - } + return -ECONNABORTED; + } - /* TODO: add support for compounds containing READ. */ - if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) { - return receive_encrypted_read(server, &mids[0], num_mids); + if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) { + cifs_server_dbg(VFS, "Transform message is broken\n"); + cifs_reconnect(server, true); + + return -ECONNABORTED; + } + + /* TODO: add support for compounds containing READ. */ + if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) + return receive_encrypted_read(server, &mids[0], num_mids); + + return receive_encrypted_standard(server, mids, bufs, num_mids); } - return receive_encrypted_standard(server, mids, bufs, num_mids); + if (should_decompress(server, buf)) + return receive_compressed(server); + + cifs_server_dbg(VFS, "Invalid ProtocolId 0x%x\n", *(__le32 *)buf); + cifs_reconnect(server, true); + + return -ECONNABORTED; } int @@ -5191,6 +5290,8 @@ static int smb2_next_header(struct TCP_Server_Info *server, char *buf, *noff = le32_to_cpu(t_hdr->OriginalMessageSize); if (unlikely(check_add_overflow(*noff, sizeof(*t_hdr), noff))) return -EINVAL; + } else if (hdr->ProtocolId == SMB2_COMPRESSION_TRANSFORM_ID) { + *noff = 0; } else { *noff = le32_to_cpu(hdr->NextCommand); } diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c index 8b4a4573e9c3..845907904b56 100644 --- a/fs/smb/client/smb2pdu.c +++ b/fs/smb/client/smb2pdu.c @@ -4714,6 +4714,12 @@ smb2_async_readv(struct cifs_io_subrequest *rdata) flags |= CIFS_HAS_CREDITS; } + if (should_compress(io_parms.tcon, &rqst)) { + struct smb2_read_req *req = (struct smb2_read_req *)buf; + + req->Flags |= SMB2_READFLAG_REQUEST_COMPRESSED; + } + rc = cifs_call_async(server, &rqst, cifs_readv_receive, smb2_readv_callback, smb3_handle_read_data, rdata, flags, -- cgit v1.2.3