Skip to content

Commit 53d7596

Browse files
authored
Merge pull request diesel-rs#1466 from diesel-rs/sg-r2d2
Merge r2d2-diesel into Diesel itself
2 parents 1285f53 + df6d43d commit 53d7596

5 files changed

Lines changed: 244 additions & 16 deletions

File tree

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
88

99
### Added
1010

11+
* `r2d2-diesel` has been merged into Diesel proper. You should no longer rely
12+
directly on `r2d2-diesel` or `r2d2`. The functionality of both is exposed from
13+
`diesel::r2d2`.
14+
15+
* `r2d2::PooledConnection` now implements `Connection`. This means that you
16+
should no longer need to write `&*connection` when using `r2d2`.
17+
1118
* The `BINARY` column type name is now supported for SQLite.
1219

1320
* The `QueryId` trait can now be derived.

diesel/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ num-traits = { version = "0.1.35", optional = true }
3131
num-integer = { version = "0.1.32", optional = true }
3232
bigdecimal = { version = "0.0.10", optional = true }
3333
bitflags = { version = "1.0", optional = true }
34+
r2d2 = { version = ">= 0.7, < 0.9", optional = true }
3435

3536
[dev-dependencies]
3637
cfg-if = "0.1.0"
@@ -40,7 +41,7 @@ tempdir = "^0.3.4"
4041

4142
[features]
4243
default = ["with-deprecated"]
43-
extras = ["chrono", "serde_json", "uuid", "deprecated-time", "network-address", "numeric"]
44+
extras = ["chrono", "serde_json", "uuid", "deprecated-time", "network-address", "numeric", "r2d2"]
4445
unstable = []
4546
lint = ["clippy"]
4647
large-tables = []

diesel/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ pub mod insertable;
146146
pub mod query_builder;
147147
pub mod query_dsl;
148148
pub mod query_source;
149+
#[cfg(feature = "r2d2")]
150+
pub mod r2d2;
149151
pub mod result;
150152
pub mod serialize;
151153
#[macro_use]

diesel/src/r2d2.rs

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//! Connection pooling via r2d2
2+
3+
extern crate r2d2;
4+
5+
pub use self::r2d2::*;
6+
7+
use std::convert::Into;
8+
use std::fmt;
9+
use std::marker::PhantomData;
10+
11+
use backend::UsesAnsiSavepointSyntax;
12+
use deserialize::QueryableByName;
13+
use prelude::*;
14+
use connection::{AnsiTransactionManager, SimpleConnection};
15+
use query_builder::{AsQuery, QueryFragment, QueryId};
16+
use sql_types::HasSqlType;
17+
18+
/// An r2d2 connection manager for use with Diesel.
19+
///
20+
/// See the [r2d2 documentation] for usage examples.
21+
///
22+
/// [r2d2 documentation]: ../../r2d2
23+
#[derive(Debug, Clone)]
24+
pub struct ConnectionManager<T> {
25+
database_url: String,
26+
_marker: PhantomData<T>,
27+
}
28+
29+
unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}
30+
31+
impl<T> ConnectionManager<T> {
32+
/// Returns a new connection manager,
33+
/// which establishes connections to the given database URL.
34+
pub fn new<S: Into<String>>(database_url: S) -> Self {
35+
ConnectionManager {
36+
database_url: database_url.into(),
37+
_marker: PhantomData,
38+
}
39+
}
40+
}
41+
42+
/// The error used when managing connections with `r2d2`.
43+
#[derive(Debug)]
44+
pub enum Error {
45+
/// An error occurred establishing the connection
46+
ConnectionError(ConnectionError),
47+
48+
/// An error occurred pinging the database
49+
QueryError(::result::Error),
50+
}
51+
52+
impl fmt::Display for Error {
53+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54+
match *self {
55+
Error::ConnectionError(ref e) => e.fmt(f),
56+
Error::QueryError(ref e) => e.fmt(f),
57+
}
58+
}
59+
}
60+
61+
impl ::std::error::Error for Error {
62+
fn description(&self) -> &str {
63+
match *self {
64+
Error::ConnectionError(ref e) => e.description(),
65+
Error::QueryError(ref e) => e.description(),
66+
}
67+
}
68+
}
69+
70+
impl<T> ManageConnection for ConnectionManager<T>
71+
where
72+
T: Connection + Send + 'static,
73+
{
74+
type Connection = T;
75+
type Error = Error;
76+
77+
fn connect(&self) -> Result<T, Error> {
78+
T::establish(&self.database_url).map_err(Error::ConnectionError)
79+
}
80+
81+
fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
82+
conn.execute("SELECT 1")
83+
.map(|_| ())
84+
.map_err(Error::QueryError)
85+
}
86+
87+
fn has_broken(&self, _conn: &mut T) -> bool {
88+
false
89+
}
90+
}
91+
92+
impl<T> SimpleConnection for PooledConnection<ConnectionManager<T>>
93+
where
94+
T: Connection + Send + 'static,
95+
{
96+
fn batch_execute(&self, query: &str) -> QueryResult<()> {
97+
(&**self).batch_execute(query)
98+
}
99+
}
100+
101+
impl<C> Connection for PooledConnection<ConnectionManager<C>>
102+
where
103+
C: Connection<TransactionManager = AnsiTransactionManager> + Send + 'static,
104+
C::Backend: UsesAnsiSavepointSyntax,
105+
{
106+
type Backend = C::Backend;
107+
type TransactionManager = C::TransactionManager;
108+
109+
fn establish(_: &str) -> ConnectionResult<Self> {
110+
Err(ConnectionError::BadConnection(String::from(
111+
"Cannot directly establish a pooled connection",
112+
)))
113+
}
114+
115+
fn execute(&self, query: &str) -> QueryResult<usize> {
116+
(&**self).execute(query)
117+
}
118+
119+
fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
120+
where
121+
T: AsQuery,
122+
T::Query: QueryFragment<Self::Backend> + QueryId,
123+
Self::Backend: HasSqlType<T::SqlType>,
124+
U: Queryable<T::SqlType, Self::Backend>,
125+
{
126+
(&**self).query_by_index(source)
127+
}
128+
129+
fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
130+
where
131+
T: QueryFragment<Self::Backend> + QueryId,
132+
U: QueryableByName<Self::Backend>,
133+
{
134+
(&**self).query_by_name(source)
135+
}
136+
137+
fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
138+
where
139+
T: QueryFragment<Self::Backend> + QueryId,
140+
{
141+
(&**self).execute_returning_count(source)
142+
}
143+
144+
fn transaction_manager(&self) -> &Self::TransactionManager {
145+
(&**self).transaction_manager()
146+
}
147+
}
148+
149+
#[cfg(test)]
150+
mod tests {
151+
use std::sync::Arc;
152+
use std::sync::mpsc;
153+
use std::thread;
154+
155+
use r2d2::*;
156+
use test_helpers::*;
157+
158+
#[test]
159+
fn establish_basic_connection() {
160+
let manager = ConnectionManager::<TestConnection>::new(database_url());
161+
let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());
162+
163+
let (s1, r1) = mpsc::channel();
164+
let (s2, r2) = mpsc::channel();
165+
166+
let pool1 = Arc::clone(&pool);
167+
let t1 = thread::spawn(move || {
168+
let conn = pool1.get().unwrap();
169+
s1.send(()).unwrap();
170+
r2.recv().unwrap();
171+
drop(conn);
172+
});
173+
174+
let pool2 = Arc::clone(&pool);
175+
let t2 = thread::spawn(move || {
176+
let conn = pool2.get().unwrap();
177+
s2.send(()).unwrap();
178+
r1.recv().unwrap();
179+
drop(conn);
180+
});
181+
182+
t1.join().unwrap();
183+
t2.join().unwrap();
184+
185+
pool.get().unwrap();
186+
}
187+
188+
#[test]
189+
fn is_valid() {
190+
let manager = ConnectionManager::<TestConnection>::new(database_url());
191+
let pool = Pool::builder()
192+
.max_size(1)
193+
.test_on_check_out(true)
194+
.build(manager)
195+
.unwrap();
196+
197+
pool.get().unwrap();
198+
}
199+
200+
#[test]
201+
fn pooled_connection_impls_connection() {
202+
use select;
203+
use sql_types::Text;
204+
205+
let manager = ConnectionManager::<TestConnection>::new(database_url());
206+
let pool = Pool::builder()
207+
.max_size(1)
208+
.test_on_check_out(true)
209+
.build(manager)
210+
.unwrap();
211+
let conn = pool.get().unwrap();
212+
213+
let query = select("foo".into_sql::<Text>());
214+
assert_eq!("foo", query.get_result::<String>(&conn).unwrap());
215+
}
216+
}

diesel/src/test_helpers.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,29 @@ cfg_if! {
77
pub fn connection() -> TestConnection {
88
SqliteConnection::establish(":memory:").unwrap()
99
}
10+
11+
pub fn database_url() -> String {
12+
String::from(":memory:")
13+
}
1014
} else if #[cfg(feature = "postgres")] {
1115
extern crate dotenv;
1216

13-
use self::dotenv::dotenv;
14-
use std::env;
15-
1617
pub type TestConnection = PgConnection;
1718

1819
pub fn connection() -> TestConnection {
19-
dotenv().ok();
20-
let database_url = env::var("PG_DATABASE_URL")
21-
.or_else(|_| env::var("DATABASE_URL"))
22-
.expect("DATABASE_URL must be set in order to run tests");
23-
let conn = PgConnection::establish(&database_url).unwrap();
20+
let conn = PgConnection::establish(&database_url()).unwrap();
2421
conn.begin_test_transaction().unwrap();
2522
conn
2623
}
24+
25+
pub fn database_url() -> String {
26+
dotenv::var("PG_DATABASE_URL")
27+
.or_else(|_| dotenv::var("DATABASE_URL"))
28+
.expect("DATABASE_URL must be set in order to run tests")
29+
}
2730
} else if #[cfg(feature = "mysql")] {
2831
extern crate dotenv;
2932

30-
use self::dotenv::dotenv;
31-
use std::env;
32-
3333
pub type TestConnection = MysqlConnection;
3434

3535
pub fn connection() -> TestConnection {
@@ -39,11 +39,13 @@ cfg_if! {
3939
}
4040

4141
pub fn connection_no_transaction() -> TestConnection {
42-
dotenv().ok();
43-
let database_url = env::var("MYSQL_UNIT_TEST_DATABASE_URL")
44-
.or_else(|_| env::var("DATABASE_URL"))
42+
MysqlConnection::establish(&database_url()).unwrap()
43+
}
44+
45+
fn database_url() -> String {
46+
dotenv::var("MYSQL_UNIT_TEST_DATABASE_URL")
47+
.or_else(|_| dotenv::var("DATABASE_URL"))
4548
.expect("DATABASE_URL must be set in order to run tests");
46-
MysqlConnection::establish(&database_url).unwrap()
4749
}
4850
} else {
4951
compile_error!(

0 commit comments

Comments
 (0)