// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Authors: Enzo Matsumiya
 *
 * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
 * See compress/ for implementation details of each algorithm.
 *
 * References:
 * MS-SMB2 "3.1.4.4 Compressing the Message"
 * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
 * MS-XCA - for details of the supported algorithms
 */
#include <linux/kernel.h>
#include <linux/slab.h>
#include <linux/uio.h>

#include "../common/smb2pdu.h"
#include "cifsglob.h"
#include "cifs_debug.h"
#include "compress/lz77.h"
#include "compress.h"

#define SAMPLING_READ_SIZE	(16)
#define SAMPLING_INTERVAL	(256)
#define BUCKET_SIZE		(256)
/*
 * The size of the sample is based on a statistical sampling rule of thumb.
 * The common way is to perform sampling tests as long as the number of
 * elements in each cell is at least 5.
 *
 * Instead of 5, we choose 32 to obtain more accurate results.
 * If the data contain the maximum number of symbols, which is 256, we obtain a
 * sample size bound by 8192.
 *
 * For a sample of at most 8KB of data per data range: 16 consecutive bytes
 * from up to 512 locations.
 */
#define MAX_SAMPLE_SIZE		(8192 * SAMPLING_READ_SIZE / SAMPLING_INTERVAL)
// ^ == LZ77 window size

struct bucket_item {
	size_t count;
};

struct heuristic_ctx {
	/* Partial copy of input data */
	const u8 *sample;
	size_t sample_size;

	/*
	 * Buckets store counters for each byte value.
	 *
	 * For statistical analysis of the input data we consider bytes that
	 * form a Galois Field of 256 objects.  Each object has an attribute
	 * count, i.e. how many times the object appeared in the sample.
	 */
	struct bucket_item bucket[BUCKET_SIZE];
	struct bucket_item aux_bucket[BUCKET_SIZE];

	struct list_head list;
};

/*
 * Shannon Entropy calculation.
 *
 * Pure byte distribution analysis fails to determine compressibility of data.
 * Try calculating entropy to estimate the average minimum number of bits
 * needed to encode the sampled data.
 *
 * For convenience, return the percentage of needed bits, instead of the
 * amount of bits directly.
 *
 * @ENTROPY_LEVEL_OK - below that threshold the sample has low byte entropy
 *		       and is likely compressible
 *
 * @ENTROPY_LEVEL_HIGH - the data is, with high probability, not compressible
 *
 * Use of ilog2() decreases precision, so the levels are set a bit lower to
 * compensate.
 */
#define ENTROPY_LEVEL_OK	65
#define ENTROPY_LEVEL_HIGH	80

/*
 * For increased precision in the shannon_entropy() calculation,
 * let's do pow(n, M) to save more digits after the comma:
 *
 * - maximum int bit length is 64
 * - ilog2(MAX_SAMPLE_SIZE) -> 13
 * - 13 * 4 = 52 < 64 -> M = 4
 *
 * So use pow(n, 4).
 */
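/*
 * Worked example (illustrative only) of the scaled arithmetic used by
 * shannon_entropy() below: for a 512-byte sample in which two byte values
 * occur 256 times each, sz_base = ilog2_w(512) = 36 and p_base =
 * ilog2_w(256) = 32, so sum = 2 * 256 * (36 - 32) = 2048, sum / 512 = 4,
 * and the result is 4 * 100 / 32 = 12 -- close to the true entropy of
 * 1 bit out of 8 (12.5%).  A sample spread uniformly over all 256 byte
 * values yields 100; a single repeated byte value yields 0.
 */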
static inline u32 ilog2_w(u64 n)
{
	return ilog2(n * n * n * n);
}

static u32 shannon_entropy(struct heuristic_ctx *ctx)
{
	const size_t max = 8 * ilog2_w(2);
	size_t i, p, p_base, sz_base, sum = 0;

	sz_base = ilog2_w(ctx->sample_size);

	for (i = 0; i < 256 && ctx->bucket[i].count > 0; i++) {
		p = ctx->bucket[i].count;
		p_base = ilog2_w(p);
		sum += p * (sz_base - p_base);
	}

	sum /= ctx->sample_size;

	return sum * 100 / max;
}

#define RADIX_BASE	4U
#define COUNTERS_SIZE	(1U << RADIX_BASE)

static __always_inline u8 get4bits(u64 num, int shift)
{
	/* Reverse order */
	return ((COUNTERS_SIZE - 1) - ((num >> shift) % COUNTERS_SIZE));
}

/*
 * Use 4 bits as radix base.
 * Use 16 counters for calculating the new position in the buf array.
 *
 * @array - array that will be sorted
 * @aux   - buffer array to store sorting results
 *          must be equal in size to @array
 * @num   - array size
 */
static void radix_sort(struct bucket_item *array, struct bucket_item *aux, int num)
{
	size_t buf_num, max_num, addr, new_addr, counters[COUNTERS_SIZE];
	int bitlen, shift, i;

	/*
	 * Try to avoid useless loop iterations for small numbers stored in
	 * big counters.  Example: 48 33 4 ... in a 64bit array
	 */
	max_num = array[0].count;
	for (i = 1; i < num; i++) {
		buf_num = array[i].count;
		if (buf_num > max_num)
			max_num = buf_num;
	}

	buf_num = ilog2(max_num);
	bitlen = ALIGN(buf_num, RADIX_BASE * 2);

	shift = 0;
	while (shift < bitlen) {
		memset(counters, 0, sizeof(counters));

		for (i = 0; i < num; i++) {
			buf_num = array[i].count;
			addr = get4bits(buf_num, shift);
			counters[addr]++;
		}

		for (i = 1; i < COUNTERS_SIZE; i++)
			counters[i] += counters[i - 1];

		for (i = num - 1; i >= 0; i--) {
			buf_num = array[i].count;
			addr = get4bits(buf_num, shift);
			counters[addr]--;
			new_addr = counters[addr];
			aux[new_addr] = array[i];
		}

		shift += RADIX_BASE;

		/*
		 * Normal radix expects to move data from a temporary array to
		 * the main one.  But that requires some CPU time.  Avoid that
		 * by doing another sort iteration to the original array
		 * instead of a memcpy().
		 */
		memset(counters, 0, sizeof(counters));

		for (i = 0; i < num; i++) {
			buf_num = aux[i].count;
			addr = get4bits(buf_num, shift);
			counters[addr]++;
		}

		for (i = 1; i < COUNTERS_SIZE; i++)
			counters[i] += counters[i - 1];

		for (i = num - 1; i >= 0; i--) {
			buf_num = aux[i].count;
			addr = get4bits(buf_num, shift);
			counters[addr]--;
			new_addr = counters[addr];
			array[new_addr] = aux[i];
		}

		shift += RADIX_BASE;
	}
}
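/*
 * Note (illustrative): because get4bits() inverts each radix digit, the sort
 * above is effectively descending.  E.g. bucket counts { 3, 48, 0, 7 } end up
 * ordered as { 48, 7, 3, 0 }.  byte_coverage() below relies on this ordering
 * to accumulate the most frequent byte values first, and the early-exit
 * condition in shannon_entropy() above assumes it as well (the buckets are
 * always sorted before it runs).
 */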
/*
 * Count how many bytes cover 90% of the sample.
 *
 * There are several types of structured binary data that use nearly all byte
 * values.  The distribution can be uniform and counts in all buckets will be
 * nearly the same (e.g. encrypted data).  Unlikely to be compressible.
 *
 * Another possibility is a normal (Gaussian) distribution, where the data
 * could be potentially compressible, but we have to take a few more steps to
 * decide how much.
 *
 * @BYTE_COVERAGE_LOW - the main part of the byte values is repeated
 *			frequently, a compression algorithm can easily
 *			handle that
 * @BYTE_COVERAGE_HIGH - the data has a uniform distribution and is, with
 *			 high probability, not compressible
 */
#define BYTE_COVERAGE_LOW	64
#define BYTE_COVERAGE_HIGH	200

static int byte_coverage(struct heuristic_ctx *ctx)
{
	const size_t threshold = ctx->sample_size * 90 / 100;
	struct bucket_item *bkt = &ctx->bucket[0];
	size_t sum = 0;
	int i;

	/* Sort in reverse order */
	radix_sort(ctx->bucket, ctx->aux_bucket, BUCKET_SIZE);

	for (i = 0; i < BYTE_COVERAGE_LOW; i++)
		sum += bkt[i].count;

	if (sum > threshold)
		return i;

	for (; i < BYTE_COVERAGE_HIGH && bkt[i].count > 0; i++) {
		sum += bkt[i].count;
		if (sum > threshold)
			break;
	}

	return i;
}

/*
 * Count ASCII bytes in buckets.
 *
 * This heuristic can detect textual data (configs, xml, json, html, etc),
 * because in most text-like data the byte set is restricted to a limited
 * number of possible characters, and that restriction in most cases makes
 * the data easy to compress.
 *
 * @ASCII_COUNT_THRESHOLD - consider all data within this byte set size:
 *	less - compressible
 *	more - needs additional analysis
 */
#define ASCII_COUNT_THRESHOLD	64

static __always_inline u32 ascii_count(const struct heuristic_ctx *ctx)
{
	size_t count = 0;
	int i;

	for (i = 0; i < ASCII_COUNT_THRESHOLD; i++)
		if (ctx->bucket[i].count > 0)
			count++;

	/*
	 * Continue collecting the count of byte values in buckets.  If the
	 * byte set size is bigger than the threshold, it's pointless to
	 * continue, the detection technique would fail for this type of data.
	 */
	for (; i < 256; i++) {
		if (ctx->bucket[i].count > 0) {
			count++;
			if (count > ASCII_COUNT_THRESHOLD)
				break;
		}
	}

	return count;
}

static __always_inline struct heuristic_ctx *heuristic_init(const u8 *buf, size_t len)
{
	struct heuristic_ctx *ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	int i = 0, s = 0;

	if (!ctx)
		return ERR_PTR(-ENOMEM);

	ctx->sample = kzalloc(MAX_SAMPLE_SIZE, GFP_KERNEL);
	if (!ctx->sample) {
		kfree(ctx);
		return ERR_PTR(-ENOMEM);
	}

	if (len > MAX_SAMPLE_SIZE)
		len = MAX_SAMPLE_SIZE;

	while (i < len - SAMPLING_READ_SIZE) {
		memcpy((void *)&ctx->sample[s], &buf[i], SAMPLING_READ_SIZE);
		i += SAMPLING_INTERVAL;
		s += SAMPLING_READ_SIZE;
	}

	ctx->sample_size = s;

	INIT_LIST_HEAD(&ctx->list);

	return ctx;
}

static __always_inline bool sample_repeated_patterns(struct heuristic_ctx *ctx)
{
	const size_t half = ctx->sample_size / 2;

	return (memcmp(&ctx->sample[0], &ctx->sample[half], half) == 0);
}
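/*
 * Decision ladder used by is_compressible() below, from the cheapest check to
 * the most expensive one:
 *
 *   1. repeated pattern in the sample	-> compressible
 *   2. small byte set (ascii_count())	-> compressible
 *   3. low byte coverage		-> compressible
 *   4. very high byte coverage		-> not compressible
 *   5. Shannon entropy			-> decides the remaining cases
 */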
static int is_compressible(const void *buf, size_t len)
{
	struct heuristic_ctx *ctx;
	int i, ret = 0;
	u8 byte;

	ctx = heuristic_init(buf, len);
	if (IS_ERR(ctx))
		return PTR_ERR(ctx);

	/*
	 * Parse from low-hanging fruits (compressible) to "need more
	 * analysis" (uncompressible).
	 */
	ret = 1;

	if (sample_repeated_patterns(ctx))
		goto out;

	for (i = 0; i < ctx->sample_size; i++) {
		byte = ctx->sample[i];
		ctx->bucket[byte].count++;
	}

	if (ascii_count(ctx) < ASCII_COUNT_THRESHOLD)
		goto out;

	i = byte_coverage(ctx);
	if (i <= BYTE_COVERAGE_LOW)
		goto out;

	if (i >= BYTE_COVERAGE_HIGH) {
		ret = 0;
		goto out;
	}

	i = shannon_entropy(ctx);
	if (i <= ENTROPY_LEVEL_OK)
		goto out;

	/*
	 * For the levels below ENTROPY_LEVEL_HIGH, additional analysis would
	 * be needed to give the green light to compression.
	 *
	 * For now just assume that compression at that level is not worth the
	 * resources because:
	 *
	 * 1. it is possible to defrag the data later
	 *
	 * 2. the data would turn out to be hardly compressible, e.g. 150 byte
	 *    values, every bucket has a counter at level ~54.  The heuristic
	 *    would be confused.  This can happen when the data has some
	 *    internal repeated patterns like "abbacbbc...".  This can be
	 *    detected by analyzing pairs of bytes, which is too costly.
	 */
	if (i < ENTROPY_LEVEL_HIGH)
		ret = 1;
	else
		ret = 0;
out:
	kvfree(ctx->sample);
	kfree(ctx);

	return ret;
}

static void pattern_scan(const u8 *src, size_t src_len,
			 struct smb2_compression_pattern_v1 *fwd,
			 struct smb2_compression_pattern_v1 *bwd)
{
	const u8 *srcp = src, *end = src + src_len - 1;
	u32 reps, plen = sizeof(*fwd);

	memset(fwd, 0, plen);
	memset(bwd, 0, plen);

	/* Forward pattern scan. */
	while (++srcp <= end && *srcp == *src) { }

	reps = srcp - src;
	if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
		fwd->Pattern = src[0];
		fwd->Repetitions = cpu_to_le32(reps);
		src += reps;
	}

	if (reps == src_len)
		return;

	/* Backward pattern scan. */
	srcp = end;
	while (--srcp >= src && *srcp == *end) { }

	reps = end - srcp;
	if (reps >= SMB_COMPRESS_PATTERN_MIN_LEN) {
		bwd->Pattern = *end;
		bwd->Repetitions = cpu_to_le32(reps);
	}
}
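/*
 * Sketch (for orientation, not normative) of the chained layout produced by
 * compress_data() below when both a forward and a backward pattern are found
 * (see MS-SMB2 "3.1.4.4 Compressing the Message"):
 *
 *   [payload hdr: PATTERN_V1][fwd pattern]
 *   [payload hdr: LZ77][OriginalPayloadSize][LZ77-compressed data]
 *   [payload hdr: PATTERN_V1][bwd pattern]
 *
 * A middle chunk no larger than SMB_COMPRESS_PAYLOAD_MIN_LEN is emitted as a
 * NONE payload instead of being LZ77-compressed.
 */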
static int compress_data(const void *src, size_t src_len, void *dst, size_t *dst_len, bool chained)
{
	struct smb2_compression_pattern_v1 fwd = { 0 }, bwd = { 0 };
	struct smb2_compression_payload_hdr *phdr;
	const u8 *srcp = src;
	u8 *dstp = dst;
	size_t plen = 0, doff = 0;
	int ret = 0;

	if (!chained)
		goto unchained;

	pattern_scan(srcp, src_len, &fwd, &bwd);

	if (le32_to_cpu(fwd.Repetitions) > 0) {
		plen = sizeof(fwd);
		if (plen > *dst_len) {
			*dst_len = 0;
			return -EMSGSIZE;
		}

		phdr = (struct smb2_compression_payload_hdr *)dstp;
		phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
		phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
		phdr->Length = cpu_to_le32(plen);
		dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
		*dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;

		memcpy(dstp, &fwd, plen);
		dstp += plen;
		*dst_len -= plen;

		srcp += le32_to_cpu(fwd.Repetitions);
		src_len -= le32_to_cpu(fwd.Repetitions);
		if (src_len == 0)
			goto out;
	}

	if (le32_to_cpu(bwd.Repetitions) > 0) {
		/* Account for the backward pattern now; it's written out at out_bwd_pattern. */
		src_len -= le32_to_cpu(bwd.Repetitions);
		*dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;
		if (src_len == 0)
			goto out_bwd_pattern;
	}

	/* Leftover uncompressed data with a size not worth compressing. */
	if (src_len <= SMB_COMPRESS_PAYLOAD_MIN_LEN) {
		if (src_len > *dst_len) {
			*dst_len = 0;
			return -EMSGSIZE;
		}

		phdr = (struct smb2_compression_payload_hdr *)dstp;
		phdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
		phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
		phdr->Length = cpu_to_le32(src_len);
		dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
		*dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;

		memcpy(dstp, srcp, src_len);
		dstp += src_len;
		*dst_len -= src_len;

		goto out_bwd_pattern;
	}
unchained:
	/*
	 * A chained LZ77 payload carries a payload header and the
	 * OriginalPayloadSize field in front of the compressed data; an
	 * unchained message carries the compressed data directly.
	 */
	if (chained)
		doff = SMB_COMPRESS_PAYLOAD_HDR_LEN + sizeof(phdr->OriginalPayloadSize);

	if (*dst_len <= doff) {
		*dst_len = 0;
		return -EMSGSIZE;
	}

	plen = *dst_len - doff;

	ret = lz77_compress(srcp, src_len, dstp + doff, &plen);
	if (ret) {
		*dst_len = 0;
		return ret;
	}

	if (chained) {
		phdr = (struct smb2_compression_payload_hdr *)dstp;
		phdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;
		phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
		phdr->OriginalPayloadSize = cpu_to_le32(src_len);
		plen += sizeof(phdr->OriginalPayloadSize);
		phdr->Length = cpu_to_le32(plen);
		plen += SMB_COMPRESS_PAYLOAD_HDR_LEN;
	}

	dstp += plen;
	*dst_len -= plen;
out_bwd_pattern:
	if (le32_to_cpu(bwd.Repetitions) > 0) {
		plen = sizeof(bwd);
		if (plen > *dst_len) {
			*dst_len = 0;
			return -EMSGSIZE;
		}

		phdr = (struct smb2_compression_payload_hdr *)dstp;
		phdr->CompressionAlgorithm = SMB3_COMPRESS_PATTERN;
		phdr->Flags = SMB2_COMPRESSION_FLAG_NONE;
		phdr->Length = cpu_to_le32(plen);
		dstp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
		*dst_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;

		memcpy(dstp, &bwd, plen);
		dstp += plen;
		*dst_len -= plen;
	}
out:
	*dst_len = dstp - (u8 *)dst;

	return 0;
}
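/*
 * Usage sketch for smb_compress() below (hypothetical caller, illustrative
 * only -- @src_rq, @comp_iov and the error handling are assumptions, not part
 * of this file):
 *
 *	struct kvec comp_iov = { 0 };
 *	struct smb_rqst comp_rq = { .rq_iov = &comp_iov, .rq_nvec = 1 };
 *
 *	// @src_rq holds the SMB2 WRITE header in rq_iov[0] and the payload
 *	// in rq_iter.
 *	rc = smb_compress(&src_rq, &comp_rq, false);
 *	if (rc == -ENODATA)
 *		// payload smaller than SMB_COMPRESS_MIN_LEN; send uncompressed
 *	else if (!rc)
 *		// send comp_rq, then kvfree(comp_iov.iov_base)
 */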
int smb_compress(struct smb_rqst *src_rq, struct smb_rqst *dst_rq, bool chained)
{
	struct smb2_compression_hdr *hdr;
	size_t buf_len, data_len;
	struct iov_iter tmp = src_rq->rq_iter;
	void *src, *dst;
	int ret = -ENOMEM;

	buf_len = src_rq->rq_iov->iov_len;
	if (WARN(buf_len != sizeof(struct smb2_write_req),
		 "%s: unexpected buf len %zu\n", __func__, buf_len))
		return -EIO;

	data_len = iov_iter_count(&src_rq->rq_iter);
	if (data_len < SMB_COMPRESS_MIN_LEN)
		return -ENODATA;

	src = kvzalloc(data_len, GFP_KERNEL);
	if (!src)
		goto err_free;

	if (!copy_from_iter_full(src, data_len, &tmp)) {
		ret = -EIO;
		goto err_free;
	}

	cifs_dbg(FYI, "%s: is compressible %d\n", __func__, is_compressible(src, data_len));

	dst_rq->rq_iov->iov_base = kvzalloc(SMB_COMPRESS_HDR_LEN + buf_len + data_len, GFP_KERNEL);
	if (!dst_rq->rq_iov->iov_base)
		goto err_free;

	dst = dst_rq->rq_iov->iov_base;

	hdr = dst;
	hdr->ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
	hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len);
	hdr->Offset = cpu_to_le32(buf_len);
	hdr->Flags = chained ? SMB2_COMPRESSION_FLAG_CHAINED : SMB2_COMPRESSION_FLAG_NONE;
	hdr->CompressionAlgorithm = SMB3_COMPRESS_LZ77;

	if (chained) {
		hdr->CompressionAlgorithm = SMB3_COMPRESS_NONE;
		hdr->OriginalCompressedSegmentSize = cpu_to_le32(data_len + buf_len);
	}

	/*
	 * Copy the SMB2 header uncompressed into @dst, right after the
	 * compression transform header set up above.
	 */
	memcpy(dst + SMB_COMPRESS_HDR_LEN, src_rq->rq_iov->iov_base, buf_len);

	ret = compress_data(src, data_len, dst + SMB_COMPRESS_HDR_LEN + buf_len, &data_len, chained);
	cifs_dbg(FYI, "%s: compress ret %d\n", __func__, ret);
err_free:
	kvfree(src);

	if (!ret) {
		dst_rq->rq_iov->iov_len = SMB_COMPRESS_HDR_LEN + buf_len + data_len;
	} else {
		kvfree(dst_rq->rq_iov->iov_base);
		dst_rq->rq_iov->iov_base = NULL;
	}

	return ret;
}

int smb_decompress(const void *src, size_t src_len, void *dst, size_t *dst_len)
{
	const struct smb2_compression_hdr *hdr;
	size_t slen, dlen;
	const void *srcp = src;
	void *dstp = dst;
	bool chained;
	__le16 alg;
	int ret;

	hdr = srcp;
	chained = (hdr->Flags == SMB2_COMPRESSION_FLAG_CHAINED);
	slen = le32_to_cpu(hdr->Offset);

	if (!chained && hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
		return -EIO;

	/* Copy the uncompressed SMB2 READ header. */
	srcp += SMB_COMPRESS_HDR_LEN;
	src_len -= SMB_COMPRESS_HDR_LEN;
	memcpy(dstp, srcp, slen);
	srcp += slen;
	src_len -= slen;
	dstp += slen;

	if (!chained) {
		/* Unchained: a single LZ77 blob follows the copied header. */
		alg = hdr->CompressionAlgorithm;
		slen = src_len;
		dlen = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
		goto unchained;
	}

	while (src_len > 0) {
		const struct smb2_compression_payload_hdr *phdr;

		phdr = srcp;
		alg = phdr->CompressionAlgorithm;

		srcp += SMB_COMPRESS_PAYLOAD_HDR_LEN;
		src_len -= SMB_COMPRESS_PAYLOAD_HDR_LEN;

		slen = le32_to_cpu(phdr->Length);
		dlen = slen;
unchained:
		if (!smb_compress_alg_valid(alg, chained))
			return -EIO;

		/* XXX: add other algs here as they're implemented */
		if (alg == SMB3_COMPRESS_LZ77) {
			if (chained) {
				/* sizeof(OriginalPayloadSize) */
				srcp += 4;
				slen -= 4;
				dlen = le32_to_cpu(phdr->OriginalPayloadSize);
			}

			ret = lz77_decompress(srcp, slen, dstp, &dlen);
			if (ret)
				return ret;

			if (chained)
				src_len -= 4;
		} else if (alg == SMB3_COMPRESS_PATTERN) {
			const struct smb2_compression_pattern_v1 *pattern = srcp;

			dlen = le32_to_cpu(pattern->Repetitions);
			if (dlen == 0 ||
			    dlen > le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
				cifs_dbg(VFS, "corrupt compressed data (%zu pattern repetitions)\n",
					 dlen);
				return -ECONNRESET;
			}

			memset(dstp, pattern->Pattern, dlen);
		} else if (alg == SMB3_COMPRESS_NONE) {
			memcpy(dstp, srcp, dlen);
		}

		srcp += slen;
		src_len -= slen;
		dstp += dlen;
	}

	*dst_len = dstp - dst;
	if (*dst_len != le32_to_cpu(hdr->OriginalCompressedSegmentSize)) {
		cifs_dbg(VFS, "decompressed size mismatch (got %zu, expected %u)\n",
			 *dst_len, le32_to_cpu(hdr->OriginalCompressedSegmentSize));
		return -ECONNRESET;
	}

	if (((struct smb2_hdr *)dst)->ProtocolId != SMB2_PROTO_NUMBER) {
		cifs_dbg(VFS, "decompressed buffer is not an SMB2 message (ProtocolId 0x%x)\n",
			 le32_to_cpu(*(__le32 *)dst));
		return -ECONNRESET;
	}

	return 0;
}
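/*
 * For reference, the chained on-the-wire layout walked by smb_decompress()
 * above (MS-SMB2 "3.1.5.3 Decompressing the Chained Message"):
 *
 *   [smb2_compression_hdr]
 *   [Offset bytes of uncompressed SMB2 header]
 *   [smb2_compression_payload_hdr][payload data] ... repeated until the
 *   source is consumed
 *
 * LZ77 payloads additionally carry a 4-byte OriginalPayloadSize field between
 * the payload header and the compressed data.
 */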