2023-01-16 23:17:52

by David Howells

[permalink] [raw]
Subject: [PATCH v6 16/34] af_alg: [RFC] Use netfs_extract_iter_to_sg() to create scatterlists

Use netfs_extract_iter_to_sg() to decant the destination iterator into a
scatterlist in af_alg_get_rsgl() and the source iterator into a scatterlist
in hash_sendmsg(). af_alg_make_sg() can then be removed.

Note that if this approach is acceptable, netfs_extract_iter_to_sg() (currently
declared in <linux/netfs.h>) should be moved to core code.

Signed-off-by: David Howells <[email protected]>
cc: Herbert Xu <[email protected]>
cc: [email protected]
---

crypto/af_alg.c | 63 +++++++++++++----------------------------------
crypto/algif_hash.c | 21 +++++++++++-----
include/crypto/if_alg.h | 7 +----
3 files changed, 35 insertions(+), 56 deletions(-)

diff --git a/crypto/af_alg.c b/crypto/af_alg.c
index c99e09fce71f..c5fbe39366ff 100644
--- a/crypto/af_alg.c
+++ b/crypto/af_alg.c
@@ -22,6 +22,7 @@
#include <linux/sched/signal.h>
#include <linux/security.h>
#include <linux/string.h>
+#include <linux/netfs.h>
#include <keys/user-type.h>
#include <keys/trusted-type.h>
#include <keys/encrypted-type.h>
@@ -531,55 +532,22 @@ static const struct net_proto_family alg_family = {
.owner = THIS_MODULE,
};

-int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len,
- unsigned int gup_flags)
-{
- struct page **pages = sgl->pages;
- size_t off;
- ssize_t n;
- int npages, i;
-
- n = iov_iter_extract_pages(iter, &pages, len, ALG_MAX_PAGES,
- gup_flags, &off);
- if (n < 0)
- return n;
-
- sgl->cleanup_mode = iov_iter_extract_mode(iter, gup_flags);
-
- npages = DIV_ROUND_UP(off + n, PAGE_SIZE);
- if (WARN_ON(npages == 0))
- return -EINVAL;
- /* Add one extra for linking */
- sg_init_table(sgl->sg, npages + 1);
-
- for (i = 0, len = n; i < npages; i++) {
- int plen = min_t(int, len, PAGE_SIZE - off);
-
- sg_set_page(sgl->sg + i, sgl->pages[i], plen, off);
-
- off = 0;
- len -= plen;
- }
- sg_mark_end(sgl->sg + npages - 1);
- sgl->npages = npages;
-
- return n;
-}
-EXPORT_SYMBOL_GPL(af_alg_make_sg);
-
static void af_alg_link_sg(struct af_alg_sgl *sgl_prev,
struct af_alg_sgl *sgl_new)
{
- sg_unmark_end(sgl_prev->sg + sgl_prev->npages - 1);
- sg_chain(sgl_prev->sg, sgl_prev->npages + 1, sgl_new->sg);
+ sg_unmark_end(sgl_prev->sgt.sgl + sgl_prev->sgt.nents - 1);
+ sg_chain(sgl_prev->sgt.sgl, sgl_prev->sgt.nents + 1, sgl_new->sgt.sgl);
}

void af_alg_free_sg(struct af_alg_sgl *sgl)
{
int i;

- for (i = 0; i < sgl->npages; i++)
- page_put_unpin(sgl->pages[i], sgl->cleanup_mode);
+ if (!(sgl->cleanup_mode & (FOLL_PIN | FOLL_GET)))
+ return;
+
+ for (i = 0; i < sgl->sgt.nents; i++)
+ page_put_unpin(sg_page(&sgl->sgt.sgl[i]), sgl->cleanup_mode);
}
EXPORT_SYMBOL_GPL(af_alg_free_sg);

@@ -1293,8 +1261,8 @@ int af_alg_get_rsgl(struct sock *sk, struct msghdr *msg, int flags,

while (maxsize > len && msg_data_left(msg)) {
struct af_alg_rsgl *rsgl;
+ ssize_t err;
size_t seglen;
- int err;

/* limit the amount of readable buffers */
if (!af_alg_readable(sk))
@@ -1311,17 +1279,22 @@ int af_alg_get_rsgl(struct sock *sk, struct msghdr *msg, int flags,
return -ENOMEM;
}

- rsgl->sgl.npages = 0;
+ rsgl->sgl.sgt.sgl = rsgl->sgl.sgl;
+ rsgl->sgl.sgt.nents = 0;
+ rsgl->sgl.sgt.orig_nents = 0;
list_add_tail(&rsgl->list, &areq->rsgl_list);

- /* make one iovec available as scatterlist */
- err = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, seglen,
- FOLL_DEST_BUF);
+ err = netfs_extract_iter_to_sg(&msg->msg_iter, seglen,
+ &rsgl->sgl.sgt, ALG_MAX_PAGES,
+ FOLL_DEST_BUF);
if (err < 0) {
rsgl->sg_num_bytes = 0;
return err;
}

+ rsgl->sgl.cleanup_mode = iov_iter_extract_mode(&msg->msg_iter,
+ FOLL_DEST_BUF);
+
/* chain the new scatterlist with previous one */
if (areq->last_rsgl)
af_alg_link_sg(&areq->last_rsgl->sgl, &rsgl->sgl);
diff --git a/crypto/algif_hash.c b/crypto/algif_hash.c
index fe3d2258145f..5aef6818a9ff 100644
--- a/crypto/algif_hash.c
+++ b/crypto/algif_hash.c
@@ -14,6 +14,7 @@
#include <linux/mm.h>
#include <linux/module.h>
#include <linux/net.h>
+#include <linux/netfs.h>
#include <net/sock.h>

struct hash_ctx {
@@ -91,14 +92,22 @@ static int hash_sendmsg(struct socket *sock, struct msghdr *msg,
if (len > limit)
len = limit;

- len = af_alg_make_sg(&ctx->sgl, &msg->msg_iter, len,
- FOLL_SOURCE_BUF);
+ ctx->sgl.sgt.sgl = ctx->sgl.sgl;
+ ctx->sgl.sgt.nents = 0;
+ ctx->sgl.sgt.orig_nents = 0;
+
+ len = netfs_extract_iter_to_sg(&msg->msg_iter, len,
+ &ctx->sgl.sgt, ALG_MAX_PAGES,
+ FOLL_SOURCE_BUF);
if (len < 0) {
err = copied ? 0 : len;
goto unlock;
}

- ahash_request_set_crypt(&ctx->req, ctx->sgl.sg, NULL, len);
+ ctx->sgl.cleanup_mode = iov_iter_extract_mode(&msg->msg_iter,
+ FOLL_SOURCE_BUF);
+
+ ahash_request_set_crypt(&ctx->req, ctx->sgl.sgt.sgl, NULL, len);

err = crypto_wait_req(crypto_ahash_update(&ctx->req),
&ctx->wait);
@@ -142,8 +151,8 @@ static ssize_t hash_sendpage(struct socket *sock, struct page *page,
flags |= MSG_MORE;

lock_sock(sk);
- sg_init_table(ctx->sgl.sg, 1);
- sg_set_page(ctx->sgl.sg, page, size, offset);
+ sg_init_table(ctx->sgl.sgl, 1);
+ sg_set_page(ctx->sgl.sgl, page, size, offset);

if (!(flags & MSG_MORE)) {
err = hash_alloc_result(sk, ctx);
@@ -152,7 +161,7 @@ static ssize_t hash_sendpage(struct socket *sock, struct page *page,
} else if (!ctx->more)
hash_free_result(sk, ctx);

- ahash_request_set_crypt(&ctx->req, ctx->sgl.sg, ctx->result, size);
+ ahash_request_set_crypt(&ctx->req, ctx->sgl.sgl, ctx->result, size);

if (!(flags & MSG_MORE)) {
if (ctx->more)
diff --git a/include/crypto/if_alg.h b/include/crypto/if_alg.h
index 95b3b7517d3f..424a2071705d 100644
--- a/include/crypto/if_alg.h
+++ b/include/crypto/if_alg.h
@@ -58,9 +58,8 @@ struct af_alg_type {
};

struct af_alg_sgl {
- struct scatterlist sg[ALG_MAX_PAGES + 1];
- struct page *pages[ALG_MAX_PAGES];
- unsigned int npages;
+ struct sg_table sgt;
+ struct scatterlist sgl[ALG_MAX_PAGES + 1];
unsigned int cleanup_mode;
};

@@ -166,8 +165,6 @@ int af_alg_release(struct socket *sock);
void af_alg_release_parent(struct sock *sk);
int af_alg_accept(struct sock *sk, struct socket *newsock, bool kern);

-int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len,
- unsigned int gup_flags);
void af_alg_free_sg(struct af_alg_sgl *sgl);

static inline struct alg_sock *alg_sk(struct sock *sk)