From: Nathan Huckleberry
Date: 2022-05-20 18:31:46
Subject: [PATCH v9 6/9] crypto: arm64/aes-xctr: Improve readability of XCTR and CTR modes

Add some clarifying comments, change the register allocations to make
the code clearer, and add register aliases.

Signed-off-by: Nathan Huckleberry <[email protected]>
Reviewed-by: Eric Biggers <[email protected]>
---
arch/arm64/crypto/aes-glue.c | 16 +++
arch/arm64/crypto/aes-modes.S | 237 ++++++++++++++++++++++++----------
2 files changed, 185 insertions(+), 68 deletions(-)

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index b6883288234c..162787c7aa86 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -464,6 +464,14 @@ static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
u8 *dst = walk.dst.virt.addr;
u8 buf[AES_BLOCK_SIZE];

+ /*
+ * If given less than 16 bytes, we must copy the partial block
+ * into a temporary buffer of 16 bytes to avoid out of bounds
+ * reads and writes. Furthermore, this code is somewhat unusual
+ * in that it expects the end of the data to be at the end of
+ * the temporary buffer, rather than the start of the data at
+ * the start of the temporary buffer.
+ */
if (unlikely(nbytes < AES_BLOCK_SIZE))
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
@@ -501,6 +509,14 @@ static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
u8 *dst = walk.dst.virt.addr;
u8 buf[AES_BLOCK_SIZE];

+ /*
+ * If given less than 16 bytes, we must copy the partial block
+ * into a temporary buffer of 16 bytes to avoid out of bounds
+ * reads and writes. Furthermore, this code is somewhat unusual
+ * in that it expects the end of the data to be at the end of
+ * the temporary buffer, rather than the start of the data at
+ * the start of the temporary buffer.
+ */
if (unlikely(nbytes < AES_BLOCK_SIZE))
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
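
To make the buffer convention described in the two comments above concrete,
here is a minimal, self-contained C sketch of the partial-block handling
(illustrative only: cipher_core_1x() and encrypt_up_to_one_block() are made-up
names standing in for the assembly core and its caller, not the kernel's API):

#include <stdint.h>
#include <string.h>

#define AES_BLOCK_SIZE 16

/*
 * Hypothetical stand-in for the assembly core (1 <= nbytes <= 16): it always
 * reads and writes the full 16 bytes that *end* at src/dst + nbytes, which is
 * why a short message must sit at the end of a 16-byte buffer.
 */
static void cipher_core_1x(uint8_t *dst, const uint8_t *src,
                           unsigned int nbytes)
{
        const uint8_t *s = src + nbytes - AES_BLOCK_SIZE;
        uint8_t *d = dst + nbytes - AES_BLOCK_SIZE;
        unsigned int i;

        for (i = 0; i < AES_BLOCK_SIZE; i++)
                d[i] = s[i] ^ 0xAA;             /* dummy keystream, not AES */
}

static void encrypt_up_to_one_block(uint8_t *dst, const uint8_t *src,
                                    unsigned int nbytes)
{
        uint8_t buf[AES_BLOCK_SIZE];
        const uint8_t *in = src;
        uint8_t *out = dst;

        if (nbytes < AES_BLOCK_SIZE)
                /* Copy to the *end* of the scratch buffer, as described above. */
                in = out = memcpy(buf + sizeof(buf) - nbytes, src, nbytes);

        cipher_core_1x(out, in, nbytes);

        if (nbytes < AES_BLOCK_SIZE)
                memcpy(dst, out, nbytes);       /* copy the result back */
}

This mirrors the memcpy(buf + sizeof(buf) - nbytes, ...) adjustment made in
both hunks before the data is handed to the assembly.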
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 9a027200bdba..5e0f0e51557c 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -322,32 +322,60 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
* This macro generates the code for CTR and XCTR mode.
*/
.macro ctr_encrypt xctr
+ // Arguments
+ OUT .req x0
+ IN .req x1
+ KEY .req x2
+ ROUNDS_W .req w3
+ BYTES_W .req w4
+ IV .req x5
+ BYTE_CTR_W .req w6 // XCTR only
+ // Intermediate values
+ CTR_W .req w11 // XCTR only
+ CTR .req x11 // XCTR only
+ IV_PART .req x12
+ BLOCKS .req x13
+ BLOCKS_W .req w13
+
stp x29, x30, [sp, #-16]!
mov x29, sp

- enc_prepare w3, x2, x12
- ld1 {vctr.16b}, [x5]
+ enc_prepare ROUNDS_W, KEY, IV_PART
+ ld1 {vctr.16b}, [IV]

+ /*
+ * Keep 64 bits of the IV in a register. For CTR mode this lets us
+ * easily increment the IV. For XCTR mode this lets us efficiently XOR
+ * the 64-bit counter with the IV.
+ */
.if \xctr
- umov x12, vctr.d[0]
- lsr w11, w6, #4
+ umov IV_PART, vctr.d[0]
+ lsr CTR_W, BYTE_CTR_W, #4
.else
- umov x12, vctr.d[1] /* keep swabbed ctr in reg */
- rev x12, x12
+ umov IV_PART, vctr.d[1]
+ rev IV_PART, IV_PART
.endif

.LctrloopNx\xctr:
- add w7, w4, #15
- sub w4, w4, #MAX_STRIDE << 4
- lsr w7, w7, #4
+ add BLOCKS_W, BYTES_W, #15
+ sub BYTES_W, BYTES_W, #MAX_STRIDE << 4
+ lsr BLOCKS_W, BLOCKS_W, #4
mov w8, #MAX_STRIDE
- cmp w7, w8
- csel w7, w7, w8, lt
+ cmp BLOCKS_W, w8
+ csel BLOCKS_W, BLOCKS_W, w8, lt

+ /*
+ * Set up the counter values in v0-v{MAX_STRIDE-1}.
+ *
+ * If we are encrypting less than MAX_STRIDE blocks, the tail block
+ * handling code expects the last keystream block to be in
+ * v{MAX_STRIDE-1}. For example: if encrypting two blocks with
+ * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks.
+ */
.if \xctr
- add x11, x11, x7
+ add CTR, CTR, BLOCKS
.else
- adds x12, x12, x7
+ adds IV_PART, IV_PART, BLOCKS
.endif
mov v0.16b, vctr.16b
mov v1.16b, vctr.16b
@@ -355,16 +383,16 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
mov v3.16b, vctr.16b
ST5( mov v4.16b, vctr.16b )
.if \xctr
- sub x6, x11, #MAX_STRIDE - 1
- sub x7, x11, #MAX_STRIDE - 2
- sub x8, x11, #MAX_STRIDE - 3
- sub x9, x11, #MAX_STRIDE - 4
-ST5( sub x10, x11, #MAX_STRIDE - 5 )
- eor x6, x6, x12
- eor x7, x7, x12
- eor x8, x8, x12
- eor x9, x9, x12
-ST5( eor x10, x10, x12 )
+ sub x6, CTR, #MAX_STRIDE - 1
+ sub x7, CTR, #MAX_STRIDE - 2
+ sub x8, CTR, #MAX_STRIDE - 3
+ sub x9, CTR, #MAX_STRIDE - 4
+ST5( sub x10, CTR, #MAX_STRIDE - 5 )
+ eor x6, x6, IV_PART
+ eor x7, x7, IV_PART
+ eor x8, x8, IV_PART
+ eor x9, x9, IV_PART
+ST5( eor x10, x10, IV_PART )
mov v0.d[0], x6
mov v1.d[0], x7
mov v2.d[0], x8
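
The sub/eor sequence above (XCTR) and the .else branch that follows (CTR)
build the per-block counter values. Roughly, in C (an illustrative model only,
little-endian host assumed; make_counter_block() is not a kernel function):

#include <stdint.h>
#include <string.h>

/*
 * Illustrative model of one counter block. For XCTR, "ctr" is the block's
 * counter value (XCTR's counter is 1-based); for CTR it is the block's
 * offset from the initial IV.
 */
static void make_counter_block(uint8_t block[16], const uint8_t iv[16],
                               uint64_t ctr, int xctr)
{
        uint64_t half;

        memcpy(block, iv, 16);

        if (xctr) {
                /* XCTR: XOR the little-endian counter into IV bytes 0..7. */
                memcpy(&half, block, 8);
                half ^= ctr;
                memcpy(block, &half, 8);
        } else {
                /*
                 * CTR: IV bytes 8..15 hold a big-endian counter; add the block
                 * offset to it (carries into bytes 0..7 are handled separately
                 * by the carry subsection).
                 */
                memcpy(&half, block + 8, 8);
                half = __builtin_bswap64(__builtin_bswap64(half) + ctr);
                memcpy(block + 8, &half, 8);
        }
        /* The block is then AES-encrypted to produce one keystream block. */
}

Keeping the relevant 64-bit half in IV_PART lets the assembly form each
block's counter half with plain scalar arithmetic before inserting it into
the vector registers.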
@@ -373,17 +401,32 @@ ST5( mov v4.d[0], x10 )
.else
bcs 0f
.subsection 1
- /* apply carry to outgoing counter */
+ /*
+ * This subsection handles carries.
+ *
+ * Conditional branching here is allowed with respect to time
+ * invariance since the branches are dependent on the IV instead
+ * of the plaintext or key. This code is rarely executed in
+ * practice anyway.
+ */
+
+ /* Apply carry to outgoing counter. */
0: umov x8, vctr.d[0]
rev x8, x8
add x8, x8, #1
rev x8, x8
ins vctr.d[0], x8

- /* apply carry to N counter blocks for N := x12 */
- cbz x12, 2f
+ /*
+ * Apply carry to counter blocks if needed.
+ *
+ * Since the carry flag was set, we know 0 <= IV_PART <
+ * MAX_STRIDE. Using the value of IV_PART we can determine how
+ * many counter blocks need to be updated.
+ */
+ cbz IV_PART, 2f
adr x16, 1f
- sub x16, x16, x12, lsl #3
+ sub x16, x16, IV_PART, lsl #3
br x16
bti c
mov v0.d[0], vctr.d[0]
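
Conceptually, the carry handling in this subsection is a 128-bit big-endian
increment of the CTR IV; a rough C model (illustrative only, little-endian
host assumed; the assembly only enters this path when the adds sets the carry
flag, and it additionally patches the already-built counter blocks via the
computed branch into the mov sequence):

#include <stdint.h>
#include <string.h>

/* Advance a 128-bit big-endian CTR IV by nblocks, carrying into the top half. */
static void ctr_iv_advance(uint8_t iv[16], uint64_t nblocks)
{
        uint64_t hi, lo;

        memcpy(&hi, iv, 8);
        memcpy(&lo, iv + 8, 8);
        hi = __builtin_bswap64(hi);
        lo = __builtin_bswap64(lo);

        if (lo + nblocks < lo)          /* low 64 bits wrapped around */
                hi++;
        lo += nblocks;

        hi = __builtin_bswap64(hi);
        lo = __builtin_bswap64(lo);
        memcpy(iv, &hi, 8);
        memcpy(iv + 8, &lo, 8);
}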
@@ -398,71 +441,88 @@ ST5( mov v4.d[0], vctr.d[0] )
1: b 2f
.previous

-2: rev x7, x12
+2: rev x7, IV_PART
ins vctr.d[1], x7
- sub x7, x12, #MAX_STRIDE - 1
- sub x8, x12, #MAX_STRIDE - 2
- sub x9, x12, #MAX_STRIDE - 3
+ sub x7, IV_PART, #MAX_STRIDE - 1
+ sub x8, IV_PART, #MAX_STRIDE - 2
+ sub x9, IV_PART, #MAX_STRIDE - 3
rev x7, x7
rev x8, x8
mov v1.d[1], x7
rev x9, x9
-ST5( sub x10, x12, #MAX_STRIDE - 4 )
+ST5( sub x10, IV_PART, #MAX_STRIDE - 4 )
mov v2.d[1], x8
ST5( rev x10, x10 )
mov v3.d[1], x9
ST5( mov v4.d[1], x10 )
.endif
- tbnz w4, #31, .Lctrtail\xctr
- ld1 {v5.16b-v7.16b}, [x1], #48
+
+ /*
+ * If there are at least MAX_STRIDE blocks left, XOR the data with
+ * keystream and store. Otherwise jump to tail handling.
+ */
+ tbnz BYTES_W, #31, .Lctrtail\xctr
+ ld1 {v5.16b-v7.16b}, [IN], #48
ST4( bl aes_encrypt_block4x )
ST5( bl aes_encrypt_block5x )
eor v0.16b, v5.16b, v0.16b
-ST4( ld1 {v5.16b}, [x1], #16 )
+ST4( ld1 {v5.16b}, [IN], #16 )
eor v1.16b, v6.16b, v1.16b
-ST5( ld1 {v5.16b-v6.16b}, [x1], #32 )
+ST5( ld1 {v5.16b-v6.16b}, [IN], #32 )
eor v2.16b, v7.16b, v2.16b
eor v3.16b, v5.16b, v3.16b
ST5( eor v4.16b, v6.16b, v4.16b )
- st1 {v0.16b-v3.16b}, [x0], #64
-ST5( st1 {v4.16b}, [x0], #16 )
- cbz w4, .Lctrout\xctr
+ st1 {v0.16b-v3.16b}, [OUT], #64
+ST5( st1 {v4.16b}, [OUT], #16 )
+ cbz BYTES_W, .Lctrout\xctr
b .LctrloopNx\xctr

.Lctrout\xctr:
.if !\xctr
- st1 {vctr.16b}, [x5] /* return next CTR value */
+ st1 {vctr.16b}, [IV] /* return next CTR value */
.endif
ldp x29, x30, [sp], #16
ret

.Lctrtail\xctr:
+ /*
+ * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
+ *
+ * This code expects the last keystream block to be in v{MAX_STRIDE-1}.
+ * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and
+ * v4 should have the next two counter blocks.
+ *
+ * This allows us to store the ciphertext by writing to overlapping
+ * regions of memory. Any invalid ciphertext blocks get overwritten by
+ * correctly computed blocks. This approach greatly simplifies the
+ * logic for storing the ciphertext.
+ */
mov x16, #16
- ands x6, x4, #0xf
- csel x13, x6, x16, ne
+ ands w7, BYTES_W, #0xf
+ csel x13, x7, x16, ne

-ST5( cmp w4, #64 - (MAX_STRIDE << 4) )
+ST5( cmp BYTES_W, #64 - (MAX_STRIDE << 4))
ST5( csel x14, x16, xzr, gt )
- cmp w4, #48 - (MAX_STRIDE << 4)
+ cmp BYTES_W, #48 - (MAX_STRIDE << 4)
csel x15, x16, xzr, gt
- cmp w4, #32 - (MAX_STRIDE << 4)
+ cmp BYTES_W, #32 - (MAX_STRIDE << 4)
csel x16, x16, xzr, gt
- cmp w4, #16 - (MAX_STRIDE << 4)
+ cmp BYTES_W, #16 - (MAX_STRIDE << 4)

- adr_l x12, .Lcts_permute_table
- add x12, x12, x13
+ adr_l x9, .Lcts_permute_table
+ add x9, x9, x13
ble .Lctrtail1x\xctr

-ST5( ld1 {v5.16b}, [x1], x14 )
- ld1 {v6.16b}, [x1], x15
- ld1 {v7.16b}, [x1], x16
+ST5( ld1 {v5.16b}, [IN], x14 )
+ ld1 {v6.16b}, [IN], x15
+ ld1 {v7.16b}, [IN], x16

ST4( bl aes_encrypt_block4x )
ST5( bl aes_encrypt_block5x )

- ld1 {v8.16b}, [x1], x13
- ld1 {v9.16b}, [x1]
- ld1 {v10.16b}, [x12]
+ ld1 {v8.16b}, [IN], x13
+ ld1 {v9.16b}, [IN]
+ ld1 {v10.16b}, [x9]

ST4( eor v6.16b, v6.16b, v0.16b )
ST4( eor v7.16b, v7.16b, v1.16b )
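
The overlapping-store scheme described in the .Lctrtail comment above can be
modelled in scalar C as follows (access pattern only; in the assembly each
step is a single 16-byte vector load/XOR/store, and "keystream" stands for
the logical CTR/XCTR keystream bytes, not a real buffer in the kernel code):

#include <stddef.h>
#include <stdint.h>

/*
 * Illustrative tail handling for len >= 16 (shorter inputs go through the
 * 16-byte scratch buffer shown earlier). Instead of a byte-by-byte loop for
 * a partial final block, the last 16 bytes of the message are (re)processed
 * so that every store is a full 16-byte store.
 */
static void xor_keystream(uint8_t *dst, const uint8_t *src,
                          const uint8_t *keystream, size_t len)
{
        size_t i, j, last;

        /* Whole 16-byte blocks, front to back. */
        for (i = 0; i + 16 <= len; i += 16)
                for (j = 0; j < 16; j++)
                        dst[i + j] = src[i + j] ^ keystream[i + j];

        if (i == len)
                return;

        /*
         * Partial final block: redo the 16 bytes ending at len. The overlap
         * with the previous block rewrites bytes that are already correct,
         * so no byte-granular store is needed.
         */
        last = len - 16;
        for (j = 0; j < 16; j++)
                dst[last + j] = src[last + j] ^ keystream[last + j];
}

In the assembly the overlapping store (v9) is issued first and the preceding
full block (v8) is stored afterwards, so, as the comment notes, any invalid
bytes in the overlap are overwritten by correctly computed ones.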
@@ -477,35 +537,70 @@ ST5( eor v7.16b, v7.16b, v2.16b )
ST5( eor v8.16b, v8.16b, v3.16b )
ST5( eor v9.16b, v9.16b, v4.16b )

-ST5( st1 {v5.16b}, [x0], x14 )
- st1 {v6.16b}, [x0], x15
- st1 {v7.16b}, [x0], x16
- add x13, x13, x0
+ST5( st1 {v5.16b}, [OUT], x14 )
+ st1 {v6.16b}, [OUT], x15
+ st1 {v7.16b}, [OUT], x16
+ add x13, x13, OUT
st1 {v9.16b}, [x13] // overlapping stores
- st1 {v8.16b}, [x0]
+ st1 {v8.16b}, [OUT]
b .Lctrout\xctr

.Lctrtail1x\xctr:
- sub x7, x6, #16
- csel x6, x6, x7, eq
- add x1, x1, x6
- add x0, x0, x6
- ld1 {v5.16b}, [x1]
- ld1 {v6.16b}, [x0]
+ /*
+ * Handle <= 16 bytes of plaintext
+ *
+ * This code always reads and writes 16 bytes. To avoid out of bounds
+ * accesses, XCTR and CTR modes must use a temporary buffer when
+ * encrypting/decrypting less than 16 bytes.
+ *
+ * This code is unusual in that it loads the input and stores the output
+ * relative to the end of the buffers rather than relative to the start.
+ * This causes unusual behaviour when encrypting/decrypting less than 16
+ * bytes; the end of the data is expected to be at the end of the
+ * temporary buffer rather than the start of the data being at the start
+ * of the temporary buffer.
+ */
+ sub x8, x7, #16
+ csel x7, x7, x8, eq
+ add IN, IN, x7
+ add OUT, OUT, x7
+ ld1 {v5.16b}, [IN]
+ ld1 {v6.16b}, [OUT]
ST5( mov v3.16b, v4.16b )
- encrypt_block v3, w3, x2, x8, w7
- ld1 {v10.16b-v11.16b}, [x12]
+ encrypt_block v3, ROUNDS_W, KEY, x8, w7
+ ld1 {v10.16b-v11.16b}, [x9]
tbl v3.16b, {v3.16b}, v10.16b
sshr v11.16b, v11.16b, #7
eor v5.16b, v5.16b, v3.16b
bif v5.16b, v6.16b, v11.16b
- st1 {v5.16b}, [x0]
+ st1 {v5.16b}, [OUT]
b .Lctrout\xctr
+
+ // Arguments
+ .unreq OUT
+ .unreq IN
+ .unreq KEY
+ .unreq ROUNDS_W
+ .unreq BYTES_W
+ .unreq IV
+ .unreq BYTE_CTR_W // XCTR only
+ // Intermediate values
+ .unreq CTR_W // XCTR only
+ .unreq CTR // XCTR only
+ .unreq IV_PART
+ .unreq BLOCKS
+ .unreq BLOCKS_W
.endm

/*
* aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int bytes, u8 ctr[])
+ *
+ * The input and output buffers must always be at least 16 bytes even if
+ * encrypting/decrypting less than 16 bytes. Otherwise out of bounds
+ * accesses will occur. The data to be encrypted/decrypted is expected
+ * to be at the end of this 16-byte temporary buffer rather than the
+ * start.
*/

AES_FUNC_START(aes_ctr_encrypt)
@@ -515,6 +610,12 @@ AES_FUNC_END(aes_ctr_encrypt)
/*
* aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int bytes, u8 const iv[], int byte_ctr)
+ *
+ * The input and output buffers must always be at least 16 bytes even if
+ * encrypting/decrypting less than 16 bytes. Otherwise out of bounds
+ * accesses will occur. The data to be encrypted/decrypted is expected
+ * to be at the end of this 16-byte temporary buffer rather than the
+ * start.
*/

AES_FUNC_START(aes_xctr_encrypt)
--
2.36.1.124.g0e6072fb45-goog