Skip to content

Commit ae817ca

Browse files
committed
don't poll the reader again after eof while waiting for the writer to flush
1 parent 97a2fbe commit ae817ca

File tree

2 files changed

+100
-4
lines changed

2 files changed

+100
-4
lines changed

src/io/copy.rs

+28-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ use crate::io::{self, BufRead, BufReader, Read, Write};
77
use crate::task::{Context, Poll};
88
use crate::utils::Context as _;
99

10+
// Note: There are two otherwise-identical implementations of this
11+
// function because unstable has removed the `?Sized` bound for the
12+
// reader and writer and accepts `R` and `W` instead of `&mut R` and
13+
// `&mut W`. If making a change to either of the implementations,
14+
// ensure that you copy it into the other.
15+
1016
/// Copies the entire contents of a reader into a writer.
1117
///
1218
/// This function will continuously read data from `reader` and then
@@ -57,6 +63,7 @@ where
5763
#[pin]
5864
writer: W,
5965
amt: u64,
66+
reader_eof: bool
6067
}
6168
}
6269

@@ -69,13 +76,20 @@ where
6976

7077
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
7178
let mut this = self.project();
79+
7280
loop {
73-
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
74-
if buffer.is_empty() {
81+
if *this.reader_eof {
7582
futures_core::ready!(this.writer.as_mut().poll_flush(cx))?;
7683
return Poll::Ready(Ok(*this.amt));
7784
}
7885

86+
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
87+
88+
if buffer.is_empty() {
89+
*this.reader_eof = true;
90+
continue;
91+
}
92+
7993
let i = futures_core::ready!(this.writer.as_mut().poll_write(cx, buffer))?;
8094
if i == 0 {
8195
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -89,6 +103,7 @@ where
89103
let future = CopyFuture {
90104
reader: BufReader::new(reader),
91105
writer,
106+
reader_eof: false,
92107
amt: 0,
93108
};
94109
future.await.context(|| String::from("io::copy failed"))
@@ -144,6 +159,7 @@ where
144159
#[pin]
145160
writer: W,
146161
amt: u64,
162+
reader_eof: bool
147163
}
148164
}
149165

@@ -156,13 +172,20 @@ where
156172

157173
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158174
let mut this = self.project();
175+
159176
loop {
160-
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
161-
if buffer.is_empty() {
177+
if *this.reader_eof {
162178
futures_core::ready!(this.writer.as_mut().poll_flush(cx))?;
163179
return Poll::Ready(Ok(*this.amt));
164180
}
165181

182+
let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?;
183+
184+
if buffer.is_empty() {
185+
*this.reader_eof = true;
186+
continue;
187+
}
188+
166189
let i = futures_core::ready!(this.writer.as_mut().poll_write(cx, buffer))?;
167190
if i == 0 {
168191
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -176,6 +199,7 @@ where
176199
let future = CopyFuture {
177200
reader: BufReader::new(reader),
178201
writer,
202+
reader_eof: false,
179203
amt: 0,
180204
};
181205
future.await.context(|| String::from("io::copy failed"))

tests/io_copy.rs

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use std::{
2+
io::Result,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
};
6+
7+
struct ReaderThatPanicsAfterEof {
8+
read_count: usize,
9+
has_sent_eof: bool,
10+
max_read: usize,
11+
}
12+
13+
impl async_std::io::Read for ReaderThatPanicsAfterEof {
14+
fn poll_read(
15+
mut self: Pin<&mut Self>,
16+
_cx: &mut Context<'_>,
17+
buf: &mut [u8],
18+
) -> Poll<Result<usize>> {
19+
if self.has_sent_eof {
20+
panic!("this should be unreachable because we should not poll after eof (Ready(Ok(0)))")
21+
} else if self.read_count >= self.max_read {
22+
self.has_sent_eof = true;
23+
Poll::Ready(Ok(0))
24+
} else {
25+
self.read_count += 1;
26+
Poll::Ready(Ok(buf.len()))
27+
}
28+
}
29+
}
30+
31+
struct WriterThatTakesAWhileToFlush {
32+
max_flush: usize,
33+
flush_count: usize,
34+
}
35+
36+
impl async_std::io::Write for WriterThatTakesAWhileToFlush {
37+
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
38+
Poll::Ready(Ok(buf.len()))
39+
}
40+
41+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
42+
self.flush_count += 1;
43+
if self.flush_count >= self.max_flush {
44+
Poll::Ready(Ok(()))
45+
} else {
46+
cx.waker().wake_by_ref();
47+
Poll::Pending
48+
}
49+
}
50+
51+
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
52+
Poll::Ready(Ok(()))
53+
}
54+
}
55+
56+
#[test]
57+
fn io_copy_does_not_poll_after_eof() {
58+
async_std::task::block_on(async {
59+
let mut reader = ReaderThatPanicsAfterEof {
60+
has_sent_eof: false,
61+
max_read: 10,
62+
read_count: 0,
63+
};
64+
65+
let mut writer = WriterThatTakesAWhileToFlush {
66+
flush_count: 0,
67+
max_flush: 10,
68+
};
69+
70+
assert!(async_std::io::copy(&mut reader, &mut writer).await.is_ok());
71+
})
72+
}

0 commit comments

Comments
 (0)