Skip to content

Commit 247e320

Browse files
committed
Add explicit statement caching
Also remove transaction depth checks to preparation methods on Connection since lifetimes of statements are tied to the connection, not any transaction that may be active. cc rust-postgres#84
1 parent 05b2a09 commit 247e320

2 files changed

Lines changed: 103 additions & 24 deletions

File tree

src/lib.rs

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -383,13 +383,21 @@ pub fn cancel_query<T>(params: T, ssl: &SslMode, data: CancelData)
383383
Ok(())
384384
}
385385

386+
#[derive(Clone)]
387+
struct CachedStatement {
388+
name: String,
389+
param_types: Vec<Type>,
390+
result_desc: Vec<ResultDescription>,
391+
}
392+
386393
struct InnerConnection {
387394
stream: BufferedStream<MaybeSslStream<InternalStream>>,
388395
next_stmt_id: usize,
389396
notice_handler: Box<NoticeHandler>,
390397
notifications: RingBuf<Notification>,
391398
cancel_data: CancelData,
392399
unknown_types: HashMap<Oid, Type>,
400+
cached_statements: HashMap<String, CachedStatement>,
393401
desynchronized: bool,
394402
finished: bool,
395403
trans_depth: u32,
@@ -421,6 +429,7 @@ impl InnerConnection {
421429
notifications: RingBuf::new(),
422430
cancel_data: CancelData { process_id: 0, secret_key: 0 },
423431
unknown_types: HashMap::new(),
432+
cached_statements: HashMap::new(),
424433
desynchronized: false,
425434
finished: false,
426435
trans_depth: 0,
@@ -647,6 +656,34 @@ impl InnerConnection {
647656
})
648657
}
649658

659+
fn prepare_cached<'a>(&mut self, query: &str, conn: &'a Connection) -> Result<Statement<'a>> {
660+
let stmt = self.cached_statements.get(query).map(|e| e.clone());
661+
662+
let CachedStatement { name, param_types, result_desc } = match stmt {
663+
Some(stmt) => stmt,
664+
None => {
665+
let stmt_name = self.make_stmt_name();
666+
let (param_types, result_desc) = try!(self.raw_prepare(&*stmt_name, query));
667+
let stmt = CachedStatement {
668+
name: stmt_name,
669+
param_types: param_types,
670+
result_desc: result_desc,
671+
};
672+
self.cached_statements.insert(query.to_owned(), stmt.clone());
673+
stmt
674+
}
675+
};
676+
677+
Ok(Statement {
678+
conn: conn,
679+
name: name,
680+
param_types: param_types,
681+
result_desc: result_desc,
682+
next_portal_id: Cell::new(0),
683+
finished: true, // << !
684+
})
685+
}
686+
650687
fn prepare_copy_in<'a>(&mut self, table: &str, rows: &[&str], conn: &'a Connection)
651688
-> Result<CopyInStatement<'a>> {
652689
let mut query = vec![];
@@ -924,11 +961,31 @@ impl Connection {
924961
/// Err(err) => panic!("Error preparing statement: {:?}", err)
925962
/// };
926963
pub fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
927-
let mut conn = self.conn.borrow_mut();
928-
if conn.trans_depth != 0 {
929-
return Err(Error::WrongTransaction);
930-
}
931-
conn.prepare(query, self)
964+
self.conn.borrow_mut().prepare(query, self)
965+
}
966+
967+
/// Creates cached prepared statement.
968+
///
969+
/// Like `prepare`, except that the statement is only prepared once and
970+
/// then cached. If the same statement is going to be used frequently,
971+
/// caching it can improve performance by reducing the number of round
972+
/// trips to the Postgres backend.
973+
///
974+
/// ## Example
975+
///
976+
/// ```rust,no_run
977+
/// # use postgres::{Connection, SslMode};
978+
/// # fn f() -> postgres::Result<()> {
979+
/// # let x = 10i32;
980+
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
981+
/// let stmt = try!(conn.prepare_cached("SELECT foo FROM bar WHERE baz = $1"));
982+
/// for row in try!(stmt.query(&[&x])) {
983+
/// println!("foo: {}", row.get::<_, String>(0));
984+
/// }
985+
/// # Ok(()) };
986+
/// ```
987+
pub fn prepare_cached<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
988+
self.conn.borrow_mut().prepare_cached(query, self)
932989
}
933990

934991
/// Creates a new COPY FROM STDIN prepared statement.
@@ -937,11 +994,7 @@ impl Connection {
937994
/// the database.
938995
pub fn prepare_copy_in<'a>(&'a self, table: &str, rows: &[&str])
939996
-> Result<CopyInStatement<'a>> {
940-
let mut conn = self.conn.borrow_mut();
941-
if conn.trans_depth != 0 {
942-
return Err(Error::WrongTransaction);
943-
}
944-
conn.prepare_copy_in(table, rows, self)
997+
self.conn.borrow_mut().prepare_copy_in(table, rows, self)
945998
}
946999

9471000
/// Begins a new transaction.
@@ -1039,11 +1092,7 @@ impl Connection {
10391092
/// }
10401093
/// ```
10411094
pub fn batch_execute(&self, query: &str) -> Result<()> {
1042-
let mut conn = self.conn.borrow_mut();
1043-
if conn.trans_depth != 0 {
1044-
return Err(Error::WrongTransaction);
1045-
}
1046-
conn.quick_query(query).map(|_| ())
1095+
self.conn.borrow_mut().quick_query(query).map(|_| ())
10471096
}
10481097

10491098
/// Returns information used to cancel pending queries.
@@ -1135,12 +1184,20 @@ impl<'conn> Transaction<'conn> {
11351184

11361185
/// Like `Connection::prepare`.
11371186
pub fn prepare(&self, query: &str) -> Result<Statement<'conn>> {
1138-
self.conn.conn.borrow_mut().prepare(query, self.conn)
1187+
self.conn.prepare(query)
1188+
}
1189+
1190+
/// Like `Connection::prepare_cached`.
1191+
///
1192+
/// Note that the statement will be cached for the duration of the
1193+
/// connection, not just the duration of this transaction.
1194+
pub fn prepare_cached(&self, query: &str) -> Result<Statement<'conn>> {
1195+
self.conn.prepare_cached(query)
11391196
}
11401197

11411198
/// Like `Connection::prepare_copy_in`.
11421199
pub fn prepare_copy_in(&self, table: &str, cols: &[&str]) -> Result<CopyInStatement<'conn>> {
1143-
self.conn.conn.borrow_mut().prepare_copy_in(table, cols, self.conn)
1200+
self.conn.prepare_copy_in(table, cols)
11441201
}
11451202

11461203
/// Like `Connection::execute`.
@@ -1150,7 +1207,7 @@ impl<'conn> Transaction<'conn> {
11501207

11511208
/// Like `Connection::batch_execute`.
11521209
pub fn batch_execute(&self, query: &str) -> Result<()> {
1153-
self.conn.conn.borrow_mut().quick_query(query).map(|_| ())
1210+
self.conn.batch_execute(query)
11541211
}
11551212

11561213
/// Like `Connection::transaction`.
@@ -1895,6 +1952,9 @@ pub trait GenericConnection {
18951952
/// Like `Connection::prepare`.
18961953
fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>>;
18971954

1955+
/// Like `Connection::prepare_cached`.
1956+
fn prepare_cached<'a>(&'a self, query: &str) -> Result<Statement<'a>>;
1957+
18981958
/// Like `Connection::execute`.
18991959
fn execute(&self, query: &str, params: &[&ToSql]) -> Result<usize>;
19001960

@@ -1914,6 +1974,10 @@ impl GenericConnection for Connection {
19141974
self.prepare(query)
19151975
}
19161976

1977+
fn prepare_cached<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
1978+
self.prepare_cached(query)
1979+
}
1980+
19171981
fn execute(&self, query: &str, params: &[&ToSql]) -> Result<usize> {
19181982
self.execute(query, params)
19191983
}
@@ -1937,6 +2001,10 @@ impl<'a> GenericConnection for Transaction<'a> {
19372001
self.prepare(query)
19382002
}
19392003

2004+
fn prepare_cached<'b>(&'b self, query: &str) -> Result<Statement<'b>> {
2005+
self.prepare_cached(query)
2006+
}
2007+
19402008
fn execute(&self, query: &str, params: &[&ToSql]) -> Result<usize> {
19412009
self.execute(query, params)
19422010
}

tests/test.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,14 +281,9 @@ fn test_nested_transactions_finish() {
281281
}
282282

283283
#[test]
284-
fn test_conn_prepare_with_trans() {
284+
fn test_conn_trans_when_nested() {
285285
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
286286
let _trans = or_panic!(conn.transaction());
287-
match conn.prepare("") {
288-
Err(Error::WrongTransaction) => {}
289-
Err(r) => panic!("Unexpected error {:?}", r),
290-
Ok(_) => panic!("Unexpected success"),
291-
}
292287
match conn.transaction() {
293288
Err(Error::WrongTransaction) => {}
294289
Err(r) => panic!("Unexpected error {:?}", r),
@@ -902,3 +897,19 @@ fn test_custom_range_element_type() {
902897
t => panic!("Unexpected type {:?}", t)
903898
}
904899
}
900+
901+
#[test]
902+
fn test_prepare_cached() {
903+
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
904+
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]));
905+
or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1), (2)", &[]));
906+
907+
let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id"));
908+
assert_eq!(&[1, 2][], or_panic!(stmt.query(&[])).map(|r| r.get(0)).collect::<Vec<i32>>());
909+
910+
let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id"));
911+
assert_eq!(&[1, 2][], or_panic!(stmt.query(&[])).map(|r| r.get(0)).collect::<Vec<i32>>());
912+
913+
let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id DESC"));
914+
assert_eq!(&[2, 1][], or_panic!(stmt.query(&[])).map(|r| r.get(0)).collect::<Vec<i32>>());
915+
}

0 commit comments

Comments
 (0)