Skip to content

Commit 4afd523

Browse files
committed
Transaction support
1 parent 88399a7 commit 4afd523

4 files changed

Lines changed: 233 additions & 130 deletions

File tree

tokio-postgres/src/client.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::tls::TlsConnect;
77
use crate::types::{Oid, ToSql, Type};
88
#[cfg(feature = "runtime")]
99
use crate::Socket;
10-
use crate::{cancel_query, cancel_query_raw, query};
10+
use crate::{cancel_query, cancel_query_raw, query, Transaction};
1111
use crate::{prepare, SimpleQueryMessage};
1212
use crate::{simple_query, Row};
1313
use crate::{Error, Statement};
@@ -274,12 +274,21 @@ impl Client {
274274
simple_query::batch_execute(self.inner(), query)
275275
}
276276

277+
/// Begins a new database transaction.
278+
///
279+
/// The transaction will roll back by default - use the `commit` method to commit it.
280+
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
281+
self.batch_execute("BEGIN").await?;
282+
Ok(Transaction::new(self))
283+
}
284+
277285
/// Attempts to cancel an in-progress query.
278286
///
279287
/// The server provides no information about whether a cancellation attempt was successful or not. An error will
280288
/// only be returned if the client was unable to connect to the database.
281289
///
282290
/// Requires the `runtime` Cargo feature (enabled by default).
291+
#[cfg(feature = "runtime")]
283292
pub fn cancel_query<T>(&mut self, tls: T) -> impl Future<Output = Result<(), Error>>
284293
where
285294
T: MakeTlsConnect<Socket>,

tokio-postgres/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108

109109
pub use crate::client::Client;
110110
pub use crate::config::Config;
111+
pub use crate::transaction::Transaction;
111112
pub use crate::connection::Connection;
112113
use crate::error::DbError;
113114
pub use crate::error::Error;
@@ -141,6 +142,7 @@ mod simple_query;
141142
#[cfg(feature = "runtime")]
142143
mod socket;
143144
mod statement;
145+
mod transaction;
144146
pub mod tls;
145147
pub mod types;
146148

tokio-postgres/src/transaction.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
use crate::codec::FrontendMessage;
2+
use crate::connection::RequestMessages;
3+
#[cfg(feature = "runtime")]
4+
use crate::tls::MakeTlsConnect;
5+
use crate::tls::TlsConnect;
6+
use crate::types::{ToSql, Type};
7+
#[cfg(feature = "runtime")]
8+
use crate::Socket;
9+
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
10+
use futures::Stream;
11+
use postgres_protocol::message::frontend;
12+
use std::future::Future;
13+
use tokio::io::{AsyncRead, AsyncWrite};
14+
15+
/// A representation of a PostgreSQL database transaction.
16+
///
17+
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
18+
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
19+
pub struct Transaction<'a> {
20+
client: &'a mut Client,
21+
done: bool,
22+
}
23+
24+
impl<'a> Drop for Transaction<'a> {
25+
fn drop(&mut self) {
26+
if self.done {
27+
return;
28+
}
29+
30+
let mut buf = vec![];
31+
frontend::query("ROLLBACK", &mut buf).unwrap();
32+
let _ = self
33+
.client
34+
.inner()
35+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
36+
}
37+
}
38+
39+
impl<'a> Transaction<'a> {
40+
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
41+
Transaction {
42+
client,
43+
done: false,
44+
}
45+
}
46+
47+
/// Consumes the transaction, committing all changes made within it.
48+
pub async fn commit(mut self) -> Result<(), Error> {
49+
self.done = true;
50+
self.client.batch_execute("COMMIT").await
51+
}
52+
53+
/// Rolls the transaction back, discarding all changes made within it.
54+
///
55+
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
56+
pub async fn rollback(mut self) -> Result<(), Error> {
57+
self.done = true;
58+
self.client.batch_execute("ROLLBACK").await
59+
}
60+
61+
/// Like `Client::prepare`.
62+
pub fn prepare(&mut self, query: &str) -> impl Future<Output = Result<Statement, Error>> {
63+
self.client.prepare(query)
64+
}
65+
66+
/// Like `Client::prepare_typed`.
67+
pub fn prepare_typed(
68+
&mut self,
69+
query: &str,
70+
parameter_types: &[Type],
71+
) -> impl Future<Output = Result<Statement, Error>> {
72+
self.client.prepare_typed(query, parameter_types)
73+
}
74+
75+
/// Like `Client::query`.
76+
pub fn query(
77+
&mut self,
78+
statement: &Statement,
79+
params: &[&dyn ToSql],
80+
) -> impl Stream<Item = Result<Row, Error>> {
81+
self.client.query(statement, params)
82+
}
83+
84+
/// Like `Client::query_iter`.
85+
pub fn query_iter<'b, I>(
86+
&mut self,
87+
statement: &Statement,
88+
params: I,
89+
) -> impl Stream<Item = Result<Row, Error>> + 'static
90+
where
91+
I: IntoIterator<Item = &'b dyn ToSql>,
92+
I::IntoIter: ExactSizeIterator,
93+
{
94+
// https://github.com/rust-lang/rust/issues/63032
95+
let buf = query::encode(statement, params);
96+
query::query(self.client.inner(), statement.clone(), buf)
97+
}
98+
99+
/// Like `Client::execute`.
100+
pub fn execute(
101+
&mut self,
102+
statement: &Statement,
103+
params: &[&dyn ToSql],
104+
) -> impl Future<Output = Result<u64, Error>> {
105+
self.client.execute(statement, params)
106+
}
107+
108+
/// Like `Client::execute_iter`.
109+
pub fn execute_iter<'b, I>(
110+
&mut self,
111+
statement: &Statement,
112+
params: I,
113+
) -> impl Future<Output = Result<u64, Error>>
114+
where
115+
I: IntoIterator<Item = &'b dyn ToSql>,
116+
I::IntoIter: ExactSizeIterator,
117+
{
118+
// https://github.com/rust-lang/rust/issues/63032
119+
let buf = query::encode(statement, params);
120+
query::execute(self.client.inner(), buf)
121+
}
122+
123+
/// Like `Client::simple_query`.
124+
pub fn simple_query(
125+
&mut self,
126+
query: &str,
127+
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> {
128+
self.client.simple_query(query)
129+
}
130+
131+
/// Like `Client::batch_execute`.
132+
pub fn batch_execute(&mut self, query: &str) -> impl Future<Output = Result<(), Error>> {
133+
self.client.batch_execute(query)
134+
}
135+
136+
/// Like `Client::cancel_query`.
137+
#[cfg(feature = "runtime")]
138+
pub fn cancel_query<T>(&mut self, tls: T) -> impl Future<Output = Result<(), Error>>
139+
where
140+
T: MakeTlsConnect<Socket>,
141+
{
142+
self.client.cancel_query(tls)
143+
}
144+
145+
/// Like `Client::cancel_query_raw`.
146+
pub fn cancel_query_raw<S, T>(
147+
&mut self,
148+
stream: S,
149+
tls: T,
150+
) -> impl Future<Output = Result<(), Error>>
151+
where
152+
S: AsyncRead + AsyncWrite + Unpin,
153+
T: TlsConnect<S>,
154+
{
155+
self.client.cancel_query_raw(stream, tls)
156+
}
157+
}

0 commit comments

Comments
 (0)