Skip to content

Commit 2480fef

Browse files
committed
Connection IO logic
1 parent 32fe524 commit 2480fef

4 files changed

Lines changed: 339 additions & 23 deletions

File tree

tokio-postgres/src/connection.rs

Lines changed: 276 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
1-
use crate::codec::{BackendMessages, FrontendMessage, PostgresCodec};
1+
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2+
use crate::error::DbError;
23
use crate::maybe_tls_stream::MaybeTlsStream;
4+
use crate::{AsyncMessage, Error, Notification};
5+
use fallible_iterator::FallibleIterator;
36
use futures::channel::mpsc;
4-
use std::collections::HashMap;
7+
use futures::{ready, Sink, Stream, StreamExt};
8+
use log::trace;
9+
use postgres_protocol::message::backend::Message;
10+
use postgres_protocol::message::frontend;
11+
use std::collections::{HashMap, VecDeque};
12+
use std::future::Future;
13+
use std::io;
14+
use std::pin::Pin;
15+
use std::task::{Context, Poll};
516
use tokio::codec::Framed;
17+
use tokio::io::{AsyncRead, AsyncWrite};
618

719
pub enum RequestMessages {
820
Single(FrontendMessage),
@@ -13,13 +25,40 @@ pub struct Request {
1325
pub sender: mpsc::Sender<BackendMessages>,
1426
}
1527

28+
pub struct Response {
29+
sender: mpsc::Sender<BackendMessages>,
30+
}
31+
32+
#[derive(PartialEq, Debug)]
33+
enum State {
34+
Active,
35+
Terminating,
36+
Closing,
37+
}
38+
39+
/// A connection to a PostgreSQL database.
40+
///
41+
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
42+
/// server, and should generally be spawned off onto an executor to run in the background.
43+
///
44+
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
45+
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
46+
#[must_use = "futures do nothing unless polled"]
1647
pub struct Connection<S, T> {
1748
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
1849
parameters: HashMap<String, String>,
1950
receiver: mpsc::UnboundedReceiver<Request>,
51+
pending_request: Option<RequestMessages>,
52+
pending_response: Option<BackendMessage>,
53+
responses: VecDeque<Response>,
54+
state: State,
2055
}
2156

22-
impl<S, T> Connection<S, T> {
57+
impl<S, T> Connection<S, T>
58+
where
59+
S: AsyncRead + AsyncWrite + Unpin,
60+
T: AsyncRead + AsyncWrite + Unpin,
61+
{
2362
pub(crate) fn new(
2463
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
2564
parameters: HashMap<String, String>,
@@ -29,6 +68,240 @@ impl<S, T> Connection<S, T> {
2968
stream,
3069
parameters,
3170
receiver,
71+
pending_request: None,
72+
pending_response: None,
73+
responses: VecDeque::new(),
74+
state: State::Active,
75+
}
76+
}
77+
78+
/// Returns the value of a runtime parameter for this connection.
79+
pub fn parameter(&self, name: &str) -> Option<&str> {
80+
self.parameters.get(name).map(|s| &**s)
81+
}
82+
83+
fn poll_response(
84+
&mut self,
85+
cx: &mut Context<'_>,
86+
) -> Poll<Option<Result<BackendMessage, Error>>> {
87+
if let Some(message) = self.pending_response.take() {
88+
trace!("retrying pending response");
89+
return Poll::Ready(Some(Ok(message)));
90+
}
91+
92+
Pin::new(&mut self.stream)
93+
.poll_next(cx)
94+
.map(|o| o.map(|r| r.map_err(Error::io)))
95+
}
96+
97+
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
98+
if self.state != State::Active {
99+
trace!("poll_read: done");
100+
return Ok(None);
101+
}
102+
103+
loop {
104+
let message = match self.poll_response(cx)? {
105+
Poll::Ready(Some(message)) => message,
106+
Poll::Ready(None) => return Err(Error::closed()),
107+
Poll::Pending => {
108+
trace!("poll_read: waiting on response");
109+
return Ok(None);
110+
}
111+
};
112+
113+
let (mut messages, request_complete) = match message {
114+
BackendMessage::Async(Message::NoticeResponse(body)) => {
115+
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
116+
return Ok(Some(AsyncMessage::Notice(error)));
117+
}
118+
BackendMessage::Async(Message::NotificationResponse(body)) => {
119+
let notification = Notification {
120+
process_id: body.process_id(),
121+
channel: body.channel().map_err(Error::parse)?.to_string(),
122+
payload: body.message().map_err(Error::parse)?.to_string(),
123+
};
124+
return Ok(Some(AsyncMessage::Notification(notification)));
125+
}
126+
BackendMessage::Async(Message::ParameterStatus(body)) => {
127+
self.parameters.insert(
128+
body.name().map_err(Error::parse)?.to_string(),
129+
body.value().map_err(Error::parse)?.to_string(),
130+
);
131+
continue;
132+
}
133+
BackendMessage::Async(_) => unreachable!(),
134+
BackendMessage::Normal {
135+
messages,
136+
request_complete,
137+
} => (messages, request_complete),
138+
};
139+
140+
let mut response = match self.responses.pop_front() {
141+
Some(response) => response,
142+
None => match messages.next().map_err(Error::parse)? {
143+
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
144+
_ => return Err(Error::unexpected_message()),
145+
},
146+
};
147+
148+
match response.sender.poll_ready(cx) {
149+
Poll::Ready(Ok(())) => {
150+
let _ = response.sender.start_send(messages);
151+
if !request_complete {
152+
self.responses.push_front(response);
153+
}
154+
}
155+
Poll::Ready(Err(_)) => {
156+
// we need to keep paging through the rest of the messages even if the receiver's hung up
157+
if !request_complete {
158+
self.responses.push_front(response);
159+
}
160+
}
161+
Poll::Pending => {
162+
self.responses.push_front(response);
163+
self.pending_response = Some(BackendMessage::Normal {
164+
messages,
165+
request_complete,
166+
});
167+
trace!("poll_read: waiting on sender");
168+
return Ok(None);
169+
}
170+
}
171+
}
172+
}
173+
174+
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
175+
if let Some(messages) = self.pending_request.take() {
176+
trace!("retrying pending request");
177+
return Poll::Ready(Some(messages));
178+
}
179+
180+
match self.receiver.poll_next_unpin(cx) {
181+
Poll::Ready(Some(request)) => {
182+
trace!("polled new request");
183+
self.responses.push_back(Response {
184+
sender: request.sender,
185+
});
186+
Poll::Ready(Some(request.messages))
187+
}
188+
Poll::Ready(None) => Poll::Ready(None),
189+
Poll::Pending => Poll::Pending,
32190
}
33191
}
192+
193+
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
194+
loop {
195+
if self.state == State::Closing {
196+
trace!("poll_write: done");
197+
return Ok(false);
198+
}
199+
200+
let request = match self.poll_request(cx) {
201+
Poll::Ready(Some(request)) => request,
202+
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
203+
trace!("poll_write: at eof, terminating");
204+
self.state = State::Terminating;
205+
let mut request = vec![];
206+
frontend::terminate(&mut request);
207+
RequestMessages::Single(FrontendMessage::Raw(request))
208+
}
209+
Poll::Ready(None) => {
210+
trace!(
211+
"poll_write: at eof, pending responses {}",
212+
self.responses.len()
213+
);
214+
return Ok(true);
215+
}
216+
Poll::Pending => {
217+
trace!("poll_write: waiting on request");
218+
return Ok(true);
219+
}
220+
};
221+
222+
if let Poll::Pending = Pin::new(&mut self.stream)
223+
.poll_ready(cx)
224+
.map_err(Error::io)?
225+
{
226+
trace!("poll_write: waiting on socket");
227+
self.pending_request = Some(request);
228+
return Ok(false);
229+
}
230+
231+
match request {
232+
RequestMessages::Single(request) => {
233+
Pin::new(&mut self.stream)
234+
.start_send(request)
235+
.map_err(Error::io)?;
236+
if self.state == State::Terminating {
237+
trace!("poll_write: sent eof, closing");
238+
self.state = State::Closing;
239+
}
240+
}
241+
}
242+
}
243+
}
244+
245+
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
246+
match Pin::new(&mut self.stream)
247+
.poll_flush(cx)
248+
.map_err(Error::io)?
249+
{
250+
Poll::Ready(()) => trace!("poll_flush: flushed"),
251+
Poll::Pending => trace!("poll_flush: waiting on socket"),
252+
}
253+
Ok(())
254+
}
255+
256+
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
257+
if self.state != State::Closing {
258+
return Poll::Pending;
259+
}
260+
261+
match Pin::new(&mut self.stream)
262+
.poll_close(cx)
263+
.map_err(Error::io)?
264+
{
265+
Poll::Ready(()) => {
266+
trace!("poll_shutdown: complete");
267+
Poll::Ready(Ok(()))
268+
}
269+
Poll::Pending => {
270+
trace!("poll_shutdown: waiting on socket");
271+
Poll::Pending
272+
}
273+
}
274+
}
275+
276+
pub fn poll_message(
277+
mut self: Pin<&mut Self>,
278+
cx: &mut Context<'_>,
279+
) -> Poll<Option<Result<AsyncMessage, Error>>> {
280+
let message = self.poll_read(cx)?;
281+
let want_flush = self.poll_write(cx)?;
282+
if want_flush {
283+
self.poll_flush(cx)?;
284+
}
285+
match message {
286+
Some(message) => Poll::Ready(Some(Ok(message))),
287+
None => match self.poll_shutdown(cx) {
288+
Poll::Ready(Ok(())) => Poll::Ready(None),
289+
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
290+
Poll::Pending => Poll::Pending,
291+
},
292+
}
293+
}
294+
}
295+
296+
impl<S, T> Future for Connection<S, T>
297+
where
298+
S: AsyncRead + AsyncWrite + Unpin,
299+
T: AsyncRead + AsyncWrite + Unpin,
300+
{
301+
type Output = Result<(), Error>;
302+
303+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
304+
while let Some(_) = ready!(Pin::as_mut(&mut self).poll_message(cx)?) {}
305+
Poll::Ready(Ok(()))
306+
}
34307
}

tokio-postgres/src/lib.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
pub use crate::client::Client;
116116
pub use crate::config::Config;
117117
pub use crate::connection::Connection;
118+
use crate::error::DbError;
118119
pub use crate::error::Error;
119120
#[cfg(feature = "runtime")]
120121
pub use crate::socket::Socket;
@@ -157,3 +158,43 @@ where
157158
let config = config.parse::<Config>()?;
158159
config.connect(tls).await
159160
}
161+
162+
/// An asynchronous notification.
163+
#[derive(Clone, Debug)]
164+
pub struct Notification {
165+
process_id: i32,
166+
channel: String,
167+
payload: String,
168+
}
169+
170+
/// An asynchronous message from the server.
171+
#[allow(clippy::large_enum_variant)]
172+
pub enum AsyncMessage {
173+
/// A notice.
174+
///
175+
/// Notices use the same format as errors, but aren't "errors" per-se.
176+
Notice(DbError),
177+
/// A notification.
178+
///
179+
/// Connections can subscribe to notifications with the `LISTEN` command.
180+
Notification(Notification),
181+
#[doc(hidden)]
182+
__NonExhaustive,
183+
}
184+
185+
impl Notification {
186+
/// The process ID of the notifying backend process.
187+
pub fn process_id(&self) -> i32 {
188+
self.process_id
189+
}
190+
191+
/// The name of the channel that the notify has been raised on.
192+
pub fn channel(&self) -> &str {
193+
&self.channel
194+
}
195+
196+
/// The "payload" string passed from the notifying process.
197+
pub fn payload(&self) -> &str {
198+
&self.payload
199+
}
200+
}

0 commit comments

Comments
 (0)