Skip to content

Commit 331f81c

Browse files
committed
Turn Connection into a trait
This was pretty straightforward. Almost everything that was previously coupled to PG has been changed to be generic over the connection (mainly `LoadDsl` and `ExecuteDsl`). Our tests are coupled to PG. We'll want to change that as we further break up PG into it's own crate. Things that apply to Sqlite will need to be tested there as well. `SimpleConnection` exists because I didn't want to make migrations generic over the connection, but `Connection` is not object safe. We only ever use `batch_execute` in that module, so I'm just taking that as a trait object. CLI and Codegen are currently coupled to PG. I will need to write additional code for both of them in order to support SQLite in the future. It occurs to me that I don't have a great way to make these support third party adapters, and I'll need to rethink how the connection info is passed in.
1 parent c26aa7d commit 331f81c

30 files changed

Lines changed: 369 additions & 288 deletions

diesel/src/connection/mod.rs

Lines changed: 34 additions & 204 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,18 @@
11
extern crate libc;
22

3-
mod cursor;
4-
#[doc(hidden)]
5-
pub mod raw;
3+
pub mod pg;
64

7-
use std::cell::Cell;
8-
use std::ffi::{CString, CStr};
9-
use std::rc::Rc;
10-
use std::ptr;
5+
pub use self::pg::PgConnection;
116

12-
use backend::Pg;
13-
use db_result::DbResult;
7+
use backend::Backend;
148
use expression::{AsExpression, Expression, NonAggregate};
15-
use expression::expression_methods::*;
169
use expression::predicates::Eq;
1710
use helper_types::{FindBy, Limit};
1811
use expression::helper_types::AsExpr;
1912
use query_builder::{AsQuery, Query, QueryFragment};
20-
use query_builder::pg::PgQueryBuilder;
2113
use query_dsl::{FilterDsl, LimitDsl};
2214
use query_source::{Table, Queryable};
2315
use result::*;
24-
use self::cursor::Cursor;
25-
use self::raw::RawConnection;
26-
use types::{NativeSqlType, ToSql};
27-
28-
pub struct Connection {
29-
raw_connection: Rc<RawConnection>,
30-
transaction_depth: Cell<i32>,
31-
}
32-
33-
unsafe impl Send for Connection {}
3416

3517
#[doc(hidden)]
3618
pub type PrimaryKey<T> = <T as Table>::PrimaryKey;
@@ -39,129 +21,53 @@ pub type PkType<T> = <PrimaryKey<T> as Expression>::SqlType;
3921
#[doc(hidden)]
4022
pub type FindPredicate<T, PK> = Eq<PrimaryKey<T>, <PK as AsExpression<PkType<T>>>::Expression>;
4123

42-
impl Connection {
24+
pub trait SimpleConnection {
25+
#[doc(hidden)]
26+
fn batch_execute(&self, query: &str) -> QueryResult<()>;
27+
}
28+
29+
pub trait Connection: SimpleConnection + Sized {
30+
type Backend: Backend;
31+
4332
/// Establishes a new connection to the database at the given URL. The URL
44-
/// should be a PostgreSQL connection string, as documented at
45-
/// http://www.postgresql.org/docs/9.4/static/libpq-connect.html#LIBPQ-CONNSTRING
46-
pub fn establish(database_url: &str) -> ConnectionResult<Connection> {
47-
RawConnection::establish(database_url).map(|raw_conn| {
48-
Connection {
49-
raw_connection: Rc::new(raw_conn),
50-
transaction_depth: Cell::new(0),
51-
}
52-
})
53-
}
33+
/// should be a valid connection string for a given backend. See the
34+
/// documentation for the specific backend for specifics.
35+
fn establish(database_url: &str) -> ConnectionResult<Self>;
5436

5537
/// Executes the given function inside of a database transaction. When
56-
/// a transaction is already occurring,
57-
/// [savepoints](http://www.postgresql.org/docs/9.1/static/sql-savepoint.html)
58-
/// will be used to emulate a nested transaction.
38+
/// a transaction is already occurring, savepoints will be used to emulate a nested
39+
/// transaction.
5940
///
6041
/// If the function returns an `Ok`, that value will be returned. If the
6142
/// function returns an `Err`,
6243
/// [`TransactionError::UserReturnedError`](result/enum.TransactionError.html#variant.UserReturnedError)
6344
/// will be returned wrapping that value.
64-
pub fn transaction<T, E, F>(&self, f: F) -> TransactionResult<T, E> where
65-
F: FnOnce() -> Result<T, E>,
66-
{
67-
try!(self.begin_transaction());
68-
match f() {
69-
Ok(value) => {
70-
try!(self.commit_transaction());
71-
Ok(value)
72-
},
73-
Err(e) => {
74-
try!(self.rollback_transaction());
75-
Err(TransactionError::UserReturnedError(e))
76-
},
77-
}
78-
}
45+
fn transaction<T, E, F>(&self, f: F) -> TransactionResult<T, E> where
46+
F: FnOnce() -> Result<T, E>;
7947

8048
/// Creates a transaction that will never be committed. This is useful for
8149
/// tests. Panics if called while inside of a transaction.
82-
pub fn begin_test_transaction(&self) -> QueryResult<usize> {
83-
assert_eq!(self.transaction_depth.get(), 0);
84-
self.begin_transaction()
85-
}
50+
fn begin_test_transaction(&self) -> QueryResult<usize>;
8651

8752
/// Executes the given function inside a transaction, but does not commit
8853
/// it. Panics if the given function returns an `Err`.
89-
pub fn test_transaction<T, E, F>(&self, f: F) -> T where
90-
F: FnOnce() -> Result<T, E>,
91-
{
92-
let mut user_result = None;
93-
let _ = self.transaction::<(), _, _>(|| {
94-
user_result = f().ok();
95-
Err(())
96-
});
97-
user_result.expect("Transaction did not succeed")
98-
}
99-
100-
#[doc(hidden)]
101-
pub fn execute(&self, query: &str) -> QueryResult<usize> {
102-
self.execute_inner(query).map(|res| res.rows_affected())
103-
}
54+
fn test_transaction<T, E, F>(&self, f: F) -> T where
55+
F: FnOnce() -> Result<T, E>;
10456

10557
#[doc(hidden)]
106-
pub fn batch_execute(&self, query: &str) -> QueryResult<()> {
107-
let query = try!(CString::new(query));
108-
let inner_result = unsafe {
109-
self.raw_connection.exec(query.as_ptr())
110-
};
111-
try!(DbResult::new(self, inner_result));
112-
Ok(())
113-
}
58+
fn execute(&self, query: &str) -> QueryResult<usize>;
11459

11560
#[doc(hidden)]
116-
pub fn query_one<T, U>(&self, source: T) -> QueryResult<U> where
61+
fn query_one<T, U>(&self, source: T) -> QueryResult<U> where
11762
T: AsQuery,
118-
T::Query: QueryFragment<Pg>,
119-
U: Queryable<T::SqlType>,
120-
{
121-
self.query_all(source)
122-
.and_then(|mut e| e.nth(0).map(Ok).unwrap_or(Err(Error::NotFound)))
123-
}
63+
T::Query: QueryFragment<Self::Backend>,
64+
U: Queryable<T::SqlType>;
12465

12566
#[doc(hidden)]
126-
pub fn query_all<'a, T, U: 'a>(&self, source: T) -> QueryResult<Box<Iterator<Item=U> + 'a>> where
67+
fn query_all<'a, T, U: 'a>(&self, source: T) -> QueryResult<Box<Iterator<Item=U> + 'a>> where
12768
T: AsQuery,
128-
T::Query: QueryFragment<Pg>,
129-
U: Queryable<T::SqlType>,
130-
{
131-
let (sql, params, types) = self.prepare_query(&source.as_query());
132-
self.exec_sql_params(&sql, &params, &Some(types))
133-
.map(|r| Box::new(Cursor::new(r)) as Box<Iterator<Item=U>>)
134-
}
135-
136-
fn exec_sql_params(&self, query: &str, param_data: &Vec<Option<Vec<u8>>>, param_types: &Option<Vec<u32>>) -> QueryResult<DbResult> {
137-
let query = try!(CString::new(query));
138-
let params_pointer = param_data.iter()
139-
.map(|data| data.as_ref().map(|d| d.as_ptr() as *const libc::c_char)
140-
.unwrap_or(ptr::null()))
141-
.collect::<Vec<_>>();
142-
let param_types_ptr = param_types.as_ref()
143-
.map(|types| types.as_ptr())
144-
.unwrap_or(ptr::null());
145-
let param_lengths = param_data.iter()
146-
.map(|data| data.as_ref().map(|d| d.len() as libc::c_int)
147-
.unwrap_or(0))
148-
.collect::<Vec<_>>();
149-
let param_formats = vec![1; param_data.len()];
150-
151-
let internal_res = unsafe {
152-
self.raw_connection.exec_params(
153-
query.as_ptr(),
154-
params_pointer.len() as libc::c_int,
155-
param_types_ptr,
156-
params_pointer.as_ptr(),
157-
param_lengths.as_ptr(),
158-
param_formats.as_ptr(),
159-
1,
160-
)
161-
};
162-
163-
DbResult::new(self, internal_res)
164-
}
69+
T::Query: QueryFragment<Self::Backend>,
70+
U: Queryable<T::SqlType>;
16571

16672
/// Attempts to find a single record from the given table by primary key.
16773
///
@@ -189,94 +95,18 @@ impl Connection {
18995
/// assert_eq!(Err::<(i32, String), _>(NotFound), connection.find(users, 3));
19096
/// # }
19197
/// ```
192-
pub fn find<T, U, PK>(&self, source: T, id: PK) -> QueryResult<U> where
98+
fn find<T, U, PK>(&self, source: T, id: PK) -> QueryResult<U> where
19399
T: Table + FilterDsl<FindPredicate<T, PK>>,
194100
FindBy<T, T::PrimaryKey, PK>: LimitDsl,
195-
Limit<FindBy<T, T::PrimaryKey, PK>>: QueryFragment<Pg>,
101+
Limit<FindBy<T, T::PrimaryKey, PK>>: QueryFragment<Self::Backend>,
196102
U: Queryable<<Limit<FindBy<T, T::PrimaryKey, PK>> as Query>::SqlType>,
197103
PK: AsExpression<PkType<T>>,
198-
AsExpr<PK, T::PrimaryKey>: NonAggregate,
199-
{
200-
let pk = source.primary_key();
201-
self.query_one(source.filter(pk.eq(id)).limit(1))
202-
}
104+
AsExpr<PK, T::PrimaryKey>: NonAggregate;
203105

204106
#[doc(hidden)]
205-
pub fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize> where
206-
T: QueryFragment<Pg>,
207-
{
208-
let (sql, params, param_types) = self.prepare_query(source);
209-
self.exec_sql_params(&sql, &params, &Some(param_types))
210-
.map(|r| r.rows_affected())
211-
}
212-
213-
fn prepare_query<T: QueryFragment<Pg>>(&self, source: &T)
214-
-> (String, Vec<Option<Vec<u8>>>, Vec<u32>)
215-
{
216-
let mut query_builder = PgQueryBuilder::new(&self.raw_connection);
217-
source.to_sql(&mut query_builder).unwrap();
218-
(query_builder.sql, query_builder.binds, query_builder.bind_types)
219-
}
220-
221-
fn execute_inner(&self, query: &str) -> QueryResult<DbResult> {
222-
self.exec_sql_params(query, &Vec::new(), &None)
223-
}
224-
225-
#[doc(hidden)]
226-
pub fn last_error_message(&self) -> String {
227-
self.raw_connection.last_error_message()
228-
}
229-
230-
fn begin_transaction(&self) -> QueryResult<usize> {
231-
let transaction_depth = self.transaction_depth.get();
232-
self.change_transaction_depth(1, if transaction_depth == 0 {
233-
self.execute("BEGIN")
234-
} else {
235-
self.execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
236-
})
237-
}
238-
239-
fn rollback_transaction(&self) -> QueryResult<usize> {
240-
let transaction_depth = self.transaction_depth.get();
241-
self.change_transaction_depth(-1, if transaction_depth == 1 {
242-
self.execute("ROLLBACK")
243-
} else {
244-
self.execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
245-
transaction_depth - 1))
246-
})
247-
}
248-
249-
fn commit_transaction(&self) -> QueryResult<usize> {
250-
let transaction_depth = self.transaction_depth.get();
251-
self.change_transaction_depth(-1, if transaction_depth <= 1 {
252-
self.execute("COMMIT")
253-
} else {
254-
self.execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}",
255-
transaction_depth - 1))
256-
})
257-
}
258-
259-
fn change_transaction_depth(&self, by: i32, query: QueryResult<usize>) -> QueryResult<usize> {
260-
if query.is_ok() {
261-
self.transaction_depth.set(self.transaction_depth.get() + by);
262-
}
263-
query
264-
}
107+
fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize> where
108+
T: QueryFragment<Self::Backend>;
265109

266110
#[doc(hidden)]
267-
pub fn silence_notices<F: FnOnce() -> T, T>(&self, f: F) -> T {
268-
self.raw_connection.set_notice_processor(noop_notice_processor);
269-
let result = f();
270-
self.raw_connection.set_notice_processor(default_notice_processor);
271-
result
272-
}
273-
}
274-
275-
extern "C" fn noop_notice_processor(_: *mut libc::c_void, _message: *const libc::c_char) {
276-
}
277-
278-
extern "C" fn default_notice_processor(_: *mut libc::c_void, message: *const libc::c_char) {
279-
use std::io::Write;
280-
let c_str = unsafe { CStr::from_ptr(message) };
281-
::std::io::stderr().write(c_str.to_bytes()).unwrap();
111+
fn silence_notices<F: FnOnce() -> T, T>(&self, f: F) -> T;
282112
}

0 commit comments

Comments
 (0)