Skip to content

Commit ec7dd96

Browse files
committed
Disallow use of wrong transactions
We don't want to allow statements to outlive the active transaction in which they were prepared.
1 parent 04ba539 commit ec7dd96

3 files changed

Lines changed: 45 additions & 29 deletions

File tree

src/error.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,9 @@ pub enum PostgresError {
535535
PgInvalidColumn,
536536
/// A value was NULL but converted to a non-nullable Rust type
537537
PgWasNull,
538+
/// An attempt was made to prepare a statement or start a transaction on an
539+
/// object other than the active transaction
540+
PgWrongTransaction,
538541
}
539542

540543
impl fmt::Show for PostgresError {
@@ -554,6 +557,10 @@ impl fmt::Show for PostgresError {
554557
PgWrongType(ref ty) => write!(fmt, "Unexpected type {}", ty),
555558
PgInvalidColumn => write!(fmt, "Invalid column"),
556559
PgWasNull => write!(fmt, "The value was NULL"),
560+
PgWrongTransaction =>
561+
write!(fmt, "An attempt was made to prepare a statement or \
562+
start a transaction on an object other than the \
563+
active transaction"),
557564
}
558565
}
559566
}

src/lib.rs

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ use error::{InvalidUrl,
105105
PostgresDbError,
106106
PostgresError,
107107
UnsupportedAuthentication,
108-
PgWrongConnection};
108+
PgWrongConnection,
109+
PgWrongTransaction};
109110
use io::{MaybeSslStream,
110111
InternalStream};
111112
use message::{AuthenticationCleartextPassword,
@@ -415,6 +416,7 @@ struct InnerPostgresConnection {
415416
unknown_types: HashMap<Oid, StrBuf>,
416417
desynchronized: bool,
417418
finished: bool,
419+
trans_depth: u32,
418420
canary: u32,
419421
}
420422

@@ -442,6 +444,7 @@ impl InnerPostgresConnection {
442444
unknown_types: HashMap::new(),
443445
desynchronized: false,
444446
finished: false,
447+
trans_depth: 0,
445448
canary: CANARY,
446449
};
447450

@@ -805,7 +808,11 @@ impl PostgresConnection {
805808
/// };
806809
pub fn prepare<'a>(&'a self, query: &str)
807810
-> PostgresResult<PostgresStatement<'a>> {
808-
self.conn.borrow_mut().prepare(query, self)
811+
let mut conn = self.conn.borrow_mut();
812+
if conn.trans_depth != 0 {
813+
return Err(PgWrongTransaction);
814+
}
815+
conn.prepare(query, self)
809816
}
810817

811818
/// Begins a new transaction.
@@ -837,11 +844,15 @@ impl PostgresConnection {
837844
pub fn transaction<'a>(&'a self)
838845
-> PostgresResult<PostgresTransaction<'a>> {
839846
check_desync!(self);
847+
if self.conn.borrow().trans_depth != 0 {
848+
return Err(PgWrongTransaction);
849+
}
840850
try!(self.quick_query("BEGIN"));
851+
self.conn.borrow_mut().trans_depth += 1;
841852
Ok(PostgresTransaction {
842853
conn: self,
843854
commit: Cell::new(true),
844-
nested: false,
855+
depth: 1,
845856
finished: false,
846857
})
847858
}
@@ -921,7 +932,7 @@ pub enum SslMode {
921932
pub struct PostgresTransaction<'conn> {
922933
conn: &'conn PostgresConnection,
923934
commit: Cell<bool>,
924-
nested: bool,
935+
depth: u32,
925936
finished: bool,
926937
}
927938

@@ -936,51 +947,50 @@ impl<'conn> Drop for PostgresTransaction<'conn> {
936947

937948
impl<'conn> PostgresTransaction<'conn> {
938949
fn finish_inner(&mut self) -> PostgresResult<()> {
950+
debug_assert!(self.depth == self.conn.conn.borrow().trans_depth);
939951
let rollback = task::failing() || !self.commit.get();
940-
let query = match (rollback, self.nested) {
952+
let query = match (rollback, self.depth != 1) {
941953
(true, true) => "ROLLBACK TO sp",
942954
(true, false) => "ROLLBACK",
943955
(false, true) => "RELEASE sp",
944956
(false, false) => "COMMIT",
945957
};
958+
self.conn.conn.borrow_mut().trans_depth -= 1;
946959
self.conn.quick_query(query).map(|_| ())
947960
}
948961

949962
/// Like `PostgresConnection::prepare`.
950963
pub fn prepare<'a>(&'a self, query: &str)
951964
-> PostgresResult<PostgresStatement<'a>> {
952-
self.conn.prepare(query)
965+
if self.conn.conn.borrow().trans_depth != self.depth {
966+
return Err(PgWrongTransaction);
967+
}
968+
self.conn.conn.borrow_mut().prepare(query, self.conn)
953969
}
954970

955971
/// Like `PostgresConnection::execute`.
956972
pub fn execute(&self, query: &str, params: &[&ToSql])
957973
-> PostgresResult<uint> {
958-
self.conn.execute(query, params)
974+
self.prepare(query).and_then(|s| s.execute(params))
959975
}
960976

961977
/// Like `PostgresConnection::transaction`.
962978
pub fn transaction<'a>(&'a self)
963979
-> PostgresResult<PostgresTransaction<'a>> {
964980
check_desync!(self.conn);
981+
if self.conn.conn.borrow().trans_depth != self.depth {
982+
return Err(PgWrongTransaction);
983+
}
965984
try!(self.conn.quick_query("SAVEPOINT sp"));
985+
self.conn.conn.borrow_mut().trans_depth += 1;
966986
Ok(PostgresTransaction {
967987
conn: self.conn,
968988
commit: Cell::new(true),
969-
nested: true,
989+
depth: self.depth + 1,
970990
finished: false,
971991
})
972992
}
973993

974-
/// Like `PostgresConnection::notifications`.
975-
pub fn notifications<'a>(&'a self) -> PostgresNotifications<'a> {
976-
self.conn.notifications()
977-
}
978-
979-
/// Like `PostgresConnection::is_desynchronized`.
980-
pub fn is_desynchronized(&self) -> bool {
981-
self.conn.is_desynchronized()
982-
}
983-
984994
/// Determines if the transaction is currently set to commit or roll back.
985995
pub fn will_commit(&self) -> bool {
986996
self.commit.get()
@@ -1447,12 +1457,8 @@ impl RowIndex for int {
14471457
impl<'a> RowIndex for &'a str {
14481458
#[inline]
14491459
fn idx(&self, stmt: &PostgresStatement) -> Option<uint> {
1450-
for (i, desc) in stmt.result_descriptions().iter().enumerate() {
1451-
if desc.name.as_slice() == *self {
1452-
return Some(i);
1453-
}
1454-
}
1455-
None
1460+
stmt.result_descriptions().iter()
1461+
.position(|d| d.name.as_slice() == *self)
14561462
}
14571463
}
14581464

src/test.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ fn test_nested_transactions() {
225225
}
226226
}
227227

228-
let stmt = or_fail!(conn.prepare("SELECT * FROM foo ORDER BY id"));
228+
let stmt = or_fail!(trans1.prepare("SELECT * FROM foo ORDER BY id"));
229229
let result = or_fail!(stmt.query([]));
230230

231231
assert_eq!(vec![1i32, 2, 4, 6], result.map(|row| row[1]).collect());
@@ -277,10 +277,13 @@ fn test_nested_transactions_finish() {
277277
assert!(trans2.finish().is_ok());
278278
}
279279

280-
let stmt = or_fail!(conn.prepare("SELECT * FROM foo ORDER BY id"));
281-
let result = or_fail!(stmt.query([]));
280+
// in a block to unborrow trans1 for the finish call
281+
{
282+
let stmt = or_fail!(trans1.prepare("SELECT * FROM foo ORDER BY id"));
283+
let result = or_fail!(stmt.query([]));
282284

283-
assert_eq!(vec![1i32, 2, 4, 6], result.map(|row| row[1]).collect());
285+
assert_eq!(vec![1i32, 2, 4, 6], result.map(|row| row[1]).collect());
286+
}
284287

285288
trans1.set_rollback();
286289
assert!(trans1.finish().is_ok());
@@ -332,7 +335,7 @@ fn test_lazy_query() {
332335
for value in values.iter() {
333336
or_fail!(stmt.execute([value as &ToSql]));
334337
}
335-
let stmt = or_fail!(conn.prepare("SELECT id FROM foo ORDER BY id"));
338+
let stmt = or_fail!(trans.prepare("SELECT id FROM foo ORDER BY id"));
336339
let result = or_fail!(trans.lazy_query(&stmt, [], 2));
337340
assert_eq!(values, result.map(|row| row.unwrap()[1]).collect());
338341

0 commit comments

Comments
 (0)