Skip to content

Commit eaef62c

Browse files
authored
Merge pull request #451 from sfackler/less-copy-copies
Avoid copies in copy_in
2 parents bcb4ca0 + 9dbeb84 commit eaef62c

File tree

10 files changed

+180
-75
lines changed

10 files changed

+180
-75
lines changed

postgres-protocol/src/message/frontend.rs

+36
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#![allow(missing_docs)]
33

44
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
5+
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
6+
use std::convert::TryFrom;
57
use std::error::Error;
68
use std::io;
79
use std::marker;
@@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
263265
})
264266
}
265267

268+
pub struct CopyData<T> {
269+
buf: T,
270+
len: i32,
271+
}
272+
273+
impl<T> CopyData<T>
274+
where
275+
T: Buf,
276+
{
277+
pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
278+
where
279+
U: IntoBuf<Buf = T>,
280+
{
281+
let buf = buf.into_buf();
282+
283+
let len = buf
284+
.remaining()
285+
.checked_add(4)
286+
.and_then(|l| i32::try_from(l).ok())
287+
.ok_or_else(|| {
288+
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
289+
})?;
290+
291+
Ok(CopyData { buf, len })
292+
}
293+
294+
pub fn write(self, out: &mut BytesMut) {
295+
out.reserve(self.len as usize + 1);
296+
out.put_u8(b'd');
297+
out.put_i32_be(self.len);
298+
out.put(self.buf);
299+
}
300+
}
301+
266302
#[inline]
267303
pub fn copy_done(buf: &mut Vec<u8>) {
268304
buf.push(b'c');

tokio-postgres/src/impls.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
170170
where
171171
S: Stream,
172172
S::Item: IntoBuf,
173-
<S::Item as IntoBuf>::Buf: Send,
173+
<S::Item as IntoBuf>::Buf: 'static + Send,
174174
S::Error: Into<Box<dyn error::Error + Sync + Send>>;
175175

176176
impl<S> Future for CopyIn<S>

tokio-postgres/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ impl Client {
242242
where
243243
S: Stream,
244244
S::Item: IntoBuf,
245-
<S::Item as IntoBuf>::Buf: Send,
245+
<S::Item as IntoBuf>::Buf: 'static + Send,
246246
// FIXME error type?
247247
S::Error: Into<Box<dyn StdError + Sync + Send>>,
248248
{

tokio-postgres/src/proto/client.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
1111
use tokio_io::{AsyncRead, AsyncWrite};
1212

1313
use crate::proto::bind::BindFuture;
14+
use crate::proto::codec::FrontendMessage;
1415
use crate::proto::connection::{Request, RequestMessages};
1516
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
1617
use crate::proto::copy_out::CopyOutStream;
@@ -185,8 +186,12 @@ impl Client {
185186
if let Ok(ref mut buf) = buf {
186187
frontend::sync(buf);
187188
}
188-
let pending =
189-
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
189+
let pending = PendingRequest(buf.map(|m| {
190+
(
191+
RequestMessages::Single(FrontendMessage::Raw(m)),
192+
self.0.idle.guard(),
193+
)
194+
}));
190195
BindFuture::new(self.clone(), pending, name, statement.clone())
191196
}
192197

@@ -208,12 +213,12 @@ impl Client {
208213
where
209214
S: Stream,
210215
S::Item: IntoBuf,
211-
<S::Item as IntoBuf>::Buf: Send,
216+
<S::Item as IntoBuf>::Buf: 'static + Send,
212217
S::Error: Into<Box<dyn StdError + Sync + Send>>,
213218
{
214219
let (mut sender, receiver) = mpsc::channel(1);
215220
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
216-
match sender.start_send(CopyMessage { data, done: false }) {
221+
match sender.start_send(CopyMessage::Message(data)) {
217222
Ok(AsyncSink::Ready) => {}
218223
_ => unreachable!("channel should have capacity"),
219224
}
@@ -278,7 +283,7 @@ impl Client {
278283
frontend::sync(&mut buf);
279284
let (sender, _) = mpsc::channel(0);
280285
let _ = self.0.sender.unbounded_send(Request {
281-
messages: RequestMessages::Single(buf),
286+
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
282287
sender,
283288
idle: None,
284289
});
@@ -326,20 +331,23 @@ impl Client {
326331
&self,
327332
statement: &Statement,
328333
params: &[&dyn ToSql],
329-
) -> Result<Vec<u8>, Error> {
334+
) -> Result<FrontendMessage, Error> {
330335
let mut buf = self.bind_message(statement, "", params)?;
331336
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
332337
frontend::sync(&mut buf);
333-
Ok(buf)
338+
Ok(FrontendMessage::Raw(buf))
334339
}
335340

336341
fn pending<F>(&self, messages: F) -> PendingRequest
337342
where
338343
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
339344
{
340345
let mut buf = vec![];
341-
PendingRequest(
342-
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
343-
)
346+
PendingRequest(messages(&mut buf).map(|()| {
347+
(
348+
RequestMessages::Single(FrontendMessage::Raw(buf)),
349+
self.0.idle.guard(),
350+
)
351+
}))
344352
}
345353
}

tokio-postgres/src/proto/codec.rs

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
1-
use bytes::BytesMut;
1+
use bytes::{Buf, BytesMut};
22
use postgres_protocol::message::backend;
3+
use postgres_protocol::message::frontend::CopyData;
34
use std::io;
45
use tokio_codec::{Decoder, Encoder};
56

7+
pub enum FrontendMessage {
8+
Raw(Vec<u8>),
9+
CopyData(CopyData<Box<dyn Buf + Send>>),
10+
}
11+
612
pub struct PostgresCodec;
713

814
impl Encoder for PostgresCodec {
9-
type Item = Vec<u8>;
15+
type Item = FrontendMessage;
1016
type Error = io::Error;
1117

12-
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
13-
dst.extend_from_slice(&item);
18+
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
19+
match item {
20+
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
21+
FrontendMessage::CopyData(data) => data.write(dst),
22+
}
23+
1424
Ok(())
1525
}
1626
}

tokio-postgres/src/proto/connect_raw.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::collections::HashMap;
1111
use tokio_codec::Framed;
1212
use tokio_io::{AsyncRead, AsyncWrite};
1313

14-
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
14+
use crate::proto::{Client, Connection, FrontendMessage, MaybeTlsStream, PostgresCodec, TlsFuture};
1515
use crate::tls::ChannelBinding;
1616
use crate::{Config, Error, TlsConnect};
1717

@@ -111,7 +111,7 @@ where
111111
let stream = Framed::new(stream, PostgresCodec);
112112

113113
transition!(SendingStartup {
114-
future: stream.send(buf),
114+
future: stream.send(FrontendMessage::Raw(buf)),
115115
config: state.config,
116116
idx: state.idx,
117117
channel_binding,
@@ -156,7 +156,7 @@ where
156156
let mut buf = vec![];
157157
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
158158
transition!(SendingPassword {
159-
future: state.stream.send(buf),
159+
future: state.stream.send(FrontendMessage::Raw(buf)),
160160
config: state.config,
161161
idx: state.idx,
162162
})
@@ -178,7 +178,7 @@ where
178178
let mut buf = vec![];
179179
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
180180
transition!(SendingPassword {
181-
future: state.stream.send(buf),
181+
future: state.stream.send(FrontendMessage::Raw(buf)),
182182
config: state.config,
183183
idx: state.idx,
184184
})
@@ -235,7 +235,7 @@ where
235235
.map_err(Error::encode)?;
236236

237237
transition!(SendingSasl {
238-
future: state.stream.send(buf),
238+
future: state.stream.send(FrontendMessage::Raw(buf)),
239239
scram,
240240
config: state.config,
241241
idx: state.idx,
@@ -293,7 +293,7 @@ where
293293
let mut buf = vec![];
294294
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
295295
transition!(SendingSasl {
296-
future: state.stream.send(buf),
296+
future: state.stream.send(FrontendMessage::Raw(buf)),
297297
scram: state.scram,
298298
config: state.config,
299299
idx: state.idx,

tokio-postgres/src/proto/connection.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ use std::io;
88
use tokio_codec::Framed;
99
use tokio_io::{AsyncRead, AsyncWrite};
1010

11-
use crate::proto::codec::PostgresCodec;
11+
use crate::proto::codec::{FrontendMessage, PostgresCodec};
1212
use crate::proto::copy_in::CopyInReceiver;
1313
use crate::proto::idle::IdleGuard;
1414
use crate::{AsyncMessage, Notification};
1515
use crate::{DbError, Error};
1616

1717
pub enum RequestMessages {
18-
Single(Vec<u8>),
18+
Single(FrontendMessage),
1919
CopyIn {
2020
receiver: CopyInReceiver,
21-
pending_message: Option<Vec<u8>>,
21+
pending_message: Option<FrontendMessage>,
2222
},
2323
}
2424

@@ -188,7 +188,7 @@ where
188188
self.state = State::Terminating;
189189
let mut request = vec![];
190190
frontend::terminate(&mut request);
191-
RequestMessages::Single(request)
191+
RequestMessages::Single(FrontendMessage::Raw(request))
192192
}
193193
Async::Ready(None) => {
194194
trace!(

0 commit comments

Comments
 (0)