Skip to content

Commit 7929d07

Browse files
committed
Fix a desynchronization issue
1 parent 6484836 commit 7929d07

2 files changed

Lines changed: 30 additions & 4 deletions

File tree

src/lib.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,7 @@ impl<'a> PostgresCopyInStatement<'a> {
15831583

15841584
/// Executes the prepared statement.
15851585
///
1586-
/// Each iterator retuned by the `rows` iterator will be interpreted as
1586+
/// Each iterator returned by the `rows` iterator will be interpreted as
15871587
/// providing a single result row.
15881588
///
15891589
/// Returns the number of rows copied.
@@ -1637,14 +1637,22 @@ impl<'a> PostgresCopyInStatement<'a> {
16371637
loop {
16381638
match (row.next(), types.next()) {
16391639
(Some(val), Some(ty)) => {
1640-
match try!(val.to_sql(ty)) {
1641-
None => {
1640+
match val.to_sql(ty) {
1641+
Ok(None) => {
16421642
let _ = buf.write_be_i32(-1);
16431643
}
1644-
Some(val) => {
1644+
Ok(Some(val)) => {
16451645
let _ = buf.write_be_i32(val.len() as i32);
16461646
let _ = buf.write(val[]);
16471647
}
1648+
Err(err) => {
1649+
// FIXME this is not the right way to handle this
1650+
try_pg_desync!(conn, conn.stream.write_message(
1651+
&CopyFail {
1652+
message: err.to_string()[],
1653+
}));
1654+
break 'l;
1655+
}
16481656
}
16491657
}
16501658
(Some(_), None) | (None, Some(_)) => {

tests/test.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,24 @@ fn test_copy_in_bad_column_count() {
752752
or_fail!(conn.execute("SELECT 1", []));
753753
}
754754

755+
#[test]
756+
fn test_copy_in_bad_type() {
757+
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
758+
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", []));
759+
760+
let stmt = or_fail!(conn.prepare_copy_in("foo", ["id", "name"]));
761+
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32, &2i32]];
762+
763+
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
764+
match res {
765+
Err(PgDbError(ref err)) if err.message[].contains("Unexpected type PgVarchar") => {}
766+
Err(err) => fail!("unexpected error {}", err),
767+
_ => fail!("Expected error"),
768+
}
769+
770+
or_fail!(conn.execute("SELECT 1", []));
771+
}
772+
755773
#[test]
756774
fn test_batch_execute_copy_from_err() {
757775
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));

0 commit comments

Comments
 (0)