diff options
-rw-r--r-- | fs/smb/client/asn1.c | 9 | ||||
-rw-r--r-- | fs/smb/client/cifs_spnego.c | 18 | ||||
-rw-r--r-- | fs/smb/client/compress.c | 208 | ||||
-rw-r--r-- | fs/smb/client/compress.h | 44 | ||||
-rw-r--r-- | fs/smb/client/compress/lz77.c | 298 | ||||
-rw-r--r-- | fs/smb/client/compress/lz77.h | 1 | ||||
-rw-r--r-- | fs/smb/client/fs_context.c | 4 | ||||
-rw-r--r-- | fs/smb/client/fs_context.h | 1 | ||||
-rw-r--r-- | fs/smb/client/smb2ops.c | 141 | ||||
-rw-r--r-- | fs/smb/client/smb2pdu.c | 9 |
10 files changed, 616 insertions, 117 deletions
diff --git a/fs/smb/client/asn1.c b/fs/smb/client/asn1.c index 214a44509e7b..87910e2e0803 100644 --- a/fs/smb/client/asn1.c +++ b/fs/smb/client/asn1.c @@ -48,9 +48,12 @@ int cifs_neg_token_init_mech_type(void *context, size_t hdrlen, server->sec_mskerberos = true; else if (oid == OID_krb5u2u) server->sec_kerberosu2u = true; - else if (oid == OID_krb5) - server->sec_kerberos = true; - else if (oid == OID_ntlmssp) + else if (oid == OID_krb5) { + if (!server->sec_iakerb) + server->sec_kerberos = true; + else + server->sec_kerberos = false; + } else if (oid == OID_ntlmssp) server->sec_ntlmssp = true; else if (oid == OID_IAKerb) server->sec_iakerb = true; diff --git a/fs/smb/client/cifs_spnego.c b/fs/smb/client/cifs_spnego.c index bc1c1e9b288a..6a8d085ec991 100644 --- a/fs/smb/client/cifs_spnego.c +++ b/fs/smb/client/cifs_spnego.c @@ -160,7 +160,23 @@ cifs_get_spnego_key(struct cifs_ses *sesInfo, if (sesInfo->user_name) { dp = description + strlen(description); - sprintf(dp, ";user=%s", sesInfo->user_name); + if (server->sec_iakerb) { + /* + * TODO: add option to set ccache name or specify realm desired + */ + if (!sesInfo->password) { + cifs_dbg(VFS, "IAKerb requested, but no password provided\n"); + spnego_key = ERR_PTR(-EINVAL); + goto out; + } + sprintf(dp, ";user=%s;pw=%s", sesInfo->user_name, sesInfo->password); + } else { + sprintf(dp, ";user=%s", sesInfo->user_name); + } + } else if (server->sec_iakerb) { + cifs_dbg(VFS, "IAKerb requested, but no username provided\n"); + spnego_key = ERR_PTR(-EINVAL); + goto out; } dp = description + strlen(description); diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c index 766b4de13da7..a6324bb15b8e 100644 --- a/fs/smb/client/compress.c +++ b/fs/smb/client/compress.c @@ -16,17 +16,18 @@ #include <linux/kernel.h> #include <linux/uio.h> #include <linux/sort.h> +#include <linux/iov_iter.h> #include "cifsglob.h" -#include "../common/smb2pdu.h" -#include "cifsproto.h" -#include "smb2proto.h" #include "compress/lz77.h" #include "compress.h" +#define SAMPLE_CHUNK_LEN SZ_2K +#define SAMPLE_MAX_LEN SZ_2M + /* - * The heuristic_*() functions below try to determine data compressibility. + * The functions below try to determine data compressibility. * * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing * unused parts. @@ -34,7 +35,8 @@ * Read that file for better and more detailed explanation of the calculations. * * The algorithms are ran in a collected sample of the input (uncompressed) data. - * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 4M. + * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 2M. + * Those are adjusted according to R/W bufsizes negotiated with the server. * * Parsing the sample goes from "low-hanging fruits" (fastest algorithms, likely compressible) * to "need more analysis" (likely uncompressible). @@ -154,61 +156,39 @@ static int cmp_bkt(const void *_a, const void *_b) return 1; } -/* - * TODO: - * Support other iter types, if required. - * Only ITER_XARRAY is supported for now. - */ -static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample) +static size_t collect_step(void *base, size_t progress, size_t len, void *priv_iov, void *priv_len) { - struct folio *folios[16], *folio; - unsigned int nr, i, j, npages; - loff_t start = iter->xarray_start + iter->iov_offset; - pgoff_t last, index = start / PAGE_SIZE; - size_t len, off, foff; - void *p; - int s = 0; - - last = (start + max - 1) / PAGE_SIZE; - do { - nr = xa_extract(iter->xarray, (void **)folios, index, last, ARRAY_SIZE(folios), - XA_PRESENT); - if (nr == 0) - return -EIO; - - for (i = 0; i < nr; i++) { - folio = folios[i]; - npages = folio_nr_pages(folio); - foff = start - folio_pos(folio); - off = foff % PAGE_SIZE; - - for (j = foff / PAGE_SIZE; j < npages; j++) { - size_t len2; - - len = min_t(size_t, max, PAGE_SIZE - off); - len2 = min_t(size_t, len, SZ_2K); - - p = kmap_local_page(folio_page(folio, j)); - memcpy(&sample[s], p, len2); - kunmap_local(p); - - s += len2; - - if (len2 < SZ_2K || s >= max - SZ_2K) - return s; - - max -= len; - if (max <= 0) - return s; - - start += len; - off = 0; - index++; - } - } - } while (nr == ARRAY_SIZE(folios)); - - return s; + size_t plen, *cur_len = priv_len; + struct kvec *iov = priv_iov; + + if (progress >= iov->iov_len) + return len; + + plen = min_t(size_t, len, SAMPLE_CHUNK_LEN); + memcpy(iov->iov_base + *cur_len, base, plen); + + *cur_len += plen; + + if (len < SAMPLE_CHUNK_LEN) + return len; + + return 0; +} + +static int collect_sample(const struct iov_iter *iter, u8 *sample, size_t max_sample_len) +{ + size_t ret, len = iov_iter_count(iter); + struct iov_iter it = *iter; + struct kvec iov = { + .iov_base = sample, + .iov_len = max_sample_len, + }; + + ret = iterate_and_advance_kernel(&it, len, &iov, &max_sample_len, collect_step); + if (ret > len || ret > max_sample_len) + return -EIO; + + return 0; } /** @@ -220,22 +200,18 @@ static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample) * Tests shows that this function is quite reliable in predicting data compressibility, * matching close to 1:1 with the behaviour of LZ77 compression success and failures. */ -static bool is_compressible(const struct iov_iter *data) +static __maybe_unused bool is_compressible(const struct iov_iter *data) { - const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M; + const size_t bkt_size = 256; struct bucket *bkt = NULL; size_t len; u8 *sample; bool ret = false; int i; - /* Preventive double check -- already checked in should_compress(). */ len = iov_iter_count(data); - if (unlikely(len < read_size)) - return ret; - - if (len - read_size > max) - len = max; + if (len > SAMPLE_MAX_LEN) + len = SAMPLE_MAX_LEN; sample = kvzalloc(len, GFP_KERNEL); if (!sample) { @@ -245,16 +221,15 @@ static bool is_compressible(const struct iov_iter *data) } /* Sample 2K bytes per page of the uncompressed data. */ - i = collect_sample(data, len, sample); - if (i <= 0) { - WARN_ON_ONCE(1); + i = collect_sample(data, sample, len); + if (i < 0) { + WARN_ONCE(1, "data len=%zu, max sample len=%zu\n", iov_iter_count(data), len); goto out; } - len = i; ret = true; - + len = i; if (has_repeated_data(sample, len)) goto out; @@ -292,12 +267,15 @@ out: bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq) { + struct TCP_Server_Info *server; const struct smb2_hdr *shdr = rq->rq_iov->iov_base; if (unlikely(!tcon || !tcon->ses || !tcon->ses->server)) return false; - if (!tcon->ses->server->compression.enabled) + server = tcon->ses->server; + + if (!server->compression.enabled) return false; if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA)) @@ -315,12 +293,36 @@ bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq) return (shdr->Command == SMB2_READ); } +bool should_decompress(struct TCP_Server_Info *server, const void *buf) +{ + u32 len; + + if (!is_compress_hdr(buf)) + return false; + + len = decompressed_size(buf); + if (len < SMB_COMPRESS_HDR_LEN + server->vals->read_rsp_size) { + pr_warn("CIFS: Compressed message too small (%u bytes)\n", len); + + return false; + } + + if (len > SMB_DECOMPRESS_MAX_LEN(server)) { + pr_warn("CIFS: Uncompressed message too big (%u bytes)\n", len); + + return false; + } + + return true; +} + int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn) { struct iov_iter iter; u32 slen, dlen; void *src, *dst = NULL; int ret; + //unsigned long start; if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base) return -EINVAL; @@ -354,12 +356,16 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s goto err_free; } + //start = jiffies; ret = lz77_compress(src, slen, dst, &dlen); + //pr_err("%s: compress runtime=%ums, ret=%d, dlen=%u\n", __func__, jiffies_to_msecs(jiffies - start), ret, dlen); if (!ret) { struct smb2_compression_hdr hdr = { 0 }; struct smb_rqst comp_rq = { .rq_nvec = 3, }; struct kvec iov[3]; + //pr_err("%s: compress runtime=%ums, dlen=%u\n", __func__, jiffies_to_msecs(jiffies - start), dlen); + hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID; hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen); hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77; @@ -384,3 +390,59 @@ err_free: return ret; } + +int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen) +{ + const struct smb2_compression_hdr *hdr = src; + struct smb2_hdr *shdr; + u32 buf_len, data_len; + int ret; + + if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77) + return -EIO; + + buf_len = le32_to_cpu(hdr->Offset); + data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize); + if (unlikely(*dlen != buf_len + data_len)) + return -EINVAL; + + /* + * Copy uncompressed data from the beginning of the payload. + * The remainder is all compressed data. + */ + src += SMB_COMPRESS_HDR_LEN; + slen -= SMB_COMPRESS_HDR_LEN; + memcpy(dst, src, buf_len); + + src += buf_len; + slen -= buf_len; + *dlen = data_len; + + ret = lz77_decompress(src, slen, dst + buf_len, dlen); + if (ret) + return ret; + + shdr = dst; + + if (*dlen != data_len) { + pr_warn("CIFS: Decompressed size mismatch: got %u, expected %u (mid=%llu)\n", + *dlen, data_len, le64_to_cpu(shdr->MessageId)); + + return -EINVAL; + } + + if (shdr->ProtocolId != SMB2_PROTO_NUMBER) { + pr_warn("CIFS: Decompressed buffer is not an SMB2 message: ProtocolId 0x%x (mid=%llu)\n", + le32_to_cpu(shdr->ProtocolId), le64_to_cpu(shdr->MessageId)); + + return -EINVAL; + } + + /* + * @dlen contains only the decompressed data size, add back the previously copied + * hdr->Offset size. + */ + *dlen += buf_len; + + return 0; +} diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h index f3ed1d3e52fb..717076bcd343 100644 --- a/fs/smb/client/compress.h +++ b/fs/smb/client/compress.h @@ -17,19 +17,48 @@ #include <linux/uio.h> #include <linux/kernel.h> +#include <linux/minmax.h> + #include "../common/smb2pdu.h" -#include "cifsglob.h" /* sizeof(smb2_compression_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_HDR_LEN 16 /* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */ #define SMB_COMPRESS_PAYLOAD_HDR_LEN 8 -#define SMB_COMPRESS_MIN_LEN PAGE_SIZE +#define SMB_COMPRESS_MIN_LEN SZ_64K +#define SMB_DECOMPRESS_MAX_LEN(_srv) \ + (256 + SMB_COMPRESS_HDR_LEN + max3((_srv)->maxBuf, (_srv)->max_read, (_srv)->max_write)) + +static __always_inline u32 decompressed_size(const void *buf) +{ + const struct smb2_compression_hdr *hdr = buf; + + if (!buf || !IS_ENABLED(CONFIG_CIFS_COMPRESSION)) + return 0; + + return le32_to_cpu(hdr->Offset) + + le32_to_cpu(hdr->OriginalCompressedSegmentSize); +} + +static __always_inline bool is_compress_hdr(const void *buf) +{ + const struct smb2_compression_hdr *hdr = buf; + + if (!buf) + return false; + + return (hdr->ProtocolId == SMB2_COMPRESSION_TRANSFORM_ID); +} #ifdef CONFIG_CIFS_COMPRESSION +struct TCP_Server_Info; +struct cifs_tcon; +struct smb_rqst; + typedef int (*compress_send_fn)(struct TCP_Server_Info *, int, struct smb_rqst *); int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn); +int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen); /** * should_compress() - Determines if a request (write) or the response to a @@ -47,6 +76,7 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s * Return false otherwise. */ bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq); +bool should_decompress(struct TCP_Server_Info *server, const void *buf); /** * smb_compress_alg_valid() - Validate a compression algorithm. @@ -77,11 +107,21 @@ static inline int smb_compress(void *unused1, void *unused2, void *unused3) return -EOPNOTSUPP; } +static inline int smb_decompress(const void *unused1, u32 unused2, void *unused3, u32 unused4) +{ + return -EOPNOTSUPP; +} + static inline bool should_compress(void *unused1, void *unused2) { return false; } +static inline bool should_decompress(void *unused1, void *unused2) +{ + return false; +} + static inline int smb_compress_alg_valid(__le16 unused1, bool unused2) { return -EOPNOTSUPP; diff --git a/fs/smb/client/compress/lz77.c b/fs/smb/client/compress/lz77.c index 96e8a8057a77..9d35691261fd 100644 --- a/fs/smb/client/compress/lz77.c +++ b/fs/smb/client/compress/lz77.c @@ -16,7 +16,7 @@ /* * Compression parameters. */ -#define LZ77_MATCH_MIN_LEN 4 +#define LZ77_MATCH_MIN_LEN 3 #define LZ77_MATCH_MIN_DIST 1 #define LZ77_MATCH_MAX_DIST SZ_1K #define LZ77_HASH_LOG 15 @@ -28,6 +28,16 @@ static __always_inline u8 lz77_read8(const u8 *ptr) return get_unaligned(ptr); } +static __always_inline u16 lz77_read16(const u16 *ptr) +{ + return get_unaligned(ptr); +} + +static __always_inline u32 lz77_read32(const u32 *ptr) +{ + return get_unaligned(ptr); +} + static __always_inline u64 lz77_read64(const u64 *ptr) { return get_unaligned(ptr); @@ -51,11 +61,13 @@ static __always_inline void lz77_write32(u32 *ptr, u32 v) static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end) { const void *start = cur; - u64 diff; - - /* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */ + /* + * Safe for a do/while (i.e. no bounds check the first iteration) because otherwise we + * wouldn't reach here from the main loop. + */ do { - diff = lz77_read64(cur) ^ lz77_read64(wnd); + const u64 diff = lz77_read64(cur) ^ lz77_read64(wnd); + if (!diff) { cur += LZ77_STEP_SIZE; wnd += LZ77_STEP_SIZE; @@ -69,8 +81,10 @@ static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, cons return (cur - start); } while (likely(cur + LZ77_STEP_SIZE < end)); - while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++)) - ; + while (cur < end && lz77_read8(cur) == lz77_read8(wnd)) { + cur++; + wnd++; + } return (cur - start); } @@ -131,29 +145,82 @@ static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) { - const void *srcp, *end; - void *dstp, *nib, *flag_pos; + const void *srcp, *end, *anchor; + void *dstp, *nib, *flag_pos, *dend; u32 flag_count = 0; long flag = 0; u64 *htable; srcp = src; + anchor = srcp; end = src + slen; dstp = dst; nib = NULL; flag_pos = dstp; dstp += 4; + dend = dst + (*dlen - (*dlen >> 3)); htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL); if (!htable) return -ENOMEM; +#if 0 + d = 4; + + /* warm up loop */ + do { + u32 dist, len = 0; + const void *wnd; + u64 hash; + + hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG); + wnd = src + htable[hash]; + htable[hash] = srcp - src; + dist = srcp - wnd; + + if (dist && dist < LZ77_MATCH_MAX_DIST) + len = lz77_match_len(wnd, srcp, end); + + if (len < LZ77_MATCH_MIN_LEN) { + srcp++; + d++; + + continue; + } + + d += 4; + srcp += len; + } while (likely(srcp + LZ77_STEP_SIZE < end)); + + if (srcp < end) + d += (end - srcp); + + if (d > (slen - (slen >> 2))) { + *dlen = slen; + goto out; + } + + srcp = src; + anchor = srcp; +#endif + /* Main loop. */ do { u32 dist, len = 0; const void *wnd; u64 hash; +#if 0 + /* + * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth + * going further. + */ + if (unlikely(dstp - dst >= slen - (slen >> 3))) { + *dlen = slen; + goto out; + } +#endif + hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG); wnd = src + htable[hash]; htable[hash] = srcp - src; @@ -162,12 +229,41 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) if (dist && dist < LZ77_MATCH_MAX_DIST) len = lz77_match_len(wnd, srcp, end); +#if 1 + if (len < LZ77_MATCH_MIN_LEN) { + srcp++; + continue; + } + + if (dstp + (srcp - anchor) >= dend) + goto out; + + while (anchor < srcp) { + u32 c = umin(srcp - anchor, 32 - flag_count); + + memcpy(dstp, anchor, c); + anchor += c; + dstp += c; + + flag <<= c; + flag_count += c; + if (flag_count == 32) { + lz77_write32(flag_pos, flag); + flag_count = 0; + flag_pos = dstp; + dstp += 4; + } + } +#else if (len < LZ77_MATCH_MIN_LEN) { lz77_write8(dstp, lz77_read8(srcp)); dstp++; srcp++; + if (dstp >= dend) + goto out; + flag <<= 1; flag_count++; if (flag_count == 32) { @@ -179,18 +275,13 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) continue; } - - /* - * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth - * going further. - */ - if (unlikely(dstp - dst >= slen - (slen >> 3))) { - *dlen = slen; +#endif + if (dstp + len >= dend) goto out; - } dstp = lz77_write_match(dstp, &nib, dist, len); srcp += len; + anchor = srcp; flag = (flag << 1) | 1; flag_count++; @@ -202,9 +293,32 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) } } while (likely(srcp + LZ77_STEP_SIZE < end)); + while (anchor < srcp) { + u32 c = umin(srcp - anchor, 32 - flag_count); + + if (dstp + c >= dend) + goto out; + + memcpy(dstp, anchor, c); + anchor += c; + dstp += c; + + flag <<= c; + flag_count += c; + if (flag_count == 32) { + lz77_write32(flag_pos, flag); + flag_count = 0; + flag_pos = dstp; + dstp += 4; + } + } + while (srcp < end) { u32 c = umin(end - srcp, 32 - flag_count); + if (dstp + c >= dend) + goto out; + memcpy(dstp, srcp, c); dstp += c; @@ -221,7 +335,7 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) } flag <<= (32 - flag_count); - flag |= (1 << (32 - flag_count)) - 1; + flag |= (1UL << (32 - flag_count)) - 1; lz77_write32(flag_pos, flag); *dlen = dstp - dst; @@ -233,3 +347,151 @@ out: return -EMSGSIZE; } + +static __always_inline const void *lz77_read_match_len(const void *src, const void **nib, u32 *dist, + u32 *len) +{ + u32 mlen; + u16 sym; + + sym = lz77_read16(src); + src += 2; + + *dist = (sym >> 3) + 1; + mlen = sym & 7; + + if (mlen == 7) { + if (!*nib) { + *nib = src; + mlen = lz77_read8(src) % 16; + src++; + } else { + mlen = lz77_read8(*nib) >> 4; + *nib = NULL; + } + + if (mlen == 15) { + mlen = lz77_read8(src); + src++; + + if (mlen == 255) { + mlen = lz77_read16(src); + src += 2; + + if (mlen == 0) { + mlen = lz77_read32(src); + src += 4; + } + + if (mlen < 15 + 7) { + pr_warn("CIFS: unexpected match length %u\n", mlen); + + return NULL; + } + mlen -= (15 + 7); + } + mlen += 15; + } + mlen += 7; + } + mlen += 3; + + *len = mlen; + + return src; +} + +int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen) +{ + const void *srcp = src, *end = src + slen, *nib = NULL; + void *dstp = dst, *dst_end; + u32 c, flag_count = 0; + long flag; + + if (!dlen || *dlen < slen) + return -EINVAL; + + dst_end = dst + *dlen; + *dlen = 0; + + while (likely(srcp < end)) { + u32 dist, len; + + if (flag_count == 0) { + flag = lz77_read32(srcp); + srcp += 4; + flag_count = 32; + } + + /* Decode literals. */ + c = flag ? __builtin_clz(flag) : 32; + c = umin(c, flag_count); + + if (unlikely(dstp + c > dst_end)) + return -EFAULT; + + flag_count -= c; + flag <<= c; + + memcpy(dstp, srcp, c); + + dstp += c; + srcp += c; + + if (flag_count == 0) + continue; + + /* + * This means we've parsed the whole input buffer @src and filled + * @dst within its memory bounds. + * + * However, it's up to callers to determine if the decompressed + * buffer and size are according to what they expected to get. + */ + if (unlikely(srcp + 1 >= end)) + break; + + flag_count--; + flag <<= 1; + + /* Decode match. */ + srcp = lz77_read_match_len(srcp, &nib, &dist, &len); + if (unlikely(!srcp)) + return -EIO; + + /* + * Even though we use a 1K window size and 4-byte min match len, the server can use + * parameters of their choosing, so check against the standard defaults. + */ + if (unlikely(dist > SZ_8K || len < 3)) + return -EIO; + + /* Bogus compression, match distance too far. */ + if (unlikely(dstp - dst < dist)) + return -EFAULT; + + /* Bogus compression, match length too long. */ + if (unlikely(dstp + len > dst_end)) + return -EMSGSIZE; + + /* + * If non-overlapping memory, we can use memcpy(). + * Otherwise, we have to do it byte by byte. + */ + if (len < dist) { + memcpy(dstp, dstp - dist, len); + dstp += len; + } else { + int i = 0; + + while (i++ < len) { + lz77_write8(dstp, lz77_read8(dstp - dist)); + dstp++; + } + } + } + + *dlen = dstp - dst; + + return 0; +} diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h index cdcb191b48a2..04bf2e022b39 100644 --- a/fs/smb/client/compress/lz77.h +++ b/fs/smb/client/compress/lz77.h @@ -12,4 +12,5 @@ #include <linux/kernel.h> int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen); +int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen); #endif /* _SMB_COMPRESS_LZ77_H */ diff --git a/fs/smb/client/fs_context.c b/fs/smb/client/fs_context.c index 063e62189bea..ef20ea5319f0 100644 --- a/fs/smb/client/fs_context.c +++ b/fs/smb/client/fs_context.c @@ -62,6 +62,7 @@ static const match_table_t cifs_secflavor_tokens = { { Opt_sec_ntlmv2, "nontlm" }, { Opt_sec_ntlmv2, "ntlmv2" }, { Opt_sec_ntlmv2i, "ntlmv2i" }, + { Opt_sec_iakerb, "iakerb" }, { Opt_sec_none, "none" }, { Opt_sec_err, NULL } @@ -248,6 +249,9 @@ cifs_parse_security_flavors(struct fs_context *fc, char *value, struct smb3_fs_c case Opt_sec_ntlmv2: ctx->sectype = NTLMv2; break; + case Opt_sec_iakerb: + ctx->sectype = IAKerb; + break; case Opt_sec_none: ctx->nullauth = 1; kfree(ctx->username); diff --git a/fs/smb/client/fs_context.h b/fs/smb/client/fs_context.h index d1d29249bcdb..26df0fd36c7e 100644 --- a/fs/smb/client/fs_context.h +++ b/fs/smb/client/fs_context.h @@ -65,6 +65,7 @@ enum cifs_sec_param { Opt_sec_krb5, Opt_sec_krb5i, Opt_sec_krb5p, + Opt_sec_iakerb, Opt_sec_ntlmsspi, Opt_sec_ntlmssp, Opt_sec_ntlmv2, diff --git a/fs/smb/client/smb2ops.c b/fs/smb/client/smb2ops.c index 41d8cd20b25f..b35b9a7747a8 100644 --- a/fs/smb/client/smb2ops.c +++ b/fs/smb/client/smb2ops.c @@ -30,6 +30,7 @@ #include "fs_context.h" #include "cached_dir.h" #include "reparse.h" +#include "compress.h" /* Change credits for different ops and return the total number of credits */ static int @@ -4511,12 +4512,17 @@ err_free: return rc; } +static __always_inline bool is_transform_hdr(const void *buf) +{ + const struct smb2_transform_hdr *trhdr = buf; + + return (trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM); +} + static int smb3_is_transform_hdr(void *buf) { - struct smb2_transform_hdr *trhdr = buf; - - return trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM; + return is_transform_hdr(buf) || is_compress_hdr(buf); } static int @@ -4725,8 +4731,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, } else if (buf_len >= data_offset + data_len) { /* read response payload is in buf */ + struct iov_iter it = rdata->subreq.io_iter; + WARN_ONCE(buffer, "read data can be either in buf or in buffer"); - length = copy_to_iter(buf + data_offset, data_len, &rdata->subreq.io_iter); + length = copy_to_iter(buf + data_offset, data_len, &it); if (length < 0) return length; rdata->got_bytes = data_len; @@ -5023,6 +5031,87 @@ one_more: return ret; } +static int receive_compressed(struct TCP_Server_Info *server) +{ + struct mid_q_entry *mid; + void *src, *dst = NULL; + u32 slen, dlen; + int ret; + + slen = server->pdu_size; + src = kvzalloc(slen, GFP_KERNEL); + if (!src) + return -ENOMEM; + + ret = server->total_read; + memcpy(src, server->smallbuf, ret); + + ret = cifs_read_from_socket(server, src + ret, slen - ret); + if (ret < 0) + goto err_free; + + server->total_read += ret; + + dlen = decompressed_size(src); + dst = kvzalloc(dlen, GFP_KERNEL); + if (!dst) { + ret = -ENOMEM; + + goto err_free; + } + + ret = smb_decompress(src, slen, dst, &dlen); + if (ret) { + spin_lock(&server->srv_lock); + server->tcpStatus = CifsNeedReconnect; + spin_unlock(&server->srv_lock); + + goto err_free; + } + + mid = smb2_find_dequeue_mid(server, dst); + if (!mid) { + ret = -EIO; + + goto err_free; + } + + ret = handle_read_data(server, mid, dst, dlen, NULL, 0, true); + if (!ret) { +#ifdef CONFIG_CIFS_STATS2 + mid->when_received = jiffies; +#endif + if (server->ops->is_network_name_deleted) + server->ops->is_network_name_deleted(dst, server); + + mid->callback(mid); + } else { + spin_lock(&server->srv_lock); + if (server->tcpStatus == CifsNeedReconnect) { + spin_lock(&server->mid_lock); + mid->mid_state = MID_RETRY_NEEDED; + spin_unlock(&server->mid_lock); + spin_unlock(&server->srv_lock); + + mid->callback(mid); + } else { + spin_lock(&server->mid_lock); + mid->mid_state = MID_REQUEST_SUBMITTED; + mid->mid_flags &= ~(MID_DELETED); + list_add_tail(&mid->qhead, &server->pending_mid_q); + spin_unlock(&server->mid_lock); + spin_unlock(&server->srv_lock); + } + } + + release_mid(mid); +err_free: + kvfree(src); + kvfree(dst); + + return ret; +} + static int smb3_receive_transform(struct TCP_Server_Info *server, struct mid_q_entry **mids, char **bufs, int *num_mids) @@ -5032,26 +5121,36 @@ smb3_receive_transform(struct TCP_Server_Info *server, struct smb2_transform_hdr *tr_hdr = (struct smb2_transform_hdr *)buf; unsigned int orig_len = le32_to_cpu(tr_hdr->OriginalMessageSize); - if (pdu_length < sizeof(struct smb2_transform_hdr) + - sizeof(struct smb2_hdr)) { - cifs_server_dbg(VFS, "Transform message is too small (%u)\n", - pdu_length); - cifs_reconnect(server, true); - return -ECONNABORTED; - } + if (is_transform_hdr(buf)) { + if (pdu_length < sizeof(struct smb2_transform_hdr) + + sizeof(struct smb2_hdr)) { + cifs_server_dbg(VFS, "Transform message is too small (%u)\n", pdu_length); + cifs_reconnect(server, true); - if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) { - cifs_server_dbg(VFS, "Transform message is broken\n"); - cifs_reconnect(server, true); - return -ECONNABORTED; - } + return -ECONNABORTED; + } - /* TODO: add support for compounds containing READ. */ - if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) { - return receive_encrypted_read(server, &mids[0], num_mids); + if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) { + cifs_server_dbg(VFS, "Transform message is broken\n"); + cifs_reconnect(server, true); + + return -ECONNABORTED; + } + + /* TODO: add support for compounds containing READ. */ + if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) + return receive_encrypted_read(server, &mids[0], num_mids); + + return receive_encrypted_standard(server, mids, bufs, num_mids); } - return receive_encrypted_standard(server, mids, bufs, num_mids); + if (should_decompress(server, buf)) + return receive_compressed(server); + + cifs_server_dbg(VFS, "Invalid ProtocolId 0x%x\n", *(__le32 *)buf); + cifs_reconnect(server, true); + + return -ECONNABORTED; } int @@ -5073,6 +5172,8 @@ static int smb2_next_header(struct TCP_Server_Info *server, char *buf, *noff = le32_to_cpu(t_hdr->OriginalMessageSize); if (unlikely(check_add_overflow(*noff, sizeof(*t_hdr), noff))) return -EINVAL; + } else if (hdr->ProtocolId == SMB2_COMPRESSION_TRANSFORM_ID) { + *noff = 0; } else { *noff = le32_to_cpu(hdr->NextCommand); } diff --git a/fs/smb/client/smb2pdu.c b/fs/smb/client/smb2pdu.c index 81e05db8e4d5..be39613f2857 100644 --- a/fs/smb/client/smb2pdu.c +++ b/fs/smb/client/smb2pdu.c @@ -1408,7 +1408,9 @@ out_free_inbuf: enum securityEnum smb2_select_sectype(struct TCP_Server_Info *server, enum securityEnum requested) { + pr_err("%s: server sec=%d, iakerb=%d, requested=%d\n", __func__, server->sec_mode, server->sec_iakerb, requested); switch (requested) { + case IAKerb: case Kerberos: case RawNTLMSSP: return requested; @@ -1883,6 +1885,7 @@ SMB2_select_sec(struct SMB2_sess_data *sess_data) } switch (type) { + case IAKerb: case Kerberos: sess_data->func = SMB2_auth_kerberos; break; @@ -4694,6 +4697,12 @@ smb2_async_readv(struct cifs_io_subrequest *rdata) flags |= CIFS_HAS_CREDITS; } + if (should_compress(io_parms.tcon, &rqst)) { + struct smb2_read_req *req = (struct smb2_read_req *)buf; + + req->Flags |= SMB2_READFLAG_REQUEST_COMPRESSED; + } + rc = cifs_call_async(server, &rqst, cifs_readv_receive, smb2_readv_callback, smb3_handle_read_data, rdata, flags, |