// SPDX-License-Identifier: GPL-2.0-only /* * Copyright (C) 2024, SUSE LLC * * Authors: Enzo Matsumiya * * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only). * See compress/ for implementation details of each algorithm. * * References: * MS-SMB2 "3.1.4.4 Compressing the Message" * MS-SMB2 "3.1.5.3 Decompressing the Chained Message" * MS-XCA - for details of the supported algorithms */ #include #include #include #include "cifsglob.h" #include "../common/smb2pdu.h" #include "cifsproto.h" #include "smb2proto.h" #include "cifs_debug.h" #include "compress/lz77.h" #include "compress.h" static void pattern_scan(const u8 *src, size_t src_len, struct smb2_compression_pattern_v1 *fwd, struct smb2_compression_pattern_v1 *bwd) { const u8 *srcp = src, *end = src + src_len - 1; u32 reps, plen = sizeof(*fwd); memset(fwd, 0, plen); memset(bwd, 0, plen); /* Forward pattern scan. */ while (++srcp <= end && *srcp == *src) { } reps = srcp - src; if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { fwd->Pattern = src[0]; fwd->Repetitions = cpu_to_le32(reps); src += reps; } if (reps == src_len) return; /* Backward pattern scan. */ srcp = end; while (--srcp >= src && *srcp == *end) { } reps = end - srcp; if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { bwd->Pattern = *end; bwd->Repetitions = cpu_to_le32(reps); return; } return; } static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained) { struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 }; struct smb2_compression_payload_hdr *phdr; const u8 *srcp = src; u8 *dstp = dst; size_t plen = 0; int ret = 0; if (!chained) goto unchained; pattern_scan(srcp, src_len, &fwd, &bwd); if (le32_to_cpu(fwd.Repetitions) > 0) { plen = sizeof(fwd); if (plen > *dst_len) { *dst_len = 0; return -EMSGSIZE; } phdr = (struct smb2_compression_payload_hdr *)dstp; phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; phdr->Length = cpu_to_le32(plen); dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; memcpy(dstp, &fwd, plen); dstp += plen; *dst_len -= plen; srcp += le32_to_cpu(fwd.Repetitions); src_len -= le32_to_cpu(fwd.Repetitions); if (src_len == 0) goto out; } if (le32_to_cpu(bwd.Repetitions) > 0) { src_len -= le32_to_cpu(bwd.Repetitions); *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; if (src_len == 0) goto out_bwd_pattern; } /* Leftover uncompressed data with size not worth compressing. */ if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) { if (src_len > *dst_len) { *dst_len = 0; return -EMSGSIZE; } phdr = (struct smb2_compression_payload_hdr *)dstp; phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; phdr->Length = cpu_to_le32(src_len); dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; memcpy(dstp, srcp, src_len); dstp += src_len; *dst_len -= src_len; goto out_bwd_pattern; } unchained: plen = *dst_len - SMB_COMPRESS_PAYLOAD_HDR_LEN - sizeof(phdr->OriginalPayloadSize); if (plen < *dst_len) { *dst_len = 0; return -EMSGSIZE; } ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen); if (ret) { *dst_len = 0; return ret; } if (chained) { phdr = (struct smb2_compression_payload_hdr *)dstp; phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; phdr->OriginalPayloadSize = cpu_to_le32(src_len); plen += sizeof(phdr->OriginalPayloadSize); phdr->Length = cpu_to_le32(plen); plen += SMB_COMPRESS_PAYLOAD_HDR_LEN; } dstp += plen; *dst_len -= plen; out_bwd_pattern: if (bwd.Repetitions >= 64) { plen = sizeof(bwd); if (plen > *dst_len) { *dst_len = 0; return -EMSGSIZE; } phdr = (struct smb2_compression_payload_hdr *)dstp; phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; phdr->Length = cpu_to_le32(plen); dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; memcpy(dstp, &bwd, plen); dstp += plen; *dst_len -= plen; } out: *dst_len = dstp - (u8 *)dst; return 0; } int smb_compress(void *buf, const void *data, size_t *len, bool chained) { struct smb2_compression_hdr *hdr; size_t buf_len, data_len; int ret; buf_len = sizeof(struct smb2_write_req); data_len = *len; *len = 0; if (data_len < SMB_COMPRESS_MIN_LEN) return -ENODATA; hdr = buf; hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len); hdr->Offset = cpu_to_le32(buf_len); hdr->Flags = chained; hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; if (chained) { hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len); } ret = compress_data(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained); if (!ret) *len = SMB_COMPRESS_HDR_LEN + buf_len + data_len; return ret; } int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len) { const struct smb2_compression_hdr *hdr; size_t slen, dlen; const void *srcp = src; void *dstp = dst; bool chained; int ret; hdr = srcp; chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED); slen = le32_to_cpu(hdr->Offset); if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) return -EIO; /* Copy the uncompressed SMB2 READ header. */ srcp += SMB_COMPRESS_HDR_LEN; src_len -= SMB_COMPRESS_HDR_LEN; memcpy(dstp, srcp, slen); srcp += slen; src_len -= slen; dstp += slen; if (!chained) { slen = src_len; goto unchained; } while (src_len > 0) { const struct smb2_compression_payload_hdr *phdr = srcp; __le16 alg = phdr->CompressionAlgorithm; srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN; src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; slen = le32_to_cpu(phdr->Length); dlen = slen; unchained: if (!smb_compress_alg_valid(alg, chained)) return -EIO; /* XXX: add other algs here as they're implemented */ if (alg == SMB3_COMPRESS_LZ77) { if (chained) { /* sizeof(OriginalPayloadSize) */ srcp += 4; slen -= 4; dlen = le32_to_cpu(phdr->OriginalPayloadSize); } ret = lz77_decompress(srcp, slen, dstp, &dlen); if (ret) return ret; if (chained) src_len -= 4; } else if (alg == SMB3_COMPRESS_PATTERN) { struct smb2_compression_pattern_v1 *pattern; pattern = (struct smb2_compression_pattern_v1 *)srcp; dlen = le32_to_cpu(pattern->Repetitions); if (dlen == 0 || dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n", dlen); return -ECONNRESET; } memset(dstp, pattern->Pattern, dlen); } else if (alg == SMB3_COMPRESS_NONE) { memcpy(dstp, srcp, dlen); } srcp += slen; src_len -= slen; dstp += dlen; } *dst_len = dstp - dst; if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len, le32_to_cpu(hdr->OriginalCompressedSegmentSize)); return -ECONNRESET; } if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) { cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n", *(__le32 *)dst); return -ECONNRESET; } return 0; }