Preempting hot-loop threads from user space



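I have an idea. It is somewhat involved, and you need to be familiar with assembly and multithreading to understand why I want to do this.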

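Have you noticed that when a program is running in a hot loop, the cancel button never does anything? Yet if you add an if statement to the hot loop, you slow it down considerably.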

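In the Linux kernel, a timer interrupt fires when a task's time slice expires, and the scheduler then decides which task runs on the CPU next. See the scheduling functions in kernel/sched/core.c. In user-space Java and C programs, by contrast, once a thread is in a hot loop you cannot interrupt it from user space unless you put an if statement in the loop that checks a termination flag.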

Only the kernel can preempt a thread.

But I have an idea for how to preempt a thread from user space in C.

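You can disassemble a function starting at its address and continuing up to the RET instruction to obtain its assembly. From that we can identify the loop's conditional jump. We can place a goto target after the hot loop that calls __sched_yield(), and we can find the relative address of sched_yield in memory from the disassembly.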

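From there, we can interrupt a user thread from user space. This requires mprotect() to make the executable code writable, so that the conditional jump instruction can be rewritten to jump to the goto target.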

What do you think? How easy would this be?

  • How to find out the size of a function
  • How to disassemble a function with libopcodes
  • How to update a function's memory

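In Java, C and Rust, you can preempt a thread that is in a hot loop by directly changing the loop variable so that the loop condition no longer holds. You can do this from any other thread.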

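If you need to use the loop variable (and you probably do), make sure you copy it into a local variable like this, otherwise you will have a data race: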

initialVar(0, 0);
for (int loopVal = 0; getValue(0) < getLimit(0); loopVal = increment(0)) {
    Math.sqrt(loopVal);
}

for (int loopVal = 0; m->value[0] < m->limit[0]; loopVal = m->value[0]++) {
    sqrt(loopVal);
}

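You can build cancellable APIs with this approach. Your cancellation token can represent every loop index the code is currently inside, so you can write deeply responsive code, even when you are deep in encrypting, compressing or uploading data, as long as you are not inside a system call.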

I have also pasted the C version of the source code below the Java version. The Rust version is below the C version.

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class Scheduler {
    public static class KernelThread extends Thread {
        // The timer thread writes this map while the kernel thread reads it,
        // so it is a concurrent map rather than a plain HashMap.
        public Map<LightWeightThread, Boolean> scheduled = new ConcurrentHashMap<>();

        // Called from the timer thread: push every registered loop to its limit.
        public void preempt() {
            for (LightWeightThread thread : scheduled.keySet()) {
                scheduled.put(thread, false);
                for (int loop = 0; loop < thread.getLoops(); loop++) {
                    thread.preempt(loop);
                }
            }
        }

        public void addLightWeightThread(LightWeightThread thread) {
            scheduled.put(thread, false);
        }

        public boolean isScheduled(LightWeightThread lightWeightThread) {
            return scheduled.get(lightWeightThread);
        }

        @Override
        public void run() {
            while (true) {
                LightWeightThread previous = null;
                for (LightWeightThread thread : scheduled.keySet()) {
                    scheduled.put(thread, false);
                }
                // Run each lightweight thread in turn until it is preempted.
                for (LightWeightThread thread : scheduled.keySet()) {
                    if (previous != null) {
                        scheduled.put(previous, false);
                    }
                    scheduled.put(thread, true);
                    thread.run();
                    previous = thread;
                }
            }
        }
    }

    public interface Preemptible {
        void registerLoop(int name, int defaultValue, int limit);
        int increment(int name);
        boolean isPreempted(int name);
        int getLimit(int name);
        int getValue(int name);
        void preempt(int id);
        int getLoops();
    }

    public static abstract class LightWeightThread implements Preemptible {
        public int kernelThreadId;
        public int threadId;
        public KernelThread parent;
        AtomicInteger[] values = new AtomicInteger[1];
        int[] limits = new int[1];
        boolean[] preempted = new boolean[1];
        int[] remembered = new int[1];

        public LightWeightThread(int kernelThreadId, int threadId, KernelThread parent) {
            this.kernelThreadId = kernelThreadId;
            this.threadId = threadId;
            this.parent = parent;
            for (int i = 0; i < values.length; i++) {
                values[i] = new AtomicInteger();
            }
        }

        public void run() {
        }

        // Resume from the remembered index if there is one, otherwise start over.
        public void registerLoop(int name, int defaultValue, int limit) {
            if (preempted.length > name && remembered[name] < limit) {
                values[name].set(remembered[name]);
                limits[name] = limit;
            } else {
                values[name].set(defaultValue);
                limits[name] = limit;
            }
            preempted[name] = false;
        }

        public int increment(int name) {
            return values[name].incrementAndGet();
        }

        public boolean isPreempted(int name) {
            return preempted[name];
        }

        public int getLimit(int name) {
            return limits[name];
        }

        public int getValue(int name) {
            return values[name].get();
        }

        public int initialVar(int name, int value) {
            values[name].set(value);
            return value;
        }

        // Remember where the loop was, then force the index to the limit so the loop exits.
        public void preempt(int id) {
            remembered[id] = values[id].get();
            preempted[id] = true;
            while (!values[id].compareAndSet(values[id].get(), limits[id])) {}
        }

        public int getLoops() {
            return values.length;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        List<KernelThread> kernelThreads = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            KernelThread kt = new KernelThread();
            for (int j = 0; j < 5; j++) {
                LightWeightThread lightWeightThread = new LightWeightThread(i, j, kt) {
                    @Override
                    public void run() {
                        while (this.parent.isScheduled(this)) {
                            System.out.println(String.format("%d %d", this.kernelThreadId, this.threadId));
                            registerLoop(0, 0, 10000000);
                            for (initialVar(0, 0); getValue(0) < getLimit(0); increment(0)) {
                                Math.sqrt(getValue(0));
                            }
                            if (isPreempted(0)) {
                                System.out.println(String.format("%d %d: %d was preempted !%d < %d", this.kernelThreadId, this.threadId, 0, values[0].get(), limits[0]));
                            }
                        }
                    }
                };
                kt.addLightWeightThread(lightWeightThread);
            }
            kernelThreads.add(kt);
        }
        for (KernelThread kt : kernelThreads) {
            kt.start();
        }
        // Preempt every kernel thread's lightweight threads every 10 ms.
        Timer timer = new Timer();
        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                for (KernelThread kt : kernelThreads) {
                    kt.preempt();
                }
            }
        }, 10, 10);
        for (KernelThread kt : kernelThreads) {
            kt.join();
        }
    }
}

Here is the C version of the same thing:

#include <pthread.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <ctype.h>
#include <time.h>
#include <math.h>

#define handle_error_en(en, msg) \
        do { errno = en; perror(msg); exit(EXIT_FAILURE); } while (0)

#define handle_error(msg) \
        do { perror(msg); exit(EXIT_FAILURE); } while (0)

struct lightweight_thread {
    int thread_num;
    volatile int preempted;     /* 1 while this thread is scheduled, 0 otherwise */
    int num_loops;
    int *limit;
    volatile int *value;
    int *remembered;
    int kernel_thread_num;
    /* Return type matches lightweight_thread_function() below. */
    int (*user_function) (struct lightweight_thread*);
};

struct thread_info {    /* Used as argument to thread_start() */
    pthread_t thread_id;        /* ID returned by pthread_create() */
    int       thread_num;       /* Application-defined thread # */
    char     *argv_string;      /* From command-line argument */
    int lightweight_threads_num;
    struct lightweight_thread *user_threads;
    volatile int running;
};

struct timer_thread {
    pthread_t thread_id;
    struct thread_info *all_threads;
    int num_threads;
    int lightweight_threads_num;
    int delay;
    volatile int running;
};

static void *
timer_thread_start(void *arg) {
    int iterations = 0;
    struct timer_thread *timer_thread = arg;
    int msec = 0, trigger = timer_thread->delay; /* 10ms */
    clock_t before = clock();
    while (timer_thread->running == 1 && iterations < 100000) {
        do {
            /* Mark every lightweight thread as descheduled. */
            for (int i = 0; i < timer_thread->num_threads; i++) {
                for (int j = 0; j < timer_thread->all_threads[i].lightweight_threads_num; j++) {
                    // printf("Preempting kernel thread %d user thread %d\n", i, j);
                    timer_thread->all_threads[i].user_threads[j].preempted = 0;
                }
            }
            /* Remember each loop index, then push it to its limit so the loop exits. */
            for (int i = 0; i < timer_thread->num_threads; i++) {
                for (int j = 0; j < timer_thread->all_threads[i].lightweight_threads_num; j++) {
                    // printf("Preempting kernel thread %d user thread %d\n", i, j);
                    struct lightweight_thread *lt = &timer_thread->all_threads[i].user_threads[j];
                    for (int loop = 0; loop < lt->num_loops; loop++) {
                        lt->remembered[loop] = lt->value[loop];
                        lt->value[loop] = lt->limit[loop];
                    }
                }
            }
            clock_t difference = clock() - before;
            msec = difference * 1000 / CLOCKS_PER_SEC;
            iterations++;
        } while (msec < trigger && iterations < 100000);
        // printf("Time taken %d seconds %d milliseconds (%d iterations)\n",
        //  msec/1000, msec%1000, iterations);
    }
    return NULL;
}

/* Thread start function: display address near top of our stack,
   and return upper-cased copy of argv_string. */
static void *
thread_start(void *arg)
{
    struct thread_info *tinfo = arg;
    char *uargv;
    printf("Thread %d: top of stack near %p; argv_string=%s\n",
           tinfo->thread_num, (void *) &tinfo, tinfo->argv_string);
    uargv = strdup(tinfo->argv_string);
    if (uargv == NULL)
        handle_error("strdup");
    for (char *p = uargv; *p != '\0'; p++) {
        *p = toupper((unsigned char) *p);
    }
    while (tinfo->running == 1) {
        for (int i = 0; i < tinfo->lightweight_threads_num; i++) {
            tinfo->user_threads[i].preempted = 0;
        }
        /* Run each lightweight thread in turn until the timer preempts it. */
        int previous = -1;
        for (int i = 0; i < tinfo->lightweight_threads_num; i++) {
            if (previous != -1) {
                tinfo->user_threads[previous].preempted = 0;
            }
            tinfo->user_threads[i].preempted = 1;
            tinfo->user_threads[i].user_function(&tinfo->user_threads[i]);
            previous = i;
        }
    }

    return uargv;
}

/* Start the loop at value, or resume from the remembered index after a preemption. */
void
register_loop(int index, int value, struct lightweight_thread* m, int limit) {
    if (m->remembered[index] == -1) {
        m->limit[index] = limit;
        m->value[index] = value;
    } else {
        m->limit[index] = limit;
        m->value[index] = m->remembered[index];
    }
}

int
lightweight_thread_function(struct lightweight_thread* m)
{
    while (m->preempted != 0) {
        register_loop(0, 0, m, 100000000);
        /* The timer thread ends this loop early by setting value[0] to limit[0]. */
        for (; m->value[0] < m->limit[0]; m->value[0]++) {
            sqrt(m->value[0]);
        }
        printf("Kernel thread %d User thread %d ran\n", m->kernel_thread_num, m->thread_num);
    }

    return 0;
}

struct lightweight_thread*
create_lightweight_threads(int kernel_thread_num, int num_threads) {
    struct lightweight_thread *lightweight_threads =
        calloc(num_threads, sizeof(*lightweight_threads));
    if (lightweight_threads == NULL)
        handle_error("calloc lightweight threads");
    for (int i = 0; i < num_threads; i++) {
        lightweight_threads[i].kernel_thread_num = kernel_thread_num;
        lightweight_threads[i].thread_num = i;
        lightweight_threads[i].num_loops = 1;
        lightweight_threads[i].user_function = lightweight_thread_function;
        int *remembered = calloc(lightweight_threads[i].num_loops, sizeof(*remembered));
        int *value = calloc(lightweight_threads[i].num_loops, sizeof(*value));
        int *limit = calloc(lightweight_threads[i].num_loops, sizeof(*limit));
        lightweight_threads[i].remembered = remembered;
        lightweight_threads[i].value = value;
        lightweight_threads[i].limit = limit;
        for (int j = 0; j < lightweight_threads[i].num_loops; j++) {
            lightweight_threads[i].remembered[j] = -1;
        }
    }
    return lightweight_threads;
}

int
main(int argc, char *argv[])
{
    int s, timer_s, opt, num_threads;
    pthread_attr_t attr;
    pthread_attr_t timer_attr;
    ssize_t stack_size;
    void *res;
    void *timer_result;

    /* Stack size for our threads; the "-t" option sets the number of threads. */
    stack_size = 16384ul;
    num_threads = 5;

    while ((opt = getopt(argc, argv, "t:")) != -1) {
        switch (opt) {
        case 't':
            num_threads = strtoul(optarg, NULL, 0);
            break;

        default:
            fprintf(stderr, "Usage: %s [-t num-threads] arg...\n",
                    argv[0]);
            exit(EXIT_FAILURE);
        }
    }

    /* Initialize thread creation attributes. */
    s = pthread_attr_init(&attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_init");
    timer_s = pthread_attr_init(&timer_attr);
    if (timer_s != 0)
        handle_error_en(timer_s, "pthread_attr_init timer_s");

    if (stack_size > 0) {
        s = pthread_attr_setstacksize(&attr, stack_size);
        int t = pthread_attr_setstacksize(&timer_attr, stack_size);
        if (t != 0)
            handle_error_en(t, "pthread_attr_setstacksize timer");
        if (s != 0)
            handle_error_en(s, "pthread_attr_setstacksize");
    }

    /* Allocate memory for pthread_create() arguments. */
    struct thread_info *tinfo = calloc(num_threads, sizeof(*tinfo));
    if (tinfo == NULL)
        handle_error("calloc");
    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].running = 1;
    }
    struct timer_thread *timer_info = calloc(1, sizeof(*timer_info));
    /* Check the allocation before touching it. */
    if (timer_info == NULL)
        handle_error("calloc timer thread");
    timer_info->running = 1;
    timer_info->delay = 10;
    timer_info->num_threads = num_threads;

    /* Create the kernel threads, each with its own set of lightweight threads. */
    timer_info->all_threads = tinfo;
    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].thread_num = tnum + 1;
        tinfo[tnum].argv_string = argv[0];
        struct lightweight_thread *lightweight_threads = create_lightweight_threads(tnum, num_threads);
        tinfo[tnum].user_threads = lightweight_threads;
        tinfo[tnum].lightweight_threads_num = num_threads;
        /* The pthread_create() call stores the thread ID into
           corresponding element of tinfo[]. */
        s = pthread_create(&tinfo[tnum].thread_id, &attr,
                           &thread_start, &tinfo[tnum]);
        if (s != 0)
            handle_error_en(s, "pthread_create");
    }
    s = pthread_create(&timer_info[0].thread_id, &timer_attr,
                       &timer_thread_start, &timer_info[0]);
    if (s != 0)
        handle_error_en(s, "pthread_create");

    /* Destroy the thread attributes object, since it is no
       longer needed. */
    s = pthread_attr_destroy(&attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_destroy");
    s = pthread_attr_destroy(&timer_attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_destroy timer");

    /* Now join with each thread, and display its returned value. */
    s = pthread_join(timer_info->thread_id, &timer_result);
    if (s != 0)
        handle_error_en(s, "pthread_join");
    printf("Joined timer thread");

    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].running = 0;
        s = pthread_join(tinfo[tnum].thread_id, &res);
        if (s != 0)
            handle_error_en(s, "pthread_join");
        printf("Joined with thread %d; returned value was %s\n",
               tinfo[tnum].thread_num, (char *) res);
        free(res);      /* Free memory allocated by thread */
        for (int user_thread_num = 0; user_thread_num < num_threads; user_thread_num++) {
            free(tinfo[tnum].user_threads[user_thread_num].remembered);
            free((void *) tinfo[tnum].user_threads[user_thread_num].value);
            free(tinfo[tnum].user_threads[user_thread_num].limit);
        }
        free(tinfo[tnum].user_threads);
    }

    free(timer_info);
    free(tinfo);
    exit(EXIT_SUCCESS);
}

The Rust version uses unsafe.

extern crate timer;
extern crate chrono;
use std::sync::Arc;
use std::thread;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time;

struct LightweightThread {
    thread_num: i32,
    preempted: AtomicI32,
    num_loops: i32,
    limit: Vec<AtomicI32>,
    value: Vec<AtomicI32>,
    remembered: Vec<AtomicI32>,
    kernel_thread_num: i32,
    lightweight_thread: fn(&mut LightweightThread)
}

// Resume from the remembered index if there is one, otherwise start over.
fn register_loop(loopindex: usize, initialValue: i32, limit: i32, _thread: &mut LightweightThread) {
    if _thread.remembered[loopindex].load(Ordering::Relaxed) < _thread.limit[loopindex].load(Ordering::Relaxed) {
        _thread.value[loopindex].store(_thread.remembered[loopindex].load(Ordering::Relaxed), Ordering::Relaxed);
        _thread.limit[loopindex].store(limit, Ordering::Relaxed);
    } else {
        _thread.value[loopindex].store(initialValue, Ordering::Relaxed);
        _thread.limit[loopindex].store(limit, Ordering::Relaxed);
    }
}

fn lightweight_thread(_thread: &mut LightweightThread) {
    register_loop(0usize, 0, 1000000, _thread);
    while _thread.preempted.load(Ordering::Relaxed) == 1 {
        // The timer thread ends this loop early by storing the limit into value[0].
        while _thread.value[0].load(Ordering::Relaxed) < _thread.limit[0].load(Ordering::Relaxed) {
            let i = _thread.value[0].load(Ordering::Relaxed);
            f64::sqrt(i.into());
            _thread.value[0].fetch_add(1, Ordering::Relaxed);
        }
    }
    println!("Kernel thread {} User thread {}", _thread.kernel_thread_num, _thread.thread_num)
}

fn main() {
    println!("Hello, world!");
    let timer = timer::Timer::new();
    static mut threads: Vec<LightweightThread> = Vec::new();
    let mut thread_handles = Vec::new();

    // Create every lightweight thread before any OS thread is spawned, so the
    // shared Vec is never resized while another thread is reading it.
    for kernel_thread_num in 1..=5 {
        for i in 1..=5 {
            let mut lthread = LightweightThread {
                thread_num: i,
                preempted: AtomicI32::new(0),
                num_loops: 1,
                limit: Vec::new(),
                value: Vec::new(),
                remembered: Vec::new(),
                kernel_thread_num: kernel_thread_num,
                lightweight_thread: lightweight_thread
            };
            lthread.limit.push(AtomicI32::new(-1));
            lthread.value.push(AtomicI32::new(-1));
            lthread.remembered.push(AtomicI32::new(1));
            unsafe {
                threads.push(lthread);
            }
        }
    }

    for kernel_thread_num in 1..=5 {
        let thread_join_handle = thread::spawn(move || {
            loop {
                let mut previous: Option<&mut LightweightThread> = None;
                unsafe {
                    for (_pos, current_thread) in threads.iter_mut().enumerate() {
                        // Only run the lightweight threads that belong to this kernel thread.
                        if current_thread.kernel_thread_num != kernel_thread_num {
                            continue;
                        }
                        if !previous.is_none() {
                            previous.unwrap().preempted.store(0, Ordering::Relaxed)
                        }
                        current_thread.preempted.store(1, Ordering::Relaxed);
                        (current_thread.lightweight_thread)(current_thread);
                        previous = Some(current_thread);
                        // println!("Running")
                    }
                }
            } // loop forever
        }); // thread

        thread_handles.push(thread_join_handle);
    } // thread generation

    let timer_handle = thread::spawn(move || {
        unsafe {
            loop {
                for thread in threads.iter() {
                    thread.preempted.store(0, Ordering::Relaxed);
                }
                let mut previous: Option<usize> = None;
                for (index, thread) in threads.iter_mut().enumerate() {
                    if !previous.is_none() {
                        threads[previous.unwrap()].preempted.store(0, Ordering::Relaxed);
                    }
                    previous = Some(index);
                    // Remember each loop index, then push it to its limit so the loop exits.
                    for loopindex in 0..thread.num_loops {
                        thread.remembered[loopindex as usize].store(thread.value[loopindex as usize].load(Ordering::Relaxed), Ordering::Relaxed);
                        thread.value[loopindex as usize].store(thread.limit[loopindex as usize].load(Ordering::Relaxed), Ordering::Relaxed);
                    }
                    thread.preempted.store(1, Ordering::Relaxed);
                }
                let ten_millis = time::Duration::from_millis(10);
                thread::sleep(ten_millis);
            } // loop
        } // unsafe
    }); // end of thread

    timer_handle.join().unwrap();
    for thread in thread_handles {
        thread.join().unwrap();
    }
}
