2022-09-11 08:38:03

by Yuanchu Xie

Subject: [RFC PATCH 0/2] mm: multi-gen LRU: per-process heatmaps

Today, the MGLRU debugfs interface (/sys/kernel/debug/lru_gen) provides
a histogram of the number of pages in each generation, which gives some
indication of memory coldness but not of where that memory actually is.
Since MGLRU revamps page reclaim to walk page tables, we can hook a BPF
program into its page table access bit harvesting to collect
information on relative hotness and coldness, NUMA node placement,
whether a page is anon or file-backed, and so on.
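
As a minimal sketch (assuming only the mglru_pte_probe hook added in
patch 1; the map and program names are purely illustrative), a tracing
program can count, per NUMA node, how many pages the aging walk found
accessed:

#include "vmlinux.h"

#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

/* pages found accessed, keyed by NUMA node id */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(max_entries, 64);
	__type(key, u32);
	__type(value, u64);
} node_hits SEC(".maps");

SEC("fentry/mglru_pte_probe")
int BPF_PROG(count_node_hits, pid_t pid, unsigned int nid,
	     unsigned long addr, unsigned long len, bool anon)
{
	u32 key = nid;
	u64 init = len, *val;

	val = bpf_map_lookup_elem(&node_hits, &key);
	if (val)
		__sync_fetch_and_add(val, len);
	else
		bpf_map_update_elem(&node_hits, &key, &init, BPF_NOEXIST);
	return 0;
}

char LICENSE[] SEC("license") = "GPL";

The PoC in patch 2 extends this idea to per-region stats annotated with
NUMA node and anon/file information.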

Using BPF programs to collect and aggregate page access information
lets a userspace agent customize what to collect and how to aggregate
it. It could focus on a particular region of interest and track a
moving-average access frequency, or find allocations that are never
accessed and could be eliminated altogether. MGLRU currently relies on
heuristics for which generation a page is assigned to; for example,
pages accessed through page tables are always assigned to the youngest
generation. Exposing page access data can allow future work to
customize generation assignment (with more BPF).
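
For instance, instead of the raw counts in the sketch above, a program
could keep a fixed-point exponentially weighted moving average per 2MB
region. A rough sketch building on the previous example (the region_ema
map and EMA_SHIFT constant are illustrative, not part of this series):

/* EMA of pages seen accessed, keyed by 2MB region (addr >> 21) */
struct {
	__uint(type, BPF_MAP_TYPE_HASH);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__uint(max_entries, 65536);
	__type(key, u64);
	__type(value, u64);
} region_ema SEC(".maps");

#define EMA_SHIFT 3	/* alpha = 1/8 */

/* ema <- ema - ema/8 + sample/8 */
static __always_inline void ema_update(u64 region, u64 sample)
{
	u64 *ema = bpf_map_lookup_elem(&region_ema, &region);

	if (ema)
		*ema += (sample >> EMA_SHIFT) - (*ema >> EMA_SHIFT);
	else
		bpf_map_update_elem(&region_ema, &region, &sample,
				    BPF_NOEXIST);
}

Calling ema_update(addr >> 21, len) from the probe above would then
replace the raw per-node counting.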

We demonstrate feasibility with a proof of concept that prints a live
heatmap of a process, with configurable MGLRU aging intervals and
aggregation intervals. This is still a very rough PoC that needs a lot
of work, but it shows how much can be done by exposing page access
information from MGLRU. I will be presenting this work at the upcoming
LPC.

As an example, I ran the memtier benchmark[1] and captured a heatmap of
memcached being populated and then exercised by the benchmark (similar
to the one Yu posted for OpenWRT[2]):

$ cat ./run_memtier_benchmark.sh
run_memtier_benchmark()
{
# populate dataset
memtier_benchmark/memtier_benchmark -s 127.0.0.1 -p 11211 \
-P memcache_binary -n allkeys -t 1 -c 1 --ratio 1:0 --pipeline 8 \
--key-minimum=1 --key-maximum=$2 --key-pattern=P:P \
-d 1000

# access dataset using Gaussian pattern
memtier_benchmark/memtier_benchmark -s 127.0.0.1 -p 11211 \
-P memcache_binary --test-time $1 -t 1 -c 1 --ratio 0:1 \
--pipeline 8 --key-minimum=1 --key-maximum=$2 \
--key-pattern=G:G --randomize --distinct-client-seed

# collect results
}

run_duration_secs=3600
max_key=8000000

run_memtier_benchmark $run_duration_secs $max_key

The following screenshot shows the dataset being populated and then
accessed:
https://services.google.com/fh/files/events/memcached_memtier_startup.png

Patch 1 adds the infrastructure that enables BPF programs to monitor
page access bit harvesting.

Patch 2 includes a proof-of-concept Python TUI program that displays
live per-process heatmaps.

[1] https://github.com/RedisLabs/memtier_benchmark
[2] https://lore.kernel.org/all/[email protected]/

Yuanchu Xie (2):
mm: multi-gen LRU: support page access info harvesting with eBPF
mm: add a BPF-based per-process heatmap tool

include/linux/mmzone.h | 1 +
mm/vmscan.c | 154 ++++++++
tools/vm/heatmap/Makefile | 30 ++
tools/vm/heatmap/heatmap.bpf.c | 123 +++++++
tools/vm/heatmap/heatmap.user.c | 188 ++++++++++
tools/vm/heatmap/heatmap_tui.py | 600 ++++++++++++++++++++++++++++++++
6 files changed, 1096 insertions(+)
create mode 100644 tools/vm/heatmap/Makefile
create mode 100644 tools/vm/heatmap/heatmap.bpf.c
create mode 100644 tools/vm/heatmap/heatmap.user.c
create mode 100755 tools/vm/heatmap/heatmap_tui.py

--
2.37.2.789.g6183377224-goog


2022-09-11 08:48:49

by Yuanchu Xie

Subject: [RFC PATCH 2/2] mm: add a BPF-based per-process heatmap tool

The heatmap tool combines a BPF program with a TUI as a proof of
concept for consuming page access information from MGLRU. It displays
heat, NUMA node, and anon/other overlays, with configurable aging and
aggregation intervals.

Signed-off-by: Yuanchu Xie <[email protected]>
---
tools/vm/heatmap/Makefile | 30 ++
tools/vm/heatmap/heatmap.bpf.c | 123 +++++++
tools/vm/heatmap/heatmap.user.c | 188 ++++++++++
tools/vm/heatmap/heatmap_tui.py | 600 ++++++++++++++++++++++++++++++++
4 files changed, 941 insertions(+)
create mode 100644 tools/vm/heatmap/Makefile
create mode 100644 tools/vm/heatmap/heatmap.bpf.c
create mode 100644 tools/vm/heatmap/heatmap.user.c
create mode 100755 tools/vm/heatmap/heatmap_tui.py

diff --git a/tools/vm/heatmap/Makefile b/tools/vm/heatmap/Makefile
new file mode 100644
index 000000000000..43ae4af67781
--- /dev/null
+++ b/tools/vm/heatmap/Makefile
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: GPL-2.0
+include ../../build/Build.include
+
+MAKEFLAGS += --no-builtin-rules
+MAKEFLAGS += --no-builtin-variables
+
+CC := clang
+KERNEL_DIR := $(abspath ../../..)
+
+.PHONY: clean all
+
+all: heatmap.user heatmap.skel.h vmlinux.h
+
+%.user: %.user.o
+ $(CC) -g -O2 $^ -l:libbpf.a -lelf -lz -o $@ -static
+
+vmlinux.h: $(KERNEL_DIR)/vmlinux
+ bpftool btf dump file $(KERNEL_DIR)/vmlinux format c > vmlinux.h
+
+%.bpf.o: vmlinux.h %.bpf.c
+ $(CC) -g -D__TARGET_ARCH_x86 -O2 -target bpf -c $*.bpf.c -o $*.bpf.o
+
+%.skel.h: %.bpf.o
+ bpftool gen skeleton $*.bpf.o > $*.skel.h
+
+%.user.o: %.skel.h %.user.c
+ $(CC) -g -O2 -c $*.user.c -o $*.user.o
+
+clean:
+ rm -f *.o heatmap.user *.skel.h vmlinux.h
diff --git a/tools/vm/heatmap/heatmap.bpf.c b/tools/vm/heatmap/heatmap.bpf.c
new file mode 100644
index 000000000000..924d896a0c4f
--- /dev/null
+++ b/tools/vm/heatmap/heatmap.bpf.c
@@ -0,0 +1,123 @@
+// SPDX-License-Identifier: GPL-2.0
+#include "vmlinux.h"
+
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_tracing.h>
+
+static const u8 one = 1;
+pid_t target_pid;
+#define MAP_SHIFT (12 + 9)
+#define MIXED_NODES -1
+#define MIXED_MEM -1
+
+struct region_stat {
+ u16 accesses;
+ s8 mem_type; /* NON_ANON, ANON */
+ s8 node_id;
+};
+
+struct heatmap_outer {
+ __uint(type, BPF_MAP_TYPE_HASH);
+ __uint(map_flags, BPF_F_NO_PREALLOC);
+ __uint(max_entries, 1000000);
+ __type(key, u64);
+ __type(value, struct region_stat);
+} heatmap SEC(".maps");
+
+int probe(unsigned int nid, unsigned long addr, unsigned long len, bool anon)
+{
+ u64 map_key = addr >> MAP_SHIFT;
+ struct region_stat *region = bpf_map_lookup_elem(&heatmap, &map_key);
+ int err;
+ struct region_stat to_insert;
+
+ if (!region) {
+ to_insert.accesses = len;
+ to_insert.mem_type = anon;
+ to_insert.node_id = nid;
+ err = bpf_map_update_elem(&heatmap, &map_key, &to_insert,
+ BPF_NOEXIST);
+ if (err)
+ return err;
+ } else {
+ region->accesses += len;
+ if (region->node_id != (int)nid)
+ region->node_id = MIXED_NODES;
+ if (region->mem_type != anon)
+ region->mem_type = MIXED_MEM;
+ err = bpf_map_update_elem(&heatmap, &map_key, region,
+ BPF_EXIST);
+ if (err)
+ return err;
+ }
+ return 0;
+}
+
+SEC("fentry/mglru_pte_probe")
+int BPF_PROG(fentry_mglru_pte_probe, pid_t pid, unsigned int nid,
+ unsigned long addr, unsigned long len, bool anon)
+{
+ int err;
+
+ if (pid != target_pid)
+ return 0;
+ err = probe(nid, addr, len, anon);
+ if (err)
+ bpf_printk("PTE called addr:0x%lx len:%lu error:%ld", addr, len,
+ err);
+ return 0;
+}
+
+SEC("fentry/mglru_pmd_probe")
+int BPF_PROG(fentry_mglru_pmd_probe, pid_t pid, unsigned int nid,
+ unsigned long addr, unsigned long len, bool anon)
+{
+ int err;
+
+ if (pid != target_pid)
+ return 0;
+ err = probe(nid, addr, len, anon);
+ if (err)
+ bpf_printk("PMD called addr:0x%lx len:%lu error:%ld", addr, len,
+ err);
+ return 0;
+}
+
+extern void
+bpf_set_skip_mm(struct bpf_mglru_should_skip_mm_control *should_skip) __ksym;
+
+SEC("fentry/bpf_mglru_should_skip_mm")
+int BPF_PROG(bpf_mglru_should_skip_mm,
+ struct bpf_mglru_should_skip_mm_control *ctl)
+{
+ if (ctl->pid != target_pid) {
+ bpf_printk("aging wrong pid");
+ bpf_set_skip_mm(ctl);
+ }
+ return 0;
+}
+
+extern int bpf_run_aging(int memcg_id, bool can_swap, bool force_scan) __ksym;
+
+struct args {
+ int memcg_id;
+};
+
+SEC("syscall")
+int memcg_run_aging(struct args *ctx)
+{
+ int err;
+
+ err = bpf_run_aging(ctx->memcg_id, true, true);
+
+ if (err != 0) {
+ bpf_printk("aging failed for memcg %ld with error %ld",
+ ctx->memcg_id, err);
+ return 0;
+ }
+
+ bpf_printk("aging succeeded for memcg %ld", ctx->memcg_id);
+ return 0;
+}
+
+char LICENSE[] SEC("license") = "GPL";
diff --git a/tools/vm/heatmap/heatmap.user.c b/tools/vm/heatmap/heatmap.user.c
new file mode 100644
index 000000000000..094ba1e49233
--- /dev/null
+++ b/tools/vm/heatmap/heatmap.user.c
@@ -0,0 +1,188 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "heatmap.skel.h"
+
+#define MAP_SHIFT (12 + 9)
+
+static bool terminate;
+
+struct args {
+ int memcg_id;
+};
+
+void handle_sigint(int sig)
+{
+ terminate = true;
+}
+
+static int libbpf_print_fn(enum libbpf_print_level level, const char *format,
+ va_list args)
+{
+ return vfprintf(stderr, format, args);
+}
+
+int run_aging(int aging_fd, int memcg_id)
+{
+ struct args ctx = {
+ .memcg_id = memcg_id,
+ };
+ LIBBPF_OPTS(bpf_test_run_opts, tattr, .ctx_in = &ctx,
+ .ctx_size_in = sizeof(ctx));
+ return bpf_prog_test_run_opts(aging_fd, &tattr);
+}
+
+int attach_progs(pid_t pid, struct heatmap_bpf **heatmap_obj, int *aging_fd,
+ int *heatmap_fd)
+{
+ int err;
+ int fd;
+ struct heatmap_bpf *obj;
+
+ obj = heatmap_bpf__open();
+ if (obj == NULL) {
+ perror("Error when opening heatmap bpf object");
+ return -1;
+ }
+ obj->bss->target_pid = pid;
+
+ err = heatmap_bpf__load(obj);
+ if (err) {
+ perror("Error loading heatmap bpf object");
+ goto cleanup;
+ }
+
+ fd = bpf_program__fd(obj->progs.memcg_run_aging);
+
+ err = heatmap_bpf__attach(obj);
+ if (err) {
+ perror("Error attaching heatmap bpf object");
+ goto cleanup;
+ }
+
+ *aging_fd = fd;
+ *heatmap_fd = bpf_map__fd(obj->maps.heatmap);
+ *heatmap_obj = obj;
+ return 0;
+
+cleanup:
+ heatmap_bpf__destroy(obj);
+ return err;
+}
+
+int bpf_map_delete_and_get_next_key(int fd, const void *key, void *next_key)
+{
+ int err = bpf_map_get_next_key(fd, key, next_key);
+
+ bpf_map_delete_elem(fd, key);
+ return err;
+}
+
+struct region_stat {
+ __u16 accesses;
+ __s8 mem_type; /* NON_ANON, ANON */
+ __s8 node_id;
+};
+
+void dump_map(int fd)
+{
+ __u64 prev_key, key;
+ struct region_stat value;
+ int err;
+
+ while (bpf_map_delete_and_get_next_key(fd, &prev_key, &key) == 0) {
+ err = bpf_map_lookup_elem(fd, &key, &value);
+ if (err < 0) {
+ /* impossible if we don't have racing deletions */
+ exit(-1);
+ }
+ printf("%llu %u %d %d\n", key << MAP_SHIFT, value.accesses,
+ value.mem_type, value.node_id);
+ prev_key = key;
+ }
+}
+
+void detach_progs(struct heatmap_bpf *heatmap_obj)
+{
+ heatmap_bpf__detach(heatmap_obj);
+ heatmap_bpf__destroy(heatmap_obj);
+}
+
+int main(void)
+{
+ struct heatmap_bpf *heatmap_obj = NULL;
+ int aging_fd = -1;
+ int heatmap_fd = -1;
+ int memcg_id = -1;
+ int err;
+
+ signal(SIGINT, handle_sigint);
+ setvbuf(stdout, NULL, _IONBF, BUFSIZ);
+ libbpf_set_print(libbpf_print_fn);
+
+ while (!terminate) {
+ char *buffer = NULL;
+
+ if (scanf("%ms", &buffer) == 1) {
+ if (strcmp(buffer, "exit") == 0) {
+ printf("No hard feelings.\n");
+ exit(0);
+
+ } else if (heatmap_obj == NULL &&
+ strcmp(buffer, "attach") == 0) {
+ pid_t pid_;
+ int memcg_id_;
+
+ if (scanf("%d %d", &pid_, &memcg_id_) == 2) {
+ err = attach_progs(pid_, &heatmap_obj,
+ &aging_fd,
+ &heatmap_fd);
+ if (err) {
+ printf("error: aging %d\n",
+ err);
+ goto next;
+ }
+ memcg_id = memcg_id_;
+ printf("success: attach\n");
+
+ } else
+ printf("error: invalid arguments\n");
+
+ } else if (heatmap_obj != NULL) {
+ if (strcmp(buffer, "map") == 0) {
+ dump_map(heatmap_fd);
+ printf("success: map\n");
+
+ } else if (strcmp(buffer, "age") == 0) {
+ err = run_aging(aging_fd, memcg_id);
+ if (err)
+ printf("error: age %d\n", err);
+ else
+ printf("success: age\n");
+
+ } else if (strcmp(buffer, "detach") == 0) {
+ detach_progs(heatmap_obj);
+ heatmap_obj = NULL;
+ heatmap_fd = -1;
+ aging_fd = -1;
+ memcg_id = -1;
+ printf("success: detach\n");
+ }
+
+ } else
+ printf("error: invalid command\n");
+
+next:
+ free(buffer);
+ } else
+ printf("error: invalid command\n");
+ }
+}
diff --git a/tools/vm/heatmap/heatmap_tui.py b/tools/vm/heatmap/heatmap_tui.py
new file mode 100755
index 000000000000..9be6b611bc24
--- /dev/null
+++ b/tools/vm/heatmap/heatmap_tui.py
@@ -0,0 +1,600 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: GPL-2.0
+import urwid
+import parse
+import subprocess
+import os
+
+from time import time
+from math import floor
+
+UPDATE_INTERVAL = 0.5
+MEM_MAX = 0
+MEM_MIN = 0x10000000000000000
+INVALID_NODE = -2
+MIXED_NODE = -1
+INVALID_MEM = -2
+MIXED_MEM = -1
+ANON_MEM = 1
+MAX_AGGREGATIONS = 20
+MIN_ROWS = 40
+MAX_COL = 7
+
+
+class HeatmapModel:
+ def __init__(self, max_nr_aggregations):
+ self.data_range = (MEM_MIN, MEM_MAX)
+ self.data = []
+ self.max_nr_aggregations = max_nr_aggregations
+
+ def get_nr_aggregations(self):
+ return len(self.data)
+
+ def get_max_nr_aggregations(self):
+ return self.max_nr_aggregations
+
+ def set_max_nr_aggregations(self, max_nr_aggregations):
+ self.max_nr_aggregations = max_nr_aggregations
+
+ def append_data(self, bpf_access_data):
+ region_min = MEM_MIN
+ region_max = MEM_MAX
+ for region in bpf_access_data:
+ if region["address"] > region_max:
+ region_max = region["address"]
+ if region["address"] < region_min:
+ region_min = region["address"]
+
+ self.data.append({
+ "min": region_min,
+ "max": region_max,
+ "data": sorted(bpf_access_data, key=lambda d: d["address"]),
+ })
+
+ if len(self.data) > self.max_nr_aggregations:
+ del self.data[:len(self.data) - self.max_nr_aggregations]
+
+
+ def get_display_sections(self, rows):
+ REGION_SIZE = 2 * 1024 * 1024 # region size (21 bits) 2 MB
+ ranges = []
+ for d in self.data:
+ for r in d["data"]:
+ addr = r["address"]
+ ranges.append(("start", addr - REGION_SIZE))
+ ranges.append(("end", addr + REGION_SIZE * 2))
+
+ ranges = sorted(ranges, key=lambda d: d[1])
+ sections = []
+ start_addr = None
+ nesting = 0
+ total_size = 0
+ for (tag, addr) in ranges:
+ if tag == "start":
+ nesting += 1
+ if start_addr == None:
+ start_addr = addr
+
+ if tag == "end":
+ nesting -= 1
+ if nesting == 0:
+ total_size += addr - start_addr
+ sections.append((start_addr, addr))
+ start_addr = None
+
+ if len(sections) > rows:
+ # compact some sections
+ sections_with_idx = [((start, end), i, end - start) for (i, (start, end)) in enumerate(sections)]
+ sections_with_idx = sorted(sections_with_idx, key=lambda x:
+ x[-1] + abs(sections[x[1] + 1][0] - x[0][1]) + abs(sections[x[1] - 1][1] - x[0][0]))
+ for i in range(0, len(sections) - rows):
+ # natural number of regions is greater than the number of rows
+ ((start, end), section_i, size) = sections_with_idx[i]
+ (_, prev_end) = sections[section_i - 1]
+ (succ_start, _) = sections[section_i + 1]
+ if abs(succ_start - end) > abs(prev_end - start):
+ sections[section_i - 1] = (sections[section_i - 1][0], end)
+ else:
+ sections[section_i + 1] = (start, sections[section_i + 1][1])
+
+ new_sections = []
+ for i in range(len(sections) - rows, len(sections)):
+ (_, section_i, _) = sections_with_idx[i]
+ new_sections.append(sections[section_i])
+ sections = sorted(new_sections, key=lambda x: x[0])
+
+ else:
+ extra_rows = rows - len(sections)
+ split_sections = []
+ spill_over_factor = 0
+ for (start, end) in sections:
+ fraction_of_row = (end - start) / total_size * rows
+ if fraction_of_row > 1:
+ additional_rows_frac = (spill_over_factor + (fraction_of_row - 1))
+ spill_over_factor = additional_rows_frac - floor(additional_rows_frac)
+ additional_rows = min(floor(additional_rows_frac), extra_rows)
+ extra_rows -= additional_rows
+ new_rows = additional_rows + 1
+ # split current interval
+ inc = (end - start + new_rows - 1) // new_rows # round up
+ for i in range(0, new_rows):
+ split_sections.append((start + inc * i, min(start + inc * (i + 1), end)))
+ else:
+ split_sections.append((start, end))
+ while extra_rows > 0:
+ split_sections.append((0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF))
+ extra_rows -= 1
+ sections = split_sections
+
+ return sections
+
+ def format_data(self, rows):
+ cell_init = (0, INVALID_MEM, INVALID_NODE)
+ def cell_update(cell, r):
+ (acc, mem, node) = cell
+ acc += r["accesses"]
+ if mem == INVALID_MEM:
+ mem = r["mem"]
+ elif mem != r["mem"]:
+ mem = MIXED_MEM
+
+ if node == INVALID_NODE:
+ node = r["node"]
+ elif node != r["node"]:
+ node = MIXED_NODE
+
+ return (acc, mem, node)
+
+ return self.__format_data(rows, cell_init, cell_update)
+
+ def __format_data(self, rows, cell_init, cell_update):
+ row_ranges = self.get_display_sections(rows)
+ assert(len(row_ranges) == rows)
+ graph = []
+ for d in self.data:
+ (curr_start, curr_end) = row_ranges[0]
+ range_idx = 0
+
+ col = []
+ cell = cell_init
+ for r in d["data"]:
+ while r["address"] >= curr_end:
+ col.append(cell)
+ cell = cell_init
+ range_idx += 1
+ (curr_start, curr_end) = row_ranges[range_idx]
+
+ cell = cell_update(cell, r)
+
+ while range_idx < len(row_ranges):
+ (curr_start, curr_end) = row_ranges[range_idx]
+ col.append(cell)
+ cell = cell_init
+ range_idx += 1
+
+ assert(len(col) == rows)
+ graph.append(col)
+
+ col_labels = [start for (start, end) in row_ranges]
+ return (col_labels, graph)
+
+
+class GraphView(urwid.Widget):
+ def __init__(self, model, modes):
+ self.model = model
+ self.mode = modes[0]
+ urwid.Widget.__init__(self)
+
+ def set_mode(self, mode):
+ self.mode = mode
+ self.update()
+
+ def render_col(self, col):
+ l = []
+ if self.mode == "Anon/Other":
+ for (_, mem, _) in col:
+ if mem == MIXED_MEM:
+ l.append(("mixed", "M" * MAX_COL))
+ elif mem == INVALID_MEM:
+ l.append(("invalid", "." * MAX_COL))
+ elif mem == ANON_MEM:
+ l.append(("anon mem", "A" * MAX_COL))
+ else:
+ l.append(("other mem", "X" * MAX_COL))
+ elif self.mode == "NUMA Node":
+ for (_, _, node) in col:
+ if node == MIXED_NODE:
+ l.append(("mixed", "M" * MAX_COL))
+ elif node == INVALID_NODE:
+ l.append(("invalid", "." * MAX_COL))
+ else:
+ node_str = str(node)
+ left_pad = (MAX_COL - len(node_str)) // 2
+ right_pad = MAX_COL - left_pad - len(node_str)
+ node_style = "node"
+ if node < 4 and node >= 0:
+ node_style += " " + node_str
+ l.append((node_style, "_" * left_pad + node_str + "_" * right_pad))
+ else:
+ for (acc, _, _) in col:
+ acc_str = str(acc)
+ left_pad = (MAX_COL - len(acc_str)) // 2
+ right_pad = MAX_COL - left_pad - len(acc_str)
+ heat_style = "heat"
+ if acc >= 1024:
+ heat_style += " high"
+ elif acc >= 512:
+ heat_style += " mid"
+ l.append((heat_style, u"\u00a0" * left_pad + acc_str + u"\u00a0" * right_pad))
+
+ return (urwid.Text(l).render((MAX_COL,)), None, False, MAX_COL)
+
+ def render(self, size, focus=False):
+ (cols, rows) = size
+ LABEL_COLS = 16 + 2 + 4 # 0x and 16 chars of hex address, plus padding
+ self.model.set_max_nr_aggregations((cols - LABEL_COLS) // MAX_COL)
+ data_cols = cols - LABEL_COLS
+ (labels, data) = self.model.format_data(rows)[-(data_cols // MAX_COL):]
+ if len(labels) == 0:
+ label_col = urwid.SolidCanvas(" ", LABEL_COLS, rows)
+ else:
+ label_col = [("pg smooth", "0x{:016X}".format(addr) + u"\u00a0" * 4) for addr in labels]
+ label_col = urwid.Text(label_col).render((LABEL_COLS,))
+
+ label_col = (label_col, None, False, LABEL_COLS)
+ if len(data) > 0:
+ return urwid.CanvasJoin([label_col] + list(map(self.render_col, data)) + [
+ (urwid.SolidCanvas(" ", data_cols - len(data) * MAX_COL, rows),
+ None, False, data_cols - len(data) * MAX_COL)])
+ else:
+ return urwid.SolidCanvas(" ", cols, rows)
+
+ def update(self):
+ self._invalidate()
+
+ def rows(self, size, focus=False):
+ return MIN_ROWS
+
+ def keypress(self, size, key):
+ return key
+
+
+class HeatmapView(urwid.WidgetWrap):
+ palette = [
+ ('body', 'black', 'light gray', 'standout'),
+ ('header', 'white', 'dark red', 'bold'),
+ ('screen edge', 'light blue', 'dark cyan'),
+ ('main shadow', 'dark gray', 'black'),
+ ('line', 'black', 'light gray', 'standout'),
+ ('bg background','light gray', 'black'),
+ ('bg 1', 'black', 'dark blue', 'standout'),
+ ('bg 1 smooth', 'dark blue', 'black'),
+ ('bg 2', 'black', 'dark cyan', 'standout'),
+ ('bg 2 smooth', 'dark cyan', 'black'),
+ ('button normal','light gray', 'dark blue', 'standout'),
+ ('button select','white', 'dark green'),
+ ('line', 'black', 'light gray', 'standout'),
+ ('pg normal', 'white', 'black', 'standout'),
+ ('pg complete', 'white', 'dark magenta'),
+ ('pg smooth', 'dark magenta','black'),
+
+ ("mixed", "light blue", "black"),
+ ("invalid", "dark gray", "black"),
+ ("anon mem", "white", "dark cyan"),
+ ("other mem", "light blue", "dark red"),
+
+ ("node", "white", "black"),
+
+ ("node 0", "light green", "black"),
+ ("node 1", "light blue", "black"),
+ ("node 2", "light red", "black"),
+ ("node 3", "yellow", "black"),
+
+ ("heat", "yellow", "black"),
+ ("heat mid", "dark red", "brown"),
+ ("heat high", "light red", "dark red"),
+ ]
+
+ def __init__(self, controller):
+ self.controller = controller
+ urwid.WidgetWrap.__init__(self, self.draw_view())
+
+ def update_graph(self):
+ self.graph_view.update()
+ pass
+
+ def set_selected_mode(self, new_mode):
+ self.graph_view.set_mode(new_mode)
+ for b in self.mode_buttons:
+ if b.get_label() == new_mode:
+ b.set_state(True, do_callback=False)
+ break
+
+ def on_mode_button(self, button, state):
+ if state:
+ self.controller.on_mode_change(button.get_label())
+
+ def radio_button(self, group, label, state, on_state_change):
+ w = urwid.RadioButton(group, label, state, on_state_change=on_state_change)
+ w = urwid.AttrWrap(w, 'button normal', 'button select')
+ return w
+
+ def button(self, label, on_press):
+ w = urwid.Button(label, on_press)
+ w = urwid.AttrWrap(w, 'button normal', 'button select')
+ return w
+
+ def set_alert_message(self, new_alert_text):
+ self.alert_text.set_text(new_alert_text)
+ pass
+
+ def set_start_button_text(self, new_text):
+ self.start_button.set_label(new_text)
+ pass
+
+ def edit_box(self, label, text, on_change):
+ w = urwid.Edit(label, text)
+ urwid.connect_signal(w, 'change', on_change)
+ w = urwid.AttrWrap(w, 'edit')
+ return w
+
+ def draw_control_pane(self):
+ g = []
+ self.mode_buttons = [self.radio_button(g, mode, mode == self.controller.mode,
+ self.on_mode_button)
+ for mode in self.controller.overlay_modes]
+ self.pid_box = self.edit_box("PID: ", "1", self.controller.on_pid_change)
+ self.memcg_box = self.edit_box("memcg id: ", "1", self.controller.on_memcg_change)
+ self.start_button = self.button("Start", self.controller.on_start_button)
+ self.alert_text = urwid.Text("", align="center")
+
+ aging_text = urwid.Text("Aging interval", align="center")
+ self.aging_box = self.edit_box("Seconds: ", str(0.5), self.controller.on_aging_change)
+ aggregation_text = urwid.Text("Aggregation Interval", align="center")
+ self.aggregation_box = self.edit_box("Aging cycles: ", str(3), self.controller.on_aggregation_change)
+
+ self.quit_button = self.button("Quit", self.controller.on_quit_button)
+ l = [
+ urwid.Text("Overlay Mode", align="center")
+ ] + self.mode_buttons + [
+ urwid.Divider(),
+ self.pid_box,
+ self.memcg_box,
+ self.start_button,
+ urwid.Divider(),
+ aging_text,
+ self.aging_box,
+ aggregation_text,
+ self.aggregation_box,
+ urwid.Divider(),
+ self.alert_text,
+ urwid.Divider(),
+ self.quit_button
+ ]
+ return urwid.ListBox(urwid.SimpleListWalker(l))
+
+ def main_shadow(self, w):
+ # Wrap a shadow and background around widget w
+ bg = urwid.AttrWrap(urwid.SolidFill(u"\u2592"), 'screen edge')
+ shadow = urwid.AttrWrap(urwid.SolidFill(u" "), 'main shadow')
+
+ bg = urwid.Overlay( shadow, bg,
+ ('fixed left', 3), ('fixed right', 1),
+ ('fixed top', 2), ('fixed bottom', 1))
+ w = urwid.Overlay( w, bg,
+ ('fixed left', 2), ('fixed right', 3),
+ ('fixed top', 1), ('fixed bottom', 2))
+ return w
+
+ def draw_view(self):
+ control_pane = self.draw_control_pane()
+ self.graph_view = GraphView(self.controller.model, self.controller.overlay_modes)
+ vline = urwid.AttrWrap(urwid.SolidFill(u'\u2502'), 'line')
+ w = urwid.Columns([("weight", 4, self.graph_view), ("fixed", 1, vline), control_pane],
+ dividechars=1, focus_column=2)
+ w = urwid.Padding(w,('fixed left',1),('fixed right',0))
+ w = urwid.AttrWrap(w,'body')
+ w = urwid.LineBox(w)
+ w = urwid.AttrWrap(w,'line')
+ w = self.main_shadow(w)
+ return w
+
+class BpfBridge:
+ def __init__(self):
+ self.bpf_bridge = subprocess.Popen(os.path.join(os.path.dirname(os.path.abspath(__file__)),
+ "heatmap.user"),
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ text=True)
+
+ def attach(self, pid, memcg_id):
+ self.bpf_bridge.stdin.write("attach\n")
+ self.bpf_bridge.stdin.write("{:d} {:d}\n".format(pid, memcg_id))
+ self.bpf_bridge.stdin.flush()
+ result = self.bpf_bridge.stdout.readline()
+ if "success" not in result:
+ return result
+
+ return "success"
+
+ def detach(self):
+ self.bpf_bridge.stdin.write("detach\n")
+ self.bpf_bridge.stdin.flush()
+ result = self.bpf_bridge.stdout.readline()
+ if "success" not in result:
+ return result
+
+ return "success"
+
+ def run_aging(self):
+ self.bpf_bridge.stdin.write("age\n")
+ self.bpf_bridge.stdin.flush()
+ result = self.bpf_bridge.stdout.readline()
+ if "success" not in result:
+ return result
+
+ return "success"
+
+ def get_map(self):
+ self.bpf_bridge.stdin.write("map\n")
+ self.bpf_bridge.stdin.flush()
+ result = ""
+ access_data = []
+ parser = parse.compile("{:d} {:d} {:d} {:d}\n")
+ result = self.bpf_bridge.stdout.readline()
+ while "success" not in result:
+ parsed_access = parser.parse(result)
+ if parsed_access is None:
+ # unparseable line from the bridge: hand it back as an error string
+ return result
+
+ access_data.append({
+ "address": parsed_access[0],
+ "accesses": parsed_access[1],
+ "mem": parsed_access[2],
+ "node": parsed_access[3],
+ })
+ result = self.bpf_bridge.stdout.readline()
+
+ return access_data
+
+
+class HeatmapController:
+ def __init__(self):
+ self.last_aged = 0
+ self.aging_interval = 0.5
+ self.aging_count = 0
+ self.aggregation_interval = 3
+ self.overlay_modes = ["Heat", "NUMA Node", "Anon/Other"]
+ self.monitoring_pid = 1
+ self.monitoring_memcg_id = 1
+ self.mode = self.overlay_modes[0]
+ self.model = HeatmapModel(MAX_AGGREGATIONS)
+ self.bpf_bridge = BpfBridge()
+ self.monitoring = False
+ self.view = HeatmapView(self)
+
+ def main(self):
+ self.loop = urwid.MainLoop(self.view, self.view.palette)
+ self.timer = self.loop.set_alarm_in(UPDATE_INTERVAL, self.on_timer)
+ # spawn the monitored process
+ self.loop.run()
+
+ def on_mode_change(self, new_mode):
+ self.mode = new_mode
+ self.view.set_selected_mode(new_mode)
+
+ def on_aging_change(self, w, new_value):
+ try:
+ aging_interval = float(new_value)
+ if aging_interval > 0:
+ self.aging_interval = aging_interval
+ self.view.set_alert_message("")
+ else:
+ self.view.set_alert_message("invalid aging interval")
+ except ValueError:
+ self.view.set_alert_message("invalid aging interval")
+
+ def on_aggregation_change(self, w, new_value):
+ try:
+ aggregation_interval = int(new_value)
+ if aggregation_interval > 0:
+ self.aggregation_interval = aggregation_interval
+ self.view.set_alert_message("")
+ else:
+ self.view.set_alert_message("invalid aggregation interval")
+ except ValueError:
+ self.view.set_alert_message("invalid aggregation interval")
+
+ def update_graph(self, new_data):
+ self.model.append_data(new_data)
+ self.view.update_graph()
+
+ def on_timer(self, loop=None, user_data=None):
+ # perform aging
+ # read data
+ delta_time = -time()
+ if self.last_aged <= -delta_time - self.aging_interval:
+ if self.monitoring:
+ # aging
+ self.last_aged = -delta_time
+ self.aging_count += 1
+ err = self.bpf_bridge.run_aging()
+ if "success" not in err:
+ self.disable_monitoring()
+ self.view.set_alert_message(err)
+
+ if self.aging_count % self.aggregation_interval == 0:
+ # aggregation
+ data = self.bpf_bridge.get_map()
+ if isinstance(data, str):
+ # get map failed
+ self.disable_monitoring()
+ self.view.set_alert_message(data)
+ else:
+ self.update_graph(data)
+
+ delta_time += time()
+ if delta_time > UPDATE_INTERVAL:
+ self.view.set_alert_message("timer running behind")
+ self.loop.set_alarm_in(0, self.on_timer)
+ else:
+ self.loop.set_alarm_in(UPDATE_INTERVAL - delta_time, self.on_timer)
+
+ def on_pid_change(self, widget, new_text):
+ pid = parse.parse("{:d}", new_text)
+ self.view.set_alert_message("")
+ if self.monitoring:
+ return
+
+ if pid is not None:
+ self.monitoring_pid = pid[0]
+ else:
+ self.monitoring_pid = -1
+
+ def on_memcg_change(self, widget, new_text):
+ memcg_id = parse.parse("{:d}", new_text)
+ self.view.set_alert_message("")
+ if self.monitoring:
+ return
+
+ if memcg_id is not None:
+ self.monitoring_memcg_id = memcg_id[0]
+ else:
+ self.monitoring_memcg_id = -1
+
+ def disable_monitoring(self):
+ self.monitoring = False
+ self.bpf_bridge.detach()
+ self.view.set_start_button_text("Start")
+
+ def on_start_button(self, w):
+ if self.monitoring:
+ err = self.bpf_bridge.detach()
+ self.monitoring = False
+ self.view.set_start_button_text("Start")
+ if "success" not in err:
+ self.view.set_alert_message(err)
+ self.view.pid_box.set_edit_text(str(self.monitoring_pid))
+ else:
+ if self.monitoring_pid != -1 and self.monitoring_memcg_id != -1:
+ err = self.bpf_bridge.attach(self.monitoring_pid, self.monitoring_memcg_id)
+ if "success" in err:
+ self.view.set_start_button_text("Stop")
+ self.monitoring = True
+ else:
+ self.view.set_alert_message(err)
+ else:
+ self.view.set_alert_message("invalid pid/memcg")
+
+
+
+ def on_quit_button(self, w):
+ raise urwid.ExitMainLoop()
+
+def main():
+ HeatmapController().main()
+
+if "__main__" == __name__:
+ main()
--
2.37.2.789.g6183377224-goog

2022-09-11 09:02:58

by Yuanchu Xie

Subject: [RFC PATCH 1/2] mm: multi-gen LRU: support page access info harvesting with eBPF

Add the infrastructure to enable BPF programs to hook into MGLRU and
capture page access information as MGLRU walks page tables.

- Add empty functions as hook points for PTE and PMD access bit
harvesting during MGLRU page table walks.

- Add a kfunc to invoke MGLRU aging.

- Add a kfunc and hook point to filter MGLRU aging by PID, as sketched
below.
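
For example (a trimmed, illustrative sketch; the full PoC lives in
patch 2), a tracing program can restrict an aging pass to a single
process, and a syscall program can then trigger that pass:

#include "vmlinux.h"

#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

pid_t target_pid;

/* kfuncs introduced by this patch */
extern void
bpf_set_skip_mm(struct bpf_mglru_should_skip_mm_control *ctl) __ksym;
extern int bpf_run_aging(int memcg_id, bool can_swap, bool force_scan) __ksym;

/* skip every mm except the one owned by target_pid during aging */
SEC("fentry/bpf_mglru_should_skip_mm")
int BPF_PROG(skip_other_mms, struct bpf_mglru_should_skip_mm_control *ctl)
{
	if (ctl->pid != target_pid)
		bpf_set_skip_mm(ctl);
	return 0;
}

struct args {
	int memcg_id;
};

/* run from userspace via bpf_prog_test_run() to kick off an aging pass */
SEC("syscall")
int run_aging_pass(struct args *ctx)
{
	return bpf_run_aging(ctx->memcg_id, true, true);
}

char LICENSE[] SEC("license") = "GPL";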

Signed-off-by: Yuanchu Xie <[email protected]>
---
include/linux/mmzone.h | 1 +
mm/vmscan.c | 154 +++++++++++++++++++++++++++++++++++++++++
2 files changed, 155 insertions(+)

diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h
index 710fc1d83bd0..f652b9473c6f 100644
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -481,6 +481,7 @@ struct lru_gen_mm_walk {
int mm_stats[NR_MM_STATS];
/* total batched items */
int batched;
+ pid_t pid;
bool can_swap;
bool force_scan;
};
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 762e7cb3d2d0..28499ba15e96 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -60,6 +60,10 @@
#include <linux/swapops.h>
#include <linux/balloon_compaction.h>
#include <linux/sched/sysctl.h>
+#include <linux/bpf.h>
+#include <linux/btf.h>
+#include <linux/btf_ids.h>
+#include <linux/rcupdate.h>

#include "internal.h"
#include "swap.h"
@@ -3381,12 +3385,41 @@ static void reset_mm_stats(struct lruvec *lruvec, struct lru_gen_mm_walk *walk,
}
}

+struct bpf_mglru_should_skip_mm_control {
+ pid_t pid;
+ bool should_skip;
+};
+
+void bpf_set_skip_mm(struct bpf_mglru_should_skip_mm_control *ctl)
+{
+ ctl->should_skip = true;
+}
+
+__weak noinline void
+bpf_mglru_should_skip_mm(struct bpf_mglru_should_skip_mm_control *ctl)
+{
+}
+
+static bool bpf_mglru_should_skip_mm_wrapper(pid_t pid)
+{
+ struct bpf_mglru_should_skip_mm_control ctl = {
+ .pid = pid,
+ .should_skip = false,
+ };
+
+ bpf_mglru_should_skip_mm(&ctl);
+ return ctl.should_skip;
+}
+
static bool should_skip_mm(struct mm_struct *mm, struct lru_gen_mm_walk *walk)
{
int type;
unsigned long size = 0;
struct pglist_data *pgdat = lruvec_pgdat(walk->lruvec);
int key = pgdat->node_id % BITS_PER_TYPE(mm->lru_gen.bitmap);
+#ifdef CONFIG_MEMCG
+ struct task_struct *task;
+#endif

if (!walk->force_scan && !test_bit(key, &mm->lru_gen.bitmap))
return true;
@@ -3402,6 +3435,16 @@ static bool should_skip_mm(struct mm_struct *mm, struct lru_gen_mm_walk *walk)
if (size < MIN_LRU_BATCH)
return true;

+#ifdef CONFIG_MEMCG
+ rcu_read_lock();
+ task = rcu_dereference(mm->owner);
+ if (task && bpf_mglru_should_skip_mm_wrapper(task->pid)) {
+ rcu_read_unlock();
+ return true;
+ }
+ rcu_read_unlock();
+#endif
+
return !mmget_not_zero(mm);
}

@@ -3842,6 +3885,22 @@ static bool suitable_to_scan(int total, int young)
return young * n >= total;
}

+/*
+ * __weak noinline guarantees that both the function and the callsite are
+ * preserved
+ */
+__weak noinline void mglru_pte_probe(pid_t pid, unsigned int nid, unsigned long addr,
+ unsigned long len, bool anon)
+{
+
+}
+
+__weak noinline void mglru_pmd_probe(pid_t pid, unsigned int nid, unsigned long addr,
+ unsigned long len, bool anon)
+{
+
+}
+
static bool walk_pte_range(pmd_t *pmd, unsigned long start, unsigned long end,
struct mm_walk *args)
{
@@ -3898,6 +3957,8 @@ static bool walk_pte_range(pmd_t *pmd, unsigned long start, unsigned long end,
folio_mark_dirty(folio);

old_gen = folio_update_gen(folio, new_gen);
+ mglru_pte_probe(walk->pid, pgdat->node_id, addr, folio_nr_pages(folio),
+ folio_test_anon(folio));
if (old_gen >= 0 && old_gen != new_gen)
update_batch_size(walk, folio, old_gen, new_gen);
}
@@ -3978,6 +4039,8 @@ static void walk_pmd_range_locked(pud_t *pud, unsigned long next, struct vm_area
folio_mark_dirty(folio);

old_gen = folio_update_gen(folio, new_gen);
+ mglru_pmd_probe(walk->pid, pgdat->node_id, addr, folio_nr_pages(folio),
+ folio_test_anon(folio));
if (old_gen >= 0 && old_gen != new_gen)
update_batch_size(walk, folio, old_gen, new_gen);
next:
@@ -4139,6 +4202,7 @@ static void walk_mm(struct lruvec *lruvec, struct mm_struct *mm, struct lru_gen_
int err;
struct mem_cgroup *memcg = lruvec_memcg(lruvec);

+ walk->pid = mm->owner->pid;
walk->next_addr = FIRST_USER_ADDRESS;

do {
@@ -5657,6 +5721,96 @@ static int run_cmd(char cmd, int memcg_id, int nid, unsigned long seq,
return err;
}

+int bpf_run_aging(int memcg_id, bool can_swap,
+ bool force_scan)
+{
+ struct scan_control sc = {
+ .may_writepage = true,
+ .may_unmap = true,
+ .may_swap = true,
+ .reclaim_idx = MAX_NR_ZONES - 1,
+ .gfp_mask = GFP_KERNEL,
+ };
+ int err = -EINVAL;
+ struct mem_cgroup *memcg = NULL;
+ struct blk_plug plug;
+ unsigned int flags;
+ unsigned int nid;
+
+ if (!mem_cgroup_disabled()) {
+ rcu_read_lock();
+ memcg = mem_cgroup_from_id(memcg_id);
+#ifdef CONFIG_MEMCG
+ if (memcg && !css_tryget(&memcg->css))
+ memcg = NULL;
+#endif
+ rcu_read_unlock();
+
+ if (!memcg)
+ return -EINVAL;
+ }
+
+ if (memcg_id != mem_cgroup_id(memcg)) {
+ mem_cgroup_put(memcg);
+ return err;
+ }
+
+ set_task_reclaim_state(current, &sc.reclaim_state);
+ flags = memalloc_noreclaim_save();
+ blk_start_plug(&plug);
+ if (!set_mm_walk(NULL)) {
+ err = -ENOMEM;
+ goto done;
+ }
+
+ for_each_online_node(nid) {
+ struct lruvec *lruvec = get_lruvec(memcg, nid);
+ DEFINE_MAX_SEQ(lruvec);
+
+ err = run_aging(lruvec, max_seq, &sc, can_swap, force_scan);
+ if (err)
+ goto done;
+ }
+done:
+ clear_mm_walk();
+ blk_finish_plug(&plug);
+ memalloc_noreclaim_restore(flags);
+ set_task_reclaim_state(current, NULL);
+ mem_cgroup_put(memcg);
+
+ return err;
+}
+
+BTF_SET8_START(bpf_lru_gen_trace_kfunc_ids)
+BTF_ID_FLAGS(func, bpf_set_skip_mm)
+BTF_SET8_END(bpf_lru_gen_trace_kfunc_ids)
+
+BTF_SET8_START(bpf_lru_gen_syscall_kfunc_ids)
+BTF_ID_FLAGS(func, bpf_run_aging)
+BTF_SET8_END(bpf_lru_gen_syscall_kfunc_ids)
+
+static const struct btf_kfunc_id_set bpf_lru_gen_trace_kfunc_set = {
+ .owner = THIS_MODULE,
+ .set = &bpf_lru_gen_trace_kfunc_ids,
+};
+
+static const struct btf_kfunc_id_set bpf_lru_gen_syscall_kfunc_set = {
+ .owner = THIS_MODULE,
+ .set = &bpf_lru_gen_syscall_kfunc_ids,
+};
+
+static int __init bpf_lru_gen_kfunc_init(void)
+{
+ int err = register_btf_kfunc_id_set(BPF_PROG_TYPE_TRACING,
+ &bpf_lru_gen_trace_kfunc_set);
+ if (err)
+ return err;
+ return register_btf_kfunc_id_set(BPF_PROG_TYPE_SYSCALL,
+ &bpf_lru_gen_syscall_kfunc_set);
+}
+late_initcall(bpf_lru_gen_kfunc_init);
+
+
/* see Documentation/admin-guide/mm/multigen_lru.rst for details */
static ssize_t lru_gen_seq_write(struct file *file, const char __user *src,
size_t len, loff_t *pos)
--
2.37.2.789.g6183377224-goog