summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEnzo Matsumiya <ematsumiya@suse.de>2025-04-14 14:46:06 +0200
committerEnzo Matsumiya <ematsumiya@suse.de>2025-04-14 14:46:06 +0200
commit9acffd588ff49f452b826c31d7a89d10158c8bb7 (patch)
treedbbefc81034c0977728814033d033dea70488a8e
parentb77d90938623aff54af3fc1cb95157427a98da66 (diff)
downloadlinux-9acffd588ff49f452b826c31d7a89d10158c8bb7.tar.gz
linux-9acffd588ff49f452b826c31d7a89d10158c8bb7.tar.bz2
linux-9acffd588ff49f452b826c31d7a89d10158c8bb7.zip
smb: client: sambaXP 2025 modssambaXP-2025
iakerb + compression patches Signed-off-by: Enzo Matsumiya <ematsumiya@suse.de>
-rw-r--r--fs/smb/client/asn1.c9
-rw-r--r--fs/smb/client/cifs_spnego.c18
-rw-r--r--fs/smb/client/compress.c208
-rw-r--r--fs/smb/client/compress.h44
-rw-r--r--fs/smb/client/compress/lz77.c298
-rw-r--r--fs/smb/client/compress/lz77.h1
-rw-r--r--fs/smb/client/fs_context.c4
-rw-r--r--fs/smb/client/fs_context.h1
-rw-r--r--fs/smb/client/smb2ops.c141
-rw-r--r--fs/smb/client/smb2pdu.c9
10 files changed, 616 insertions, 117 deletions
diff --git a/fs/smb/client/asn1.c b/fs/smb/client/asn1.c
index 214a44509e7b..87910e2e0803 100644
--- a/fs/smb/client/asn1.c
+++ b/fs/smb/client/asn1.c
@@ -48,9 +48,12 @@ int cifs_neg_token_init_mech_type(void *context, size_t hdrlen,
server->sec_mskerberos = true;
else if (oid == OID_krb5u2u)
server->sec_kerberosu2u = true;
- else if (oid == OID_krb5)
- server->sec_kerberos = true;
- else if (oid == OID_ntlmssp)
+ else if (oid == OID_krb5) {
+ if (!server->sec_iakerb)
+ server->sec_kerberos = true;
+ else
+ server->sec_kerberos = false;
+ } else if (oid == OID_ntlmssp)
server->sec_ntlmssp = true;
else if (oid == OID_IAKerb)
server->sec_iakerb = true;
diff --git a/fs/smb/client/cifs_spnego.c b/fs/smb/client/cifs_spnego.c
index bc1c1e9b288a..6a8d085ec991 100644
--- a/fs/smb/client/cifs_spnego.c
+++ b/fs/smb/client/cifs_spnego.c
@@ -160,7 +160,23 @@ cifs_get_spnego_key(struct cifs_ses *sesInfo,
if (sesInfo->user_name) {
dp = description + strlen(description);
- sprintf(dp, ";user=%s", sesInfo->user_name);
+ if (server->sec_iakerb) {
+ /*
+ * TODO: add option to set ccache name or specify realm desired
+ */
+ if (!sesInfo->password) {
+ cifs_dbg(VFS, "IAKerb requested, but no password provided\n");
+ spnego_key = ERR_PTR(-EINVAL);
+ goto out;
+ }
+ sprintf(dp, ";user=%s;pw=%s", sesInfo->user_name, sesInfo->password);
+ } else {
+ sprintf(dp, ";user=%s", sesInfo->user_name);
+ }
+ } else if (server->sec_iakerb) {
+ cifs_dbg(VFS, "IAKerb requested, but no username provided\n");
+ spnego_key = ERR_PTR(-EINVAL);
+ goto out;
}
dp = description + strlen(description);
diff --git a/fs/smb/client/compress.c b/fs/smb/client/compress.c
index 766b4de13da7..a6324bb15b8e 100644
--- a/fs/smb/client/compress.c
+++ b/fs/smb/client/compress.c
@@ -16,17 +16,18 @@
#include <linux/kernel.h>
#include <linux/uio.h>
#include <linux/sort.h>
+#include <linux/iov_iter.h>
#include "cifsglob.h"
-#include "../common/smb2pdu.h"
-#include "cifsproto.h"
-#include "smb2proto.h"
#include "compress/lz77.h"
#include "compress.h"
+#define SAMPLE_CHUNK_LEN SZ_2K
+#define SAMPLE_MAX_LEN SZ_2M
+
/*
- * The heuristic_*() functions below try to determine data compressibility.
+ * The functions below try to determine data compressibility.
*
* Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
* unused parts.
@@ -34,7 +35,8 @@
* Read that file for better and more detailed explanation of the calculations.
*
* The algorithms are ran in a collected sample of the input (uncompressed) data.
- * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 4M.
+ * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 2M.
+ * Those are adjusted according to R/W bufsizes negotiated with the server.
*
* Parsing the sample goes from "low-hanging fruits" (fastest algorithms, likely compressible)
* to "need more analysis" (likely uncompressible).
@@ -154,61 +156,39 @@ static int cmp_bkt(const void *_a, const void *_b)
return 1;
}
-/*
- * TODO:
- * Support other iter types, if required.
- * Only ITER_XARRAY is supported for now.
- */
-static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
+static size_t collect_step(void *base, size_t progress, size_t len, void *priv_iov, void *priv_len)
{
- struct folio *folios[16], *folio;
- unsigned int nr, i, j, npages;
- loff_t start = iter->xarray_start + iter->iov_offset;
- pgoff_t last, index = start / PAGE_SIZE;
- size_t len, off, foff;
- void *p;
- int s = 0;
-
- last = (start + max - 1) / PAGE_SIZE;
- do {
- nr = xa_extract(iter->xarray, (void **)folios, index, last, ARRAY_SIZE(folios),
- XA_PRESENT);
- if (nr == 0)
- return -EIO;
-
- for (i = 0; i < nr; i++) {
- folio = folios[i];
- npages = folio_nr_pages(folio);
- foff = start - folio_pos(folio);
- off = foff % PAGE_SIZE;
-
- for (j = foff / PAGE_SIZE; j < npages; j++) {
- size_t len2;
-
- len = min_t(size_t, max, PAGE_SIZE - off);
- len2 = min_t(size_t, len, SZ_2K);
-
- p = kmap_local_page(folio_page(folio, j));
- memcpy(&sample[s], p, len2);
- kunmap_local(p);
-
- s += len2;
-
- if (len2 < SZ_2K || s >= max - SZ_2K)
- return s;
-
- max -= len;
- if (max <= 0)
- return s;
-
- start += len;
- off = 0;
- index++;
- }
- }
- } while (nr == ARRAY_SIZE(folios));
-
- return s;
+ size_t plen, *cur_len = priv_len;
+ struct kvec *iov = priv_iov;
+
+ if (progress >= iov->iov_len)
+ return len;
+
+ plen = min_t(size_t, len, SAMPLE_CHUNK_LEN);
+ memcpy(iov->iov_base + *cur_len, base, plen);
+
+ *cur_len += plen;
+
+ if (len < SAMPLE_CHUNK_LEN)
+ return len;
+
+ return 0;
+}
+
+static int collect_sample(const struct iov_iter *iter, u8 *sample, size_t max_sample_len)
+{
+ size_t ret, len = iov_iter_count(iter);
+ struct iov_iter it = *iter;
+ struct kvec iov = {
+ .iov_base = sample,
+ .iov_len = max_sample_len,
+ };
+
+ ret = iterate_and_advance_kernel(&it, len, &iov, &max_sample_len, collect_step);
+ if (ret > len || ret > max_sample_len)
+ return -EIO;
+
+ return 0;
}
/**
@@ -220,22 +200,18 @@ static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
* Tests shows that this function is quite reliable in predicting data compressibility,
* matching close to 1:1 with the behaviour of LZ77 compression success and failures.
*/
-static bool is_compressible(const struct iov_iter *data)
+static __maybe_unused bool is_compressible(const struct iov_iter *data)
{
- const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
+ const size_t bkt_size = 256;
struct bucket *bkt = NULL;
size_t len;
u8 *sample;
bool ret = false;
int i;
- /* Preventive double check -- already checked in should_compress(). */
len = iov_iter_count(data);
- if (unlikely(len < read_size))
- return ret;
-
- if (len - read_size > max)
- len = max;
+ if (len > SAMPLE_MAX_LEN)
+ len = SAMPLE_MAX_LEN;
sample = kvzalloc(len, GFP_KERNEL);
if (!sample) {
@@ -245,16 +221,15 @@ static bool is_compressible(const struct iov_iter *data)
}
/* Sample 2K bytes per page of the uncompressed data. */
- i = collect_sample(data, len, sample);
- if (i <= 0) {
- WARN_ON_ONCE(1);
+ i = collect_sample(data, sample, len);
+ if (i < 0) {
+ WARN_ONCE(1, "data len=%zu, max sample len=%zu\n", iov_iter_count(data), len);
goto out;
}
- len = i;
ret = true;
-
+ len = i;
if (has_repeated_data(sample, len))
goto out;
@@ -292,12 +267,15 @@ out:
bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
{
+ struct TCP_Server_Info *server;
const struct smb2_hdr *shdr = rq->rq_iov->iov_base;
if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
return false;
- if (!tcon->ses->server->compression.enabled)
+ server = tcon->ses->server;
+
+ if (!server->compression.enabled)
return false;
if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
@@ -315,12 +293,36 @@ bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
return (shdr->Command == SMB2_READ);
}
+bool should_decompress(struct TCP_Server_Info *server, const void *buf)
+{
+ u32 len;
+
+ if (!is_compress_hdr(buf))
+ return false;
+
+ len = decompressed_size(buf);
+ if (len < SMB_COMPRESS_HDR_LEN + server->vals->read_rsp_size) {
+ pr_warn("CIFS: Compressed message too small (%u bytes)\n", len);
+
+ return false;
+ }
+
+ if (len > SMB_DECOMPRESS_MAX_LEN(server)) {
+ pr_warn("CIFS: Uncompressed message too big (%u bytes)\n", len);
+
+ return false;
+ }
+
+ return true;
+}
+
int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
{
struct iov_iter iter;
u32 slen, dlen;
void *src, *dst = NULL;
int ret;
+ //unsigned long start;
if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
return -EINVAL;
@@ -354,12 +356,16 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s
goto err_free;
}
+ //start = jiffies;
ret = lz77_compress(src, slen, dst, &dlen);
+ //pr_err("%s: compress runtime=%ums, ret=%d, dlen=%u\n", __func__, jiffies_to_msecs(jiffies - start), ret, dlen);
if (!ret) {
struct smb2_compression_hdr hdr = { 0 };
struct smb_rqst comp_rq = { .rq_nvec = 3, };
struct kvec iov[3];
+ //pr_err("%s: compress runtime=%ums, dlen=%u\n", __func__, jiffies_to_msecs(jiffies - start), dlen);
+
hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
@@ -384,3 +390,59 @@ err_free:
return ret;
}
+
+int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen)
+{
+ const struct smb2_compression_hdr *hdr = src;
+ struct smb2_hdr *shdr;
+ u32 buf_len, data_len;
+ int ret;
+
+ if (hdr->CompressionAlgorithm != SMB3_COMPRESS_LZ77)
+ return -EIO;
+
+ buf_len = le32_to_cpu(hdr->Offset);
+ data_len = le32_to_cpu(hdr->OriginalCompressedSegmentSize);
+ if (unlikely(*dlen != buf_len + data_len))
+ return -EINVAL;
+
+ /*
+ * Copy uncompressed data from the beginning of the payload.
+ * The remainder is all compressed data.
+ */
+ src += SMB_COMPRESS_HDR_LEN;
+ slen -= SMB_COMPRESS_HDR_LEN;
+ memcpy(dst, src, buf_len);
+
+ src += buf_len;
+ slen -= buf_len;
+ *dlen = data_len;
+
+ ret = lz77_decompress(src, slen, dst + buf_len, dlen);
+ if (ret)
+ return ret;
+
+ shdr = dst;
+
+ if (*dlen != data_len) {
+ pr_warn("CIFS: Decompressed size mismatch: got %u, expected %u (mid=%llu)\n",
+ *dlen, data_len, le64_to_cpu(shdr->MessageId));
+
+ return -EINVAL;
+ }
+
+ if (shdr->ProtocolId != SMB2_PROTO_NUMBER) {
+ pr_warn("CIFS: Decompressed buffer is not an SMB2 message: ProtocolId 0x%x (mid=%llu)\n",
+ le32_to_cpu(shdr->ProtocolId), le64_to_cpu(shdr->MessageId));
+
+ return -EINVAL;
+ }
+
+ /*
+ * @dlen contains only the decompressed data size, add back the previously copied
+ * hdr->Offset size.
+ */
+ *dlen += buf_len;
+
+ return 0;
+}
diff --git a/fs/smb/client/compress.h b/fs/smb/client/compress.h
index f3ed1d3e52fb..717076bcd343 100644
--- a/fs/smb/client/compress.h
+++ b/fs/smb/client/compress.h
@@ -17,19 +17,48 @@
#include <linux/uio.h>
#include <linux/kernel.h>
+#include <linux/minmax.h>
+
#include "../common/smb2pdu.h"
-#include "cifsglob.h"
/* sizeof(smb2_compression_hdr) - sizeof(OriginalPayloadSize) */
#define SMB_COMPRESS_HDR_LEN 16
/* sizeof(smb2_compression_payload_hdr) - sizeof(OriginalPayloadSize) */
#define SMB_COMPRESS_PAYLOAD_HDR_LEN 8
-#define SMB_COMPRESS_MIN_LEN PAGE_SIZE
+#define SMB_COMPRESS_MIN_LEN SZ_64K
+#define SMB_DECOMPRESS_MAX_LEN(_srv) \
+ (256 + SMB_COMPRESS_HDR_LEN + max3((_srv)->maxBuf, (_srv)->max_read, (_srv)->max_write))
+
+static __always_inline u32 decompressed_size(const void *buf)
+{
+ const struct smb2_compression_hdr *hdr = buf;
+
+ if (!buf || !IS_ENABLED(CONFIG_CIFS_COMPRESSION))
+ return 0;
+
+ return le32_to_cpu(hdr->Offset) +
+ le32_to_cpu(hdr->OriginalCompressedSegmentSize);
+}
+
+static __always_inline bool is_compress_hdr(const void *buf)
+{
+ const struct smb2_compression_hdr *hdr = buf;
+
+ if (!buf)
+ return false;
+
+ return (hdr->ProtocolId == SMB2_COMPRESSION_TRANSFORM_ID);
+}
#ifdef CONFIG_CIFS_COMPRESSION
+struct TCP_Server_Info;
+struct cifs_tcon;
+struct smb_rqst;
+
typedef int (*compress_send_fn)(struct TCP_Server_Info *, int, struct smb_rqst *);
int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn);
+int smb_decompress(const void *src, u32 slen, void *dst, u32 *dlen);
/**
* should_compress() - Determines if a request (write) or the response to a
@@ -47,6 +76,7 @@ int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_s
* Return false otherwise.
*/
bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq);
+bool should_decompress(struct TCP_Server_Info *server, const void *buf);
/**
* smb_compress_alg_valid() - Validate a compression algorithm.
@@ -77,11 +107,21 @@ static inline int smb_compress(void *unused1, void *unused2, void *unused3)
return -EOPNOTSUPP;
}
+static inline int smb_decompress(const void *unused1, u32 unused2, void *unused3, u32 unused4)
+{
+ return -EOPNOTSUPP;
+}
+
static inline bool should_compress(void *unused1, void *unused2)
{
return false;
}
+static inline bool should_decompress(void *unused1, void *unused2)
+{
+ return false;
+}
+
static inline int smb_compress_alg_valid(__le16 unused1, bool unused2)
{
return -EOPNOTSUPP;
diff --git a/fs/smb/client/compress/lz77.c b/fs/smb/client/compress/lz77.c
index 96e8a8057a77..9d35691261fd 100644
--- a/fs/smb/client/compress/lz77.c
+++ b/fs/smb/client/compress/lz77.c
@@ -16,7 +16,7 @@
/*
* Compression parameters.
*/
-#define LZ77_MATCH_MIN_LEN 4
+#define LZ77_MATCH_MIN_LEN 3
#define LZ77_MATCH_MIN_DIST 1
#define LZ77_MATCH_MAX_DIST SZ_1K
#define LZ77_HASH_LOG 15
@@ -28,6 +28,16 @@ static __always_inline u8 lz77_read8(const u8 *ptr)
return get_unaligned(ptr);
}
+static __always_inline u16 lz77_read16(const u16 *ptr)
+{
+ return get_unaligned(ptr);
+}
+
+static __always_inline u32 lz77_read32(const u32 *ptr)
+{
+ return get_unaligned(ptr);
+}
+
static __always_inline u64 lz77_read64(const u64 *ptr)
{
return get_unaligned(ptr);
@@ -51,11 +61,13 @@ static __always_inline void lz77_write32(u32 *ptr, u32 v)
static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end)
{
const void *start = cur;
- u64 diff;
-
- /* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */
+ /*
+ * Safe for a do/while (i.e. no bounds check the first iteration) because otherwise we
+ * wouldn't reach here from the main loop.
+ */
do {
- diff = lz77_read64(cur) ^ lz77_read64(wnd);
+ const u64 diff = lz77_read64(cur) ^ lz77_read64(wnd);
+
if (!diff) {
cur += LZ77_STEP_SIZE;
wnd += LZ77_STEP_SIZE;
@@ -69,8 +81,10 @@ static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, cons
return (cur - start);
} while (likely(cur + LZ77_STEP_SIZE < end));
- while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++))
- ;
+ while (cur < end && lz77_read8(cur) == lz77_read8(wnd)) {
+ cur++;
+ wnd++;
+ }
return (cur - start);
}
@@ -131,29 +145,82 @@ static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u
noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
{
- const void *srcp, *end;
- void *dstp, *nib, *flag_pos;
+ const void *srcp, *end, *anchor;
+ void *dstp, *nib, *flag_pos, *dend;
u32 flag_count = 0;
long flag = 0;
u64 *htable;
srcp = src;
+ anchor = srcp;
end = src + slen;
dstp = dst;
nib = NULL;
flag_pos = dstp;
dstp += 4;
+ dend = dst + (*dlen - (*dlen >> 3));
htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
if (!htable)
return -ENOMEM;
+#if 0
+ d = 4;
+
+ /* warm up loop */
+ do {
+ u32 dist, len = 0;
+ const void *wnd;
+ u64 hash;
+
+ hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
+ wnd = src + htable[hash];
+ htable[hash] = srcp - src;
+ dist = srcp - wnd;
+
+ if (dist && dist < LZ77_MATCH_MAX_DIST)
+ len = lz77_match_len(wnd, srcp, end);
+
+ if (len < LZ77_MATCH_MIN_LEN) {
+ srcp++;
+ d++;
+
+ continue;
+ }
+
+ d += 4;
+ srcp += len;
+ } while (likely(srcp + LZ77_STEP_SIZE < end));
+
+ if (srcp < end)
+ d += (end - srcp);
+
+ if (d > (slen - (slen >> 2))) {
+ *dlen = slen;
+ goto out;
+ }
+
+ srcp = src;
+ anchor = srcp;
+#endif
+
/* Main loop. */
do {
u32 dist, len = 0;
const void *wnd;
u64 hash;
+#if 0
+ /*
+ * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
+ * going further.
+ */
+ if (unlikely(dstp - dst >= slen - (slen >> 3))) {
+ *dlen = slen;
+ goto out;
+ }
+#endif
+
hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
wnd = src + htable[hash];
htable[hash] = srcp - src;
@@ -162,12 +229,41 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
if (dist && dist < LZ77_MATCH_MAX_DIST)
len = lz77_match_len(wnd, srcp, end);
+#if 1
+ if (len < LZ77_MATCH_MIN_LEN) {
+ srcp++;
+ continue;
+ }
+
+ if (dstp + (srcp - anchor) >= dend)
+ goto out;
+
+ while (anchor < srcp) {
+ u32 c = umin(srcp - anchor, 32 - flag_count);
+
+ memcpy(dstp, anchor, c);
+ anchor += c;
+ dstp += c;
+
+ flag <<= c;
+ flag_count += c;
+ if (flag_count == 32) {
+ lz77_write32(flag_pos, flag);
+ flag_count = 0;
+ flag_pos = dstp;
+ dstp += 4;
+ }
+ }
+#else
if (len < LZ77_MATCH_MIN_LEN) {
lz77_write8(dstp, lz77_read8(srcp));
dstp++;
srcp++;
+ if (dstp >= dend)
+ goto out;
+
flag <<= 1;
flag_count++;
if (flag_count == 32) {
@@ -179,18 +275,13 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
continue;
}
-
- /*
- * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
- * going further.
- */
- if (unlikely(dstp - dst >= slen - (slen >> 3))) {
- *dlen = slen;
+#endif
+ if (dstp + len >= dend)
goto out;
- }
dstp = lz77_write_match(dstp, &nib, dist, len);
srcp += len;
+ anchor = srcp;
flag = (flag << 1) | 1;
flag_count++;
@@ -202,9 +293,32 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
}
} while (likely(srcp + LZ77_STEP_SIZE < end));
+ while (anchor < srcp) {
+ u32 c = umin(srcp - anchor, 32 - flag_count);
+
+ if (dstp + c >= dend)
+ goto out;
+
+ memcpy(dstp, anchor, c);
+ anchor += c;
+ dstp += c;
+
+ flag <<= c;
+ flag_count += c;
+ if (flag_count == 32) {
+ lz77_write32(flag_pos, flag);
+ flag_count = 0;
+ flag_pos = dstp;
+ dstp += 4;
+ }
+ }
+
while (srcp < end) {
u32 c = umin(end - srcp, 32 - flag_count);
+ if (dstp + c >= dend)
+ goto out;
+
memcpy(dstp, srcp, c);
dstp += c;
@@ -221,7 +335,7 @@ noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
}
flag <<= (32 - flag_count);
- flag |= (1 << (32 - flag_count)) - 1;
+ flag |= (1UL << (32 - flag_count)) - 1;
lz77_write32(flag_pos, flag);
*dlen = dstp - dst;
@@ -233,3 +347,151 @@ out:
return -EMSGSIZE;
}
+
+static __always_inline const void *lz77_read_match_len(const void *src, const void **nib, u32 *dist,
+ u32 *len)
+{
+ u32 mlen;
+ u16 sym;
+
+ sym = lz77_read16(src);
+ src += 2;
+
+ *dist = (sym >> 3) + 1;
+ mlen = sym & 7;
+
+ if (mlen == 7) {
+ if (!*nib) {
+ *nib = src;
+ mlen = lz77_read8(src) % 16;
+ src++;
+ } else {
+ mlen = lz77_read8(*nib) >> 4;
+ *nib = NULL;
+ }
+
+ if (mlen == 15) {
+ mlen = lz77_read8(src);
+ src++;
+
+ if (mlen == 255) {
+ mlen = lz77_read16(src);
+ src += 2;
+
+ if (mlen == 0) {
+ mlen = lz77_read32(src);
+ src += 4;
+ }
+
+ if (mlen < 15 + 7) {
+ pr_warn("CIFS: unexpected match length %u\n", mlen);
+
+ return NULL;
+ }
+ mlen -= (15 + 7);
+ }
+ mlen += 15;
+ }
+ mlen += 7;
+ }
+ mlen += 3;
+
+ *len = mlen;
+
+ return src;
+}
+
+int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen)
+{
+ const void *srcp = src, *end = src + slen, *nib = NULL;
+ void *dstp = dst, *dst_end;
+ u32 c, flag_count = 0;
+ long flag;
+
+ if (!dlen || *dlen < slen)
+ return -EINVAL;
+
+ dst_end = dst + *dlen;
+ *dlen = 0;
+
+ while (likely(srcp < end)) {
+ u32 dist, len;
+
+ if (flag_count == 0) {
+ flag = lz77_read32(srcp);
+ srcp += 4;
+ flag_count = 32;
+ }
+
+ /* Decode literals. */
+ c = flag ? __builtin_clz(flag) : 32;
+ c = umin(c, flag_count);
+
+ if (unlikely(dstp + c > dst_end))
+ return -EFAULT;
+
+ flag_count -= c;
+ flag <<= c;
+
+ memcpy(dstp, srcp, c);
+
+ dstp += c;
+ srcp += c;
+
+ if (flag_count == 0)
+ continue;
+
+ /*
+ * This means we've parsed the whole input buffer @src and filled
+ * @dst within its memory bounds.
+ *
+ * However, it's up to callers to determine if the decompressed
+ * buffer and size are according to what they expected to get.
+ */
+ if (unlikely(srcp + 1 >= end))
+ break;
+
+ flag_count--;
+ flag <<= 1;
+
+ /* Decode match. */
+ srcp = lz77_read_match_len(srcp, &nib, &dist, &len);
+ if (unlikely(!srcp))
+ return -EIO;
+
+ /*
+ * Even though we use a 1K window size and 4-byte min match len, the server can use
+ * parameters of their choosing, so check against the standard defaults.
+ */
+ if (unlikely(dist > SZ_8K || len < 3))
+ return -EIO;
+
+ /* Bogus compression, match distance too far. */
+ if (unlikely(dstp - dst < dist))
+ return -EFAULT;
+
+ /* Bogus compression, match length too long. */
+ if (unlikely(dstp + len > dst_end))
+ return -EMSGSIZE;
+
+ /*
+ * If non-overlapping memory, we can use memcpy().
+ * Otherwise, we have to do it byte by byte.
+ */
+ if (len < dist) {
+ memcpy(dstp, dstp - dist, len);
+ dstp += len;
+ } else {
+ int i = 0;
+
+ while (i++ < len) {
+ lz77_write8(dstp, lz77_read8(dstp - dist));
+ dstp++;
+ }
+ }
+ }
+
+ *dlen = dstp - dst;
+
+ return 0;
+}
diff --git a/fs/smb/client/compress/lz77.h b/fs/smb/client/compress/lz77.h
index cdcb191b48a2..04bf2e022b39 100644
--- a/fs/smb/client/compress/lz77.h
+++ b/fs/smb/client/compress/lz77.h
@@ -12,4 +12,5 @@
#include <linux/kernel.h>
int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen);
+int lz77_decompress(const void *src, u32 slen, void *dst, u32 *dlen);
#endif /* _SMB_COMPRESS_LZ77_H */
diff --git a/fs/smb/client/fs_context.c b/fs/smb/client/fs_context.c
index 063e62189bea..ef20ea5319f0 100644
--- a/fs/smb/client/fs_context.c
+++ b/fs/smb/client/fs_context.c
@@ -62,6 +62,7 @@ static const match_table_t cifs_secflavor_tokens = {
{ Opt_sec_ntlmv2, "nontlm" },
{ Opt_sec_ntlmv2, "ntlmv2" },
{ Opt_sec_ntlmv2i, "ntlmv2i" },
+ { Opt_sec_iakerb, "iakerb" },
{ Opt_sec_none, "none" },
{ Opt_sec_err, NULL }
@@ -248,6 +249,9 @@ cifs_parse_security_flavors(struct fs_context *fc, char *value, struct smb3_fs_c
case Opt_sec_ntlmv2:
ctx->sectype = NTLMv2;
break;
+ case Opt_sec_iakerb:
+ ctx->sectype = IAKerb;
+ break;
case Opt_sec_none:
ctx->nullauth = 1;
kfree(ctx->username);
diff --git a/fs/smb/client/fs_context.h b/fs/smb/client/fs_context.h
index d1d29249bcdb..26df0fd36c7e 100644
--- a/fs/smb/client/fs_context.h
+++ b/fs/smb/client/fs_context.h
@@ -65,6 +65,7 @@ enum cifs_sec_param {
Opt_sec_krb5,
Opt_sec_krb5i,
Opt_sec_krb5p,
+ Opt_sec_iakerb,
Opt_sec_ntlmsspi,
Opt_sec_ntlmssp,
Opt_sec_ntlmv2,
diff --git a/fs/smb/client/smb2ops.c b/fs/smb/client/smb2ops.c
index 41d8cd20b25f..b35b9a7747a8 100644
--- a/fs/smb/client/smb2ops.c
+++ b/fs/smb/client/smb2ops.c
@@ -30,6 +30,7 @@
#include "fs_context.h"
#include "cached_dir.h"
#include "reparse.h"
+#include "compress.h"
/* Change credits for different ops and return the total number of credits */
static int
@@ -4511,12 +4512,17 @@ err_free:
return rc;
}
+static __always_inline bool is_transform_hdr(const void *buf)
+{
+ const struct smb2_transform_hdr *trhdr = buf;
+
+ return (trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM);
+}
+
static int
smb3_is_transform_hdr(void *buf)
{
- struct smb2_transform_hdr *trhdr = buf;
-
- return trhdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM;
+ return is_transform_hdr(buf) || is_compress_hdr(buf);
}
static int
@@ -4725,8 +4731,10 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid,
} else if (buf_len >= data_offset + data_len) {
/* read response payload is in buf */
+ struct iov_iter it = rdata->subreq.io_iter;
+
WARN_ONCE(buffer, "read data can be either in buf or in buffer");
- length = copy_to_iter(buf + data_offset, data_len, &rdata->subreq.io_iter);
+ length = copy_to_iter(buf + data_offset, data_len, &it);
if (length < 0)
return length;
rdata->got_bytes = data_len;
@@ -5023,6 +5031,87 @@ one_more:
return ret;
}
+static int receive_compressed(struct TCP_Server_Info *server)
+{
+ struct mid_q_entry *mid;
+ void *src, *dst = NULL;
+ u32 slen, dlen;
+ int ret;
+
+ slen = server->pdu_size;
+ src = kvzalloc(slen, GFP_KERNEL);
+ if (!src)
+ return -ENOMEM;
+
+ ret = server->total_read;
+ memcpy(src, server->smallbuf, ret);
+
+ ret = cifs_read_from_socket(server, src + ret, slen - ret);
+ if (ret < 0)
+ goto err_free;
+
+ server->total_read += ret;
+
+ dlen = decompressed_size(src);
+ dst = kvzalloc(dlen, GFP_KERNEL);
+ if (!dst) {
+ ret = -ENOMEM;
+
+ goto err_free;
+ }
+
+ ret = smb_decompress(src, slen, dst, &dlen);
+ if (ret) {
+ spin_lock(&server->srv_lock);
+ server->tcpStatus = CifsNeedReconnect;
+ spin_unlock(&server->srv_lock);
+
+ goto err_free;
+ }
+
+ mid = smb2_find_dequeue_mid(server, dst);
+ if (!mid) {
+ ret = -EIO;
+
+ goto err_free;
+ }
+
+ ret = handle_read_data(server, mid, dst, dlen, NULL, 0, true);
+ if (!ret) {
+#ifdef CONFIG_CIFS_STATS2
+ mid->when_received = jiffies;
+#endif
+ if (server->ops->is_network_name_deleted)
+ server->ops->is_network_name_deleted(dst, server);
+
+ mid->callback(mid);
+ } else {
+ spin_lock(&server->srv_lock);
+ if (server->tcpStatus == CifsNeedReconnect) {
+ spin_lock(&server->mid_lock);
+ mid->mid_state = MID_RETRY_NEEDED;
+ spin_unlock(&server->mid_lock);
+ spin_unlock(&server->srv_lock);
+
+ mid->callback(mid);
+ } else {
+ spin_lock(&server->mid_lock);
+ mid->mid_state = MID_REQUEST_SUBMITTED;
+ mid->mid_flags &= ~(MID_DELETED);
+ list_add_tail(&mid->qhead, &server->pending_mid_q);
+ spin_unlock(&server->mid_lock);
+ spin_unlock(&server->srv_lock);
+ }
+ }
+
+ release_mid(mid);
+err_free:
+ kvfree(src);
+ kvfree(dst);
+
+ return ret;
+}
+
static int
smb3_receive_transform(struct TCP_Server_Info *server,
struct mid_q_entry **mids, char **bufs, int *num_mids)
@@ -5032,26 +5121,36 @@ smb3_receive_transform(struct TCP_Server_Info *server,
struct smb2_transform_hdr *tr_hdr = (struct smb2_transform_hdr *)buf;
unsigned int orig_len = le32_to_cpu(tr_hdr->OriginalMessageSize);
- if (pdu_length < sizeof(struct smb2_transform_hdr) +
- sizeof(struct smb2_hdr)) {
- cifs_server_dbg(VFS, "Transform message is too small (%u)\n",
- pdu_length);
- cifs_reconnect(server, true);
- return -ECONNABORTED;
- }
+ if (is_transform_hdr(buf)) {
+ if (pdu_length < sizeof(struct smb2_transform_hdr) +
+ sizeof(struct smb2_hdr)) {
+ cifs_server_dbg(VFS, "Transform message is too small (%u)\n", pdu_length);
+ cifs_reconnect(server, true);
- if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) {
- cifs_server_dbg(VFS, "Transform message is broken\n");
- cifs_reconnect(server, true);
- return -ECONNABORTED;
- }
+ return -ECONNABORTED;
+ }
- /* TODO: add support for compounds containing READ. */
- if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server)) {
- return receive_encrypted_read(server, &mids[0], num_mids);
+ if (pdu_length < orig_len + sizeof(struct smb2_transform_hdr)) {
+ cifs_server_dbg(VFS, "Transform message is broken\n");
+ cifs_reconnect(server, true);
+
+ return -ECONNABORTED;
+ }
+
+ /* TODO: add support for compounds containing READ. */
+ if (pdu_length > CIFSMaxBufSize + MAX_HEADER_SIZE(server))
+ return receive_encrypted_read(server, &mids[0], num_mids);
+
+ return receive_encrypted_standard(server, mids, bufs, num_mids);
}
- return receive_encrypted_standard(server, mids, bufs, num_mids);
+ if (should_decompress(server, buf))
+ return receive_compressed(server);
+
+ cifs_server_dbg(VFS, "Invalid ProtocolId 0x%x\n", *(__le32 *)buf);
+ cifs_reconnect(server, true);
+
+ return -ECONNABORTED;
}
int
@@ -5073,6 +5172,8 @@ static int smb2_next_header(stru