How to get callback/update when using tokio::io::copy? - rust

tokio::io::copy doesn't provide progress update callbacks. When a copy takes a long time to complete, it would be good to report the progress of the operation. Is there a way to get status/progress while the copy happens?

The tokio::io::copy function simply works off the AsyncRead and AsyncWrite traits. I'm not sure if there's anything pre-built, but it's not too bad to write your own adapter to display progress:
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use pin_project::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::time::{interval, Interval};
#[pin_project]
pub struct ProgressReadAdapter<R: AsyncRead> {
    #[pin]
    inner: R,
    interval: Interval,
    interval_bytes: usize,
}

impl<R: AsyncRead> ProgressReadAdapter<R> {
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            interval: interval(Duration::from_millis(100)),
            interval_bytes: 0,
        }
    }
}

impl<R: AsyncRead> AsyncRead for ProgressReadAdapter<R> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = self.project();

        // Measure how many bytes the inner reader appended to the buffer.
        let before = buf.filled().len();
        let result = this.inner.poll_read(cx, buf);
        let after = buf.filled().len();
        *this.interval_bytes += after - before;

        // The interval ticks every 100 ms, so multiplying by 10 gives bytes per second.
        match this.interval.poll_tick(cx) {
            Poll::Pending => {}
            Poll::Ready(_) => {
                println!("reading at {} bytes per second", *this.interval_bytes * 10);
                *this.interval_bytes = 0;
            }
        };

        result
    }
}
Which can then be used like this:
use tokio::fs::File;
#[tokio::main]
async fn main() -> Result<()> {
    let file1 = File::open("foo.bin").await?;
    let mut file2 = File::create("bar.bin").await?;

    // Wrap the reader so progress is reported while the copy runs.
    let mut file1 = ProgressReadAdapter::new(file1);

    tokio::io::copy(&mut file1, &mut file2).await?;
    Ok(())
}
Running this prints output along these lines:
reading at 38338560 bytes per second
reading at 315146240 bytes per second
reading at 422625280 bytes per second
reading at 497827840 bytes per second
reading at 428605440 bytes per second
You could probably make this look nicer by hooking it up to indicatif, or at least make it more configurable, but I'll leave that as an exercise for the reader.
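For example, here is a minimal sketch (not from the original answer) of the same adapter reworked to report through a caller-supplied callback instead of println!, with the reporting period made configurable:
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use pin_project::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::time::{interval, Interval};

#[pin_project]
pub struct CallbackReadAdapter<R: AsyncRead, F: FnMut(usize)> {
    #[pin]
    inner: R,
    interval: Interval,
    interval_bytes: usize,
    // Invoked once per tick with the number of bytes read since the last tick.
    on_progress: F,
}

impl<R: AsyncRead, F: FnMut(usize)> CallbackReadAdapter<R, F> {
    pub fn new(inner: R, period: Duration, on_progress: F) -> Self {
        Self {
            inner,
            interval: interval(period),
            interval_bytes: 0,
            on_progress,
        }
    }
}

impl<R: AsyncRead, F: FnMut(usize)> AsyncRead for CallbackReadAdapter<R, F> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = self.project();
        let before = buf.filled().len();
        let result = this.inner.poll_read(cx, buf);
        let after = buf.filled().len();
        *this.interval_bytes += after - before;

        // Report and reset the per-period counter whenever the interval ticks.
        if this.interval.poll_tick(cx).is_ready() {
            (this.on_progress)(*this.interval_bytes);
            *this.interval_bytes = 0;
        }
        result
    }
}
The callback receives the byte count for each period, so the caller can print a rate, update an indicatif progress bar, or send the value over a channel.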

Related

how to implement trait futures::stream::Stream?

So I'm getting a Response from the reqwest crate and passing it to a HttpResponseBuilder from the actix_web crate. However, I've tried and failed to understand how to implement the Stream trait from the futures crate on a custom struct to act as a middleman and copy the contents down to a file.
I've tried doing this so far, but I'm not sure what to put inside that poll_next function to make it all work.
struct FileCache {
    stream: Box<dyn futures::Stream<Item = reqwest::Result<bytes::Bytes>>>,
}

impl FileCache {
    fn new(stream: Box<dyn futures::Stream<Item = reqwest::Result<bytes::Bytes>>>) -> Self {
        FileCache { stream }
    }
}

impl Stream for FileCache {
    type Item = reqwest::Result<bytes::Bytes>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // What goes here?
    }
}
This is possible but requires you to understand what pinning is and how to use it safely.
Basically, we just need to proxy to self.stream.poll_next(), but this method accepts Pin<&mut Self> (as you can see in your own implementation). Storing the box as Pin<Box<T>> instead of Box<T> will give us a way to obtain this Pin relatively easily, without requiring unsafe. Making this change is straightforward, since there is a From implementation allowing conversion of Box<T> to Pin<Box<T>> directly:
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Stream;
struct FileCache {
    stream: Pin<Box<dyn Stream<Item = reqwest::Result<bytes::Bytes>>>>,
}

impl FileCache {
    fn new(stream: Box<dyn Stream<Item = reqwest::Result<bytes::Bytes>>>) -> FileCache {
        FileCache { stream: stream.into() }
    }
}
Now we have to figure out how to go from Pin<&mut FileCache> to Pin<&mut dyn Stream<...>>. The correct incantation here is self.get_mut().stream.as_mut():
impl Stream for FileCache {
    type Item = reqwest::Result<bytes::Bytes>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.get_mut().stream.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(v) => {
                // Do what you need to do with v here.
                Poll::Ready(v)
            }
        }
    }
}
The catch is that poll_next isn't async, so you can't asynchronously wait for whatever you're doing with v. bytes::Bytes is atomically reference-counted, though, so you could clone the inner bytes::Bytes value and spawn a separate task on your executor. That is probably what you want to do anyway, so that whoever is waiting on FileCache doesn't have to wait for that task to complete before using the data. So you'd do something like:
Poll::Ready(v) => {
    if let Some(Ok(ref bytes)) = &v {
        let bytes = bytes.clone();
        spawn_new_task(async move {
            // Do something with bytes
        });
    }
    Poll::Ready(v)
}
Where spawn_new_task() is the function your executor provides, e.g. tokio::spawn().
Now that we can see what we're doing here, we can simplify this down and eliminate the match by pushing Poll::Ready into our pattern, and unconditionally returning whatever the inner poll_next() call did:
impl Stream for FileCache {
    type Item = reqwest::Result<bytes::Bytes>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let r = self.get_mut().stream.as_mut().poll_next(cx);
        if let Poll::Ready(Some(Ok(ref bytes))) = &r {
            let bytes = bytes.clone();
            spawn_new_task(async move {
                // Do something with bytes
            });
        }
        r
    }
}
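For completeness, a hypothetical usage sketch (not part of the question or answer), assuming the FileCache above with spawn_new_task filled in (e.g. tokio::spawn) and reqwest's stream feature enabled so that Response::bytes_stream() is available:
use futures::StreamExt;

async fn cache_body(resp: reqwest::Response) {
    // Box the response body stream and wrap it in the caching middleman.
    let mut cached = FileCache::new(Box::new(resp.bytes_stream()));
    while let Some(chunk) = cached.next().await {
        // Each chunk is a reqwest::Result<bytes::Bytes>; the clone handed to
        // the spawned task is what gets written to the cache file.
        let _ = chunk;
    }
}
The same wrapper can be handed to whatever consumes the stream, and the caching side effect happens as it is polled.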

try_lock on futures::lock::Mutex outside of async?

I'm trying to implement AsyncRead for a struct that has a futures::lock::Mutex:
pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}

impl<'a> AsyncRead for SmolSocket<'a> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        block_on(self.stack).read(...)
    }
}
The problem is that, since poll_read is not async, I cannot call await. But I also don't want to, as it would block. I could call try_lock, and if it fails, register a Waker to be called by SmolSocket in the future.
Since I cannot do that either because it's not async, is there a version of block_on that does the same as try_lock for futures::lock::Mutex outside of async?
You probably mean to poll the MutexLockFuture instead. This can be done, for example, with the core::task::ready! macro, which desugars to the following:
let num = match fut.poll(cx) {
    Poll::Ready(t) => t,
    Poll::Pending => return Poll::Pending,
};
To poll a future, you also need to pin it (ensure it doesn't get moved). This can be done on the stack with tokio::pin!, or Pin::new if the type is already Unpin (MutexLockFuture is), or by moving onto the heap with Box::pin.
Below is a runnable example.
⚠️ KEEP READING TO SEE WHY YOU DON'T WANT TO DO THIS!
#![feature(ready_macro)]

use core::{
    future::Future,
    pin::Pin,
    task::{ready, Context, Poll},
};
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncReadExt};

pub struct SmolStackWithDevice<'a> {
    counter: usize,
    data: &'a [u8],
}

impl<'a> AsyncRead for SmolStackWithDevice<'a> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Simulate a source that is only ready on every other poll.
        if self.counter % 2 == 0 {
            self.counter += 1;
            cx.waker().wake_by_ref();
            println!("read nothing");
            return Poll::Pending;
        }
        buf.put_slice(&[self.data[self.counter / 2]]);
        self.counter += 1;
        println!("read something");
        Poll::Ready(Ok(()))
    }
}

pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}

impl<'a> AsyncRead for SmolSocket<'a> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Poll the MutexLockFuture; ready! returns Pending for us if the lock
        // is not available yet.
        let mut lock_fut = self.stack.lock();
        let pinned_lock_fut = Pin::new(&mut lock_fut);
        let mut guard = ready!(pinned_lock_fut.poll(cx));
        println!("acquired lock");
        let pinned_inner = Pin::new(&mut *guard);
        pinned_inner.poll_read(cx, buf)
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let data = b"HORSE";
    let mut buf = [0; 5];
    let mut s = SmolSocket {
        stack: Arc::new(
            SmolStackWithDevice {
                counter: 0,
                data: &data[..],
            }
            .into(),
        ),
    };
    s.read_exact(&mut buf).await.unwrap();
    println!("{}", String::from_utf8_lossy(&buf));
}
Look at it go! (in Rust Playground)
⚠️ KEEP READING TO SEE WHY YOU DON'T WANT TO DO THIS!
So, what is the problem?
Well, as you can see from the output, whenever we succeed in acquiring the lock but the underlying source is not ready to read, or only gives us a small read, we drop the lock, and on the next poll we will have to acquire it again.
This is a good point to remember that async flavors of Mutex are only recommended over std or parking_lot when the Guard from a successful lock is expected to be held across an await, or explicitly stored in a Future data structure.
We are not doing that here; we only ever exercise the fast path equivalent to Mutex::try_lock, because whenever the lock is not immediately available, we drop the MutexLockFuture instead of waiting to be woken to poll it again.
However, storing the lock in the data structure would make it easy to accidentally deadlock. So a good design might be to create an awkward-to-store (borrowing) AsyncRead adapter that wraps the lock:
pub struct SmolSocket<'a> {
    stack: Arc<futures::lock::Mutex<SmolStackWithDevice<'a>>>,
}

impl<'a> SmolSocket<'a> {
    fn read(&'a self) -> Reader<'a> {
        Reader::Locking(self.stack.lock())
    }
}

pub enum Reader<'a> {
    Locking(futures::lock::MutexLockFuture<'a, SmolStackWithDevice<'a>>),
    Locked(futures::lock::MutexGuard<'a, SmolStackWithDevice<'a>>),
}

impl<'a> AsyncRead for Reader<'a> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Reader::Locking(f) => {
                // Still waiting for the lock: poll the future, and once it
                // resolves, switch to the Locked state and read immediately.
                *this = Reader::Locked(ready!(Pin::new(f).poll(cx)));
                println!("acquired lock");
                Pin::new(this).poll_read(cx, buf)
            }
            Reader::Locked(l) => Pin::new(&mut **l).poll_read(cx, buf),
        }
    }
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
    let data = b"HORSE";
    let mut buf = [0; 5];
    let s = SmolSocket {
        stack: Arc::new(
            SmolStackWithDevice {
                counter: 0,
                data: &data[..],
            }
            .into(),
        ),
    };
    s.read().read_exact(&mut buf).await.unwrap();
    println!("{}", String::from_utf8_lossy(&buf));
}
Look at it go! (executable Playground link)
This works out, because both the LockFuture and our SmolStackWithDevice are Unpin (non-self-referential) and so we don't have to guarantee we aren't moving them.
In a general case, for example if your SmolStackWithDevice is not Unpin, you'd have to project the Pin like this:
unsafe {
    let this = self.get_unchecked_mut();
    match this {
        Reader::Locking(f) => {
            *this = Reader::Locked(ready!(Pin::new_unchecked(f).poll(cx)));
            println!("acquired lock");
            Pin::new_unchecked(this).poll_read(cx, buf)
        }
        Reader::Locked(l) => Pin::new_unchecked(&mut **l).poll_read(cx, buf),
    }
}
I'm not sure how to encapsulate the unsafety; pin_project isn't enough here, as we also need to dereference the guard.
But this only acquires the lock once, and drops it when the Reader is dropped, so, great success.
You can also see that it doesn't deadlock if you do
let mut r1 = s.read();
let mut r2 = s.read();
r1.read_exact(&mut buf[..3]).await.unwrap();
drop(r1);
r2.read_exact(&mut buf[3..]).await.unwrap();
println!("{}", String::from_utf8_lossy(&buf));
This is only possible because we deferred locking until polling.

Creating a stream of values while calling async fns?

I can't figure out how to provide a Stream where I await async functions to get the data needed for the values of the stream.
I've tried to implement the Stream trait directly, but I run into issues because I'd like to use async things like awaiting, and the compiler does not let me call async functions there.
I assume that I'm missing some background on what the goal of Stream is and I'm just attacking this incorrectly and perhaps I shouldn't be looking at Stream at all, but I don't know where else to turn. I've seen the other functions in the stream module that could be useful, but I'm unsure how I could store any state and use these functions.
As a slightly simplified version of my actual goal, I want to provide a stream of 64-byte Vecs from an AsyncRead object (i.e. TCP stream), but also store a little state inside whatever logic ends up producing values for the stream, in this example, a counter.
pub struct Receiver<T>
where
    T: AsyncRead + Unpin,
{
    readme: T,
    num: u64,
}

// ..code for a simple `new() -> Self` function..

impl<T> Stream for Receiver<T>
where
    T: AsyncRead + Unpin,
{
    type Item = Result<Vec<u8>, io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut buf: [u8; 64] = [0; 64];
        match self.readme.read_exact(&mut buf).await {
            Ok(()) => {
                self.num += 1;
                Poll::Ready(Some(Ok(buf.to_vec())))
            }
            Err(e) => Poll::Ready(Some(Err(e))),
        }
    }
}
This fails to build, saying
error[E0728]: `await` is only allowed inside `async` functions and blocks
I'm using rustc 1.36.0-nightly (d35181ad8 2019-05-20) and my Cargo.toml looks like this:
[dependencies]
futures-preview = { version = "0.3.0-alpha.16", features = ["compat", "io-compat"] }
pin-utils = "0.1.0-alpha.4"
Answer copy/pasted from the reddit post by user Matthias247:
It's unfortunately not possible at the moment: Streams have to be implemented by hand and cannot utilize async fn. Whether it's possible to change this in the future is unclear.
You can work around it by defining a different Stream trait which makes use of Futures like:
trait Stream<T> {
    type NextFuture: Future<Output = T>;

    fn next(&mut self) -> Self::NextFuture;
}
This article and this futures-rs issue have more information around it.
You can do it with the gen-stream crate:
#![feature(generators, generator_trait, gen_future)]

use {
    futures::prelude::*,
    gen_stream::{gen_await, GenTryStream},
    pin_utils::unsafe_pinned,
    std::{
        io,
        marker::PhantomData,
        pin::Pin,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
        task::{Context, Poll},
    },
};

pub type Inner = Pin<Box<dyn Stream<Item = Result<Vec<u8>, io::Error>> + Send>>;

pub struct Receiver<T> {
    inner: Inner,
    pub num: Arc<AtomicU64>,
    _marker: PhantomData<T>,
}

impl<T> Receiver<T> {
    unsafe_pinned!(inner: Inner);
}

impl<T> From<T> for Receiver<T>
where
    T: AsyncRead + Unpin + Send + 'static,
{
    fn from(mut readme: T) -> Self {
        let num = Arc::new(AtomicU64::new(0));
        Self {
            inner: Box::pin(GenTryStream::from({
                let num = num.clone();
                static move || loop {
                    let mut buf: [u8; 64] = [0; 64];
                    match gen_await!(readme.read_exact(&mut buf)) {
                        Ok(()) => {
                            num.fetch_add(1, Ordering::Relaxed);
                            yield Poll::Ready(buf.to_vec())
                        }
                        Err(e) => return Err(e),
                    }
                }
            })),
            num,
            _marker: PhantomData,
        }
    }
}

impl<T> Stream for Receiver<T>
where
    T: AsyncRead + Unpin,
{
    type Item = Result<Vec<u8>, io::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner().poll_next(cx)
    }
}
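A hypothetical usage sketch (not in the original answer), consuming the stream with StreamExt::next from the prelude and reading the running counter:
async fn consume<T>(reader: T) -> Result<(), io::Error>
where
    T: AsyncRead + Unpin + Send + 'static,
{
    let mut receiver = Receiver::from(reader);
    while let Some(chunk) = receiver.next().await {
        let chunk = chunk?;
        println!(
            "chunk #{}: {} bytes",
            receiver.num.load(Ordering::Relaxed),
            chunk.len()
        );
    }
    Ok(())
}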

How do I apply a limit to the number of bytes read by futures::Stream::concat2?

An answer to How do I read the entire body of a Tokio-based Hyper request? suggests:
you may wish to establish some kind of cap on the number of bytes read [when using futures::Stream::concat2]
How can I actually achieve this? For example, here's some code that mimics a malicious user who is sending my service an infinite amount of data:
extern crate futures; // 0.1.25

use futures::{prelude::*, stream};

fn some_bytes() -> impl Stream<Item = Vec<u8>, Error = ()> {
    stream::repeat(b"0123456789ABCDEF".to_vec())
}

fn limited() -> impl Future<Item = Vec<u8>, Error = ()> {
    some_bytes().concat2()
}

fn main() {
    let v = limited().wait().unwrap();
    println!("{}", v.len());
}
One solution is to create a stream combinator that ends the stream once some threshold of bytes has passed. Here's one possible implementation:
struct TakeBytes<S> {
    inner: S,
    seen: usize,
    limit: usize,
}

impl<S> Stream for TakeBytes<S>
where
    S: Stream<Item = Vec<u8>>,
{
    type Item = Vec<u8>;
    type Error = S::Error;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        if self.seen >= self.limit {
            return Ok(Async::Ready(None)); // Stream is over
        }

        let inner = self.inner.poll();
        if let Ok(Async::Ready(Some(ref v))) = inner {
            self.seen += v.len();
        }
        inner
    }
}

trait TakeBytesExt: Sized {
    fn take_bytes(self, limit: usize) -> TakeBytes<Self>;
}

impl<S> TakeBytesExt for S
where
    S: Stream<Item = Vec<u8>>,
{
    fn take_bytes(self, limit: usize) -> TakeBytes<Self> {
        TakeBytes {
            inner: self,
            limit,
            seen: 0,
        }
    }
}
This can then be chained onto the stream before concat2:
fn limited() -> impl Future<Item = Vec<u8>, Error = ()> {
    some_bytes().take_bytes(999).concat2()
}
This implementation has caveats:
it only works for Vec<u8>. You can introduce generics to make it more broadly applicable, of course.
it allows for more bytes than the limit to come in; it just stops the stream after that point. Those types of decisions are application-dependent.
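As a sketch of the generic variant mentioned in the first caveat (same futures 0.1 API, not from the original answer), the item type can be anything that exposes its bytes:
use futures::{Async, Poll, Stream};

struct TakeBytesGeneric<S> {
    inner: S,
    seen: usize,
    limit: usize,
}

impl<S> Stream for TakeBytesGeneric<S>
where
    S: Stream,
    S::Item: AsRef<[u8]>,
{
    type Item = S::Item;
    type Error = S::Error;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        if self.seen >= self.limit {
            return Ok(Async::Ready(None)); // limit reached: end the stream
        }
        let polled = self.inner.poll();
        if let Ok(Async::Ready(Some(ref v))) = polled {
            // Count whatever byte view the item provides (Vec<u8>, Bytes, ...).
            self.seen += v.as_ref().len();
        }
        polled
    }
}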
Another thing to keep in mind is that you want to attempt to tackle this problem as low as you can — if the source of the data has already allocated a gigabyte of memory, placing a limit won't help as much.

How to keep track of how many bytes written when using 'std::io::Write'?

When writing to a binary file format, it's useful to be able to check how many bytes have been written (for alignment, for example), or just to ensure nested functions wrote the correct amount of data.
Is there a way to inspect std::io::Write to know how much has been written? If not, what would be a good approach to wrap the writer so it could track how many bytes have been written?
Write has two required methods: write and flush. Since write already returns the number of bytes written, you just track that:
use std::io::{self, Write};
struct ByteCounter<W> {
    inner: W,
    count: usize,
}

impl<W> ByteCounter<W>
where
    W: Write,
{
    fn new(inner: W) -> Self {
        ByteCounter { inner, count: 0 }
    }

    fn into_inner(self) -> W {
        self.inner
    }

    fn bytes_written(&self) -> usize {
        self.count
    }
}

impl<W> Write for ByteCounter<W>
where
    W: Write,
{
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let res = self.inner.write(buf);
        if let Ok(size) = res {
            self.count += size;
        }
        res
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

fn main() {
    let out = std::io::stdout();
    let mut out = ByteCounter::new(out);
    writeln!(&mut out, "Hello, world! {}", 42).unwrap();
    println!("Wrote {} bytes", out.bytes_written());
}
It's important to not delegate write_all or write_fmt because these do not return the count of bytes. Delegating them would allow bytes to be written and not tracked.
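For illustration (not part of the original answer), a small check that the default write_all still routes through the counter, assuming the ByteCounter defined above:
fn demo() -> io::Result<()> {
    let mut counter = ByteCounter::new(Vec::new());
    // write_all's default implementation loops over `write`, so every byte is counted.
    counter.write_all(b"0123456789")?;
    assert_eq!(counter.bytes_written(), 10);
    Ok(())
}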
If the type you write to implements std::io::Seek, you can use seek to get the current position:
let pos = f.seek(SeekFrom::Current(0))?;
Seek is implemented by std::fs::File (and std::io::BufWriter if the wrapped type implements Seek too).
So the function signature:
use ::std::io::{Write, Seek, SeekFrom, Error};
fn my_write<W: Write>(f: &mut W) -> Result<(), Error> { ... }
Needs to have the Seek trait added:
fn my_write<W: Write + Seek>(f: &mut W) -> Result<(), Error> { ... }
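For instance, a minimal sketch (not from the original answer) that uses Seek to measure how much a nested write produced and then pads to an 8-byte boundary; write_header is a hypothetical helper:
use std::io::{Error, Seek, SeekFrom, Write};

fn write_header<W: Write>(f: &mut W) -> Result<(), Error> {
    f.write_all(b"MAGIC")
}

fn my_write<W: Write + Seek>(f: &mut W) -> Result<(), Error> {
    let start = f.seek(SeekFrom::Current(0))?;
    write_header(f)?;
    let written = f.seek(SeekFrom::Current(0))? - start;
    // Pad with zero bytes so the next section starts on an 8-byte boundary.
    let padding = (8 - (written % 8) as usize) % 8;
    f.write_all(&vec![0u8; padding])?;
    Ok(())
}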
