Skip to content

Commit 249db6b

Browse files
committed
Correctly handle bad column counts in copy
1 parent f88f908 commit 249db6b

2 files changed

Lines changed: 49 additions & 15 deletions

File tree

src/lib.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,26 +1614,33 @@ impl<'a> PostgresCopyInStatement<'a> {
16141614
let _ = buf.write_be_i32(0);
16151615
let _ = buf.write_be_i32(0);
16161616

1617-
for mut row in rows {
1617+
for row in rows {
16181618
let _ = buf.write_be_i16(self.column_types.len() as i16);
16191619

1620-
let mut count = 0;
1621-
for (i, (val, ty)) in row.by_ref().zip(self.column_types.iter()).enumerate() {
1622-
match try!(val.to_sql(ty)) {
1623-
(_, None) => {
1624-
let _ = buf.write_be_i32(-1);
1620+
let mut row = row.fuse();
1621+
let mut types = self.column_types.iter();
1622+
loop {
1623+
match (row.next(), types.next()) {
1624+
(Some(val), Some(ty)) => {
1625+
match try!(val.to_sql(ty)) {
1626+
(_, None) => {
1627+
let _ = buf.write_be_i32(-1);
1628+
}
1629+
(_, Some(val)) => {
1630+
let _ = buf.write_be_i32(val.len() as i32);
1631+
let _ = buf.write(val.as_slice());
1632+
}
1633+
}
16251634
}
1626-
(_, Some(val)) => {
1627-
let _ = buf.write_be_i32(val.len() as i32);
1628-
let _ = buf.write(val.as_slice());
1635+
(Some(_), None) | (None, Some(_)) => {
1636+
try_pg!(conn.stream.write_message(
1637+
&CopyFail {
1638+
message: "Invalid column count",
1639+
}));
1640+
break;
16291641
}
1642+
(None, None) => break
16301643
}
1631-
count = i+1;
1632-
}
1633-
1634-
if row.next().is_some() || count != self.column_types.len() {
1635-
// FIXME
1636-
fail!()
16371644
}
16381645

16391646
try_pg!(conn.stream.write_message(

tests/test.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,30 @@ fn test_copy_in() {
723723
assert_eq!(vec![(0i32, Some("Steven".to_string())), (1, None)],
724724
or_fail!(stmt.query([])).map(|r| (r.get(0u), r.get(1u))).collect());
725725
}
726+
727+
#[test]
728+
fn test_copy_in_bad_column_count() {
729+
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
730+
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", []));
731+
732+
let stmt = or_fail!(conn.prepare_copy_in("foo", ["id", "name"]));
733+
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32]];
734+
735+
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
736+
match res {
737+
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
738+
Err(err) => fail!("unexpected error {}", err),
739+
_ => fail!("Expected error"),
740+
}
741+
742+
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32, &"Steven".to_string(), &1i32]];
743+
744+
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
745+
match res {
746+
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
747+
Err(err) => fail!("unexpected error {}", err),
748+
_ => fail!("Expected error"),
749+
}
750+
751+
or_fail!(conn.execute("SELECT 1", []));
752+
}

0 commit comments

Comments
 (0)