Skip to content

Commit

Permalink
Enable thread-safety marker traits for structs
Browse files Browse the repository at this point in the history
- Array: Send, Sync
- Features: Send, Sync
- Event: Send
- RandomEngine: Send
- Indexer: Send

Added a new threading tutorial with code examples illustrating
how to share Array across threads.

Added unit tests in corresponding modules
  • Loading branch information
9prady9 committed Oct 6, 2020
1 parent 653ef75 commit f3b6c03
Show file tree
Hide file tree
Showing 7 changed files with 421 additions and 0 deletions.
249 changes: 249 additions & 0 deletions src/core/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ extern "C" {
///
/// Currently, Array objects can store only data until four dimensions
///
/// ## Sharing Across Threads
///
/// While sharing an Array with other threads, there is no need to wrap
/// this in an Arc object unless only one such object is required to exist.
/// The reason being that ArrayFire's internal Array is appropriately reference
/// counted in thread safe manner. However, if you need to modify Array object,
/// then please do wrap the object using a Mutex or Read-Write lock.
///
/// Examples on how to share Array across threads is illustrated in our
/// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
///
/// ### NOTE
///
/// All operators(traits) from std::ops module implemented for Array object
Expand All @@ -156,6 +167,11 @@ pub struct Array<T: HasAfEnum> {
_marker: PhantomData<T>,
}

/// Enable safely moving Array objects across threads
unsafe impl<T: HasAfEnum> Send for Array<T> {}

unsafe impl<T: HasAfEnum> Sync for Array<T> {}

macro_rules! is_func {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
#[doc=$doc_str]
Expand Down Expand Up @@ -834,3 +850,236 @@ pub fn is_eval_manual() -> bool {
ret_val > 0
}
}

#[cfg(test)]
mod tests {
use super::super::array::print;
use super::super::data::constant;
use super::super::device::{info, set_device, sync};
use crate::dim4;
use std::sync::{mpsc, Arc, RwLock};
use std::thread;

#[test]
fn thread_move_array() {
// ANCHOR: move_array_to_thread
set_device(0);
info();
let mut a = constant(1, dim4!(3, 3));

let handle = thread::spawn(move || {
//set_device to appropriate device id is required in each thread
set_device(0);

println!("\nFrom thread {:?}", thread::current().id());

a += constant(2, dim4!(3, 3));
print(&a);
});

//Need to join other threads as main thread holds arrayfire context
handle.join().unwrap();
// ANCHOR_END: move_array_to_thread
}

#[test]
fn thread_borrow_array() {
set_device(0);
info();
let a = constant(1i32, dim4!(3, 3));

let handle = thread::spawn(move || {
set_device(0); //set_device to appropriate device id is required in each thread
println!("\nFrom thread {:?}", thread::current().id());
print(&a);
});
//Need to join other threads as main thread holds arrayfire context
handle.join().unwrap();
}

// ANCHOR: multiple_threads_enum_def
#[derive(Debug, Copy, Clone)]
enum Op {
Add,
Sub,
Div,
Mul,
}
// ANCHOR_END: multiple_threads_enum_def

#[test]
fn read_from_multiple_threads() {
// ANCHOR: read_from_multiple_threads
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

// Set active GPU/device on main thread on which
// subsequent Array objects are created
set_device(0);

// ArrayFire Array's are internally maintained via atomic reference counting
// Thus, they need no Arc wrapping while moving to another thread.
// Just call clone method on the object and share the resulting clone object
let a = constant(1.0f32, dim4!(3, 3));
let b = constant(2.0f32, dim4!(3, 3));

let threads: Vec<_> = ops
.into_iter()
.map(|op| {
let x = a.clone();
let y = b.clone();
thread::spawn(move || {
set_device(0); //Both of objects are created on device 0 earlier
match op {
Op::Add => {
let _c = x + y;
}
Op::Sub => {
let _c = x - y;
}
Op::Div => {
let _c = x / y;
}
Op::Mul => {
let _c = x * y;
}
}
sync(0);
thread::sleep(std::time::Duration::new(1, 0));
})
})
.collect();
for child in threads {
let _ = child.join();
}
// ANCHOR_END: read_from_multiple_threads
}

#[test]
fn access_using_rwlock() {
// ANCHOR: access_using_rwlock
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

// Set active GPU/device on main thread on which
// subsequent Array objects are created
set_device(0);

let c = constant(0.0f32, dim4!(3, 3));
let a = constant(1.0f32, dim4!(3, 3));
let b = constant(2.0f32, dim4!(3, 3));

// Move ownership to RwLock and wrap in Arc since same object is to be modified
let c_lock = Arc::new(RwLock::new(c));

// a and b are internally reference counted by ArrayFire. Unless there
// is prior known need that they may be modified, you can simply clone
// the objects pass them to threads

let threads: Vec<_> = ops
.into_iter()
.map(|op| {
let x = a.clone();
let y = b.clone();

let wlock = c_lock.clone();
thread::spawn(move || {
//Both of objects are created on device 0 in main thread
//Every thread needs to set the device that it is going to
//work on. Note that all Array objects must have been created
//on same device as of date this is written on.
set_device(0);
if let Ok(mut c_guard) = wlock.write() {
match op {
Op::Add => {
*c_guard += x + y;
}
Op::Sub => {
*c_guard += x - y;
}
Op::Div => {
*c_guard += x / y;
}
Op::Mul => {
*c_guard += x * y;
}
}
}
})
})
.collect();

for child in threads {
let _ = child.join();
}

//let read_guard = c_lock.read().unwrap();
//af_print!("C after threads joined", *read_guard);
//C after threads joined
//[3 3 1 1]
// 8.0000 8.0000 8.0000
// 8.0000 8.0000 8.0000
// 8.0000 8.0000 8.0000
// ANCHOR_END: access_using_rwlock
}

#[test]
fn accum_using_channel() {
// ANCHOR: accum_using_channel
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
let ops_len: usize = ops.len();

// Set active GPU/device on main thread on which
// subsequent Array objects are created
set_device(0);

let mut c = constant(0.0f32, dim4!(3, 3));
let a = constant(1.0f32, dim4!(3, 3));
let b = constant(2.0f32, dim4!(3, 3));

let (tx, rx) = mpsc::channel();

let threads: Vec<_> = ops
.into_iter()
.map(|op| {
// a and b are internally reference counted by ArrayFire. Unless there
// is prior known need that they may be modified, you can simply clone
// the objects pass them to threads
let x = a.clone();
let y = b.clone();

let tx_clone = tx.clone();

thread::spawn(move || {
//Both of objects are created on device 0 in main thread
//Every thread needs to set the device that it is going to
//work on. Note that all Array objects must have been created
//on same device as of date this is written on.
set_device(0);

let c = match op {
Op::Add => x + y,
Op::Sub => x - y,
Op::Div => x / y,
Op::Mul => x * y,
};
tx_clone.send(c).unwrap();
})
})
.collect();

for _i in 0..ops_len {
c += rx.recv().unwrap();
}

//Need to join other threads as main thread holds arrayfire context
for child in threads {
let _ = child.join();
}

//af_print!("C after accumulating results", &c);
//[3 3 1 1]
// 8.0000 8.0000 8.0000
// 8.0000 8.0000 8.0000
// 8.0000 8.0000 8.0000
// ANCHOR_END: accum_using_channel
}
}
85 changes: 85 additions & 0 deletions src/core/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ extern "C" {
}

/// RAII construct to manage ArrayFire events
///
/// ## Sharing Across Threads
///
/// While sharing an Event with other threads, just move it across threads.
pub struct Event {
event_handle: af_event,
}

unsafe impl Send for Event {}
// No borrowed references are to be shared for Events, hence no sync trait

impl Default for Event {
fn default() -> Self {
let mut temp: af_event = std::ptr::null_mut();
Expand Down Expand Up @@ -74,3 +81,81 @@ impl Drop for Event {
}
}
}

#[cfg(test)]
mod tests {
use super::super::arith::pow;
use super::super::device::{info, set_device};
use super::super::event::Event;
use crate::{af_print, randu};
use std::sync::mpsc;
use std::thread;

#[test]
fn event_block() {
// This code example will try to compute the following expression
// using data-graph approach using threads, evens for illustration.
//
// (a * (b + c))^(d - 2)
//
// ANCHOR: event_block

// Set active GPU/device on main thread on which
// subsequent Array objects are created
set_device(0);
info();

let a = randu!(10, 10);
let b = randu!(10, 10);
let c = randu!(10, 10);
let d = randu!(10, 10);

let (tx, rx) = mpsc::channel();

thread::spawn(move || {
set_device(0);

let add_event = Event::default();

let add_res = b + c;

add_event.mark();
tx.send((add_res, add_event)).unwrap();

thread::sleep(std::time::Duration::new(10, 0));

let sub_event = Event::default();

let sub_res = d - 2;

sub_event.mark();
tx.send((sub_res, sub_event)).unwrap();
});

let (add_res, add_event) = rx.recv().unwrap();

println!("Got first message, waiting for addition result ...");
thread::sleep(std::time::Duration::new(5, 0));
// Perhaps, do some other tasks
add_event.block();

println!("Got addition result, now scaling it ... ");
let scaled = a * add_res;

let (sub_res, sub_event) = rx.recv().unwrap();

println!("Got message, waiting for subtraction result ...");
thread::sleep(std::time::Duration::new(5, 0));
// Perhaps, do some other tasks
sub_event.block();

let fin_res = pow(&scaled, &sub_res, false);

af_print!(
"Final result of the expression: ((a * (b + c))^(d - 2))",
&fin_res
);

// ANCHOR_END: event_block
}
}
7 changes: 7 additions & 0 deletions src/core/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ extern "C" {

/// Struct to manage an array of resources of type `af_indexer_t`(ArrayFire C struct)
///
/// ## Sharing Across Threads
///
/// While sharing an Indexer object with other threads, just move it across threads. At the
/// moment, one cannot share borrowed references across threads.
///
/// # Examples
///
/// Given below are examples illustrating correct and incorrect usage of Indexer struct.
Expand Down Expand Up @@ -108,6 +113,8 @@ pub struct Indexer<'object> {
marker: PhantomData<&'object ()>,
}

unsafe impl<'object> Send for Indexer<'object> {}

/// Trait bound indicating indexability
///
/// Any object to be able to be passed on to [Indexer::set_index()](./struct.Indexer.html#method.set_index) method should implement this trait with appropriate implementation of `set` method.
Expand Down
Loading

0 comments on commit f3b6c03

Please sign in to comment.