2024-02-13 06:01:45

by Eric Biggers

[permalink] [raw]
Subject: [PATCH riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS

From: Eric Biggers <[email protected]>

Add an implementation of cts(cbc(aes)) accelerated using the Zvkned
RISC-V vector crypto extension. This is mainly useful for fscrypt,
where cts(cbc(aes)) is the "default" filenames encryption algorithm. In
that use case, typically most messages are short and are block-aligned.
The CBC-CTS variant implemented is CS3; this is the variant Linux uses.

To perform well on short messages, the new implementation processes the
full message in one call to the assembly function if the data is
contiguous. Otherwise it falls back to CBC operations followed by CTS
at the end. For decryption, to further improve performance on short
messages, especially block-aligned messages, the CBC-CTS assembly
function parallelizes the AES decryption of all full blocks. This
improves on the arm64 implementation of cts(cbc(aes)), which always
splits the CBC part(s) from the CTS part, doing the AES decryptions for
the last two blocks serially and usually loading the round keys twice.

Tested in QEMU with CONFIG_CRYPTO_MANAGER_EXTRA_TESTS=y.

Signed-off-by: Eric Biggers <[email protected]>
---
arch/riscv/crypto/Kconfig | 4 +-
arch/riscv/crypto/aes-riscv64-glue.c | 93 ++++++++++++++-
arch/riscv/crypto/aes-riscv64-zvkned.S | 153 +++++++++++++++++++++++++
3 files changed, 245 insertions(+), 5 deletions(-)

diff --git a/arch/riscv/crypto/Kconfig b/arch/riscv/crypto/Kconfig
index 2ad44e1d464a..ad58dad9a580 100644
--- a/arch/riscv/crypto/Kconfig
+++ b/arch/riscv/crypto/Kconfig
@@ -1,23 +1,23 @@
# SPDX-License-Identifier: GPL-2.0

menu "Accelerated Cryptographic Algorithms for CPU (riscv)"

config CRYPTO_AES_RISCV64
- tristate "Ciphers: AES, modes: ECB, CBC, CTR, XTS"
+ tristate "Ciphers: AES, modes: ECB, CBC, CTS, CTR, XTS"
depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
select CRYPTO_ALGAPI
select CRYPTO_LIB_AES
select CRYPTO_SKCIPHER
help
Block cipher: AES cipher algorithms
- Length-preserving ciphers: AES with ECB, CBC, CTR, XTS
+ Length-preserving ciphers: AES with ECB, CBC, CTS, CTR, XTS

Architecture: riscv64 using:
- Zvkned vector crypto extension
- Zvbb vector extension (XTS)
- Zvkb vector crypto extension (CTR)
- Zvkg vector crypto extension (XTS)

config CRYPTO_CHACHA_RISCV64
tristate "Ciphers: ChaCha"
depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
diff --git a/arch/riscv/crypto/aes-riscv64-glue.c b/arch/riscv/crypto/aes-riscv64-glue.c
index 37bc6ef0be40..f814ee048555 100644
--- a/arch/riscv/crypto/aes-riscv64-glue.c
+++ b/arch/riscv/crypto/aes-riscv64-glue.c
@@ -1,20 +1,22 @@
// SPDX-License-Identifier: GPL-2.0-only
/*
* AES using the RISC-V vector crypto extensions. Includes the bare block
- * cipher and the ECB, CBC, CTR, and XTS modes.
+ * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes.
*
* Copyright (C) 2023 VRULL GmbH
* Author: Heiko Stuebner <[email protected]>
*
* Copyright (C) 2023 SiFive, Inc.
* Author: Jerry Shih <[email protected]>
+ *
+ * Copyright 2024 Google LLC
*/

#include <asm/simd.h>
#include <asm/vector.h>
#include <crypto/aes.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
@@ -33,20 +35,24 @@ asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
const u8 *in, u8 *out, size_t len);

asmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key,
const u8 *in, u8 *out, size_t len,
u8 iv[AES_BLOCK_SIZE]);
asmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key,
const u8 *in, u8 *out, size_t len,
u8 iv[AES_BLOCK_SIZE]);

+asmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
+ const u8 *in, u8 *out, size_t len,
+ const u8 iv[AES_BLOCK_SIZE], bool enc);
+
asmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key,
const u8 *in, u8 *out, size_t len,
u8 iv[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg(
const struct crypto_aes_ctx *key,
const u8 *in, u8 *out, size_t len,
u8 tweak[AES_BLOCK_SIZE]);

asmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg(
@@ -157,21 +163,21 @@ static int riscv64_aes_ecb_encrypt(struct skcipher_request *req)
return riscv64_aes_ecb_crypt(req, true);
}

static int riscv64_aes_ecb_decrypt(struct skcipher_request *req)
{
return riscv64_aes_ecb_crypt(req, false);
}

/* AES-CBC */

-static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
+static int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
{
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
struct skcipher_walk walk;
unsigned int nbytes;
int err;

err = skcipher_walk_virt(&walk, req, false);
while ((nbytes = walk.nbytes) != 0) {
kernel_vector_begin();
@@ -195,20 +201,84 @@ static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
static int riscv64_aes_cbc_encrypt(struct skcipher_request *req)
{
return riscv64_aes_cbc_crypt(req, true);
}

static int riscv64_aes_cbc_decrypt(struct skcipher_request *req)
{
return riscv64_aes_cbc_crypt(req, false);
}

+/* AES-CBC-CTS */
+
+static int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc)
+{
+ struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+ const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+ struct scatterlist sg_src[2], sg_dst[2];
+ struct skcipher_request subreq;
+ struct scatterlist *src, *dst;
+ struct skcipher_walk walk;
+ unsigned int cbc_len;
+ int err;
+
+ if (req->cryptlen < AES_BLOCK_SIZE)
+ return -EINVAL;
+
+ err = skcipher_walk_virt(&walk, req, false);
+ if (err)
+ return err;
+ /*
+ * If the full message is available in one step, decrypt it in one call
+ * to the CBC-CTS assembly function. This reduces overhead, especially
+ * on short messages. Otherwise, fall back to doing CBC up to the last
+ * two blocks, then invoke CTS just for the ciphertext stealing.
+ */
+ if (unlikely(walk.nbytes != req->cryptlen)) {
+ cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1,
+ AES_BLOCK_SIZE);
+ skcipher_walk_abort(&walk);
+ skcipher_request_set_tfm(&subreq, tfm);
+ skcipher_request_set_callback(&subreq,
+ skcipher_request_flags(req),
+ NULL, NULL);
+ skcipher_request_set_crypt(&subreq, req->src, req->dst,
+ cbc_len, req->iv);
+ err = riscv64_aes_cbc_crypt(&subreq, enc);
+ if (err)
+ return err;
+ dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len);
+ if (req->dst != req->src)
+ dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len);
+ skcipher_request_set_crypt(&subreq, src, dst,
+ req->cryptlen - cbc_len, req->iv);
+ err = skcipher_walk_virt(&walk, &subreq, false);
+ if (err)
+ return err;
+ }
+ kernel_vector_begin();
+ aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr,
+ walk.nbytes, req->iv, enc);
+ kernel_vector_end();
+ return skcipher_walk_done(&walk, 0);
+}
+
+static int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req)
+{
+ return riscv64_aes_cbc_cts_crypt(req, true);
+}
+
+static int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req)
+{
+ return riscv64_aes_cbc_cts_crypt(req, false);
+}
+
/* AES-CTR */

static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
{
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
unsigned int nbytes, p1_nbytes;
struct skcipher_walk walk;
u32 ctr32, nblocks;
int err;
@@ -427,20 +497,36 @@ static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
.max_keysize = AES_MAX_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE,
.base = {
.cra_blocksize = AES_BLOCK_SIZE,
.cra_ctxsize = sizeof(struct crypto_aes_ctx),
.cra_priority = 300,
.cra_name = "cbc(aes)",
.cra_driver_name = "cbc-aes-riscv64-zvkned",
.cra_module = THIS_MODULE,
},
+ }, {
+ .setkey = riscv64_aes_setkey_skcipher,
+ .encrypt = riscv64_aes_cbc_cts_encrypt,
+ .decrypt = riscv64_aes_cbc_cts_decrypt,
+ .min_keysize = AES_MIN_KEY_SIZE,
+ .max_keysize = AES_MAX_KEY_SIZE,
+ .ivsize = AES_BLOCK_SIZE,
+ .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
+ .base = {
+ .cra_blocksize = AES_BLOCK_SIZE,
+ .cra_ctxsize = sizeof(struct crypto_aes_ctx),
+ .cra_priority = 300,
+ .cra_name = "cts(cbc(aes))",
+ .cra_driver_name = "cts-cbc-aes-riscv64-zvkned",
+ .cra_module = THIS_MODULE,
+ },
}
};

static struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = {
.setkey = riscv64_aes_setkey_skcipher,
.encrypt = riscv64_aes_ctr_crypt,
.decrypt = riscv64_aes_ctr_crypt,
.min_keysize = AES_MIN_KEY_SIZE,
.max_keysize = AES_MAX_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE,
@@ -533,18 +619,19 @@ static void __exit riscv64_aes_mod_exit(void)
if (riscv_isa_extension_available(NULL, ZVKB))
crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
}

module_init(riscv64_aes_mod_init);
module_exit(riscv64_aes_mod_exit);

-MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS (RISC-V accelerated)");
+MODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)");
MODULE_AUTHOR("Jerry Shih <[email protected]>");
MODULE_LICENSE("GPL");
MODULE_ALIAS_CRYPTO("aes");
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
+MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
diff --git a/arch/riscv/crypto/aes-riscv64-zvkned.S b/arch/riscv/crypto/aes-riscv64-zvkned.S
index 43541aad6386..23d063f94ce6 100644
--- a/arch/riscv/crypto/aes-riscv64-zvkned.S
+++ b/arch/riscv/crypto/aes-riscv64-zvkned.S
@@ -177,10 +177,163 @@ SYM_FUNC_END(aes_cbc_encrypt_zvkned)

// Same prototype and calling convention as the encryption function
SYM_FUNC_START(aes_cbc_decrypt_zvkned)
aes_begin KEYP, 128f, 192f
aes_cbc_decrypt 256
128:
aes_cbc_decrypt 128
192:
aes_cbc_decrypt 192
SYM_FUNC_END(aes_cbc_decrypt_zvkned)
+
+.macro aes_cbc_cts_encrypt keylen
+
+ // CBC-encrypt all blocks except the last. But don't store the
+ // second-to-last block to the output buffer yet, since it will be
+ // handled specially in the ciphertext stealing step. Exception: if the
+ // message is single-block, still encrypt the last (and only) block.
+ li t0, 16
+ j 2f
+1:
+ vse32.v v16, (OUTP) // Store ciphertext block
+ addi OUTP, OUTP, 16
+2:
+ vle32.v v17, (INP) // Load plaintext block
+ vxor.vv v16, v16, v17 // XOR with IV or prev ciphertext block
+ aes_encrypt v16, \keylen // Encrypt
+ addi INP, INP, 16
+ addi LEN, LEN, -16
+ bgt LEN, t0, 1b // Repeat if more than one block remains
+
+ // Special case: if the message is a single block, just do CBC.
+ beqz LEN, .Lcts_encrypt_done\@
+
+ // Encrypt the last two blocks using ciphertext stealing as follows:
+ // C[n-1] = Encrypt(Encrypt(P[n-1] ^ C[n-2]) ^ P[n])
+ // C[n] = Encrypt(P[n-1] ^ C[n-2])[0..LEN]
+ //
+ // C[i] denotes the i'th ciphertext block, and likewise P[i] the i'th
+ // plaintext block. Block n, the last block, may be partial; its length
+ // is 1 <= LEN <= 16. If there are only 2 blocks, C[n-2] means the IV.
+ //
+ // v16 already contains Encrypt(P[n-1] ^ C[n-2]).
+ // INP points to P[n]. OUTP points to where C[n-1] should go.
+ // To support in-place encryption, load P[n] before storing C[n].
+ addi t0, OUTP, 16 // Get pointer to where C[n] should go
+ vsetvli zero, LEN, e8, m1, tu, ma
+ vle8.v v17, (INP) // Load P[n]
+ vse8.v v16, (t0) // Store C[n]
+ vxor.vv v16, v16, v17 // v16 = Encrypt(P[n-1] ^ C[n-2]) ^ P[n]
+ vsetivli zero, 4, e32, m1, ta, ma
+ aes_encrypt v16, \keylen
+.Lcts_encrypt_done\@:
+ vse32.v v16, (OUTP) // Store C[n-1] (or C[n] in single-block case)
+ ret
+.endm
+
+#define LEN32 t4 // Length of remaining full blocks in 32-bit words
+#define LEN_MOD16 t5 // Length of message in bytes mod 16
+
+.macro aes_cbc_cts_decrypt keylen
+ andi LEN32, LEN, ~15
+ srli LEN32, LEN32, 2
+ andi LEN_MOD16, LEN, 15
+
+ // Save C[n-2] in v28 so that it's available later during the ciphertext
+ // stealing step. If there are fewer than three blocks, C[n-2] means
+ // the IV, otherwise it means the third-to-last ciphertext block.
+ vmv.v.v v28, v16 // IV
+ add t0, LEN, -33
+ bltz t0, .Lcts_decrypt_loop\@
+ andi t0, t0, ~15
+ add t0, t0, INP
+ vle32.v v28, (t0)
+
+ // CBC-decrypt all full blocks. For the last full block, or the last 2
+ // full blocks if the message is block-aligned, this doesn't write the
+ // correct output blocks (unless the message is only a single block),
+ // because it XORs the wrong values with the raw AES plaintexts. But we
+ // fix this after this loop without redoing the AES decryptions. This
+ // approach allows more of the AES decryptions to be parallelized.
+.Lcts_decrypt_loop\@:
+ vsetvli t0, LEN32, e32, m4, ta, ma
+ addi t1, t0, -4
+ vle32.v v20, (INP) // Load next set of ciphertext blocks
+ vmv.v.v v24, v16 // Get IV or last ciphertext block of prev set
+ vslideup.vi v24, v20, 4 // Setup prev ciphertext blocks
+ vslidedown.vx v16, v20, t1 // Save last ciphertext block of this set
+ aes_decrypt v20, \keylen // Decrypt this set of blocks
+ vxor.vv v24, v24, v20 // XOR prev ciphertext blocks with decrypted blocks
+ vse32.v v24, (OUTP) // Store this set of plaintext blocks
+ sub LEN32, LEN32, t0
+ slli t0, t0, 2 // Words to bytes
+ add INP, INP, t0
+ add OUTP, OUTP, t0
+ bnez LEN32, .Lcts_decrypt_loop\@
+
+ vsetivli zero, 4, e32, m4, ta, ma
+ vslidedown.vx v20, v20, t1 // Extract raw plaintext of last full block
+ addi t0, OUTP, -16 // Get pointer to last full plaintext block
+ bnez LEN_MOD16, .Lcts_decrypt_non_block_aligned\@
+
+ // Special case: if the message is a single block, just do CBC.
+ li t1, 16
+ beq LEN, t1, .Lcts_decrypt_done\@
+
+ // Block-aligned message. Just fix up the last 2 blocks. We need:
+ //
+ // P[n-1] = Decrypt(C[n]) ^ C[n-2]
+ // P[n] = Decrypt(C[n-1]) ^ C[n]
+ //
+ // We have C[n] in v16, Decrypt(C[n]) in v20, and C[n-2] in v28.
+ // Together with Decrypt(C[n-1]) ^ C[n-2] from the output buffer, this
+ // is everything needed to fix the output without re-decrypting blocks.
+ addi t1, OUTP, -32 // Get pointer to where P[n-1] should go
+ vxor.vv v20, v20, v28 // Decrypt(C[n]) ^ C[n-2] == P[n-1]
+ vle32.v v24, (t1) // Decrypt(C[n-1]) ^ C[n-2]
+ vse32.v v20, (t1) // Store P[n-1]
+ vxor.vv v20, v24, v16 // Decrypt(C[n-1]) ^ C[n-2] ^ C[n] == P[n] ^ C[n-2]
+ j .Lcts_decrypt_finish\@
+
+.Lcts_decrypt_non_block_aligned\@:
+ // Decrypt the last two blocks using ciphertext stealing as follows:
+ //
+ // P[n-1] = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16]) ^ C[n-2]
+ // P[n] = (Decrypt(C[n-1]) ^ C[n])[0..LEN_MOD16]
+ //
+ // We already have Decrypt(C[n-1]) in v20 and C[n-2] in v28.
+ vmv.v.v v16, v20 // v16 = Decrypt(C[n-1])
+ vsetvli zero, LEN_MOD16, e8, m1, tu, ma
+ vle8.v v20, (INP) // v20 = C[n] || Decrypt(C[n-1])[LEN_MOD16..16]
+ vxor.vv v16, v16, v20 // v16 = Decrypt(C[n-1]) ^ C[n]
+ vse8.v v16, (OUTP) // Store P[n]
+ vsetivli zero, 4, e32, m1, ta, ma
+ aes_decrypt v20, \keylen // v20 = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16])
+.Lcts_decrypt_finish\@:
+ vxor.vv v20, v20, v28 // XOR with C[n-2]
+ vse32.v v20, (t0) // Store last full plaintext block
+.Lcts_decrypt_done\@:
+ ret
+.endm
+
+.macro aes_cbc_cts_crypt keylen
+ vle32.v v16, (IVP) // Load IV
+ beqz a5, .Lcts_decrypt\@
+ aes_cbc_cts_encrypt \keylen
+.Lcts_decrypt\@:
+ aes_cbc_cts_decrypt \keylen
+.endm
+
+// void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
+// const u8 *in, u8 *out, size_t len,
+// const u8 iv[16], bool enc);
+//
+// Encrypts or decrypts a message with the CS3 variant of AES-CBC-CTS.
+// This is the variant that unconditionally swaps the last two blocks.
+SYM_FUNC_START(aes_cbc_cts_crypt_zvkned)
+ aes_begin KEYP, 128f, 192f
+ aes_cbc_cts_crypt 256
+128:
+ aes_cbc_cts_crypt 128
+192:
+ aes_cbc_cts_crypt 192
+SYM_FUNC_END(aes_cbc_cts_crypt_zvkned)

base-commit: cb4ede926134a65bc3bf90ed58dace8451d7e759
prerequisite-patch-id: 2a69e1270be0fa567cc43269826171d6e46d65de
--
2.43.0



2024-02-14 16:38:22

by Ard Biesheuvel

[permalink] [raw]
Subject: Re: [PATCH riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS

On Tue, 13 Feb 2024 at 06:57, Eric Biggers <[email protected]> wrote:
>
> From: Eric Biggers <[email protected]>
>
> Add an implementation of cts(cbc(aes)) accelerated using the Zvkned
> RISC-V vector crypto extension. This is mainly useful for fscrypt,
> where cts(cbc(aes)) is the "default" filenames encryption algorithm. In
> that use case, typically most messages are short and are block-aligned.

Does this mean the storage space for filenames is rounded up to AES block size?

> The CBC-CTS variant implemented is CS3; this is the variant Linux uses.
>
> To perform well on short messages, the new implementation processes the
> full message in one call to the assembly function if the data is
> contiguous. Otherwise it falls back to CBC operations followed by CTS
> at the end. For decryption, to further improve performance on short
> messages, especially block-aligned messages, the CBC-CTS assembly
> function parallelizes the AES decryption of all full blocks.

Nice!

> This
> improves on the arm64 implementation of cts(cbc(aes)), which always
> splits the CBC part(s) from the CTS part, doing the AES decryptions for
> the last two blocks serially and usually loading the round keys twice.
>

So is the overhead of this sub-optimal approach mostly in the
redundant loading of the round keys? Or are there other significant
benefits?

If there are, I suppose we might port this improvement to x86 too, but
otherwise, I guess it'll only make sense for arm64.


> Tested in QEMU with CONFIG_CRYPTO_MANAGER_EXTRA_TESTS=y.
>
> Signed-off-by: Eric Biggers <[email protected]>
> ---
> arch/riscv/crypto/Kconfig | 4 +-
> arch/riscv/crypto/aes-riscv64-glue.c | 93 ++++++++++++++-
> arch/riscv/crypto/aes-riscv64-zvkned.S | 153 +++++++++++++++++++++++++
> 3 files changed, 245 insertions(+), 5 deletions(-)
>
> diff --git a/arch/riscv/crypto/Kconfig b/arch/riscv/crypto/Kconfig
> index 2ad44e1d464a..ad58dad9a580 100644
> --- a/arch/riscv/crypto/Kconfig
> +++ b/arch/riscv/crypto/Kconfig
> @@ -1,23 +1,23 @@
> # SPDX-License-Identifier: GPL-2.0
>
> menu "Accelerated Cryptographic Algorithms for CPU (riscv)"
>
> config CRYPTO_AES_RISCV64
> - tristate "Ciphers: AES, modes: ECB, CBC, CTR, XTS"
> + tristate "Ciphers: AES, modes: ECB, CBC, CTS, CTR, XTS"
> depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
> select CRYPTO_ALGAPI
> select CRYPTO_LIB_AES
> select CRYPTO_SKCIPHER
> help
> Block cipher: AES cipher algorithms
> - Length-preserving ciphers: AES with ECB, CBC, CTR, XTS
> + Length-preserving ciphers: AES with ECB, CBC, CTS, CTR, XTS
>
> Architecture: riscv64 using:
> - Zvkned vector crypto extension
> - Zvbb vector extension (XTS)
> - Zvkb vector crypto extension (CTR)
> - Zvkg vector crypto extension (XTS)
>
> config CRYPTO_CHACHA_RISCV64
> tristate "Ciphers: ChaCha"
> depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
> diff --git a/arch/riscv/crypto/aes-riscv64-glue.c b/arch/riscv/crypto/aes-riscv64-glue.c
> index 37bc6ef0be40..f814ee048555 100644
> --- a/arch/riscv/crypto/aes-riscv64-glue.c
> +++ b/arch/riscv/crypto/aes-riscv64-glue.c
> @@ -1,20 +1,22 @@
> // SPDX-License-Identifier: GPL-2.0-only
> /*
> * AES using the RISC-V vector crypto extensions. Includes the bare block
> - * cipher and the ECB, CBC, CTR, and XTS modes.
> + * cipher and the ECB, CBC, CBC-CTS, CTR, and XTS modes.
> *
> * Copyright (C) 2023 VRULL GmbH
> * Author: Heiko Stuebner <[email protected]>
> *
> * Copyright (C) 2023 SiFive, Inc.
> * Author: Jerry Shih <[email protected]>
> + *
> + * Copyright 2024 Google LLC
> */
>
> #include <asm/simd.h>
> #include <asm/vector.h>
> #include <crypto/aes.h>
> #include <crypto/internal/cipher.h>
> #include <crypto/internal/simd.h>
> #include <crypto/internal/skcipher.h>
> #include <crypto/scatterwalk.h>
> #include <crypto/xts.h>
> @@ -33,20 +35,24 @@ asmlinkage void aes_ecb_encrypt_zvkned(const struct crypto_aes_ctx *key,
> asmlinkage void aes_ecb_decrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len);
>
> asmlinkage void aes_cbc_encrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
> asmlinkage void aes_cbc_decrypt_zvkned(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
>
> +asmlinkage void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
> + const u8 *in, u8 *out, size_t len,
> + const u8 iv[AES_BLOCK_SIZE], bool enc);
> +
> asmlinkage void aes_ctr32_crypt_zvkned_zvkb(const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 iv[AES_BLOCK_SIZE]);
>
> asmlinkage void aes_xts_encrypt_zvkned_zvbb_zvkg(
> const struct crypto_aes_ctx *key,
> const u8 *in, u8 *out, size_t len,
> u8 tweak[AES_BLOCK_SIZE]);
>
> asmlinkage void aes_xts_decrypt_zvkned_zvbb_zvkg(
> @@ -157,21 +163,21 @@ static int riscv64_aes_ecb_encrypt(struct skcipher_request *req)
> return riscv64_aes_ecb_crypt(req, true);
> }
>
> static int riscv64_aes_ecb_decrypt(struct skcipher_request *req)
> {
> return riscv64_aes_ecb_crypt(req, false);
> }
>
> /* AES-CBC */
>
> -static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> +static int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> {
> struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> struct skcipher_walk walk;
> unsigned int nbytes;
> int err;
>
> err = skcipher_walk_virt(&walk, req, false);
> while ((nbytes = walk.nbytes) != 0) {
> kernel_vector_begin();
> @@ -195,20 +201,84 @@ static inline int riscv64_aes_cbc_crypt(struct skcipher_request *req, bool enc)
> static int riscv64_aes_cbc_encrypt(struct skcipher_request *req)
> {
> return riscv64_aes_cbc_crypt(req, true);
> }
>
> static int riscv64_aes_cbc_decrypt(struct skcipher_request *req)
> {
> return riscv64_aes_cbc_crypt(req, false);
> }
>
> +/* AES-CBC-CTS */
> +
> +static int riscv64_aes_cbc_cts_crypt(struct skcipher_request *req, bool enc)
> +{
> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> + const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> + struct scatterlist sg_src[2], sg_dst[2];
> + struct skcipher_request subreq;
> + struct scatterlist *src, *dst;
> + struct skcipher_walk walk;
> + unsigned int cbc_len;
> + int err;
> +
> + if (req->cryptlen < AES_BLOCK_SIZE)
> + return -EINVAL;
> +
> + err = skcipher_walk_virt(&walk, req, false);
> + if (err)
> + return err;
> + /*
> + * If the full message is available in one step, decrypt it in one call
> + * to the CBC-CTS assembly function. This reduces overhead, especially
> + * on short messages. Otherwise, fall back to doing CBC up to the last
> + * two blocks, then invoke CTS just for the ciphertext stealing.
> + */
> + if (unlikely(walk.nbytes != req->cryptlen)) {
> + cbc_len = round_down(req->cryptlen - AES_BLOCK_SIZE - 1,
> + AES_BLOCK_SIZE);
> + skcipher_walk_abort(&walk);
> + skcipher_request_set_tfm(&subreq, tfm);
> + skcipher_request_set_callback(&subreq,
> + skcipher_request_flags(req),
> + NULL, NULL);
> + skcipher_request_set_crypt(&subreq, req->src, req->dst,
> + cbc_len, req->iv);
> + err = riscv64_aes_cbc_crypt(&subreq, enc);
> + if (err)
> + return err;
> + dst = src = scatterwalk_ffwd(sg_src, req->src, cbc_len);
> + if (req->dst != req->src)
> + dst = scatterwalk_ffwd(sg_dst, req->dst, cbc_len);
> + skcipher_request_set_crypt(&subreq, src, dst,
> + req->cryptlen - cbc_len, req->iv);
> + err = skcipher_walk_virt(&walk, &subreq, false);
> + if (err)
> + return err;
> + }
> + kernel_vector_begin();
> + aes_cbc_cts_crypt_zvkned(ctx, walk.src.virt.addr, walk.dst.virt.addr,
> + walk.nbytes, req->iv, enc);
> + kernel_vector_end();
> + return skcipher_walk_done(&walk, 0);
> +}
> +
> +static int riscv64_aes_cbc_cts_encrypt(struct skcipher_request *req)
> +{
> + return riscv64_aes_cbc_cts_crypt(req, true);
> +}
> +
> +static int riscv64_aes_cbc_cts_decrypt(struct skcipher_request *req)
> +{
> + return riscv64_aes_cbc_cts_crypt(req, false);
> +}
> +
> /* AES-CTR */
>
> static int riscv64_aes_ctr_crypt(struct skcipher_request *req)
> {
> struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> unsigned int nbytes, p1_nbytes;
> struct skcipher_walk walk;
> u32 ctr32, nblocks;
> int err;
> @@ -427,20 +497,36 @@ static struct skcipher_alg riscv64_zvkned_aes_skcipher_algs[] = {
> .max_keysize = AES_MAX_KEY_SIZE,
> .ivsize = AES_BLOCK_SIZE,
> .base = {
> .cra_blocksize = AES_BLOCK_SIZE,
> .cra_ctxsize = sizeof(struct crypto_aes_ctx),
> .cra_priority = 300,
> .cra_name = "cbc(aes)",
> .cra_driver_name = "cbc-aes-riscv64-zvkned",
> .cra_module = THIS_MODULE,
> },
> + }, {
> + .setkey = riscv64_aes_setkey_skcipher,
> + .encrypt = riscv64_aes_cbc_cts_encrypt,
> + .decrypt = riscv64_aes_cbc_cts_decrypt,
> + .min_keysize = AES_MIN_KEY_SIZE,
> + .max_keysize = AES_MAX_KEY_SIZE,
> + .ivsize = AES_BLOCK_SIZE,
> + .walksize = 4 * AES_BLOCK_SIZE, /* matches LMUL=4 */
> + .base = {
> + .cra_blocksize = AES_BLOCK_SIZE,
> + .cra_ctxsize = sizeof(struct crypto_aes_ctx),
> + .cra_priority = 300,
> + .cra_name = "cts(cbc(aes))",
> + .cra_driver_name = "cts-cbc-aes-riscv64-zvkned",
> + .cra_module = THIS_MODULE,
> + },
> }
> };
>
> static struct skcipher_alg riscv64_zvkned_zvkb_aes_skcipher_alg = {
> .setkey = riscv64_aes_setkey_skcipher,
> .encrypt = riscv64_aes_ctr_crypt,
> .decrypt = riscv64_aes_ctr_crypt,
> .min_keysize = AES_MIN_KEY_SIZE,
> .max_keysize = AES_MAX_KEY_SIZE,
> .ivsize = AES_BLOCK_SIZE,
> @@ -533,18 +619,19 @@ static void __exit riscv64_aes_mod_exit(void)
> if (riscv_isa_extension_available(NULL, ZVKB))
> crypto_unregister_skcipher(&riscv64_zvkned_zvkb_aes_skcipher_alg);
> crypto_unregister_skciphers(riscv64_zvkned_aes_skcipher_algs,
> ARRAY_SIZE(riscv64_zvkned_aes_skcipher_algs));
> crypto_unregister_alg(&riscv64_zvkned_aes_cipher_alg);
> }
>
> module_init(riscv64_aes_mod_init);
> module_exit(riscv64_aes_mod_exit);
>
> -MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS (RISC-V accelerated)");
> +MODULE_DESCRIPTION("AES-ECB/CBC/CTS/CTR/XTS (RISC-V accelerated)");
> MODULE_AUTHOR("Jerry Shih <[email protected]>");
> MODULE_LICENSE("GPL");
> MODULE_ALIAS_CRYPTO("aes");
> MODULE_ALIAS_CRYPTO("ecb(aes)");
> MODULE_ALIAS_CRYPTO("cbc(aes)");
> +MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
> MODULE_ALIAS_CRYPTO("ctr(aes)");
> MODULE_ALIAS_CRYPTO("xts(aes)");
> diff --git a/arch/riscv/crypto/aes-riscv64-zvkned.S b/arch/riscv/crypto/aes-riscv64-zvkned.S
> index 43541aad6386..23d063f94ce6 100644
> --- a/arch/riscv/crypto/aes-riscv64-zvkned.S
> +++ b/arch/riscv/crypto/aes-riscv64-zvkned.S
> @@ -177,10 +177,163 @@ SYM_FUNC_END(aes_cbc_encrypt_zvkned)
>
> // Same prototype and calling convention as the encryption function
> SYM_FUNC_START(aes_cbc_decrypt_zvkned)
> aes_begin KEYP, 128f, 192f
> aes_cbc_decrypt 256
> 128:
> aes_cbc_decrypt 128
> 192:
> aes_cbc_decrypt 192
> SYM_FUNC_END(aes_cbc_decrypt_zvkned)
> +
> +.macro aes_cbc_cts_encrypt keylen
> +
> + // CBC-encrypt all blocks except the last. But don't store the
> + // second-to-last block to the output buffer yet, since it will be
> + // handled specially in the ciphertext stealing step. Exception: if the
> + // message is single-block, still encrypt the last (and only) block.
> + li t0, 16
> + j 2f
> +1:
> + vse32.v v16, (OUTP) // Store ciphertext block
> + addi OUTP, OUTP, 16
> +2:
> + vle32.v v17, (INP) // Load plaintext block
> + vxor.vv v16, v16, v17 // XOR with IV or prev ciphertext block
> + aes_encrypt v16, \keylen // Encrypt
> + addi INP, INP, 16
> + addi LEN, LEN, -16
> + bgt LEN, t0, 1b // Repeat if more than one block remains
> +
> + // Special case: if the message is a single block, just do CBC.
> + beqz LEN, .Lcts_encrypt_done\@
> +
> + // Encrypt the last two blocks using ciphertext stealing as follows:
> + // C[n-1] = Encrypt(Encrypt(P[n-1] ^ C[n-2]) ^ P[n])
> + // C[n] = Encrypt(P[n-1] ^ C[n-2])[0..LEN]
> + //
> + // C[i] denotes the i'th ciphertext block, and likewise P[i] the i'th
> + // plaintext block. Block n, the last block, may be partial; its length
> + // is 1 <= LEN <= 16. If there are only 2 blocks, C[n-2] means the IV.
> + //
> + // v16 already contains Encrypt(P[n-1] ^ C[n-2]).
> + // INP points to P[n]. OUTP points to where C[n-1] should go.
> + // To support in-place encryption, load P[n] before storing C[n].
> + addi t0, OUTP, 16 // Get pointer to where C[n] should go
> + vsetvli zero, LEN, e8, m1, tu, ma
> + vle8.v v17, (INP) // Load P[n]
> + vse8.v v16, (t0) // Store C[n]
> + vxor.vv v16, v16, v17 // v16 = Encrypt(P[n-1] ^ C[n-2]) ^ P[n]
> + vsetivli zero, 4, e32, m1, ta, ma
> + aes_encrypt v16, \keylen
> +.Lcts_encrypt_done\@:
> + vse32.v v16, (OUTP) // Store C[n-1] (or C[n] in single-block case)
> + ret
> +.endm
> +
> +#define LEN32 t4 // Length of remaining full blocks in 32-bit words
> +#define LEN_MOD16 t5 // Length of message in bytes mod 16
> +
> +.macro aes_cbc_cts_decrypt keylen
> + andi LEN32, LEN, ~15
> + srli LEN32, LEN32, 2
> + andi LEN_MOD16, LEN, 15
> +
> + // Save C[n-2] in v28 so that it's available later during the ciphertext
> + // stealing step. If there are fewer than three blocks, C[n-2] means
> + // the IV, otherwise it means the third-to-last ciphertext block.
> + vmv.v.v v28, v16 // IV
> + add t0, LEN, -33
> + bltz t0, .Lcts_decrypt_loop\@
> + andi t0, t0, ~15
> + add t0, t0, INP
> + vle32.v v28, (t0)
> +
> + // CBC-decrypt all full blocks. For the last full block, or the last 2
> + // full blocks if the message is block-aligned, this doesn't write the
> + // correct output blocks (unless the message is only a single block),
> + // because it XORs the wrong values with the raw AES plaintexts. But we
> + // fix this after this loop without redoing the AES decryptions. This
> + // approach allows more of the AES decryptions to be parallelized.
> +.Lcts_decrypt_loop\@:
> + vsetvli t0, LEN32, e32, m4, ta, ma
> + addi t1, t0, -4
> + vle32.v v20, (INP) // Load next set of ciphertext blocks
> + vmv.v.v v24, v16 // Get IV or last ciphertext block of prev set
> + vslideup.vi v24, v20, 4 // Setup prev ciphertext blocks
> + vslidedown.vx v16, v20, t1 // Save last ciphertext block of this set
> + aes_decrypt v20, \keylen // Decrypt this set of blocks
> + vxor.vv v24, v24, v20 // XOR prev ciphertext blocks with decrypted blocks
> + vse32.v v24, (OUTP) // Store this set of plaintext blocks
> + sub LEN32, LEN32, t0
> + slli t0, t0, 2 // Words to bytes
> + add INP, INP, t0
> + add OUTP, OUTP, t0
> + bnez LEN32, .Lcts_decrypt_loop\@
> +
> + vsetivli zero, 4, e32, m4, ta, ma
> + vslidedown.vx v20, v20, t1 // Extract raw plaintext of last full block
> + addi t0, OUTP, -16 // Get pointer to last full plaintext block
> + bnez LEN_MOD16, .Lcts_decrypt_non_block_aligned\@
> +
> + // Special case: if the message is a single block, just do CBC.
> + li t1, 16
> + beq LEN, t1, .Lcts_decrypt_done\@
> +
> + // Block-aligned message. Just fix up the last 2 blocks. We need:
> + //
> + // P[n-1] = Decrypt(C[n]) ^ C[n-2]
> + // P[n] = Decrypt(C[n-1]) ^ C[n]
> + //
> + // We have C[n] in v16, Decrypt(C[n]) in v20, and C[n-2] in v28.
> + // Together with Decrypt(C[n-1]) ^ C[n-2] from the output buffer, this
> + // is everything needed to fix the output without re-decrypting blocks.
> + addi t1, OUTP, -32 // Get pointer to where P[n-1] should go
> + vxor.vv v20, v20, v28 // Decrypt(C[n]) ^ C[n-2] == P[n-1]
> + vle32.v v24, (t1) // Decrypt(C[n-1]) ^ C[n-2]
> + vse32.v v20, (t1) // Store P[n-1]
> + vxor.vv v20, v24, v16 // Decrypt(C[n-1]) ^ C[n-2] ^ C[n] == P[n] ^ C[n-2]
> + j .Lcts_decrypt_finish\@
> +
> +.Lcts_decrypt_non_block_aligned\@:
> + // Decrypt the last two blocks using ciphertext stealing as follows:
> + //
> + // P[n-1] = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16]) ^ C[n-2]
> + // P[n] = (Decrypt(C[n-1]) ^ C[n])[0..LEN_MOD16]
> + //
> + // We already have Decrypt(C[n-1]) in v20 and C[n-2] in v28.
> + vmv.v.v v16, v20 // v16 = Decrypt(C[n-1])
> + vsetvli zero, LEN_MOD16, e8, m1, tu, ma
> + vle8.v v20, (INP) // v20 = C[n] || Decrypt(C[n-1])[LEN_MOD16..16]
> + vxor.vv v16, v16, v20 // v16 = Decrypt(C[n-1]) ^ C[n]
> + vse8.v v16, (OUTP) // Store P[n]
> + vsetivli zero, 4, e32, m1, ta, ma
> + aes_decrypt v20, \keylen // v20 = Decrypt(C[n] || Decrypt(C[n-1])[LEN_MOD16..16])
> +.Lcts_decrypt_finish\@:
> + vxor.vv v20, v20, v28 // XOR with C[n-2]
> + vse32.v v20, (t0) // Store last full plaintext block
> +.Lcts_decrypt_done\@:
> + ret
> +.endm
> +
> +.macro aes_cbc_cts_crypt keylen
> + vle32.v v16, (IVP) // Load IV
> + beqz a5, .Lcts_decrypt\@
> + aes_cbc_cts_encrypt \keylen
> +.Lcts_decrypt\@:
> + aes_cbc_cts_decrypt \keylen
> +.endm
> +
> +// void aes_cbc_cts_crypt_zvkned(const struct crypto_aes_ctx *key,
> +// const u8 *in, u8 *out, size_t len,
> +// const u8 iv[16], bool enc);
> +//
> +// Encrypts or decrypts a message with the CS3 variant of AES-CBC-CTS.
> +// This is the variant that unconditionally swaps the last two blocks.
> +SYM_FUNC_START(aes_cbc_cts_crypt_zvkned)
> + aes_begin KEYP, 128f, 192f
> + aes_cbc_cts_crypt 256
> +128:
> + aes_cbc_cts_crypt 128
> +192:
> + aes_cbc_cts_crypt 192
> +SYM_FUNC_END(aes_cbc_cts_crypt_zvkned)
>
> base-commit: cb4ede926134a65bc3bf90ed58dace8451d7e759
> prerequisite-patch-id: 2a69e1270be0fa567cc43269826171d6e46d65de
> --
> 2.43.0
>

2024-02-14 22:42:35

by Eric Biggers

[permalink] [raw]
Subject: Re: [PATCH riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS

On Wed, Feb 14, 2024 at 05:34:03PM +0100, Ard Biesheuvel wrote:
> On Tue, 13 Feb 2024 at 06:57, Eric Biggers <[email protected]> wrote:
> >
> > From: Eric Biggers <[email protected]>
> >
> > Add an implementation of cts(cbc(aes)) accelerated using the Zvkned
> > RISC-V vector crypto extension. This is mainly useful for fscrypt,
> > where cts(cbc(aes)) is the "default" filenames encryption algorithm. In
> > that use case, typically most messages are short and are block-aligned.
>
> Does this mean the storage space for filenames is rounded up to AES block size?

Yes, in most cases. fscrypt allows the filenames padding to be configured to be
4, 8, 16, or 32 bytes. If it's 16 or 32, which is recommended, then the sizes
of encrypted filenames are multiples of the AES block size, except for filenames
longer than 240 bytes which get rounded up to 255 bytes.

>
> > The CBC-CTS variant implemented is CS3; this is the variant Linux uses.
> >
> > To perform well on short messages, the new implementation processes the
> > full message in one call to the assembly function if the data is
> > contiguous. Otherwise it falls back to CBC operations followed by CTS
> > at the end. For decryption, to further improve performance on short
> > messages, especially block-aligned messages, the CBC-CTS assembly
> > function parallelizes the AES decryption of all full blocks.
>
> Nice!
>
> > This
> > improves on the arm64 implementation of cts(cbc(aes)), which always
> > splits the CBC part(s) from the CTS part, doing the AES decryptions for
> > the last two blocks serially and usually loading the round keys twice.
> >
>
> So is the overhead of this sub-optimal approach mostly in the
> redundant loading of the round keys? Or are there other significant
> benefits?
>
> If there are, I suppose we might port this improvement to x86 too, but
> otherwise, I guess it'll only make sense for arm64.

I expect that the serialization of the last two AES decryptions makes the
biggest difference, followed by the other sources of overhead (loading round
keys, skcipher_walk, kernel_neon_begin). It needs to be measured, though.

I'd like to try the same optimization for arm64 and x86. It's not fun going
back to SIMD after working with the RISC-V Vector Extension, though!

- Eric

Subject: Re: [PATCH riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS

Hello:

This patch was applied to riscv/linux.git (for-next)
by Palmer Dabbelt <[email protected]>:

On Mon, 12 Feb 2024 21:54:42 -0800 you wrote:
> From: Eric Biggers <[email protected]>
>
> Add an implementation of cts(cbc(aes)) accelerated using the Zvkned
> RISC-V vector crypto extension. This is mainly useful for fscrypt,
> where cts(cbc(aes)) is the "default" filenames encryption algorithm. In
> that use case, typically most messages are short and are block-aligned.
> The CBC-CTS variant implemented is CS3; this is the variant Linux uses.
>
> [...]

Here is the summary with links:
- [riscv/for-next] crypto: riscv - add vector crypto accelerated AES-CBC-CTS
https://git.kernel.org/riscv/c/c70dfa4a2723

You are awesome, thank you!
--
Deet-doot-dot, I am a bot.
https://korg.docs.kernel.org/patchwork/pwbot.html