Skip to content

Commit 2e61f13

Browse files
committed
Initial COPY FROM support!
cc rust-postgres#51
1 parent a4a625a commit 2e61f13

4 files changed

Lines changed: 233 additions & 23 deletions

File tree

src/lib.rs

Lines changed: 202 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ use serialize::hex::ToHex;
7373
use std::cell::{Cell, RefCell};
7474
use std::collections::HashMap;
7575
use std::from_str::FromStr;
76-
use std::io::{BufferedStream, IoResult};
76+
use std::io::{BufferedStream, IoResult, MemWriter};
7777
use std::io::net::ip::Port;
7878
use std::mem;
7979
use std::fmt;
@@ -125,6 +125,8 @@ use message::{AuthenticationCleartextPassword,
125125
use message::{Bind,
126126
CancelRequest,
127127
Close,
128+
CopyData,
129+
CopyDone,
128130
CopyFail,
129131
Describe,
130132
Execute,
@@ -146,6 +148,7 @@ mod io;
146148
pub mod pool;
147149
mod message;
148150
pub mod types;
151+
mod util;
149152

150153
static CANARY: u32 = 0xdeadbeef;
151154

@@ -520,8 +523,8 @@ impl InnerPostgresConnection {
520523
mem::replace(&mut self.notice_handler, handler)
521524
}
522525

523-
fn prepare<'a>(&mut self, query: &str, conn: &'a PostgresConnection)
524-
-> PostgresResult<PostgresStatement<'a>> {
526+
fn raw_prepare(&mut self, query: &str)
527+
-> PostgresResult<(String, Vec<PostgresType>, Vec<ResultDescription>)> {
525528
let stmt_name = format!("s{}", self.next_stmt_id);
526529
self.next_stmt_id += 1;
527530

@@ -572,6 +575,12 @@ impl InnerPostgresConnection {
572575
try!(self.set_type_names(param_types.iter_mut()));
573576
try!(self.set_type_names(result_desc.iter_mut().map(|d| &mut d.ty)));
574577

578+
Ok((stmt_name, param_types, result_desc))
579+
}
580+
581+
fn prepare<'a>(&mut self, query: &str, conn: &'a PostgresConnection)
582+
-> PostgresResult<PostgresStatement<'a>> {
583+
let (stmt_name, param_types, result_desc) = try!(self.raw_prepare(query));
575584
Ok(PostgresStatement {
576585
conn: conn,
577586
name: stmt_name,
@@ -582,6 +591,54 @@ impl InnerPostgresConnection {
582591
})
583592
}
584593

594+
fn prepare_copy_in<'a>(&mut self, table: &str, rows: &[&str], conn: &'a PostgresConnection)
595+
-> PostgresResult<PostgresCopyInStatement<'a>> {
596+
let mut query = MemWriter::new();
597+
let _ = write!(query, "SELECT ");
598+
let _ = util::comma_join(&mut query, rows.iter().map(|&e| e));
599+
let _ = write!(query, " FROM {}", table);
600+
let query = String::from_utf8(query.unwrap()).unwrap();
601+
let (stmt_name, _, result_desc) = try!(self.raw_prepare(query.as_slice()));
602+
603+
let column_types = result_desc.iter().map(|desc| desc.ty.clone()).collect();
604+
try!(self.close_statement(stmt_name.as_slice()));
605+
606+
let mut query = MemWriter::new();
607+
let _ = write!(query, "COPY {} (", table);
608+
let _ = util::comma_join(&mut query, rows.iter().map(|&e| e));
609+
let _ = write!(query, ") FROM STDIN WITH (FORMAT binary)");
610+
let query = String::from_utf8(query.unwrap()).unwrap();
611+
let (stmt_name, _, _) = try!(self.raw_prepare(query.as_slice()));
612+
613+
Ok(PostgresCopyInStatement {
614+
conn: conn,
615+
name: stmt_name,
616+
column_types: column_types,
617+
next_portal_id: Cell::new(0),
618+
finished: false,
619+
})
620+
}
621+
622+
fn close_statement(&mut self, stmt_name: &str) -> PostgresResult<()> {
623+
try_pg!(self.write_messages([
624+
Close {
625+
variant: b'S',
626+
name: stmt_name,
627+
},
628+
Sync]));
629+
loop {
630+
match try_pg!(self.read_message_()) {
631+
ReadyForQuery { .. } => break,
632+
ErrorResponse { fields } => {
633+
try!(self.wait_for_ready());
634+
return Err(PgDbError(PostgresDbError::new(fields)));
635+
}
636+
_ => {}
637+
}
638+
}
639+
Ok(())
640+
}
641+
585642
fn set_type_names<'a, I>(&mut self, mut it: I) -> PostgresResult<()>
586643
where I: Iterator<&'a mut PostgresType> {
587644
for ty in it {
@@ -759,6 +816,15 @@ impl PostgresConnection {
759816
conn.prepare(query, self)
760817
}
761818

819+
pub fn prepare_copy_in<'a>(&'a self, table: &str, rows: &[&str])
820+
-> PostgresResult<PostgresCopyInStatement<'a>> {
821+
let mut conn = self.conn.borrow_mut();
822+
if conn.trans_depth != 0 {
823+
return Err(PgWrongTransaction);
824+
}
825+
conn.prepare_copy_in(table, rows, self)
826+
}
827+
762828
/// Begins a new transaction.
763829
///
764830
/// Returns a `PostgresTransaction` object which should be used instead of
@@ -1057,24 +1123,9 @@ impl<'conn> Drop for PostgresStatement<'conn> {
10571123

10581124
impl<'conn> PostgresStatement<'conn> {
10591125
fn finish_inner(&mut self) -> PostgresResult<()> {
1060-
check_desync!(self.conn);
1061-
try_pg!(self.conn.write_messages([
1062-
Close {
1063-
variant: b'S',
1064-
name: self.name.as_slice()
1065-
},
1066-
Sync]));
1067-
loop {
1068-
match try_pg!(self.conn.read_message_()) {
1069-
ReadyForQuery { .. } => break,
1070-
ErrorResponse { fields } => {
1071-
try!(self.conn.wait_for_ready());
1072-
return Err(PgDbError(PostgresDbError::new(fields)));
1073-
}
1074-
_ => {}
1075-
}
1076-
}
1077-
Ok(())
1126+
let mut conn = self.conn.conn.borrow_mut();
1127+
check_desync!(conn);
1128+
conn.close_statement(self.name.as_slice())
10781129
}
10791130

10801131
fn inner_execute(&self, portal_name: &str, row_limit: i32, params: &[&ToSql])
@@ -1495,3 +1546,133 @@ impl<'trans, 'stmt> Iterator<PostgresResult<PostgresRow<'stmt>>>
14951546
self.result.size_hint()
14961547
}
14971548
}
1549+
1550+
pub struct PostgresCopyInStatement<'a> {
1551+
conn: &'a PostgresConnection,
1552+
name: String,
1553+
column_types: Vec<PostgresType>,
1554+
next_portal_id: Cell<uint>,
1555+
finished: bool,
1556+
}
1557+
1558+
#[unsafe_destructor]
1559+
impl<'a> Drop for PostgresCopyInStatement<'a> {
1560+
fn drop(&mut self) {
1561+
if !self.finished {
1562+
let _ = self.finish_inner();
1563+
}
1564+
}
1565+
}
1566+
1567+
impl<'a> PostgresCopyInStatement<'a> {
1568+
fn finish_inner(&mut self) -> PostgresResult<()> {
1569+
let mut conn = self.conn.conn.borrow_mut();
1570+
check_desync!(conn);
1571+
conn.close_statement(self.name.as_slice())
1572+
}
1573+
1574+
pub fn execute<'b, I, J>(&self, mut rows: I) -> PostgresResult<()>
1575+
where I: Iterator<J>, J: Iterator<&'b ToSql + 'b> {
1576+
let mut conn = self.conn.conn.borrow_mut();
1577+
1578+
try_pg!(conn.write_messages([
1579+
Bind {
1580+
portal: "",
1581+
statement: self.name.as_slice(),
1582+
formats: [],
1583+
values: [],
1584+
result_formats: []
1585+
},
1586+
Execute {
1587+
portal: "",
1588+
max_rows: 0,
1589+
},
1590+
Sync]));
1591+
1592+
match try_pg!(conn.read_message_()) {
1593+
BindComplete => {},
1594+
ErrorResponse { fields } => {
1595+
try!(conn.wait_for_ready());
1596+
return Err(PgDbError(PostgresDbError::new(fields)));
1597+
}
1598+
_ => {
1599+
conn.desynchronized = true;
1600+
return Err(PgBadResponse);
1601+
}
1602+
}
1603+
1604+
match try_pg!(conn.read_message_()) {
1605+
CopyInResponse { .. } => {}
1606+
_ => {
1607+
conn.desynchronized = true;
1608+
return Err(PgBadResponse);
1609+
}
1610+
}
1611+
1612+
let mut buf = MemWriter::new();
1613+
let _ = buf.write(b"PGCOPY\n\xff\r\n\x00");
1614+
let _ = buf.write_be_i32(0);
1615+
let _ = buf.write_be_i32(0);
1616+
1617+
for mut row in rows {
1618+
let _ = buf.write_be_i16(self.column_types.len() as i16);
1619+
1620+
let mut count = 0;
1621+
for (i, (val, ty)) in row.by_ref().zip(self.column_types.iter()).enumerate() {
1622+
match try!(val.to_sql(ty)) {
1623+
(_, None) => {
1624+
let _ = buf.write_be_i32(-1);
1625+
}
1626+
(_, Some(val)) => {
1627+
let _ = buf.write_be_i32(val.len() as i32);
1628+
let _ = buf.write(val.as_slice());
1629+
}
1630+
}
1631+
count = i+1;
1632+
}
1633+
1634+
if row.next().is_some() || count != self.column_types.len() {
1635+
// FIXME
1636+
fail!()
1637+
}
1638+
1639+
try_pg!(conn.write_messages([
1640+
CopyData {
1641+
data: buf.unwrap().as_slice()
1642+
}]));
1643+
buf = MemWriter::new();
1644+
}
1645+
1646+
let _ = buf.write_be_i16(-1);
1647+
try_pg!(conn.write_messages([
1648+
CopyData {
1649+
data: buf.unwrap().as_slice(),
1650+
},
1651+
CopyDone,
1652+
Sync]));
1653+
1654+
match try_pg!(conn.read_message_()) {
1655+
CommandComplete { .. } => {},
1656+
ErrorResponse { fields } => {
1657+
try!(conn.wait_for_ready());
1658+
return Err(PgDbError(PostgresDbError::new(fields)));
1659+
}
1660+
_ => {
1661+
conn.desynchronized = true;
1662+
return Err(PgBadResponse);
1663+
}
1664+
}
1665+
1666+
conn.wait_for_ready()
1667+
}
1668+
1669+
/// Consumes the statement, clearing it from the Postgres session.
1670+
///
1671+
/// Functionally identical to the `Drop` implementation of the
1672+
/// `PostgresCopyInStatement` except that it returns any error to the
1673+
/// caller.
1674+
pub fn finish(mut self) -> PostgresResult<()> {
1675+
self.finished = true;
1676+
self.finish_inner()
1677+
}
1678+
}

src/message.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ impl<W: Writer> WriteMessage for W {
193193
try!(buf.write(data));
194194
}
195195
CopyDone => {
196-
ident = Some(b'C');
196+
ident = Some(b'c');
197197
}
198198
CopyFail { message } => {
199199
ident = Some(b'f');

src/util.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use std::io::IoResult;
2+
3+
pub fn comma_join<'a, W, I>(writer: &mut W, mut strs: I) -> IoResult<()>
4+
where W: Writer, I: Iterator<&'a str> {
5+
let mut first = true;
6+
for str_ in strs {
7+
if !first {
8+
try!(write!(writer, ", "));
9+
}
10+
first = false;
11+
try!(write!(writer, "{}", str_));
12+
}
13+
Ok(())
14+
}

tests/test.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use postgres::error::{PgConnectDbError,
3535
InvalidCatalogName,
3636
PgWrongTransaction,
3737
CardinalityViolation};
38-
use postgres::types::{PgInt4, PgVarchar};
38+
use postgres::types::{PgInt4, PgVarchar, ToSql};
3939

4040
macro_rules! or_fail(
4141
($e:expr) => (
@@ -708,3 +708,18 @@ fn test_execute_copy_from_err() {
708708
_ => fail!("Expected error"),
709709
}
710710
}
711+
712+
#[test]
713+
fn test_copy_in() {
714+
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
715+
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", []));
716+
717+
let stmt = or_fail!(conn.prepare_copy_in("foo", ["id", "name"]));
718+
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32, &None::<String>]];
719+
720+
or_fail!(stmt.execute(data.iter().map(|r| r.iter().map(|&e| e))));
721+
722+
let stmt = or_fail!(conn.prepare("SELECT id, name FROM foo ORDER BY id"));
723+
assert_eq!(vec![(0i32, Some("Steven".to_string())), (1, None)],
724+
or_fail!(stmt.query([])).map(|r| (r.get(0u), r.get(1u))).collect());
725+
}

0 commit comments

Comments
 (0)