Skip to content

Commit 9938fff

Browse files
committed
Test and fix simple_query
1 parent 07e5930 commit 9938fff

4 files changed

Lines changed: 115 additions & 147 deletions

File tree

tokio-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ futures-preview = { version = "0.3.0-alpha.17", features = ["nightly", "async-aw
3737
log = "0.4"
3838
parking_lot = "0.9"
3939
percent-encoding = "1.0"
40+
pin-utils = "0.1.0-alpha.4"
4041
phf = "0.7.23"
4142
postgres-protocol = { version = "0.4.1", path = "../postgres-protocol" }
4243
tokio = { git = "https://github.com/tokio-rs/tokio", default-features = false, features = ["io", "codec"] }

tokio-postgres/src/connect.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::config::{Host, TargetSessionAttrs};
2+
use pin_utils::pin_mut;
23
use crate::connect_raw::connect_raw;
34
use crate::connect_socket::connect_socket;
45
use crate::tls::{MakeTlsConnect, TlsConnect};
@@ -55,7 +56,8 @@ where
5556
let (mut client, connection) = connect_raw(socket, tls, config, Some(idx)).await?;
5657

5758
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
58-
let mut rows = client.simple_query("SHOW transaction_read_only");
59+
let rows = client.simple_query("SHOW transaction_read_only");
60+
pin_mut!(rows);
5961

6062
loop {
6163
match rows.try_next().await? {

tokio-postgres/src/simple_query.rs

Lines changed: 47 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
55
use fallible_iterator::FallibleIterator;
6-
use futures::{ready, Stream};
6+
use futures::{ready, Stream, TryFutureExt};
77
use postgres_protocol::message::backend::Message;
88
use postgres_protocol::message::frontend;
99
use std::future::Future;
10-
use std::mem;
1110
use std::pin::Pin;
1211
use std::sync::Arc;
1312
use std::task::{Context, Poll};
@@ -17,7 +16,18 @@ pub fn simple_query(
1716
query: &str,
1817
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> {
1918
let buf = encode(query);
20-
SimpleQuery::Start { client, buf }
19+
20+
let start = async move {
21+
let buf = buf?;
22+
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
23+
24+
Ok(SimpleQuery {
25+
responses,
26+
columns: None,
27+
})
28+
};
29+
30+
start.try_flatten_stream()
2131
}
2232

2333
pub fn batch_execute(
@@ -49,84 +59,49 @@ fn encode(query: &str) -> Result<Vec<u8>, Error> {
4959
Ok(buf)
5060
}
5161

52-
enum SimpleQuery {
53-
Start {
54-
client: Arc<InnerClient>,
55-
buf: Result<Vec<u8>, Error>,
56-
},
57-
Reading {
58-
responses: Responses,
59-
columns: Option<Arc<[String]>>,
60-
},
61-
Done,
62+
struct SimpleQuery {
63+
responses: Responses,
64+
columns: Option<Arc<[String]>>
6265
}
6366

6467
impl Stream for SimpleQuery {
6568
type Item = Result<SimpleQueryMessage, Error>;
6669

6770
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
6871
loop {
69-
match mem::replace(&mut *self, SimpleQuery::Done) {
70-
SimpleQuery::Start { client, buf } => {
71-
let buf = buf?;
72-
let responses =
73-
client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
74-
75-
*self = SimpleQuery::Reading {
76-
responses,
77-
columns: None,
72+
match ready!(self.responses.poll_next(cx)?) {
73+
Message::CommandComplete(body) => {
74+
let rows = body
75+
.tag()
76+
.map_err(Error::parse)?
77+
.rsplit(' ')
78+
.next()
79+
.unwrap()
80+
.parse()
81+
.unwrap_or(0);
82+
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
83+
}
84+
Message::EmptyQueryResponse => {
85+
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
86+
}
87+
Message::RowDescription(body) => {
88+
let columns = body
89+
.fields()
90+
.map(|f| Ok(f.name().to_string()))
91+
.collect::<Vec<_>>()
92+
.map_err(Error::parse)?
93+
.into();
94+
self.columns = Some(columns);
95+
}
96+
Message::DataRow(body) => {
97+
let row = match &self.columns {
98+
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
99+
None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
78100
};
101+
return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
79102
}
80-
SimpleQuery::Reading {
81-
mut responses,
82-
columns,
83-
} => match ready!(responses.poll_next(cx)?) {
84-
Message::CommandComplete(body) => {
85-
let rows = body
86-
.tag()
87-
.map_err(Error::parse)?
88-
.rsplit(' ')
89-
.next()
90-
.unwrap()
91-
.parse()
92-
.unwrap_or(0);
93-
*self = SimpleQuery::Reading {
94-
responses,
95-
columns: None,
96-
};
97-
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
98-
}
99-
Message::EmptyQueryResponse => {
100-
*self = SimpleQuery::Reading {
101-
responses,
102-
columns: None,
103-
};
104-
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
105-
}
106-
Message::RowDescription(body) => {
107-
let columns = body
108-
.fields()
109-
.map(|f| Ok(f.name().to_string()))
110-
.collect::<Vec<_>>()
111-
.map_err(Error::parse)?
112-
.into();
113-
*self = SimpleQuery::Reading {
114-
responses,
115-
columns: Some(columns),
116-
};
117-
}
118-
Message::DataRow(body) => {
119-
let row = match &columns {
120-
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
121-
None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
122-
};
123-
*self = SimpleQuery::Reading { responses, columns };
124-
return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
125-
}
126-
Message::ReadyForQuery(_) => return Poll::Ready(None),
127-
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
128-
},
129-
SimpleQuery::Done => return Poll::Ready(None),
103+
Message::ReadyForQuery(_) => return Poll::Ready(None),
104+
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
130105
}
131106
}
132107
}

tokio-postgres/tests/test/main.rs

Lines changed: 64 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tokio::net::TcpStream;
66
use tokio_postgres::error::SqlState;
77
use tokio_postgres::tls::{NoTls, NoTlsStream};
88
use tokio_postgres::types::{Kind, Type};
9-
use tokio_postgres::{Client, Config, Connection, Error};
9+
use tokio_postgres::{Client, Config, Connection, Error, SimpleQueryMessage};
1010

1111
mod parse;
1212
#[cfg(feature = "runtime")]
@@ -114,12 +114,10 @@ async fn pipelined_prepare() {
114114
async fn insert_select() {
115115
let mut client = connect("user=postgres").await;
116116

117-
let setup = client
118-
.prepare("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
117+
client
118+
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
119119
.await
120120
.unwrap();
121-
client.execute(&setup, &[]).await.unwrap();
122-
drop(setup);
123121

124122
let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)");
125123
let select = client.prepare("SELECT id, name FROM foo ORDER BY id");
@@ -142,8 +140,8 @@ async fn insert_select() {
142140
async fn custom_enum() {
143141
let mut client = connect("user=postgres").await;
144142

145-
let create = client
146-
.prepare(
143+
client
144+
.batch_execute(
147145
"CREATE TYPE pg_temp.mood AS ENUM (
148146
'sad',
149147
'ok',
@@ -152,7 +150,6 @@ async fn custom_enum() {
152150
)
153151
.await
154152
.unwrap();
155-
client.execute(&create, &[]).await.unwrap();
156153

157154
let select = client.prepare("SELECT $1::mood").await.unwrap();
158155

@@ -172,11 +169,10 @@ async fn custom_enum() {
172169
async fn custom_domain() {
173170
let mut client = connect("user=postgres").await;
174171

175-
let create = client
176-
.prepare("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)")
172+
client
173+
.batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)")
177174
.await
178175
.unwrap();
179-
client.execute(&create, &[]).await.unwrap();
180176

181177
let select = client.prepare("SELECT $1::session_id").await.unwrap();
182178

@@ -206,17 +202,16 @@ async fn custom_array() {
206202
async fn custom_composite() {
207203
let mut client = connect("user=postgres").await;
208204

209-
let create = client
210-
.prepare(
205+
client
206+
.batch_execute(
211207
"CREATE TYPE pg_temp.inventory_item AS (
212-
name TEXT,
213-
supplier INTEGER,
214-
price NUMERIC
215-
)",
208+
name TEXT,
209+
supplier INTEGER,
210+
price NUMERIC
211+
)",
216212
)
217213
.await
218214
.unwrap();
219-
client.execute(&create, &[]).await.unwrap();
220215

221216
let select = client.prepare("SELECT $1::inventory_item").await.unwrap();
222217

@@ -239,16 +234,15 @@ async fn custom_composite() {
239234
async fn custom_range() {
240235
let mut client = connect("user=postgres").await;
241236

242-
let create = client
243-
.prepare(
237+
client
238+
.batch_execute(
244239
"CREATE TYPE pg_temp.floatrange AS RANGE (
245-
subtype = float8,
246-
subtype_diff = float8mi
247-
)",
240+
subtype = float8,
241+
subtype_diff = float8mi
242+
)",
248243
)
249244
.await
250245
.unwrap();
251-
client.execute(&create, &[]).await.unwrap();
252246

253247
let select = client.prepare("SELECT $1::floatrange").await.unwrap();
254248

@@ -257,6 +251,52 @@ async fn custom_range() {
257251
assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind());
258252
}
259253

254+
#[tokio::test]
255+
async fn simple_query() {
256+
let mut client = connect("user=postgres").await;
257+
258+
let messages = client
259+
.simple_query(
260+
"CREATE TEMPORARY TABLE foo (
261+
id SERIAL,
262+
name TEXT
263+
);
264+
INSERT INTO foo (name) VALUES ('steven'), ('joe');
265+
SELECT * FROM foo ORDER BY id;",
266+
)
267+
.try_collect::<Vec<_>>()
268+
.await
269+
.unwrap();
270+
271+
match messages[0] {
272+
SimpleQueryMessage::CommandComplete(0) => {}
273+
_ => panic!("unexpected message"),
274+
}
275+
match messages[1] {
276+
SimpleQueryMessage::CommandComplete(2) => {}
277+
_ => panic!("unexpected message"),
278+
}
279+
match &messages[2] {
280+
SimpleQueryMessage::Row(row) => {
281+
assert_eq!(row.get(0), Some("1"));
282+
assert_eq!(row.get(1), Some("steven"));
283+
}
284+
_ => panic!("unexpected message"),
285+
}
286+
match &messages[3] {
287+
SimpleQueryMessage::Row(row) => {
288+
assert_eq!(row.get(0), Some("2"));
289+
assert_eq!(row.get(1), Some("joe"));
290+
}
291+
_ => panic!("unexpected message"),
292+
}
293+
match messages[4] {
294+
SimpleQueryMessage::CommandComplete(2) => {}
295+
_ => panic!("unexpected message"),
296+
}
297+
assert_eq!(messages.len(), 5);
298+
}
299+
260300
/*
261301
#[test]
262302
fn query_portal() {
@@ -675,56 +715,6 @@ fn transaction_builder_around_moved_client() {
675715
runtime.run().unwrap();
676716
}
677717
678-
#[test]
679-
fn simple_query() {
680-
let _ = env_logger::try_init();
681-
let mut runtime = Runtime::new().unwrap();
682-
683-
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
684-
let connection = connection.map_err(|e| panic!("{}", e));
685-
runtime.handle().spawn(connection).unwrap();
686-
687-
let f = client
688-
.simple_query(
689-
"CREATE TEMPORARY TABLE foo (
690-
id SERIAL,
691-
name TEXT
692-
);
693-
INSERT INTO foo (name) VALUES ('steven'), ('joe');
694-
SELECT * FROM foo ORDER BY id;",
695-
)
696-
.collect();
697-
let messages = runtime.block_on(f).unwrap();
698-
699-
match messages[0] {
700-
SimpleQueryMessage::CommandComplete(0) => {}
701-
_ => panic!("unexpected message"),
702-
}
703-
match messages[1] {
704-
SimpleQueryMessage::CommandComplete(2) => {}
705-
_ => panic!("unexpected message"),
706-
}
707-
match &messages[2] {
708-
SimpleQueryMessage::Row(row) => {
709-
assert_eq!(row.get(0), Some("1"));
710-
assert_eq!(row.get(1), Some("steven"));
711-
}
712-
_ => panic!("unexpected message"),
713-
}
714-
match &messages[3] {
715-
SimpleQueryMessage::Row(row) => {
716-
assert_eq!(row.get(0), Some("2"));
717-
assert_eq!(row.get(1), Some("joe"));
718-
}
719-
_ => panic!("unexpected message"),
720-
}
721-
match messages[4] {
722-
SimpleQueryMessage::CommandComplete(2) => {}
723-
_ => panic!("unexpected message"),
724-
}
725-
assert_eq!(messages.len(), 5);
726-
}
727-
728718
#[test]
729719
fn poll_idle_running() {
730720
struct DelayStream(Delay);

0 commit comments

Comments
 (0)