Skip to content

Commit f07ebc7

Browse files
committed
Support nested transactions
1 parent 2311cea commit f07ebc7

4 files changed

Lines changed: 58 additions & 17 deletions

File tree

postgres/src/client.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,20 @@ impl Client {
319319
Ok(Iter::new(self.0.simple_query(query)))
320320
}
321321

322+
/// Executes a sequence of SQL statements using the simple query protocol.
323+
///
324+
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
325+
/// point. This is intended for use when, for example, initializing a database schema.
326+
///
327+
/// # Warning
328+
///
329+
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
330+
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
331+
/// them to this method!
332+
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
333+
executor::block_on(self.0.batch_execute(query))
334+
}
335+
322336
/// Begins a new database transaction.
323337
///
324338
/// The transaction will roll back by default - use the `commit` method to commit it.

postgres/src/test.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,12 @@ fn transaction_drop() {
9696
assert_eq!(rows.len(), 0);
9797
}
9898

99-
/*
10099
#[test]
101100
fn nested_transactions() {
102101
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
103102

104103
client
105-
.simple_query("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
104+
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
106105
.unwrap();
107106

108107
let mut transaction = client.transaction().unwrap();
@@ -147,7 +146,6 @@ fn nested_transactions() {
147146
assert_eq!(rows[1].get::<_, i32>(0), 3);
148147
assert_eq!(rows[2].get::<_, i32>(0), 4);
149148
}
150-
*/
151149

152150
#[test]
153151
fn copy_in() {

postgres/src/transaction.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,14 @@ impl<'a> Transaction<'a> {
150150
Ok(Iter::new(self.0.simple_query(query)))
151151
}
152152

153-
// /// Like `Client::transaction`.
154-
// pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
155-
// let depth = self.depth + 1;
156-
// self.client
157-
// .simple_query(&format!("SAVEPOINT sp{}", depth))?;
158-
// Ok(Transaction {
159-
// client: self.client,
160-
// depth,
161-
// done: false,
162-
// })
163-
// }
153+
/// Like `Client::batch_execute`.
154+
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
155+
executor::block_on(self.0.batch_execute(query))
156+
}
157+
158+
/// Like `Client::transaction`.
159+
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
160+
let transaction = executor::block_on(self.0.transaction())?;
161+
Ok(Transaction(transaction))
162+
}
164163
}

tokio-postgres/src/transaction.rs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
2020
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
2121
pub struct Transaction<'a> {
2222
client: &'a mut Client,
23+
depth: u32,
2324
done: bool,
2425
}
2526

@@ -30,7 +31,12 @@ impl<'a> Drop for Transaction<'a> {
3031
}
3132

3233
let mut buf = vec![];
33-
frontend::query("ROLLBACK", &mut buf).unwrap();
34+
let query = if self.depth == 0 {
35+
"ROLLBACK".to_string()
36+
} else {
37+
format!("ROLLBACK TO sp{}", self.depth)
38+
};
39+
frontend::query(&query, &mut buf).unwrap();
3440
let _ = self
3541
.client
3642
.inner()
@@ -42,22 +48,33 @@ impl<'a> Transaction<'a> {
4248
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
4349
Transaction {
4450
client,
51+
depth: 0,
4552
done: false,
4653
}
4754
}
4855

4956
/// Consumes the transaction, committing all changes made within it.
5057
pub async fn commit(mut self) -> Result<(), Error> {
5158
self.done = true;
52-
self.client.batch_execute("COMMIT").await
59+
let query = if self.depth == 0 {
60+
"COMMIT".to_string()
61+
} else {
62+
format!("RELEASE sp{}", self.depth)
63+
};
64+
self.client.batch_execute(&query).await
5365
}
5466

5567
/// Rolls the transaction back, discarding all changes made within it.
5668
///
5769
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
5870
pub async fn rollback(mut self) -> Result<(), Error> {
5971
self.done = true;
60-
self.client.batch_execute("ROLLBACK").await
72+
let query = if self.depth == 0 {
73+
"ROLLBACK".to_string()
74+
} else {
75+
format!("ROLLBACK TO sp{}", self.depth)
76+
};
77+
self.client.batch_execute(&query).await
6178
}
6279

6380
/// Like `Client::prepare`.
@@ -227,4 +244,17 @@ impl<'a> Transaction<'a> {
227244
{
228245
self.client.cancel_query_raw(stream, tls)
229246
}
247+
248+
/// Like `Client::transaction`.
249+
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
250+
let depth = self.depth + 1;
251+
let query = format!("SAVEPOINT sp{}", depth);
252+
self.batch_execute(&query).await?;
253+
254+
Ok(Transaction {
255+
client: self.client,
256+
depth,
257+
done: false,
258+
})
259+
}
230260
}

0 commit comments

Comments
 (0)