Skip to content

Commit 7eaac1c

Browse files
committed
Sync copy_in support
1 parent 1fdfefb commit 7eaac1c

3 files changed

Lines changed: 131 additions & 2 deletions

File tree

postgres/src/client.rs

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use futures::{Future, Stream};
1+
use futures::sync::mpsc;
2+
use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream};
3+
use std::io::{self, Read};
24
use tokio_postgres::types::{ToSql, Type};
35
use tokio_postgres::{Error, Row};
46
#[cfg(feature = "runtime")]
@@ -52,6 +54,29 @@ impl Client {
5254
self.0.query(&statement.0, params).collect().wait()
5355
}
5456

57+
pub fn copy_in<T, R>(
58+
&mut self,
59+
query: &T,
60+
params: &[&dyn ToSql],
61+
reader: R,
62+
) -> Result<u64, Error>
63+
where
64+
T: ?Sized + Query,
65+
R: Read,
66+
{
67+
let statement = query.__statement(self)?;
68+
let (sender, receiver) = mpsc::channel(1);
69+
let future = self.0.copy_in(&statement.0, params, CopyInStream(receiver));
70+
71+
CopyInFuture {
72+
future,
73+
sender,
74+
reader,
75+
pending: None,
76+
}
77+
.wait()
78+
}
79+
5580
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
5681
self.0.batch_execute(query).wait()
5782
}
@@ -71,3 +96,80 @@ impl From<tokio_postgres::Client> for Client {
7196
Client(c)
7297
}
7398
}
99+
100+
enum CopyData {
101+
Data(Vec<u8>),
102+
Error(io::Error),
103+
Done,
104+
}
105+
106+
struct CopyInStream(mpsc::Receiver<CopyData>);
107+
108+
impl Stream for CopyInStream {
109+
type Item = Vec<u8>;
110+
type Error = io::Error;
111+
112+
fn poll(&mut self) -> Poll<Option<Vec<u8>>, io::Error> {
113+
match self.0.poll().expect("mpsc::Receiver can't error") {
114+
Async::Ready(Some(CopyData::Data(buf))) => Ok(Async::Ready(Some(buf))),
115+
Async::Ready(Some(CopyData::Error(e))) => Err(e),
116+
Async::Ready(Some(CopyData::Done)) => Ok(Async::Ready(None)),
117+
Async::Ready(None) => Err(io::Error::new(io::ErrorKind::Other, "writer disconnected")),
118+
Async::NotReady => Ok(Async::NotReady),
119+
}
120+
}
121+
}
122+
123+
struct CopyInFuture<R> {
124+
future: tokio_postgres::CopyIn<CopyInStream>,
125+
sender: mpsc::Sender<CopyData>,
126+
reader: R,
127+
pending: Option<CopyData>,
128+
}
129+
130+
impl<R> CopyInFuture<R> {
131+
fn poll_send_data(&mut self, data: CopyData) -> Poll<(), Error> {
132+
match self.sender.start_send(data) {
133+
Ok(AsyncSink::Ready) => Ok(Async::Ready(())),
134+
Ok(AsyncSink::NotReady(pending)) => {
135+
self.pending = Some(pending);
136+
return Ok(Async::NotReady);
137+
}
138+
// the future's hung up on its end of the channel, so we'll wait for it to report an error
139+
Err(_) => {
140+
self.pending = Some(CopyData::Done);
141+
return Ok(Async::NotReady);
142+
}
143+
}
144+
}
145+
}
146+
147+
impl<R> Future for CopyInFuture<R>
148+
where
149+
R: Read,
150+
{
151+
type Item = u64;
152+
type Error = Error;
153+
154+
fn poll(&mut self) -> Poll<u64, Error> {
155+
if let Async::Ready(n) = self.future.poll()? {
156+
return Ok(Async::Ready(n));
157+
}
158+
159+
loop {
160+
let data = match self.pending.take() {
161+
Some(pending) => pending,
162+
None => {
163+
let mut buf = vec![];
164+
match self.reader.by_ref().take(4096).read_to_end(&mut buf) {
165+
Ok(0) => CopyData::Done,
166+
Ok(_) => CopyData::Data(buf),
167+
Err(e) => CopyData::Error(e),
168+
}
169+
}
170+
};
171+
172+
try_ready!(self.poll_send_data(data));
173+
}
174+
}
175+
}

postgres/src/test.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,30 @@ fn nested_transactions() {
144144
assert_eq!(rows[1].get::<_, i32>(0), 3);
145145
assert_eq!(rows[2].get::<_, i32>(0), 4);
146146
}
147+
148+
#[test]
149+
fn copy_in() {
150+
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
151+
152+
client
153+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
154+
.unwrap();
155+
156+
client
157+
.copy_in(
158+
"COPY foo FROM stdin",
159+
&[],
160+
&mut &b"1\tsteven\n2\ttimothy"[..],
161+
)
162+
.unwrap();
163+
164+
let rows = client
165+
.query("SELECT id, name FROM foo ORDER BY id", &[])
166+
.unwrap();
167+
168+
assert_eq!(rows.len(), 2);
169+
assert_eq!(rows[0].get::<_, i32>(0), 1);
170+
assert_eq!(rows[0].get::<_, &str>(1), "steven");
171+
assert_eq!(rows[1].get::<_, i32>(0), 2);
172+
assert_eq!(rows[1].get::<_, &str>(1), "timothy");
173+
}

tokio-postgres/src/proto/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ impl Client {
189189
<S::Item as IntoBuf>::Buf: Send,
190190
S::Error: Into<Box<dyn StdError + Sync + Send>>,
191191
{
192-
let (mut sender, receiver) = mpsc::channel(0);
192+
let (mut sender, receiver) = mpsc::channel(1);
193193
let pending = PendingRequest(self.excecute_message(statement, params).map(|buf| {
194194
match sender.start_send(CopyMessage::Data(buf)) {
195195
Ok(AsyncSink::Ready) => {}

0 commit comments

Comments
 (0)