Skip to content

Commit 2c204b3

Browse files
authored
Merge pull request #452 from sfackler/block-response
Send response messages in blocks
2 parents eaef62c + 2d2a5de commit 2c204b3

File tree

15 files changed

+324
-102
lines changed

15 files changed

+324
-102
lines changed

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ members = [
77
"tokio-postgres-native-tls",
88
"tokio-postgres-openssl",
99
]
10+
11+
[profile.release]
12+
debug = 2

postgres-protocol/src/message/backend.rs

+82-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![allow(missing_docs)]
22

3-
use byteorder::{BigEndian, ReadBytesExt};
3+
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
44
use bytes::{Bytes, BytesMut};
55
use fallible_iterator::FallibleIterator;
66
use memchr::memchr;
@@ -11,6 +11,66 @@ use std::str;
1111

1212
use crate::Oid;
1313

14+
pub const PARSE_COMPLETE_TAG: u8 = b'1';
15+
pub const BIND_COMPLETE_TAG: u8 = b'2';
16+
pub const CLOSE_COMPLETE_TAG: u8 = b'3';
17+
pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
18+
pub const COPY_DONE_TAG: u8 = b'c';
19+
pub const COMMAND_COMPLETE_TAG: u8 = b'C';
20+
pub const COPY_DATA_TAG: u8 = b'd';
21+
pub const DATA_ROW_TAG: u8 = b'D';
22+
pub const ERROR_RESPONSE_TAG: u8 = b'E';
23+
pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
24+
pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
25+
pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
26+
pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
27+
pub const NO_DATA_TAG: u8 = b'n';
28+
pub const NOTICE_RESPONSE_TAG: u8 = b'N';
29+
pub const AUTHENTICATION_TAG: u8 = b'R';
30+
pub const PORTAL_SUSPENDED_TAG: u8 = b's';
31+
pub const PARAMETER_STATUS_TAG: u8 = b'S';
32+
pub const PARAMETER_DESCRIPTION_TAG: u8 = b't';
33+
pub const ROW_DESCRIPTION_TAG: u8 = b'T';
34+
pub const READY_FOR_QUERY_TAG: u8 = b'Z';
35+
36+
#[derive(Debug, Copy, Clone)]
37+
pub struct Header {
38+
tag: u8,
39+
len: i32,
40+
}
41+
42+
#[allow(clippy::len_without_is_empty)]
43+
impl Header {
44+
#[inline]
45+
pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
46+
if buf.len() < 5 {
47+
return Ok(None);
48+
}
49+
50+
let tag = buf[0];
51+
let len = BigEndian::read_i32(&buf[1..]);
52+
53+
if len < 4 {
54+
return Err(io::Error::new(
55+
io::ErrorKind::InvalidData,
56+
"invalid message length",
57+
));
58+
}
59+
60+
Ok(Some(Header { tag, len }))
61+
}
62+
63+
#[inline]
64+
pub fn tag(self) -> u8 {
65+
self.tag
66+
}
67+
68+
#[inline]
69+
pub fn len(self) -> i32 {
70+
self.len
71+
}
72+
}
73+
1474
/// An enum representing Postgres backend messages.
1575
pub enum Message {
1676
AuthenticationCleartextPassword,
@@ -80,10 +140,10 @@ impl Message {
80140
};
81141

82142
let message = match tag {
83-
b'1' => Message::ParseComplete,
84-
b'2' => Message::BindComplete,
85-
b'3' => Message::CloseComplete,
86-
b'A' => {
143+
PARSE_COMPLETE_TAG => Message::ParseComplete,
144+
BIND_COMPLETE_TAG => Message::BindComplete,
145+
CLOSE_COMPLETE_TAG => Message::CloseComplete,
146+
NOTIFICATION_RESPONSE_TAG => {
87147
let process_id = buf.read_i32::<BigEndian>()?;
88148
let channel = buf.read_cstr()?;
89149
let message = buf.read_cstr()?;
@@ -93,25 +153,25 @@ impl Message {
93153
message,
94154
})
95155
}
96-
b'c' => Message::CopyDone,
97-
b'C' => {
156+
COPY_DONE_TAG => Message::CopyDone,
157+
COMMAND_COMPLETE_TAG => {
98158
let tag = buf.read_cstr()?;
99159
Message::CommandComplete(CommandCompleteBody { tag })
100160
}
101-
b'd' => {
161+
COPY_DATA_TAG => {
102162
let storage = buf.read_all();
103163
Message::CopyData(CopyDataBody { storage })
104164
}
105-
b'D' => {
165+
DATA_ROW_TAG => {
106166
let len = buf.read_u16::<BigEndian>()?;
107167
let storage = buf.read_all();
108168
Message::DataRow(DataRowBody { storage, len })
109169
}
110-
b'E' => {
170+
ERROR_RESPONSE_TAG => {
111171
let storage = buf.read_all();
112172
Message::ErrorResponse(ErrorResponseBody { storage })
113173
}
114-
b'G' => {
174+
COPY_IN_RESPONSE_TAG => {
115175
let format = buf.read_u8()?;
116176
let len = buf.read_u16::<BigEndian>()?;
117177
let storage = buf.read_all();
@@ -121,7 +181,7 @@ impl Message {
121181
storage,
122182
})
123183
}
124-
b'H' => {
184+
COPY_OUT_RESPONSE_TAG => {
125185
let format = buf.read_u8()?;
126186
let len = buf.read_u16::<BigEndian>()?;
127187
let storage = buf.read_all();
@@ -131,21 +191,21 @@ impl Message {
131191
storage,
132192
})
133193
}
134-
b'I' => Message::EmptyQueryResponse,
135-
b'K' => {
194+
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
195+
BACKEND_KEY_DATA_TAG => {
136196
let process_id = buf.read_i32::<BigEndian>()?;
137197
let secret_key = buf.read_i32::<BigEndian>()?;
138198
Message::BackendKeyData(BackendKeyDataBody {
139199
process_id,
140200
secret_key,
141201
})
142202
}
143-
b'n' => Message::NoData,
144-
b'N' => {
203+
NO_DATA_TAG => Message::NoData,
204+
NOTICE_RESPONSE_TAG => {
145205
let storage = buf.read_all();
146206
Message::NoticeResponse(NoticeResponseBody { storage })
147207
}
148-
b'R' => match buf.read_i32::<BigEndian>()? {
208+
AUTHENTICATION_TAG => match buf.read_i32::<BigEndian>()? {
149209
0 => Message::AuthenticationOk,
150210
2 => Message::AuthenticationKerberosV5,
151211
3 => Message::AuthenticationCleartextPassword,
@@ -180,23 +240,23 @@ impl Message {
180240
));
181241
}
182242
},
183-
b's' => Message::PortalSuspended,
184-
b'S' => {
243+
PORTAL_SUSPENDED_TAG => Message::PortalSuspended,
244+
PARAMETER_STATUS_TAG => {
185245
let name = buf.read_cstr()?;
186246
let value = buf.read_cstr()?;
187247
Message::ParameterStatus(ParameterStatusBody { name, value })
188248
}
189-
b't' => {
249+
PARAMETER_DESCRIPTION_TAG => {
190250
let len = buf.read_u16::<BigEndian>()?;
191251
let storage = buf.read_all();
192252
Message::ParameterDescription(ParameterDescriptionBody { storage, len })
193253
}
194-
b'T' => {
254+
ROW_DESCRIPTION_TAG => {
195255
let len = buf.read_u16::<BigEndian>()?;
196256
let storage = buf.read_all();
197257
Message::RowDescription(RowDescriptionBody { storage, len })
198258
}
199-
b'Z' => {
259+
READY_FOR_QUERY_TAG => {
200260
let status = buf.read_u8()?;
201261
Message::ReadyForQuery(ReadyForQueryBody { status })
202262
}

tokio-postgres/src/proto/bind.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use futures::sync::mpsc;
2-
use futures::{Poll, Stream};
1+
use futures::{try_ready, Poll, Stream};
32
use postgres_protocol::message::backend::Message;
43
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
54

65
use crate::proto::client::{Client, PendingRequest};
76
use crate::proto::portal::Portal;
7+
use crate::proto::responses::Responses;
88
use crate::proto::statement::Statement;
99
use crate::Error;
1010

@@ -19,7 +19,7 @@ pub enum Bind {
1919
},
2020
#[state_machine_future(transitions(Finished))]
2121
ReadBindComplete {
22-
receiver: mpsc::Receiver<Message>,
22+
receiver: Responses,
2323
client: Client,
2424
name: String,
2525
statement: Statement,
@@ -46,7 +46,7 @@ impl PollBind for Bind {
4646
fn poll_read_bind_complete<'a>(
4747
state: &'a mut RentToOwn<'a, ReadBindComplete>,
4848
) -> Poll<AfterReadBindComplete, Error> {
49-
let message = try_ready_receive!(state.receiver.poll());
49+
let message = try_ready!(state.receiver.poll());
5050
let state = state.take();
5151

5252
match message {

tokio-postgres/src/proto/client.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use bytes::IntoBuf;
33
use futures::sync::mpsc;
44
use futures::{AsyncSink, Poll, Sink, Stream};
55
use postgres_protocol;
6-
use postgres_protocol::message::backend::Message;
76
use postgres_protocol::message::frontend;
87
use std::collections::HashMap;
98
use std::error::Error as StdError;
@@ -20,6 +19,7 @@ use crate::proto::idle::{IdleGuard, IdleState};
2019
use crate::proto::portal::Portal;
2120
use crate::proto::prepare::PrepareFuture;
2221
use crate::proto::query::QueryStream;
22+
use crate::proto::responses::{self, Responses};
2323
use crate::proto::simple_query::SimpleQueryStream;
2424
use crate::proto::statement::Statement;
2525
#[cfg(feature = "runtime")]
@@ -130,9 +130,9 @@ impl Client {
130130
self.0.state.lock().typeinfo_composite_query = Some(statement.clone());
131131
}
132132

133-
pub fn send(&self, request: PendingRequest) -> Result<mpsc::Receiver<Message>, Error> {
133+
pub fn send(&self, request: PendingRequest) -> Result<Responses, Error> {
134134
let (messages, idle) = request.0?;
135-
let (sender, receiver) = mpsc::channel(1);
135+
let (sender, receiver) = responses::channel();
136136
self.0
137137
.sender
138138
.unbounded_send(Request {

tokio-postgres/src/proto/codec.rs

+67-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use bytes::{Buf, BytesMut};
2+
use fallible_iterator::FallibleIterator;
23
use postgres_protocol::message::backend;
34
use postgres_protocol::message::frontend::CopyData;
45
use std::io;
@@ -9,6 +10,31 @@ pub enum FrontendMessage {
910
CopyData(CopyData<Box<dyn Buf + Send>>),
1011
}
1112

13+
pub enum BackendMessage {
14+
Normal {
15+
messages: BackendMessages,
16+
request_complete: bool,
17+
},
18+
Async(backend::Message),
19+
}
20+
21+
pub struct BackendMessages(BytesMut);
22+
23+
impl BackendMessages {
24+
pub fn empty() -> BackendMessages {
25+
BackendMessages(BytesMut::new())
26+
}
27+
}
28+
29+
impl FallibleIterator for BackendMessages {
30+
type Item = backend::Message;
31+
type Error = io::Error;
32+
33+
fn next(&mut self) -> io::Result<Option<backend::Message>> {
34+
backend::Message::parse(&mut self.0)
35+
}
36+
}
37+
1238
pub struct PostgresCodec;
1339

1440
impl Encoder for PostgresCodec {
@@ -26,10 +52,48 @@ impl Encoder for PostgresCodec {
2652
}
2753

2854
impl Decoder for PostgresCodec {
29-
type Item = backend::Message;
55+
type Item = BackendMessage;
3056
type Error = io::Error;
3157

32-
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<backend::Message>, io::Error> {
33-
backend::Message::parse(src)
58+
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
59+
let mut idx = 0;
60+
let mut request_complete = false;
61+
62+
while let Some(header) = backend::Header::parse(&src[idx..])? {
63+
let len = header.len() as usize + 1;
64+
if src[idx..].len() < len {
65+
break;
66+
}
67+
68+
match header.tag() {
69+
backend::NOTICE_RESPONSE_TAG
70+
| backend::NOTIFICATION_RESPONSE_TAG
71+
| backend::PARAMETER_STATUS_TAG => {
72+
if idx == 0 {
73+
let message = backend::Message::parse(src)?.unwrap();
74+
return Ok(Some(BackendMessage::Async(message)));
75+
} else {
76+
break;
77+
}
78+
}
79+
_ => {}
80+
}
81+
82+
idx += len;
83+
84+
if header.tag() == backend::READY_FOR_QUERY_TAG {
85+
request_complete = true;
86+
break;
87+
}
88+
}
89+
90+
if idx == 0 {
91+
Ok(None)
92+
} else {
93+
Ok(Some(BackendMessage::Normal {
94+
messages: BackendMessages(src.split_to(idx)),
95+
request_complete,
96+
}))
97+
}
3498
}
3599
}

0 commit comments

Comments
 (0)