2021-07-17 10:20:21

by Xiyu Yang

[permalink] [raw]
Subject: [PATCH] sched: Convert from atomic_t to refcount_t on root_domain->refcount

The refcount_t type and its corresponding API protect reference counters
from accidental underflow and overflow, and thereby from the
use-after-free situations that can follow.

Signed-off-by: Xiyu Yang <[email protected]>
Signed-off-by: Xin Tan <[email protected]>
---
kernel/sched/sched.h | 3 ++-
kernel/sched/topology.c | 12 ++++++------
2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index 14a41a243f7b..8197738a615a 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -3,6 +3,7 @@
* Scheduler internal types and methods:
*/
#include <linux/sched.h>
+#include <linux/refcount.h>

#include <linux/sched/autogroup.h>
#include <linux/sched/clock.h>
@@ -784,7 +785,7 @@ struct perf_domain {
*
*/
struct root_domain {
- atomic_t refcount;
+ refcount_t refcount;
atomic_t rto_count;
struct rcu_head rcu;
cpumask_var_t span;
diff --git a/kernel/sched/topology.c b/kernel/sched/topology.c
index b77ad49dc14f..5d7d767e62ed 100644
--- a/kernel/sched/topology.c
+++ b/kernel/sched/topology.c
@@ -482,11 +482,11 @@ void rq_attach_root(struct rq *rq, struct root_domain *rd)
* set old_rd to NULL to skip the freeing later
* in this function:
*/
- if (!atomic_dec_and_test(&old_rd->refcount))
+ if (!refcount_dec_and_test(&old_rd->refcount))
old_rd = NULL;
}

- atomic_inc(&rd->refcount);
+ refcount_inc(&rd->refcount);
rq->rd = rd;

cpumask_set_cpu(rq->cpu, rd->span);
@@ -501,12 +501,12 @@ void rq_attach_root(struct rq *rq, struct root_domain *rd)

void sched_get_rd(struct root_domain *rd)
{
- atomic_inc(&rd->refcount);
+ refcount_inc(&rd->refcount);
}

void sched_put_rd(struct root_domain *rd)
{
- if (!atomic_dec_and_test(&rd->refcount))
+ if (!refcount_dec_and_test(&rd->refcount))
return;

call_rcu(&rd->rcu, free_rootdomain);
@@ -562,7 +562,7 @@ void init_defrootdomain(void)
{
init_rootdomain(&def_root_domain);

- atomic_set(&def_root_domain.refcount, 1);
+ refcount_set(&def_root_domain.refcount, 1);
}

static struct root_domain *alloc_rootdomain(void)
@@ -1419,7 +1419,7 @@ static void __free_domain_allocs(struct s_data *d, enum s_alloc what,
{
switch (what) {
case sa_rootdomain:
- if (!atomic_read(&d->rd->refcount))
+ if (!refcount_read(&d->rd->refcount))
free_rootdomain(&d->rd->rcu);
fallthrough;
case sa_sd:
--
2.7.4