diff options
Diffstat (limited to 'fs/smb/client/compress.c')
| -rw-r--r-- | fs/smb/client/compress.c | 272 |
1 files changed, 239 insertions, 33 deletions
diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index de55ddd122a5..7fcb8c796aef 100644 --- a/fs/smb/client/compress.c +++ b/fs/smb/client/compress.c @@ -25,7 +25,162 @@ #include "compress/lz77.h" #include "compress.h" -int smb_compress(void *buf, const void *data, size_t *len) +static void pattern_scan(const u8 *src, size_t src_len, + struct smb2_compression_pattern_v1 *fwd, + struct smb2_compression_pattern_v1 *bwd) +{ + const u8 *srcp = src, *end = src + src_len - 1; + u32 reps, plen = sizeof(*fwd); + + memset(fwd, 0, plen); + memset(bwd, 0, plen); + + /* Forward pattern scan. */ + while (++srcp <= end && *srcp == *src) { } + + reps = srcp - src; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + fwd->Pattern = src[0]; + fwd->Repetitions = cpu_to_le32(reps); + + src += reps; + } + + if (reps == src_len) + return; + + /* Backward pattern scan. */ + srcp = end; + while (--srcp >= src && *srcp == *end) { } + + reps = end - srcp; + if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) { + bwd->Pattern = *end; + bwd->Repetitions = cpu_to_le32(reps); + + return; + } + + return; +} + +static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained) +{ + struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 }; + struct smb2_compression_payload_hdr *phdr; + const u8 *srcp = src; + u8 *dstp = dst; + size_t plen = 0; + int ret = 0; + + if (!chained) + goto unchained; + + pattern_scan(srcp, src_len, &fwd, &bwd); + + if (le32_to_cpu(fwd.Repetitions) > 0) { + plen = sizeof(fwd); + if (plen > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, &fwd, plen); + dstp += plen; + *dst_len -= plen; + + srcp += le32_to_cpu(fwd.Repetitions); + src_len -= le32_to_cpu(fwd.Repetitions); + if (src_len == 0) + goto out; + } + + if (le32_to_cpu(bwd.Repetitions) > 0) { + src_len -= le32_to_cpu(bwd.Repetitions); + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + + if (src_len == 0) + goto out_bwd_pattern; + } + + /* Leftover uncompressed data with size not worth compressing. */ + if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) { + if (src_len > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(src_len); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, srcp, src_len); + dstp += src_len; + *dst_len -= src_len; + + goto out_bwd_pattern; + } +unchained: + plen = *dst_len - SMB_COMPRESS_PAYLOAD_HDR_LEN - sizeof(phdr->OriginalPayloadSize); + if (plen < *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + ret = lz77_compress(srcp, src_len, dstp + SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize), &plen); + if (ret) { + *dst_len = 0; + return ret; + } + + if (chained) { + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->OriginalPayloadSize = cpu_to_le32(src_len); + plen += sizeof(phdr->OriginalPayloadSize); + phdr->Length = cpu_to_le32(plen); + plen += SMB_COMPRESS_PAYLOAD_HDR_LEN; + } + + dstp += plen; + *dst_len -= plen; +out_bwd_pattern: + if (bwd.Repetitions >= 64) { + plen = sizeof(bwd); + if (plen > *dst_len) { + *dst_len = 0; + return -EMSGSIZE; + } + + phdr = (struct smb2_compression_payload_hdr *)dstp; + phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN; + phdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + phdr->Length = cpu_to_le32(plen); + + dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + *dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + memcpy(dstp, &bwd, plen); + dstp += plen; + *dst_len -= plen; + } +out: + *dst_len = dstp - (u8 *)dst; + + return 0; +} + +int smb_compress(void *buf, const void *data, size_t *len, bool chained) { struct smb2_compression_hdr *hdr; size_t buf_len, data_len; @@ -35,15 +190,21 @@ int smb_compress(void *buf, const void *data, size_t *len) data_len = *len; *len = 0; + if (data_len < SMB_COMPRESS_MIN_LEN) + return -ENODATA; + hdr = buf; hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len); hdr->Offset = cpu_to_le32(buf_len); - hdr->Flags = SMB2_COMPRESSION_FLAG_NONE; + hdr->Flags = chained; hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77; + if (chained) { + hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE; + hdr->OriginalCompressedSegmentSize += cpu_to_le32(buf_len); + } - /* XXX: add other algs here as they're implemented */ - ret = lz77_compress(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len); + ret = compress_data(data, data_len, buf + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained); if (!ret) *len = SMB_COMPRESS_HDR_LEN + buf_len + data_len; @@ -53,47 +214,92 @@ int smb_compress(void *buf, const void *data, size_t *len) int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len) { const struct smb2_compression_hdr *hdr; - size_t buf_len, data_len; + size_t slen, dlen; + const void *srcp = src; + void *dstp = dst; + bool chained; int ret; - hdr = src; - if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) + hdr = srcp; + chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED); + slen = le32_to_cpu(hdr->Offset); + + if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) return -EIO; - buf_len = le32_to_cpu(hdr->Offset); - data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); - - /* - * Copy uncompressed data from the beginning of the payload. - * The remainder is all compressed data. - */ - src += SMB_COMPRESS_HDR_LEN; - memcpy(dst, src, buf_len); - src += buf_len; - src_len -= SMB_COMPRESS_HDR_LEN + buf_len; - *dst_len -= buf_len; - - ret = lz77_decompress(src, src_len, dst + buf_len, dst_len); - if (ret) - return ret; + /* Copy the uncompressed SMB2 READ header. */ + srcp += SMB_COMPRESS_HDR_LEN; + src_len -= SMB_COMPRESS_HDR_LEN; + memcpy(dstp, srcp, slen); + srcp += slen; + src_len -= slen; + dstp += slen; + + if (!chained) { + slen = src_len; + goto unchained; + } + + while (src_len > 0) { + const struct smb2_compression_payload_hdr *phdr = srcp; + __le16 alg = phdr->CompressionAlgorithm; + + srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN; + src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN; + slen = le32_to_cpu(phdr->Length); + dlen = slen; +unchained: + if (!smb_compress_alg_valid(alg, chained)) + return -EIO; + + /* XXX: add other algs here as they're implemented */ + if (alg == SMB3_COMPRESS_LZ77) { + if (chained) { + /* sizeof(OriginalPayloadSize) */ + srcp += 4; + slen -= 4; + dlen = le32_to_cpu(phdr->OriginalPayloadSize); + } + + ret = lz77_decompress(srcp, slen, dstp, &dlen); + if (ret) + return ret; - if (*dst_len != data_len) { - cifs_dbg(VFS, "decompressed size mismatch: got %zu, expected %zu\n", - *dst_len, data_len); + if (chained) + src_len -= 4; + } else if (alg == SMB3_COMPRESS_PATTERN) { + struct smb2_compression_pattern_v1 *pattern; + + pattern = (struct smb2_compression_pattern_v1 *)srcp; + dlen = le32_to_cpu(pattern->Repetitions); + + if (dlen == 0 || dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n", dlen); + return -ECONNRESET; + } + + memset(dstp, pattern->Pattern, dlen); + } else if (alg == SMB3_COMPRESS_NONE) { + memcpy(dstp, srcp, dlen); + } + + srcp += slen; + src_len -= slen; + dstp += dlen; + } + + *dst_len = dstp - dst; + if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) { + cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n", *dst_len, + le32_to_cpu(hdr->OriginalCompressedSegmentSize)); return -ECONNRESET; } if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) { - cifs_dbg(VFS, "decompressed buffer is not an SMB2 message: ProtocolId 0x%x\n", + cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n", *(__le32 *)dst); return -ECONNRESET; } - /* - * @dst_len contains only the decompressed data size, add back - * the previously copied uncompressed size - */ - *dst_len += buf_len; - return 0; } |
