Skip to content

Commit a4a625a

Browse files
committed
Detect direct queries to COPY FROM
1 parent 945714b commit a4a625a

3 files changed

Lines changed: 68 additions & 1 deletion

File tree

src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ use message::{AuthenticationCleartextPassword,
108108
BackendMessage,
109109
BindComplete,
110110
CommandComplete,
111+
CopyInResponse,
111112
DataRow,
112113
EmptyQueryResponse,
113114
ErrorResponse,
@@ -124,6 +125,7 @@ use message::{AuthenticationCleartextPassword,
124125
use message::{Bind,
125126
CancelRequest,
126127
Close,
128+
CopyFail,
127129
Describe,
128130
Execute,
129131
FrontendMessage,
@@ -1116,7 +1118,7 @@ impl<'conn> PostgresStatement<'conn> {
11161118
}
11171119
_ => {
11181120
conn.desynchronized = true;
1119-
return Err(PgBadResponse);
1121+
Err(PgBadResponse)
11201122
}
11211123
}
11221124
}
@@ -1190,6 +1192,13 @@ impl<'conn> PostgresStatement<'conn> {
11901192
num = 0;
11911193
break;
11921194
}
1195+
CopyInResponse { .. } => {
1196+
try_pg!(conn.write_messages([
1197+
CopyFail {
1198+
message: "COPY queries cannot be directly executed",
1199+
},
1200+
Sync]));
1201+
}
11931202
_ => {
11941203
conn.desynchronized = true;
11951204
return Err(PgBadResponse);
@@ -1305,6 +1314,14 @@ impl<'stmt> PostgresRows<'stmt> {
13051314
try!(conn.wait_for_ready());
13061315
return Err(PgDbError(PostgresDbError::new(fields)));
13071316
}
1317+
CopyInResponse { .. } => {
1318+
try_pg!(conn.write_messages([
1319+
CopyFail {
1320+
message: "COPY queries cannot be directly executed",
1321+
},
1322+
Sync]));
1323+
continue;
1324+
}
13081325
_ => {
13091326
conn.desynchronized = true;
13101327
return Err(PgBadResponse);

src/message.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ pub enum BackendMessage {
2626
CommandComplete {
2727
pub tag: String,
2828
},
29+
CopyInResponse {
30+
pub format: u8,
31+
pub column_formats: Vec<u16>,
32+
},
2933
DataRow {
3034
pub row: Vec<Option<Vec<u8>>>
3135
},
@@ -86,6 +90,13 @@ pub enum FrontendMessage<'a> {
8690
pub variant: u8,
8791
pub name: &'a str
8892
},
93+
CopyData {
94+
pub data: &'a [u8],
95+
},
96+
CopyDone,
97+
CopyFail {
98+
pub message: &'a str
99+
},
89100
Describe {
90101
pub variant: u8,
91102
pub name: &'a str
@@ -177,6 +188,17 @@ impl<W: Writer> WriteMessage for W {
177188
try!(buf.write_u8(variant));
178189
try!(buf.write_cstr(name));
179190
}
191+
CopyData { data } => {
192+
ident = Some(b'd');
193+
try!(buf.write(data));
194+
}
195+
CopyDone => {
196+
ident = Some(b'C');
197+
}
198+
CopyFail { message } => {
199+
ident = Some(b'f');
200+
try!(buf.write_cstr(message));
201+
}
180202
Describe { variant, name } => {
181203
ident = Some(b'D');
182204
try!(buf.write_u8(variant));
@@ -276,6 +298,17 @@ impl<R: Reader> ReadMessage for R {
276298
b'C' => CommandComplete { tag: try!(buf.read_cstr()) },
277299
b'D' => try!(read_data_row(&mut buf)),
278300
b'E' => ErrorResponse { fields: try!(read_fields(&mut buf)) },
301+
b'G' => {
302+
let format = try!(buf.read_u8());
303+
let mut column_formats = vec![];
304+
for _ in range(0, try!(buf.read_be_u16())) {
305+
column_formats.push(try!(buf.read_be_u16()));
306+
}
307+
CopyInResponse {
308+
format: format,
309+
column_formats: column_formats,
310+
}
311+
}
279312
b'I' => EmptyQueryResponse,
280313
b'K' => BackendKeyData {
281314
process_id: try!(buf.read_be_u32()),

tests/test.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,20 @@ fn test_md5_pass_wrong_pass() {
691691
_ => fail!("Expected error")
692692
}
693693
}
694+
695+
#[test]
696+
fn test_execute_copy_from_err() {
697+
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
698+
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", []));
699+
let stmt = or_fail!(conn.prepare("COPY foo (id) FROM STDIN"));
700+
match stmt.execute([]) {
701+
Err(PgDbError(ref err)) if err.message.as_slice().contains("COPY") => {}
702+
Err(err) => fail!("Unexptected error {}", err),
703+
_ => fail!("Expected error"),
704+
}
705+
match stmt.query([]) {
706+
Err(PgDbError(ref err)) if err.message.as_slice().contains("COPY") => {}
707+
Err(err) => fail!("Unexptected error {}", err),
708+
_ => fail!("Expected error"),
709+
}
710+
}

0 commit comments

Comments
 (0)