2024-04-04 11:43:18

by Puranjay Mohan

[permalink] [raw]
Subject: [PATCH bpf-next v4 1/2] bpf,riscv: Implement PROBE_MEM32 pseudo instructions

Add support for [LDX | STX | ST], PROBE_MEM32, [B | H | W | DW]
instructions. They are similar to PROBE_MEM instructions with the
following differences:
- PROBE_MEM32 supports store.
- PROBE_MEM32 relies on the verifier to clear the upper 32 bits of the
src/dst register
- PROBE_MEM32 adds the 64-bit kern_vm_start address (which is stored in S7
in the prologue). Due to bpf_arena constructions, such an S7 + reg +
off16 access is guaranteed to be within the arena virtual range, so no
address check is needed at run-time.
- S7 is a free callee-saved register, so it is used to store kern_vm_start
- PROBE_MEM32 allows STX and ST. If they fault, the store is a nop. When
LDX faults, the destination register is zeroed.

To support these on riscv, we do tmp = S7 + src/dst reg and then use
tmp as the new src/dst register. This allows us to reuse most of the
code for normal [LDX | STX | ST].

Tested-by: Björn Töpel <[email protected]>
Tested-by: Pu Lehui <[email protected]>
Acked-by: Björn Töpel <[email protected]>
Reviewed-by: Pu Lehui <[email protected]>
Signed-off-by: Puranjay Mohan <[email protected]>
---
arch/riscv/net/bpf_jit.h | 1 +
arch/riscv/net/bpf_jit_comp64.c | 189 +++++++++++++++++++++++++++++++-
arch/riscv/net/bpf_jit_core.c | 1 +
3 files changed, 189 insertions(+), 2 deletions(-)

diff --git a/arch/riscv/net/bpf_jit.h b/arch/riscv/net/bpf_jit.h
index f4b6b3b9edda..8a47da08dd9c 100644
--- a/arch/riscv/net/bpf_jit.h
+++ b/arch/riscv/net/bpf_jit.h
@@ -81,6 +81,7 @@ struct rv_jit_context {
int nexentries;
unsigned long flags;
int stack_size;
+ u64 arena_vm_start;
};

/* Convert from ninsns to bytes. */
diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
index 1adf2f39ce59..a4c8e1e6c1e2 100644
--- a/arch/riscv/net/bpf_jit_comp64.c
+++ b/arch/riscv/net/bpf_jit_comp64.c
@@ -18,6 +18,7 @@

#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
+#define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */

static const int regmap[] = {
[BPF_REG_0] = RV_REG_A5,
@@ -255,6 +256,10 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
store_offset -= 8;
}
+ if (ctx->arena_vm_start) {
+ emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
+ store_offset -= 8;
+ }

emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
/* Set return value. */
@@ -548,6 +553,7 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,

#define BPF_FIXUP_OFFSET_MASK GENMASK(26, 0)
#define BPF_FIXUP_REG_MASK GENMASK(31, 27)
+#define REG_DONT_CLEAR_MARKER 0 /* RV_REG_ZERO unused in pt_regmap */

bool ex_handler_bpf(const struct exception_table_entry *ex,
struct pt_regs *regs)
@@ -555,7 +561,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex,
off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);

- *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
+ if (regs_offset != REG_DONT_CLEAR_MARKER)
+ *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
regs->epc = (unsigned long)&ex->fixup - offset;

return true;
@@ -572,7 +579,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
off_t fixup_offset;

if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
- (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
+ (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
+ BPF_MODE(insn->code) != BPF_PROBE_MEM32))
return 0;

if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
@@ -1539,6 +1547,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
+ /* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
+ case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
+ case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
+ case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
+ case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
{
int insn_len, insns_start;
bool sign_ext;
@@ -1546,6 +1559,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
BPF_MODE(insn->code) == BPF_PROBE_MEMSX;

+ if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
+ emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
+ rs = RV_REG_T2;
+ }
+
switch (BPF_SIZE(code)) {
case BPF_B:
if (is_12b_int(off)) {
@@ -1682,6 +1700,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
break;

+ case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
+ case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
+ case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
+ case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
+ {
+ int insn_len, insns_start;
+
+ emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
+ rd = RV_REG_T3;
+
+ /* Load imm to a register then store it */
+ emit_imm(RV_REG_T1, imm, ctx);
+
+ switch (BPF_SIZE(code)) {
+ case BPF_B:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit(rv_sb(rd, off, RV_REG_T1), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T2, off, ctx);
+ emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_H:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit(rv_sh(rd, off, RV_REG_T1), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T2, off, ctx);
+ emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_W:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit_sw(rd, off, RV_REG_T1, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T2, off, ctx);
+ emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_DW:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit_sd(rd, off, RV_REG_T1, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T2, off, ctx);
+ emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
+ insn_len);
+ if (ret)
+ return ret;
+
+ break;
+ }
+
/* STX: *(size *)(dst + off) = src */
case BPF_STX | BPF_MEM | BPF_B:
if (is_12b_int(off)) {
@@ -1728,6 +1826,84 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
emit_atomic(rd, rs, off, imm,
BPF_SIZE(code) == BPF_DW, ctx);
break;
+
+ case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
+ case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
+ case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
+ case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
+ {
+ int insn_len, insns_start;
+
+ emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
+ rd = RV_REG_T2;
+
+ switch (BPF_SIZE(code)) {
+ case BPF_B:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit(rv_sb(rd, off, rs), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T1, off, ctx);
+ emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit(rv_sb(RV_REG_T1, 0, rs), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_H:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit(rv_sh(rd, off, rs), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T1, off, ctx);
+ emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit(rv_sh(RV_REG_T1, 0, rs), ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_W:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit_sw(rd, off, rs, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T1, off, ctx);
+ emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit_sw(RV_REG_T1, 0, rs, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ case BPF_DW:
+ if (is_12b_int(off)) {
+ insns_start = ctx->ninsns;
+ emit_sd(rd, off, rs, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ emit_imm(RV_REG_T1, off, ctx);
+ emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+ insns_start = ctx->ninsns;
+ emit_sd(RV_REG_T1, 0, rs, ctx);
+ insn_len = ctx->ninsns - insns_start;
+ break;
+ }
+
+ ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
+ insn_len);
+ if (ret)
+ return ret;
+
+ break;
+ }
+
default:
pr_err("bpf-jit: unknown opcode %02x\n", code);
return -EINVAL;
@@ -1759,6 +1935,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
stack_adjust += 8;
if (seen_reg(RV_REG_S6, ctx))
stack_adjust += 8;
+ if (ctx->arena_vm_start)
+ stack_adjust += 8;

stack_adjust = round_up(stack_adjust, 16);
stack_adjust += bpf_stack_adjust;
@@ -1810,6 +1988,10 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
store_offset -= 8;
}
+ if (ctx->arena_vm_start) {
+ emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
+ store_offset -= 8;
+ }

emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);

@@ -1823,6 +2005,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);

ctx->stack_size = stack_adjust;
+
+ if (ctx->arena_vm_start)
+ emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
}

void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
diff --git a/arch/riscv/net/bpf_jit_core.c b/arch/riscv/net/bpf_jit_core.c
index 6b3acac30c06..9ab739b9f9a2 100644
--- a/arch/riscv/net/bpf_jit_core.c
+++ b/arch/riscv/net/bpf_jit_core.c
@@ -80,6 +80,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
goto skip_init_ctx;
}

+ ctx->arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
ctx->prog = prog;
ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
if (!ctx->offset) {
--
2.40.1