I'm trying to write a client program that communicates with a server using a TcpStream wrapped in an openssl::ssl::SslStream (from crates.io). It should wait for incoming data and process whatever the server sends as soon as it arrives. At the same time, it should be able to send messages to the server independently of reading.
I tried several approaches:
Passing a single stream to both the read and write threads. Both the read and write methods require a mutable reference, so I couldn't share one stream between two threads.
Following "In Rust how do I handle parallel read writes on a TcpStream", but TcpStream doesn't seem to have a clone method, and neither does SslStream.
Making a copy of the TcpStream with as_raw_fd and from_raw_fd:
fn irc_read(mut stream: SslStream<TcpStream>) {
loop {
let mut buf = vec![0; 2048];
let resp = stream.ssl_read(&mut buf);
match resp {
// Process Message
}
}
}
fn irc_write(mut stream: SslStream<TcpStream>) {
thread::sleep(Duration::new(3, 0));
let msg = "QUIT\n";
let res = stream.ssl_write(msg.as_bytes());
let _ = stream.flush();
match res {
// Process
}
}
fn main() {
let ctx = SslContext::new(SslMethod::Sslv23).unwrap();
let read_ssl = Ssl::new(&ctx).unwrap();
let write_ssl = Ssl::new(&ctx).unwrap();
let raw_stream = TcpStream::connect((SERVER, PORT)).unwrap();
let mut fd_stream: TcpStream;
unsafe {
fd_stream = TcpStream::from_raw_fd(raw_stream.as_raw_fd());
}
let mut read_stream = SslStream::connect(read_ssl, raw_stream).unwrap();
let mut write_stream = SslStream::connect(write_ssl, fd_stream).unwrap();
let read_thread = thread::spawn(move || {
irc_read(read_stream);
});
let write_thread = thread::spawn(move || {
irc_write(write_stream);
});
let _ = read_thread.join();
let _ = write_thread.join();
}
This code compiles, but panics on the second SslStream::connect:
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Failure(Ssl(ErrorStack([Error { library: "SSL routines", function: "SSL23_GET_SERVER_HELLO", reason: "unknown protocol" }])))', ../src/libcore/result.rs:788
stack backtrace:
1: 0x556d719c6069 - std::sys::backtrace::tracing::imp::write::h00e948915d1e4c72
2: 0x556d719c9d3c - std::panicking::default_hook::_{{closure}}::h7b8a142818383fb8
3: 0x556d719c8f89 - std::panicking::default_hook::h41cf296f654245d7
4: 0x556d719c9678 - std::panicking::rust_panic_with_hook::h4cbd7ca63ce1aee9
5: 0x556d719c94d2 - std::panicking::begin_panic::h93672d0313d5e8e9
6: 0x556d719c9440 - std::panicking::begin_panic_fmt::hd0daa02942245d81
7: 0x556d719c93c1 - rust_begin_unwind
8: 0x556d719ffcbf - core::panicking::panic_fmt::hbfc935564d134c1b
9: 0x556d71899f02 - core::result::unwrap_failed::h66f79b2edc69ddfd
at /buildslave/rust-buildbot/slave/stable-dist-rustc-linux/build/obj/../src/libcore/result.rs:29
10: 0x556d718952cb - _<core..result..Result<T, E>>::unwrap::h49a140af593bc4fa
at /buildslave/rust-buildbot/slave/stable-dist-rustc-linux/build/obj/../src/libcore/result.rs:726
11: 0x556d718a5e3d - dbrust::main::h24a50e631826915e
at /home/lastone817/dbrust/src/main.rs:87
12: 0x556d719d1826 - __rust_maybe_catch_panic
13: 0x556d719c8702 - std::rt::lang_start::h53bf99b0829cc03c
14: 0x556d718a6b83 - main
15: 0x7f40a0b5082f - __libc_start_main
16: 0x556d7188d038 - _start
17: 0x0 - <unknown>
error: Process didn't exit successfully: `target/debug/dbrust` (exit code: 101)
The best solution I've found so far is to use non-blocking mode. I wrapped the stream in a Mutex and passed it to both threads. The reading thread acquires the lock and calls read; if there is no message, it releases the lock so that the writing thread can use the stream. With this method the reading thread busy-waits, which results in 100% CPU consumption. I don't think this is the best solution.
Is there a safe way to separate the read and write aspects of the stream?
I accomplished the split of an SSL stream into a read and a write part by using Rust's std::cell::UnsafeCell.
extern crate native_tls;
use native_tls::TlsConnector;
use std::cell::UnsafeCell;
use std::error::Error;
use std::io::Read;
use std::io::Write;
use std::marker::Sync;
use std::net::TcpStream;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
struct UnsafeMutator<T> {
value: UnsafeCell<T>,
}
impl<T> UnsafeMutator<T> {
fn new(value: T) -> UnsafeMutator<T> {
return UnsafeMutator {
value: UnsafeCell::new(value),
};
}
fn mut_value(&self) -> &mut T {
return unsafe { &mut *self.value.get() };
}
}
unsafe impl<T> Sync for UnsafeMutator<T> {}
struct ReadWrapper<R>
where
R: Read,
{
inner: Arc<UnsafeMutator<R>>,
}
impl<R: Read> Read for ReadWrapper<R> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
return self.inner.mut_value().read(buf);
}
}
struct WriteWrapper<W>
where
W: Write,
{
inner: Arc<UnsafeMutator<W>>,
}
impl<W: Write> Write for WriteWrapper<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
return self.inner.mut_value().write(buf);
}
fn flush(&mut self) -> Result<(), std::io::Error> {
return self.inner.mut_value().flush();
}
}
pub struct Socket {
pub output_stream: Arc<Mutex<Write + Send>>,
pub input_stream: Arc<Mutex<Read + Send>>,
}
impl Socket {
pub fn bind(host: &str, port: u16, secure: bool) -> Result<Socket, Box<Error>> {
let tcp_stream = match TcpStream::connect((host, port)) {
Ok(x) => x,
Err(e) => return Err(Box::new(e)),
};
if secure {
let tls_connector = TlsConnector::builder().build().unwrap();
let tls_stream = match tls_connector.connect(host, tcp_stream) {
Ok(x) => x,
Err(e) => return Err(Box::new(e)),
};
let mutator = Arc::new(UnsafeMutator::new(tls_stream));
let input_stream = Arc::new(Mutex::new(ReadWrapper {
inner: mutator.clone(),
}));
let output_stream = Arc::new(Mutex::new(WriteWrapper { inner: mutator }));
let socket = Socket {
output_stream,
input_stream,
};
return Ok(socket);
} else {
let mutator = Arc::new(UnsafeMutator::new(tcp_stream));
let input_stream = Arc::new(Mutex::new(ReadWrapper {
inner: mutator.clone(),
}));
let output_stream = Arc::new(Mutex::new(WriteWrapper { inner: mutator }));
let socket = Socket {
output_stream,
input_stream,
};
return Ok(socket);
}
}
}
fn main() {
let socket = Arc::new(Socket::bind("google.com", 443, true).unwrap());
let socket_clone = Arc::clone(&socket);
let reader_thread = thread::spawn(move || {
let mut res = vec![];
let mut input_stream = socket_clone.input_stream.lock().unwrap();
input_stream.read_to_end(&mut res).unwrap();
println!("{}", String::from_utf8_lossy(&res));
});
let writer_thread = thread::spawn(move || {
let mut output_stream = socket.output_stream.lock().unwrap();
output_stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
});
writer_thread.join().unwrap();
reader_thread.join().unwrap();
}
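For the non-TLS branch, a possible alternative that avoids `unsafe` is `TcpStream::try_clone` from the standard library, which returns a second handle to the same socket. The sketch below is only an illustration for plain TCP (it does not cover the TLS stream); `exchange` and `spawn_echo_server` are hypothetical helper names, and the echo server exists only to make the demo self-contained.

```rust
use std::io::{BufRead, BufReader, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::thread;

/// Hypothetical helper: writes `lines` on one handle while a second thread
/// reads the same number of reply lines on a cloned handle of the socket.
fn exchange(stream: TcpStream, lines: Vec<String>) -> std::io::Result<Vec<String>> {
    // `try_clone` duplicates the OS-level handle; both values refer to the
    // same connection and can be moved to different threads.
    let reader_half = stream.try_clone()?;
    let mut writer_half = stream;
    let expected = lines.len();

    let reader = thread::spawn(move || -> std::io::Result<Vec<String>> {
        let mut replies = Vec::new();
        for line in BufReader::new(reader_half).lines().take(expected) {
            replies.push(line?);
        }
        Ok(replies)
    });

    for line in &lines {
        writer_half.write_all(line.as_bytes())?;
        writer_half.write_all(b"\n")?;
    }
    writer_half.flush()?;

    reader.join().expect("reader thread panicked")
}

/// Starts a throwaway echo server on an OS-assigned local port (demo only).
fn spawn_echo_server() -> std::io::Result<SocketAddr> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    thread::spawn(move || {
        if let Ok((mut conn, _)) = listener.accept() {
            let mut incoming = conn.try_clone().expect("clone failed");
            // Echo everything back until the client closes the connection.
            let _ = std::io::copy(&mut incoming, &mut conn);
        }
    });
    Ok(addr)
}

fn main() -> std::io::Result<()> {
    let addr = spawn_echo_server()?;
    let stream = TcpStream::connect(addr)?;
    let replies = exchange(stream, vec!["NICK demo".into(), "QUIT".into()])?;
    println!("{:?}", replies);
    Ok(())
}
```

Because each thread owns its own handle, neither a Mutex nor a busy-wait loop is needed for the plain TCP case.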
I'm trying to make a simple host which can handle multiple streams (similar to the example given in https://tokio.rs/tokio/tutorial):
use std::{error::Error, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
time::sleep,
};
type GenericResult<T> = Result<T, Box<dyn Error>>;
#[tokio::main]
async fn main() {
const ADDRESS: &str = "127.0.0.1:8080";
let listener = TcpListener::bind(ADDRESS).await.unwrap();
tokio::spawn(async { host(listener) });
let mut stream = TcpStream::connect(ADDRESS).await.unwrap();
stream.write_all(b"testing").await.unwrap();
}
async fn host(listener: TcpListener) -> GenericResult<()> {
loop {
let (stream, _) = listener.accept().await?;
println!("new connection");
tokio::spawn(async { process(stream).await.unwrap() });
}
async fn process(mut stream: TcpStream) -> GenericResult<()> {
// Reads from stream
let mut buffer = vec![0u8; 128]; // zero-filled: a Vec with only capacity has length 0, so the slice below would be empty
let mut position = 0;
loop {
// Read from stream into buffer
let n = stream.read(&mut buffer[position..]).await?;
// Advance position
position += n;
// Print buffer
println!("buffer: {:?}", buffer);
sleep(Duration::from_millis(100)).await;
}
}
}
But when running this I encounter:
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 10061, kind: ConnectionRefused, message: "No connection could be made because the target machine actively refused it." }', src\main.rs:12:56
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Any help would be really appreciated, specifically on why my implementation differs (in functionality) from the working implementation at https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=5b944a8c4703d438d48ef5e556f1fc08
cargo --version --verbose:
cargo 1.60.0 (d1fd9fe2c 2022-03-01)
release: 1.60.0
commit-hash: d1fd9fe2c40a1a56af9132b5c92ab963ac7ae422
commit-date: 2022-03-01
host: x86_64-pc-windows-msvc
libgit2: 1.3.0 (sys:0.13.23 vendored)
libcurl: 7.80.0-DEV (sys:0.4.51+curl-7.80.0 vendored ssl:Schannel)
os: Windows 10.0.22000 (Windows 10 Pro) [64-bit]
With tokio = {version="1.18.0",features=["full"]}
Connect the client stream before you spawn the task that takes ownership of the listener:
use std::{error::Error, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
time::sleep,
};
type GenericResult<T> = Result<T, Box<dyn Error>>;
#[tokio::main]
async fn main() {
const ADDRESS: &str = "127.0.0.1:8080";
let listener = TcpListener::bind(ADDRESS).await.unwrap();
let mut stream = TcpStream::connect(ADDRESS).await.unwrap();
tokio::spawn(async { host(listener) });
stream.write_all(b"testing").await.unwrap();
}
async fn host(listener: TcpListener) -> GenericResult<()> {
loop {
let (stream, _) = listener.accept().await?;
println!("new connection");
tokio::spawn(async { process(stream).await.unwrap() });
}
async fn process(mut stream: TcpStream) -> GenericResult<()> {
// Reads from stream
let mut buffer = vec![0u8; 128]; // zero-filled: a Vec with only capacity has length 0, so the slice below would be empty
let mut position = 0;
loop {
// Read from stream into buffer
let n = stream.read(&mut buffer[position..]).await?;
// Advance position
position += n;
// Print buffer
println!("buffer: {:?}", buffer);
sleep(Duration::from_millis(100)).await;
}
}
}
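One thing worth noting about the listing above: `host(listener)` still sits inside the async block without `.await`. The body of an async fn does not run until its future is polled, so that call builds a future and then drops it. The std-only sketch below illustrates the difference. Here `host` is a stand-in that only sets a flag, and `NoopWake`, `poll_once`, `run_without_await` and `run_with_await` are hypothetical names made up for the demo.

```rust
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Stand-in for `host(listener)`: the body runs only when the future is polled.
async fn host(started: Arc<AtomicBool>) {
    started.store(true, Ordering::SeqCst);
}

/// Waker that does nothing; enough to poll a future that never suspends.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls `fut` exactly once and reports whether it completed.
fn poll_once<F: Future>(fut: F) -> bool {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    matches!(fut.as_mut().poll(&mut cx), Poll::Ready(_))
}

/// Mirrors `async { host(listener) }`: the inner future is built and dropped.
fn run_without_await() -> bool {
    let started = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&started);
    poll_once(async move {
        let _never_polled = host(flag);
    });
    started.load(Ordering::SeqCst)
}

/// Mirrors `async move { host(listener).await }`: the inner future is driven.
fn run_with_await() -> bool {
    let started = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&started);
    poll_once(async move {
        host(flag).await;
    });
    started.load(Ordering::SeqCst)
}

fn main() {
    println!("without .await: body ran = {}", run_without_await());
    println!("with .await:    body ran = {}", run_with_await());
}
```

In the tokio program the analogous change would be `tokio::spawn(async move { host(listener).await })`. Since `tokio::spawn` requires the task's output to be `Send`, the error type would probably have to become `Box<dyn Error + Send + Sync>` as well.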
In Java I would use an ExecutorService, i.e. a thread pool with a fixed size, submit(Callable)s to it, and then get() the results.
What is the idiom that matches this in Rust? I could thread::spawn() a bunch of tasks and join() them, but it would create a thread per task, and I want to limit the number of concurrent threads.
In order to make things a little more concrete, here is a rough sketch of some code:
let a4 = thread_pool.spawn(|| svg.compute_stitches("path674"));
let a1 = thread_pool.spawn(|| svg.compute_stitches("path653"));
let a2 = thread_pool.spawn(|| svg.compute_stitches("g659"));
let a3 = thread_pool.spawn(|| svg.compute_stitches("path664"));
let a5 = thread_pool.spawn(|| svg.compute_stitches("path679"));
stitcher.stitch(a1.join());
stitcher.stitch(a2.join());
stitcher.next_color();
stitcher.stitch(a3.join());
stitcher.next_color();
stitcher.stitch(a4.join());
stitcher.next_color();
stitcher.stitch(a5.join());
I have rolled my own solution for the time being. It looks like this:
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, RecvError};
use std::{panic, thread};
pub struct ThreadPool {
channel: spmc::Sender<Mission>,
}
impl ThreadPool {
pub fn new(thread_count: usize) -> ThreadPool {
let (tx, rx) = spmc::channel();
for _ in 0..thread_count {
let rx2 = rx.clone();
thread::spawn(move || Self::work_loop(rx2));
}
ThreadPool { channel: tx }
}
pub fn spawn<F, T>(&mut self, task: F) -> Answer<T>
where
F: FnOnce() -> T + std::panic::UnwindSafe + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = mpsc::channel();
let mission = Mission {
go: Box::new(move || {
let tmp = panic::catch_unwind(task);
tx.send(tmp).unwrap()
}),
};
self.channel.send(mission).unwrap();
Answer { channel: rx }
}
fn work_loop(channel: spmc::Receiver<Mission>) {
while let Ok(mission) = channel.recv() {
(mission.go)();
}
}
}
// The `Send` bound lets the compiler check that every mission may cross threads,
// so no `unsafe impl Send` is needed.
struct Mission {
go: Box<dyn FnOnce() + Send>,
}
pub struct Answer<T> {
channel: Receiver<std::thread::Result<T>>,
}
impl<T> Answer<T> {
pub fn get(self) -> Result<T, RecvError> {
let tmp = self.channel.recv();
match tmp {
Ok(rval) => match rval {
Ok(rval) => Ok(rval),
Err(explosion) => panic::resume_unwind(explosion),
},
Err(e) => Err(e),
}
}
}
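For comparison, here is a rough std-only sketch of the same idea that avoids the external spmc crate by sharing one `mpsc::Receiver` behind a `Mutex`. It is an illustration under simplifying assumptions rather than a drop-in replacement: `Job` is a hypothetical alias, `get` returns an `Option` instead of re-raising panics, and the closure in `main` merely stands in for `svg.compute_stitches(..)`.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

/// Fixed-size pool: workers pull boxed jobs from one shared queue.
pub struct ThreadPool {
    queue: Sender<Job>,
}

/// Handle to a pending result, loosely comparable to a Java `Future`.
pub struct Answer<T> {
    result: Receiver<T>,
}

impl<T> Answer<T> {
    /// Blocks until the task has finished; `None` if the task panicked.
    pub fn get(self) -> Option<T> {
        self.result.recv().ok()
    }
}

impl ThreadPool {
    pub fn new(thread_count: usize) -> ThreadPool {
        let (tx, rx) = mpsc::channel::<Job>();
        // `Receiver` cannot be cloned, so the workers share it behind a mutex.
        let rx = Arc::new(Mutex::new(rx));
        for _ in 0..thread_count {
            let rx = Arc::clone(&rx);
            thread::spawn(move || loop {
                // The lock guard is a temporary, so it is released before the job runs.
                let job = rx.lock().unwrap().recv();
                match job {
                    Ok(job) => job(),
                    Err(_) => break, // pool dropped, queue closed
                }
            });
        }
        ThreadPool { queue: tx }
    }

    pub fn spawn<F, T>(&self, task: F) -> Answer<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        let (tx, rx) = mpsc::channel();
        let job: Job = Box::new(move || {
            // Ignore the error if the caller already dropped the `Answer`.
            let _ = tx.send(task());
        });
        self.queue.send(job).expect("all workers have exited");
        Answer { result: rx }
    }
}

fn main() {
    let pool = ThreadPool::new(2);
    let answers: Vec<Answer<usize>> = ["path674", "path653", "g659"]
        .iter()
        .map(|id| {
            let id = id.to_string();
            pool.spawn(move || id.len()) // stand-in for `svg.compute_stitches(id)`
        })
        .collect();
    for answer in answers {
        println!("{:?}", answer.get());
    }
}
```

Unlike the version above, this sketch does not use `catch_unwind`, so a panicking task takes its worker thread down with it. Crates such as rayon and threadpool provide ready-made pools if rolling your own is not a goal.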
I was able to move forward with implementing my asynchronous UDP server. However, this error shows up twice because my variable data has type *mut u8, which is not Send:
error: future cannot be sent between threads safely
help: within `impl std::future::Future`, the trait `std::marker::Send` is not implemented for `*mut u8`
note: captured value is not `Send`
And the code (MRE):
use std::error::Error;
use std::time::Duration;
use std::env;
use tokio::net::UdpSocket;
use tokio::{sync::mpsc, task, time}; // 1.4.0
use std::alloc::{alloc, Layout};
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
const UDP_HEADER: usize = 8;
const IP_HEADER: usize = 20;
const AG_HEADER: usize = 4;
const MAX_DATA_LENGTH: usize = (64 * 1024 - 1) - UDP_HEADER - IP_HEADER;
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
const MAX_DATAGRAM_SIZE: usize = 0x10000;
/// A wrapper for [ptr::copy_nonoverlapping] with different argument order (same as original memcpy)
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
}
// Different from https://doc.rust-lang.org/std/primitive.u32.html#method.next_power_of_two
// Returns the [exponent] from the smallest power of two greater than or equal to n.
const fn next_power_of_two_exponent(n: u32) -> u32 {
return 32 - (n - 1).leading_zeros();
}
async fn run_server(socket: UdpSocket) {
let mut missing_indexes: Vec<u16> = Vec::new();
let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
let mut data = std::ptr::null_mut(); // ptr for the file bytes
let mut len: usize = 0; // total len of bytes that will be written
let mut layout = MaybeUninit::<Layout>::uninit();
let mut buf = [0u8; MAX_DATA_LENGTH];
let mut start = false;
let (debounce_tx, mut debounce_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
let (network_tx, mut network_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
loop {
// Listen for events
let debouncer = task::spawn(async move {
let duration = Duration::from_millis(3300);
loop {
match time::timeout(duration, debounce_rx.recv()).await {
Ok(Some((size, peer))) => {
eprintln!("Network activity");
}
Ok(None) => {
if start == true {
eprintln!("Debounce finished");
break;
}
}
Err(_) => {
eprintln!("{:?} since network activity", duration);
}
}
}
});
// Listen for network activity
let server = task::spawn({
// async{
let debounce_tx = debounce_tx.clone();
async move {
while let Some((size, peer)) = network_rx.recv().await {
// Received a new packet
debounce_tx.send((size, peer)).await.expect("Unable to talk to debounce");
eprintln!("Received a packet {} from: {}", size, peer);
let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;
if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
start = true;
let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
unsafe {
layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n, mem::align_of::<u8>()));
// /!\ data has type `*mut u8` which is not `Send`
data = alloc(layout.assume_init());
peer_addr.as_mut_ptr().write(peer);
}
let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
missing_indexes = a;
}
missing_indexes[packet_index as usize] = 1;
unsafe {
let dst_ptr = data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
memcpy(dst_ptr, &buf[AG_HEADER], size - AG_HEADER);
};
println!("receiving packet {} from: {}", packet_index, peer);
}
}
});
// Prevent deadlocks
drop(debounce_tx);
match socket.recv_from(&mut buf).await {
Ok((size, src)) => {
network_tx.send((size, src)).await.expect("Unable to talk to network");
}
Err(e) => {
eprintln!("couldn't receive a datagram: {}", e);
}
}
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
let socket = UdpSocket::bind(&addr).await?;
println!("Listening on: {}", socket.local_addr()?);
run_server(socket).await; // without `.await` the returned future is dropped and the server never runs
Ok(())
}
Since I was converting from synchronous to asynchronous code, I know that multiple threads could potentially be writing to data, and that is probably why I get this error. But I don't know which syntax I could use to "clone" the mut ptr and make it unique for each thread (and the same for the buffer).
As suggested by user4815162342, I think the best option would be to make the pointer Send by wrapping it in a struct and declaring unsafe impl Send for NewStruct {}.
Any help strongly appreciated!
PS: Full code can be found on my github repository
Short version
Thanks to the comment by user4815162342, I wrapped the mut ptr in a struct and implemented Send and Sync for it, which solved this part (there are still other issues, but they are beyond the scope of this question):
pub struct FileBuffer {
data: *mut u8
}
unsafe impl Send for FileBuffer {}
unsafe impl Sync for FileBuffer {}
//let mut data = std::ptr::null_mut(); // ptr for the file bytes
let mut fileBuffer: FileBuffer = FileBuffer { data: std::ptr::null_mut() };
Long version
use std::error::Error;
use std::time::Duration;
use std::env;
use tokio::net::UdpSocket;
use tokio::{sync::mpsc, task, time}; // 1.4.0
use std::alloc::{alloc, Layout};
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
const UDP_HEADER: usize = 8;
const IP_HEADER: usize = 20;
const AG_HEADER: usize = 4;
const MAX_DATA_LENGTH: usize = (64 * 1024 - 1) - UDP_HEADER - IP_HEADER;
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
const MAX_DATAGRAM_SIZE: usize = 0x10000;
/// A wrapper for [ptr::copy_nonoverlapping] with different argument order (same as original memcpy)
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, len);
}
// Different from https://doc.rust-lang.org/std/primitive.u32.html#method.next_power_of_two
// Returns the [exponent] from the smallest power of two greater than or equal to n.
const fn next_power_of_two_exponent(n: u32) -> u32 {
return 32 - (n - 1).leading_zeros();
}
pub struct FileBuffer {
data: *mut u8
}
unsafe impl Send for FileBuffer {}
unsafe impl Sync for FileBuffer {}
async fn run_server(socket: UdpSocket) {
let mut missing_indexes: Vec<u16> = Vec::new();
let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
//let mut data = std::ptr::null_mut(); // ptr for the file bytes
let mut fileBuffer: FileBuffer = FileBuffer { data: std::ptr::null_mut() };
let mut len: usize = 0; // total len of bytes that will be written
let mut layout = MaybeUninit::<Layout>::uninit();
let mut buf = [0u8; MAX_DATA_LENGTH];
let mut start = false;
let (debounce_tx, mut debounce_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
let (network_tx, mut network_rx) = mpsc::channel::<(usize, SocketAddr)>(3300);
loop {
// Listen for events
let debouncer = task::spawn(async move {
let duration = Duration::from_millis(3300);
loop {
match time::timeout(duration, debounce_rx.recv()).await {
Ok(Some((size, peer))) => {
eprintln!("Network activity");
}
Ok(None) => {
if start == true {
eprintln!("Debounce finished");
break;
}
}
Err(_) => {
eprintln!("{:?} since network activity", duration);
}
}
}
});
// Listen for network activity
let server = task::spawn({
// async{
let debounce_tx = debounce_tx.clone();
async move {
while let Some((size, peer)) = network_rx.recv().await {
// Received a new packet
debounce_tx.send((size, peer)).await.expect("Unable to talk to debounce");
eprintln!("Received a packet {} from: {}", size, peer);
let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;
if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
start = true;
let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
unsafe {
layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n, mem::align_of::<u8>()));
// /!\ data has type `*mut u8` which is not `Send`
fileBuffer.data = alloc(layout.assume_init());
peer_addr.as_mut_ptr().write(peer);
}
let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
missing_indexes = a;
}
missing_indexes[packet_index as usize] = 1;
unsafe {
let dst_ptr = fileBuffer.data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
memcpy(dst_ptr, &buf[AG_HEADER], size - AG_HEADER);
};
println!("receiving packet {} from: {}", packet_index, peer);
}
}
});
// Prevent deadlocks
drop(debounce_tx);
match socket.recv_from(&mut buf).await {
Ok((size, src)) => {
network_tx.send((size, src)).await.expect("Unable to talk to network");
}
Err(e) => {
eprintln!("couldn't receive a datagram: {}", e);
}
}
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
let socket = UdpSocket::bind(&addr).await?;
println!("Listening on: {}", socket.local_addr()?);
run_server(socket).await; // without `.await` the returned future is dropped and the server never runs
Ok(())
}
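Note that `unsafe impl Send` and `unsafe impl Sync` only silence the compiler; they do not by themselves make concurrent writes through the pointer safe. As a possible alternative, the sketch below stores the file bytes in a `Vec<u8>`, which is already `Send`, and copies each payload with bounds checks. The `store` and `bytes` methods, the `chunk_size` field, and the datagram layout used in `main` are assumptions made for illustration.

```rust
use std::thread;

const AG_HEADER: usize = 4;

/// Hypothetical replacement for the raw `*mut u8`: an owned, zero-filled buffer.
/// `Vec<u8>` is `Send`, so no `unsafe impl` is needed to move it into a task.
pub struct FileBuffer {
    data: Vec<u8>,
    chunk_size: usize,
}

impl FileBuffer {
    pub fn new(chunks_cnt: usize, chunk_size: usize) -> Self {
        FileBuffer { data: vec![0; chunks_cnt * chunk_size], chunk_size }
    }

    /// Copies the payload of one datagram (header stripped) into its slot.
    /// Returns `false` instead of writing out of bounds.
    pub fn store(&mut self, packet_index: usize, datagram: &[u8]) -> bool {
        let payload = match datagram.get(AG_HEADER..) {
            Some(p) if p.len() <= self.chunk_size => p,
            _ => return false,
        };
        let start = match packet_index.checked_mul(self.chunk_size) {
            Some(s) => s,
            None => return false,
        };
        let end = start.saturating_add(payload.len());
        match self.data.get_mut(start..end) {
            Some(slot) => {
                slot.copy_from_slice(payload);
                true
            }
            None => false,
        }
    }

    pub fn bytes(&self) -> &[u8] {
        &self.data
    }
}

fn main() {
    let mut buffer = FileBuffer::new(2, 8);
    // Moving the buffer into another thread compiles without any `unsafe`.
    let handle = thread::spawn(move || {
        // Assumed layout: 4 header bytes followed by the payload.
        buffer.store(1, &[0, 1, 0, 2, b'h', b'i']);
        buffer
    });
    let buffer = handle.join().unwrap();
    println!("{:?}", buffer.bytes());
}
```

If several tasks really need to write into the same buffer, wrapping it in `Arc<Mutex<FileBuffer>>` would be the usual next step.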
I am attempting to create a simple event loop with Mio. Every event invokes an associated callback depending on the Token of the event. The event loop runs on a separate thread from the rest of the code.
Events can be registered via the register function. However, registering an event has to add it to the same HashMap of callbacks (and access the same Mio Poll struct) that the event loop is iterating over.
The issue is that callbacks themselves may register events. In that case it is impossible to take the mutex, because the event loop already holds it. Is it possible to somehow release the Mutex in the start() function while the callback is being invoked? That does not seem performant, though.
Is there a better way to handle this in Rust?
use mio::event::Event;
use mio::net::{TcpListener, TcpStream};
use mio::{event, Events, Interest, Poll, Token};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use std::{io, thread};
pub trait HookCb: Send {
fn call(&self, event: &Event);
}
impl<F: Send> HookCb for F
where
F: Fn(&Event),
{
fn call(&self, event: &Event) {
self(event)
}
}
struct EventLoopInner {
handlers: HashMap<Token, Box<dyn HookCb>>,
poll: Poll,
}
pub struct EventLoop {
inner: Mutex<EventLoopInner>,
}
impl EventLoop {
pub fn new() -> io::Result<Self> {
Ok(Self {
inner: Mutex::new(EventLoopInner {
handlers: HashMap::new(),
poll: Poll::new()?,
}),
})
}
pub fn start(&self) -> io::Result<()> {
let mut events = Events::with_capacity(1024);
let inner = &mut *self.inner.lock().unwrap(); // Inner mutex taken
loop {
inner.poll.poll(&mut events, None)?;
for event in events.iter() {
match event.token() {
Token(v) => {
if let Some(cb) = inner.handlers.get(&Token(v)) {
// TODO release the inner mutex before here so that the callback can invoke register
cb.call(event)
}
}
}
}
}
}
pub fn register<S>(
&self,
source: &mut S,
interest: Interest,
cb: impl HookCb + 'static,
) -> io::Result<Token>
where
S: event::Source + std::marker::Send,
{
let mut inner = self.inner.lock().unwrap(); // Cannot get this lock after start() has been invoked
let token = Token(inner.handlers.len());
inner.poll.registry().register(source, token, interest)?;
inner.handlers.insert(token, Box::new(cb));
Ok(token)
}
}
struct ServerConn {
listener: Option<TcpListener>,
connections: HashMap<Token, TcpStream>,
}
struct Server {
eventloop: Arc<EventLoop>,
thread: Option<JoinHandle<()>>,
conn: Arc<Mutex<ServerConn>>,
}
impl Server {
pub fn new() -> Self {
Self {
eventloop: Arc::new(EventLoop::new().unwrap()),
thread: None,
conn: Arc::new(Mutex::new(ServerConn {
listener: None,
connections: HashMap::new(),
})),
}
}
pub fn listen(&mut self, addr: &str) {
{
let mut conn = self.conn.lock().unwrap();
conn.listener = Some(TcpListener::bind(addr.parse().unwrap()).unwrap());
let cb_conn = Arc::clone(&self.conn);
let cb_eventloop = Arc::clone(&self.eventloop);
self.eventloop
.register(
conn.listener.as_mut().unwrap(),
Interest::READABLE,
move |e: &Event| {
Self::accept_cb(e, &cb_conn, &cb_eventloop);
},
)
.unwrap();
} // Unlock conn
let t_eventloop = Arc::clone(&self.eventloop);
self.thread = Some(thread::spawn(move || {
t_eventloop.start().unwrap();
}));
self.thread.take().unwrap().join().unwrap(); // Temp fix to block main thread so application does not exit
}
fn accept_cb(_e: &Event, conn: &Arc<Mutex<ServerConn>>, evloop: &Arc<EventLoop>) {
let mut conn_lock = conn.lock().unwrap();
loop {
let (mut stream, addr) = match conn_lock.listener.as_ref().unwrap().accept() {
Ok((stream, addr)) => (stream, addr),
Err(_) => return,
};
println!("Accepted connection from: {}", addr);
// TODO can this clone be avoided?
let cb_conn = Arc::clone(conn);
let cb_evloop = Arc::clone(evloop);
let token = evloop
.register(
&mut stream,
Interest::READABLE.add(Interest::WRITABLE),
move |e: &Event| {
Self::conn_cb(e, &cb_conn, &cb_evloop);
},
)
.unwrap();
conn_lock.connections.insert(token, stream);
}
}
pub fn conn_cb(e: &Event, conn: &Arc<Mutex<ServerConn>>, _evloop: &Arc<EventLoop>) {
let conn_lock = conn.lock().unwrap();
let mut _connection = conn_lock.connections.get(&e.token()).unwrap();
if e.is_writable() {
// TODO write logic -- connection.write(b"Hello World!\n").unwrap();
// TODO evloop.reregister(&mut connection, event.token(), Interest::READABLE)?
}
if e.is_readable() {
// TODO read logic -- connection.read(&mut received_data);
}
}
}
fn main() {
let mut s = Server::new();
s.listen("127.0.0.1:8000");
}
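One way to release the lock around the callback, sketched here with the standard library only, is to store the handlers as `Arc<dyn Fn>` values, clone the `Arc` while the mutex is held, and call it after the guard has been dropped. This is a simplified model, not a Mio program: `Registry`, `dispatch` and `handler_count` are hypothetical names, `Token` is a plain `usize`, and the `Poll` is deliberately left out. Mio's `Registry::try_clone` may be worth checking for sharing the registration side of a real `Poll`.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

type Token = usize;
type Callback = Arc<dyn Fn(Token, &Registry) + Send + Sync>;

/// Stand-in for the handler table of the event loop.
#[derive(Default)]
pub struct Registry {
    handlers: Mutex<HashMap<Token, Callback>>,
}

impl Registry {
    pub fn register(&self, cb: impl Fn(Token, &Registry) + Send + Sync + 'static) -> Token {
        let mut handlers = self.handlers.lock().unwrap();
        let token = handlers.len();
        handlers.insert(token, Arc::new(cb));
        token
    }

    /// Looks the callback up under the lock, then calls it after the guard is gone.
    pub fn dispatch(&self, token: Token) -> bool {
        // The guard is a temporary and is dropped at the end of this statement.
        let cb = self.handlers.lock().unwrap().get(&token).cloned();
        match cb {
            Some(cb) => {
                cb(token, self); // may call `register` without deadlocking
                true
            }
            None => false,
        }
    }

    pub fn handler_count(&self) -> usize {
        self.handlers.lock().unwrap().len()
    }
}

fn main() {
    let registry = Registry::default();
    let listener = registry.register(|_token, registry| {
        // Like `accept_cb`: registers a handler for the new connection.
        registry.register(|token, _| println!("connection {} is ready", token));
    });
    registry.dispatch(listener);
    println!("{} handlers registered", registry.handler_count());
}
```

Cloning an `Arc` is a reference-count increment, so the lock is held only for the map lookup rather than for the whole callback.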
I implemented the future and sent a request to it, but it blocked my curl, and the log shows that poll was invoked only once.
Did I implement anything wrong?
use failure::{format_err, Error};
use futures::{future, Async};
use hyper::rt::Future;
use hyper::service::{service_fn, service_fn_ok};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use log::{debug, error, info};
use std::{
sync::{Arc, Mutex},
task::Waker,
thread,
};
pub struct TimerFuture {
shared_state: Arc<Mutex<SharedState>>,
}
struct SharedState {
completed: bool,
resp: String,
}
impl Future for TimerFuture {
type Item = Response<Body>;
type Error = hyper::Error;
fn poll(&mut self) -> futures::Poll<Response<Body>, hyper::Error> {
let mut shared_state = self.shared_state.lock().unwrap();
if shared_state.completed {
return Ok(Async::Ready(Response::new(Body::from(
shared_state.resp.clone(),
))));
} else {
return Ok(Async::NotReady);
}
}
}
impl TimerFuture {
pub fn new(instance: String) -> Self {
let shared_state = Arc::new(Mutex::new(SharedState {
completed: false,
resp: String::new(),
}));
let thread_shared_state = shared_state.clone();
thread::spawn(move || {
let res = match request_health(instance) {
Ok(status) => status.clone(),
Err(err) => {
error!("{:?}", err);
format!("{}", err)
}
};
let mut shared_state = thread_shared_state.lock().unwrap();
shared_state.completed = true;
shared_state.resp = res;
});
TimerFuture { shared_state }
}
}
fn request_health(instance_name: String) -> Result<String, Error> {
std::thread::sleep(std::time::Duration::from_secs(1));
Ok("health".to_string())
}
type BoxFut = Box<dyn Future<Item = Response<Body>, Error = hyper::Error> + Send>;
fn serve_health(req: Request<Body>) -> BoxFut {
let mut response = Response::new(Body::empty());
let path = req.uri().path().to_owned();
match (req.method(), path) {
(&Method::GET, path) => {
return Box::new(TimerFuture::new(path.clone()));
}
_ => *response.status_mut() = StatusCode::NOT_FOUND,
}
Box::new(future::ok(response))
}
fn main() {
let endpoint_addr = "0.0.0.0:8080";
match std::thread::spawn(move || {
let addr = endpoint_addr.parse().unwrap();
info!("Server is running on {}", addr);
hyper::rt::run(
Server::bind(&addr)
.serve(move || service_fn(serve_health))
.map_err(|e| eprintln!("server error: {}", e)),
);
})
.join()
{
Ok(e) => e,
Err(e) => println!("{:?}", e),
}
}
After compiling and running this code, a server is listening on port 8080. Call the server with curl and it will block:
curl 127.0.0.1:8080/my-health-scope
Did I implement anything wrong?
Yes, you did not read and follow the documentation for the method you are implementing (emphasis mine):
When a future is not ready yet, the Async::NotReady value will be returned. In this situation the future will also register interest of the current task in the value being produced. This is done by calling task::park to retrieve a handle to the current Task. When the future is then ready to make progress (e.g. it should be polled again) the unpark method is called on the Task.
As a minimal, reproducible example, let's use this:
use futures::{future::Future, Async};
use std::{
mem,
sync::{Arc, Mutex},
thread,
time::Duration,
};
pub struct Timer {
data: Arc<Mutex<String>>,
}
impl Timer {
pub fn new(instance: String) -> Self {
let data = Arc::new(Mutex::new(String::new()));
thread::spawn({
let data = data.clone();
move || {
thread::sleep(Duration::from_secs(1));
*data.lock().unwrap() = instance;
}
});
Timer { data }
}
}
impl Future for Timer {
type Item = String;
type Error = ();
fn poll(&mut self) -> futures::Poll<Self::Item, Self::Error> {
let mut data = self.data.lock().unwrap();
eprintln!("poll was called");
if data.is_empty() {
Ok(Async::NotReady)
} else {
let data = mem::replace(&mut *data, String::new());
Ok(Async::Ready(data))
}
}
}
fn main() {
let v = Timer::new("Some text".into()).wait();
println!("{:?}", v);
}
It only prints out "poll was called" once.
You can call task::current (previously task::park) in the implementation of Future::poll, save the resulting value, then use the value with Task::notify (previously Task::unpark) whenever the future may be polled again:
use futures::{
future::Future,
task::{self, Task},
Async,
};
use std::{
mem,
sync::{Arc, Mutex},
thread,
time::Duration,
};
pub struct Timer {
data: Arc<Mutex<(String, Option<Task>)>>,
}
impl Timer {
pub fn new(instance: String) -> Self {
let data = Arc::new(Mutex::new((String::new(), None)));
let me = Timer { data };
thread::spawn({
let data = me.data.clone();
move || {
thread::sleep(Duration::from_secs(1));
let mut data = data.lock().unwrap();
data.0 = instance;
if let Some(task) = data.1.take() {
task.notify();
}
}
});
me
}
}
impl Future for Timer {
type Item = String;
type Error = ();
fn poll(&mut self) -> futures::Poll<Self::Item, Self::Error> {
let mut data = self.data.lock().unwrap();
eprintln!("poll was called");
if data.0.is_empty() {
let v = task::current();
data.1 = Some(v);
Ok(Async::NotReady)
} else {
let data = mem::replace(&mut data.0, String::new());
Ok(Async::Ready(data))
}
}
}
fn main() {
let v = Timer::new("Some text".into()).wait();
println!("{:?}", v);
}
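The same idea carries over to `std::future::Future`, where the handle is a `std::task::Waker` (which the question already imports). The sketch below is a std-only approximation of the timer above: the `Shared` struct, the `delay` parameter, `ThreadWaker` and the tiny `block_on` executor are assumptions made for the demo, not part of futures 0.1 or hyper.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;
use std::time::Duration;

struct Shared {
    value: Option<String>,
    waker: Option<Waker>,
}

pub struct Timer {
    shared: Arc<Mutex<Shared>>,
}

impl Timer {
    pub fn new(instance: String, delay: Duration) -> Self {
        let shared = Arc::new(Mutex::new(Shared { value: None, waker: None }));
        let thread_shared = Arc::clone(&shared);
        thread::spawn(move || {
            thread::sleep(delay);
            let mut shared = thread_shared.lock().unwrap();
            shared.value = Some(instance);
            // Tell the executor that polling again will make progress.
            if let Some(waker) = shared.waker.take() {
                waker.wake();
            }
        });
        Timer { shared }
    }
}

impl Future for Timer {
    type Output = String;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<String> {
        let mut shared = self.shared.lock().unwrap();
        match shared.value.take() {
            Some(value) => Poll::Ready(value),
            None => {
                // Store the latest waker before returning `Pending`.
                shared.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}

/// Unparks the blocked thread; used by the tiny executor below.
struct ThreadWaker(thread::Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal single-future executor (an assumption for the demo, not a real runtime).
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(value) => return value,
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    let text = block_on(Timer::new("Some text".into(), Duration::from_millis(100)));
    println!("{}", text);
}
```

Without the stored waker the executor would park after the first `Pending` and never poll again, which is the std counterpart of the hang described in the question.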
See also:
Why does Future::select choose the future with a longer sleep period first?
Why is `Future::poll` not called repeatedly after returning `NotReady`?
What is the best approach to encapsulate blocking I/O in future-rs?