summaryrefslogtreecommitdiff
path: root/fs/smb/client/compress.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/smb/client/compress.c')
-rw-r--r--fs/smb/client/compress.c272
1 files changed, 239 insertions, 33 deletions
diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c
index de55ddd122a5..7fcb8c796aef 100644
--- a/fs/smb/client/compress.c
+++ b/fs/smb/client/compress.c
@@ -25,7 +25,162 @@
#include "compress/lz77.h"
#include "compress.h"
-int smb_compress(void *buf, const void *data, size_t *len)
+static void pattern_scan(const u8 *src, size_t src_len,
+ struct smb2_compression_pattern_v1 *fwd,
+ struct smb2_compression_pattern_v1 *bwd)
+{
+ const u8 *srcp = src, *end = src + src_len - 1;
+ u32 reps, plen = sizeof(*fwd);
+
+ memset(fwd, 0, plen);
+ memset(bwd, 0, plen);
+
+ /* Forward pattern scan. */
+ while (++srcp <= end && *srcp == *src) { }
+
+ reps = srcp - src;
+ if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
+ fwd->Pattern = src[0];
+ fwd->Repetitions = cpu_to_le32(reps);
+
+ src += reps;
+ }
+
+ if (reps == src_len)
+ return;
+
+ /* Backward pattern scan. */
+ srcp = end;
+ while (--srcp >= src && *srcp == *end) { }
+
+ reps = end - srcp;
+ if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
+ bwd->Pattern = *end;
+ bwd->Repetitions = cpu_to_le32(reps);
+
+ return;
+ }
+
+ return;
+}
+
+static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained)
+{
+ struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 };
+ struct smb2_compression_payload_hdr *phdr;
+ const u8 *srcp = src;
+ u8 *dstp = dst;
+ size_t plen = 0;
+ int ret = 0;
+
+ if (!chained)
+ goto unchained;
+
+ pattern_scan(srcp, src_len, &fwd, &bwd);
+
+ if (le32_to_cpu(fwd.Repetitions) > 0) {
+ plen = sizeof(fwd);
+ if (plen > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(plen);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, &fwd, plen);
+ dstp += plen;
+ *dst_len -= plen;
+
+ srcp += le32_to_cpu(fwd.Repetitions);
+ src_len -= le32_to_cpu(fwd.Repetitions);
+ if (src_len == 0)
+ goto out;
+ }
+
+ if (le32_to_cpu(bwd.Repetitions) > 0) {
+ src_len -= le32_to_cpu(bwd.Repetitions);
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+
+ if (src_len == 0)
+ goto out_bwd_pattern;
+ }
+
+ /* Leftover uncompressed data with size not worth compressing. */
+ if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) {
+ if (src_len > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(src_len);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, srcp, src_len);
+ dstp += src_len;
+ *dst_len -= src_len;
+
+ goto out_bwd_pattern;
+ }
+unchained:
+ plen = *dst_len - SMB_COMPRESS_PAYLOAD_HDR_LEN - sizeof(phdr->OriginalPayloadSize);
+ if (plen < *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen);
+ if (ret) {
+ *dst_len = 0;
+ return ret;
+ }
+
+ if (chained) {
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->OriginalPayloadSize = cpu_to_le32(src_len);
+ plen += sizeof(phdr->OriginalPayloadSize);
+ phdr->Length = cpu_to_le32(plen);
+ plen += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ }
+
+ dstp += plen;
+ *dst_len -= plen;
+out_bwd_pattern:
+ if (bwd.Repetitions >= 64) {
+ plen = sizeof(bwd);
+ if (plen > *dst_len) {
+ *dst_len = 0;
+ return -EMSGSIZE;
+ }
+
+ phdr = (struct smb2_compression_payload_hdr *)dstp;
+ phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
+ phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ phdr->Length = cpu_to_le32(plen);
+
+ dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ memcpy(dstp, &bwd, plen);
+ dstp += plen;
+ *dst_len -= plen;
+ }
+out:
+ *dst_len = dstp - (u8 *)dst;
+
+ return 0;
+}
+
+int smb_compress(void *buf, const void *data, size_t *len, bool chained)
{
struct smb2_compression_hdr *hdr;
size_t buf_len, data_len;
@@ -35,15 +190,21 @@ int smb_compress(void *buf, const void *data, size_t *len)
data_len = *len;
*len = 0;
+ if (data_len < SMB_COMPRESS_MIN_LEN)
+ return -ENODATA;
+
hdr = buf;
hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len);
hdr->Offset = cpu_to_le32(buf_len);
- hdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
+ hdr->Flags = chained;
hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
+ if (chained) {
+ hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
+ hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len);
+ }
- /* XXX: add other algs here as they're implemented */
- ret = lz77_compress(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len);
+ ret = compress_data(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained);
if (!ret)
*len = SMB_COMPRESS_HDR_LEN + buf_len + data_len;
@@ -53,47 +214,92 @@ int smb_compress(void *buf, const void *data, size_t *len)
int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len)
{
const struct smb2_compression_hdr *hdr;
- size_t buf_len, data_len;
+ size_t slen, dlen;
+ const void *srcp = src;
+ void *dstp = dst;
+ bool chained;
int ret;
- hdr = src;
- if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
+ hdr = srcp;
+ chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED);
+ slen = le32_to_cpu(hdr->Offset);
+
+ if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
return -EIO;
- buf_len = le32_to_cpu(hdr->Offset);
- data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
-
- /*
- * Copy uncompressed data from the beginning of the payload.
- * The remainder is all compressed data.
- */
- src += SMB_COMPRESS_HDR_LEN;
- memcpy(dst, src, buf_len);
- src += buf_len;
- src_len -= SMB_COMPRESS_HDR_LEN + buf_len;
- *dst_len -= buf_len;
-
- ret = lz77_decompress(src, src_len, dst + buf_len, dst_len);
- if (ret)
- return ret;
+ /* Copy the uncompressed SMB2 READ header. */
+ srcp += SMB_COMPRESS_HDR_LEN;
+ src_len -= SMB_COMPRESS_HDR_LEN;
+ memcpy(dstp, srcp, slen);
+ srcp += slen;
+ src_len -= slen;
+ dstp += slen;
+
+ if (!chained) {
+ slen = src_len;
+ goto unchained;
+ }
+
+ while (src_len > 0) {
+ const struct smb2_compression_payload_hdr *phdr = srcp;
+ __le16 alg = phdr->CompressionAlgorithm;
+
+ srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
+ slen = le32_to_cpu(phdr->Length);
+ dlen = slen;
+unchained:
+ if (!smb_compress_alg_valid(alg, chained))
+ return -EIO;
+
+ /* XXX: add other algs here as they're implemented */
+ if (alg == SMB3_COMPRESS_LZ77) {
+ if (chained) {
+ /* sizeof(OriginalPayloadSize) */
+ srcp += 4;
+ slen -= 4;
+ dlen = le32_to_cpu(phdr->OriginalPayloadSize);
+ }
+
+ ret = lz77_decompress(srcp, slen, dstp, &dlen);
+ if (ret)
+ return ret;
- if (*dst_len != data_len) {
- cifs_dbg(VFS, "decompressed size mismatch: got %zu, expected %zu\n",
- *dst_len, data_len);
+ if (chained)
+ src_len -= 4;
+ } else if (alg == SMB3_COMPRESS_PATTERN) {
+ struct smb2_compression_pattern_v1 *pattern;
+
+ pattern = (struct smb2_compression_pattern_v1 *)srcp;
+ dlen = le32_to_cpu(pattern->Repetitions);
+
+ if (dlen == 0 || dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
+ cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n", dlen);
+ return -ECONNRESET;
+ }
+
+ memset(dstp, pattern->Pattern, dlen);
+ } else if (alg == SMB3_COMPRESS_NONE) {
+ memcpy(dstp, srcp, dlen);
+ }
+
+ srcp += slen;
+ src_len -= slen;
+ dstp += dlen;
+ }
+
+ *dst_len = dstp - dst;
+ if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
+ cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len,
+ le32_to_cpu(hdr->OriginalCompressedSegmentSize));
return -ECONNRESET;
}
if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) {
- cifs_dbg(VFS, "decompressed buffer is not an SMB2 message: ProtocolId 0x%x\n",
+ cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n",
*(__le32 *)dst);
return -ECONNRESET;
}
- /*
- * @dst_len contains only the decompressed data size, add back
- * the previously copied uncompressed size
- */
- *dst_len += buf_len;
-
return 0;
}