基本概念
在 Rust 中我们除了使用消息传递在线程间通信之外,还可以通过共享内存的方式来进行通信。保证并发安全的前提是使用互斥器(mutex)提供的锁(lock)来保护数据,下面是一个使用互斥器的例子:
use std::sync::{Arc, Mutex};
use std::thread;
const THREAD_NUM: usize = 5;
const N: usize = 1_000_000;
fn create_thread(sum: &Arc<Mutex<i32>>, handles: &mut Vec<thread::JoinHandle<()>>, i: usize) {
let sum_ref = Arc::clone(sum);
let handle = thread::spawn(move || {
for _ in 0..N {
let mut guard = sum_ref.lock().unwrap();
for _ in 0..10 {
*guard += 1;
}
}
println!("Thread {}: sum = {}", i, *sum_ref.lock().unwrap());
});
handles.push(handle);
}
fn main() {
let sum = Arc::new(Mutex::new(0));
let mut handles = vec![];
for i in 0..THREAD_NUM {
create_thread(&sum, &mut handles, i);
}
for handle in handles {
handle.join().unwrap();
}
println!("sum = {}", *sum.lock().unwrap());
println!("{}*n = {}", THREAD_NUM * 10, (THREAD_NUM * 10 * N) as isize);
}在这个例子中创建了多个线程对同一个变量 sum 执行修改操作,输出如下:
Thread 2: sum = 40646100
Thread 0: sum = 47149680
Thread 4: sum = 49003000
Thread 1: sum = 49860580
Thread 3: sum = 50000000
sum = 50000000
50*n = 50000000条件变量
条件变量(condition variable)是一种万能的处理并发问题的解决方法,其形式可能不如信号量那样优雅,但适用性广泛。当线程之间必须发生某种信号时,如果一个线程在等待另一个线程继续执行某些操作,条件变量就很有用。线程可以使用条件变量,来等待一个条件变成真。条件变量是一个显式队列,当某些执行状态(即条件,condition)不满足时,线程可以把自己加入队列,等待(waiting)该条件。另外某个线程,当它改变了上述状态时,就可以唤醒一个或者多个等待线程(通过在该条件上发信号),让它们继续执行。
条件变量的使用形式很简单:
lock(mutex);
while (!condition) { // 可以是任意条件
wait(cvar, mutex) // 释放锁并阻塞当前线程
}
assert(condition);
// do something
unlock(mutex); // 释放锁在 Rust 中使用条件变量只需要 use std::sync::Condvar; 即可。
一个简单问题:括号序列
条件变量一个很常见的应用场景是解决“生产者/消费者(producer/consumer)问题”。
考虑这样一个问题:存在两种线程分别可以死循环打印字符“(”与字符“)”,括号嵌套深度不能超过规定值,打印合法的括号序列。
这是个显然的生产者/消费者问题,两种线程分别对应生产者与消费者,那么该问题下同步的条件则为当前打印序列的嵌套深度,当深度小于 0 时可以打印左括号,当深度大于 0 时可以打印右括号,在打印之后对状态进行更新。
现在来将其用 Rust 实现。首先定义规定的嵌套深度与线程数,将当前的嵌套深度抽象为一个结构体,再将它包装一下方便传参:
const MAX_DEPTH: usize = 2;
const THREAD_NUM: usize = 4;
struct SharedState {
current_depth: usize,
}
type SharedPair = Arc<(Mutex<SharedState>, Condvar)>;main 函数的写法是简单的,这里不赘述:
fn main() {
let initial_state = SharedState { current_depth: 0 };
let pair = Arc::new((Mutex::new(initial_state), Condvar::new()));
let mut handles = vec![];
for _ in 0..THREAD_NUM {
let pair_ref = Arc::clone(&pair);
let handle = thread::spawn(move || {
produce(pair_ref);
});
handles.push(handle);
let pair_ref = Arc::clone(&pair);
let handle = thread::spawn(move || {
consume(pair_ref);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}对于生产者线程函数的,可以发现只需要将标准库提供的条件变量 API 套上前文的条件变量使用方法即可,最后需要唤醒其他线程,锁在离开作用域时会自动释放:
fn produce(pair: SharedPair) {
let (lk, cvar) = &*pair;
loop {
let mut state = lk.lock().unwrap();
while !(state.current_depth < MAX_DEPTH) {
state = cvar.wait(state).unwrap();
}
assert!(state.current_depth < MAX_DEPTH);
print!("(");
io::stdout().flush().unwrap();
state.current_depth += 1;
cvar.notify_all();
}
}消费者函数同理:
fn consume(pair: SharedPair) {
let (lk, cvar) = &*pair;
loop {
let mut state = lk.lock().unwrap();
while !(state.current_depth > 0) {
state = cvar.wait(state).unwrap();
}
assert!(state.current_depth > 0);
print!(")");
io::stdout().flush().unwrap();
state.current_depth -= 1;
cvar.notify_all();
}
}完整代码:
use std::io::{self, Write};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
const MAX_DEPTH: usize = 2;
const THREAD_NUM: usize = 4;
struct SharedState {
current_depth: usize,
}
type SharedPair = Arc<(Mutex<SharedState>, Condvar)>;
fn produce(pair: SharedPair) {
let (lk, cvar) = &*pair;
loop {
let mut state = lk.lock().unwrap();
while !(state.current_depth < MAX_DEPTH) {
state = cvar.wait(state).unwrap();
}
assert!(state.current_depth < MAX_DEPTH);
print!("(");
io::stdout().flush().unwrap();
state.current_depth += 1;
cvar.notify_all();
}
}
fn consume(pair: SharedPair) {
let (lk, cvar) = &*pair;
loop {
let mut state = lk.lock().unwrap();
while !(state.current_depth > 0) {
state = cvar.wait(state).unwrap();
}
assert!(state.current_depth > 0);
print!(")");
io::stdout().flush().unwrap();
state.current_depth -= 1;
cvar.notify_all();
}
}
fn main() {
let initial_state = SharedState { current_depth: 0 };
let pair = Arc::new((Mutex::new(initial_state), Condvar::new()));
let mut handles = vec![];
for _ in 0..THREAD_NUM {
let pair_ref = Arc::clone(&pair);
let handle = thread::spawn(move || {
produce(pair_ref);
});
handles.push(handle);
let pair_ref = Arc::clone(&pair);
let handle = thread::spawn(move || {
consume(pair_ref);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}稍微复杂一点的问题:打印小鱼
当问题简单时使用条件变量需要考虑的状态很简单,现在让我们来看看复杂的状态下如何解决。
存在三种线程:
- 打印字符 “
<” - 打印字符 “
>” - 打印字符 “
_”
以下的字符序列是合法的:“<><_”或“><>_”,输出 个合法的序列。
从状态机的视角考虑,两种序列中的字符分别对应一种状态,根据当前的状态来进行状态转移打印对应的合法字符:
#[derive(Copy, Clone, PartialEq, Eq)]
enum State {
Start, // 初始状态
LeftStart, // 收到 <
LeftMiddle, // 收到 <>
RightStart, // 收到 >
RightMiddle, // 收到 ><
Separate, // 收到 <>< 或 ><>
}
impl State {
fn transition(self, ch: char) -> Option<State> {
use State::*;
match (self, ch) {
// Sequence 1: <><
(Start, '<') => Some(LeftStart),
(LeftStart, '>') => Some(LeftMiddle),
(LeftMiddle, '<') => Some(Separate),
// Sequence 2: ><>
(Start, '>') => Some(RightStart),
(RightStart, '<') => Some(RightMiddle),
(RightMiddle, '>') => Some(Separate),
// Common End: _
(Separate, '_') => Some(Start),
// 当前线程无法进行状态转移
_ => None,
}
}
}现在可以写出需要加锁保护的数据了:
struct JobState {
state: State,
remaining: usize,
}现在可以写出代码的主体框架:
type SyncState = Arc<(Mutex<JobState>, Condvar)>;
fn worker(sync_state: SyncState, ch: char) {
todo!();
}
fn main() {
let shared = JobState {
state: State::Start,
remaining: 1000,
};
let sync_state = Arc::new((Mutex::new(shared), Condvar::new()));
let mut handles = vec![];
for ch in ['<', '>', '_'] {
let sync_state_ref = Arc::clone(&sync_state);
let handle = thread::spawn(move || worker(sync_state_ref, ch));
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!();
}对于打印线程,我们可以用同样的方法写出对应的代码。
fn worker(sync_state: SyncState, ch: char) {
let (lock, cvar) = &*sync_state;
loop {
let mut shared = lock.lock().unwrap();
// 阻止当前线程,直到此条件变量接收到通知并且所提供的条件为 false 为止
// 条件:还有剩余任务 && 无法进行状态转移
shared = cvar
.wait_while(shared, |s| {
s.remaining > 0 && s.state.transition(ch).is_none()
})
.unwrap();
if shared.remaining == 0 {
break;
}
// 打印并更新状态
print!("{}", ch);
io::stdout().flush().unwrap();
// 这里的 unwrap 是安全的,因为 wait_while 保证了 next(ch) 是 Some
let next_state = shared.state.transition(ch).unwrap();
shared.state = next_state;
if next_state == State::Start {
shared.remaining -= 1;
}
cvar.notify_all();
}
}至此,我们已经可以成功实现打印合法的序列了,如果想要控制两种序列的数量只需要修改一下同步状态结构体、状态转移函数与维护当前状态的操作即可。
完整代码:
use std::io::{self, Write};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
#[derive(Copy, Clone, PartialEq, Eq)]
enum State {
Start, // 初始状态
LeftStart, // 收到 <
LeftMiddle, // 收到 <>
RightStart, // 收到 >
RightMiddle, // 收到 ><
Separate, // 收到 <>< 或 ><>
}
impl State {
fn transition(self, ch: char) -> Option<State> {
use State::*;
match (self, ch) {
// Sequence 1: <><
(Start, '<') => Some(LeftStart),
(LeftStart, '>') => Some(LeftMiddle),
(LeftMiddle, '<') => Some(Separate),
// Sequence 2: ><>
(Start, '>') => Some(RightStart),
(RightStart, '<') => Some(RightMiddle),
(RightMiddle, '>') => Some(Separate),
// Common End: _
(Separate, '_') => Some(Start),
// 当前线程无法进行状态转移
_ => None,
}
}
}
struct JobState {
state: State,
remaining: usize,
}
type SyncState = Arc<(Mutex<JobState>, Condvar)>;
fn worker(sync_state: SyncState, ch: char) {
let (lock, cvar) = &*sync_state;
loop {
let mut shared = lock.lock().unwrap();
// 阻止当前线程,直到此条件变量接收到通知并且所提供的条件为 false 为止
// 条件:还有剩余任务 && 无法进行状态转移
shared = cvar
.wait_while(shared, |s| {
s.remaining > 0 && s.state.transition(ch).is_none()
})
.unwrap();
if shared.remaining == 0 {
break;
}
// 打印并更新状态
print!("{}", ch);
io::stdout().flush().unwrap();
// 这里的 unwrap 是安全的,因为 wait_while 保证了 next(ch) 是 Some
let next_state = shared.state.transition(ch).unwrap();
shared.state = next_state;
if next_state == State::Start {
shared.remaining -= 1;
}
cvar.notify_all();
}
}
fn main() {
let shared = JobState {
state: State::Start,
remaining: 1000,
};
let sync_state = Arc::new((Mutex::new(shared), Condvar::new()));
let mut handles = vec![];
for ch in ['<', '>', '_'] {
let sync_state_ref = Arc::clone(&sync_state);
let handle = thread::spawn(move || worker(sync_state_ref, ch));
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!();
}同步变量实现信号量
相较于条件变量,信号量用来解决并发问题更加优雅,尤其是对计数型的资源上的管理。由于 Rust 标准库中没有提供信号量的实现,我们可以用同步变量实现一个简易的信号量。
首先是结构体定义,包含一个用 Mutex 保护的计数与条件变量:
struct Semaphore {
count: Mutex<usize>,
cvar: Condvar,
}接下来是构造函数以及 P 操作与 V 操作,这里可以发现在信号量的实现上与条件变量解决并发问题的方法几乎一样:
impl Semaphore {
fn new(count: usize) -> Self {
Semaphore {
count: Mutex::new(count),
cvar: Condvar::new(),
}
}
// P 操作
fn acquire(&self) {
let mut count = self.count.lock().unwrap();
while !(*count > 0) {
count = self.cvar.wait(count).unwrap();
}
*count -= 1;
}
// V 操作
fn release(&self) {
let mut count = self.count.lock().unwrap();
*count += 1;
self.cvar.notify_one();
}
}现在来使用信号量解决打印括号序列问题,同样的我们使用生产者/消费者模型:
fn produce(empty: Arc<Semaphore>, fill: Arc<Semaphore>) {
loop {
empty.acquire();
print!("(");
io::stdout().flush().unwrap();
fill.release();
}
}
fn consume(empty: Arc<Semaphore>, fill: Arc<Semaphore>) {
loop {
fill.acquire();
print!(")");
io::stdout().flush().unwrap();
empty.release();
}
}相较于条件变量,信号量将上锁的操作细节隐藏了起来。
完整代码:
use std::io::{self, Write};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
struct Semaphore {
count: Mutex<usize>,
cvar: Condvar,
}
impl Semaphore {
fn new(count: usize) -> Self {
Semaphore {
count: Mutex::new(count),
cvar: Condvar::new(),
}
}
fn acquire(&self) {
let mut count = self.count.lock().unwrap();
while !(*count > 0) {
count = self.cvar.wait(count).unwrap();
}
*count -= 1;
}
fn release(&self) {
let mut count = self.count.lock().unwrap();
*count += 1;
self.cvar.notify_one();
}
}
const MAX_DEPTH: usize = 2;
const THREAD_NUM: usize = 4;
fn produce(empty: Arc<Semaphore>, fill: Arc<Semaphore>) {
loop {
empty.acquire();
print!("(");
io::stdout().flush().unwrap();
fill.release();
}
}
fn consume(empty: Arc<Semaphore>, fill: Arc<Semaphore>) {
loop {
fill.acquire();
print!(")");
io::stdout().flush().unwrap();
empty.release();
}
}
fn main() {
let empty = Arc::new(Semaphore::new(MAX_DEPTH));
let fill = Arc::new(Semaphore::new(0));
let mut handles = vec![];
for _ in 0..THREAD_NUM {
let empty_ref = Arc::clone(&empty);
let fill_ref = Arc::clone(&fill);
let handle = thread::spawn(move || {
produce(empty_ref, fill_ref);
});
handles.push(handle);
let empty_ref = Arc::clone(&empty);
let fill_ref = Arc::clone(&fill);
let handle = thread::spawn(move || {
consume(empty_ref, fill_ref);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}需要注意的是这里的信号量的实现是简易的,甚至没有错误处理的实现,想要实现可靠、高效的信号量需要考虑的问题很多。
