Skip to content

Commit 560f650

Browse files
committed
Make AstPass an opaque struct
I'm moving towards merging `to_sql` with `walk_ast`. The code will look *really* bad if there's a ton of conditionals on which pass it is. Instead I'd like to have the code look the same as it would if I were able to make `AstPass` a trait (see 97e5cba for context on why we can't do that) In order to enforce that only the methods are used, and nothing is actually destructuring, I've had to do a bit of funky indirection.
1 parent 1fc74f4 commit 560f650

14 files changed

Lines changed: 136 additions & 135 deletions

File tree

diesel/src/expression/array_comparison.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,9 @@ impl<T, DB> QueryFragment<DB> for Many<T> where
203203
}
204204

205205
fn walk_ast(&self, mut pass: AstPass<DB>) -> QueryResult<()> {
206-
if let AstPass::IsSafeToCachePrepared(result) = pass {
207-
*result = false;
208-
} else {
209-
for value in &self.0 {
210-
value.walk_ast(pass.reborrow())?;
211-
}
206+
pass.unsafe_to_cache_prepared();
207+
for value in &self.0 {
208+
value.walk_ast(pass.reborrow())?;
212209
}
213210
Ok(())
214211
}

diesel/src/expression/bound.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ impl<T, U, DB> QueryFragment<DB> for Bound<T, U> where
3434
Ok(())
3535
}
3636

37-
fn walk_ast(&self, pass: AstPass<DB>) -> QueryResult<()> {
38-
if let AstPass::CollectBinds(out) = pass {
39-
out.push_bound_value(&self.item)?;
40-
}
37+
fn walk_ast(&self, mut pass: AstPass<DB>) -> QueryResult<()> {
38+
pass.push_bind_param(&self.item)?;
4139
Ok(())
4240
}
4341
}

diesel/src/expression/predicates.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ impl<T, U, DB> Changeset<DB> for Eq<T, U> where
252252
QueryFragment::to_sql(&self.right, out)
253253
}
254254

255-
fn collect_binds(&self, out: &mut DB::BindCollector) -> QueryResult<()> {
256-
QueryFragment::collect_binds(&self.right, out)
255+
fn walk_ast(&self, out: AstPass<DB>) -> QueryResult<()> {
256+
QueryFragment::walk_ast(&self.right, out)
257257
}
258258
}
259259

diesel/src/expression/sql_literal.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ impl<ST, DB> QueryFragment<DB> for SqlLiteral<ST> where
4040
Ok(())
4141
}
4242

43-
fn walk_ast(&self, pass: AstPass<DB>) -> QueryResult<()> {
44-
if let AstPass::IsSafeToCachePrepared(result) = pass {
45-
*result = false;
46-
}
43+
fn walk_ast(&self, mut pass: AstPass<DB>) -> QueryResult<()> {
44+
pass.unsafe_to_cache_prepared();
4745
Ok(())
4846
}
4947
}

diesel/src/insertable.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::iter;
33
use backend::{Backend, SupportsDefaultKeyword};
44
use expression::Expression;
55
use result::QueryResult;
6-
use query_builder::{QueryBuilder, BuildQueryResult};
6+
use query_builder::{QueryBuilder, BuildQueryResult, AstPass};
77
use query_source::{Table, Column};
88
use types::IntoNullable;
99

@@ -27,7 +27,7 @@ pub trait Insertable<T: Table, DB: Backend> {
2727
pub trait InsertValues<DB: Backend> {
2828
fn column_names(&self, out: &mut DB::QueryBuilder) -> BuildQueryResult;
2929
fn values_clause(&self, out: &mut DB::QueryBuilder) -> BuildQueryResult;
30-
fn values_bind_params(&self, out: &mut DB::BindCollector) -> QueryResult<()>;
30+
fn walk_ast(&self, out: AstPass<DB>) -> QueryResult<()>;
3131
}
3232

3333
#[derive(Debug, Copy, Clone)]
@@ -85,9 +85,9 @@ impl<T, DB> InsertValues<DB> for BatchInsertValues<T> where
8585
Ok(())
8686
}
8787

88-
fn values_bind_params(&self, out: &mut DB::BindCollector) -> QueryResult<()> {
88+
fn walk_ast(&self, mut out: AstPass<DB>) -> QueryResult<()> {
8989
for values in self.0.clone() {
90-
try!(values.values_bind_params(out));
90+
values.walk_ast(out.reborrow())?;
9191
}
9292
Ok(())
9393
}

diesel/src/pg/upsert/on_conflict_actions.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,9 @@ impl<T> QueryFragment<Pg> for DoUpdate<T> where
163163
Ok(())
164164
}
165165

166-
fn walk_ast(&self, pass: AstPass<Pg>) -> QueryResult<()> {
167-
match pass {
168-
AstPass::CollectBinds(out) => self.changeset.collect_binds(out)?,
169-
AstPass::IsSafeToCachePrepared(result) => *result = false,
170-
}
166+
fn walk_ast(&self, mut pass: AstPass<Pg>) -> QueryResult<()> {
167+
pass.unsafe_to_cache_prepared();
168+
self.changeset.walk_ast(pass)?;
171169
Ok(())
172170
}
173171
}

diesel/src/pg/upsert/on_conflict_clause.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ impl<Values, Target, Action> InsertValues<Pg> for OnConflictValues<Values, Targe
116116
Ok(())
117117
}
118118

119-
fn values_bind_params(&self, out: &mut <Pg as Backend>::BindCollector) -> QueryResult<()> {
120-
try!(self.values.values_bind_params(out));
121-
try!(self.target.collect_binds(out));
122-
try!(self.action.collect_binds(out));
119+
fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> {
120+
self.values.walk_ast(out.reborrow())?;
121+
self.target.walk_ast(out.reborrow())?;
122+
self.action.walk_ast(out.reborrow())?;
123123
Ok(())
124124
}
125125
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use backend::Backend;
2+
use query_builder::BindCollector;
3+
use result::QueryResult;
4+
use types::{ToSql, HasSqlType};
5+
6+
#[doc(hidden)]
7+
#[allow(missing_debug_implementations)]
8+
pub struct AstPass<'a, DB> where
9+
DB: Backend,
10+
DB::BindCollector: 'a,
11+
{
12+
internals: AstPassInternals<'a, DB>,
13+
}
14+
15+
impl<'a, DB> AstPass<'a, DB> where
16+
DB: Backend,
17+
DB::BindCollector: 'a,
18+
{
19+
pub fn collect_binds(collector: &'a mut DB::BindCollector) -> Self {
20+
AstPass {
21+
internals: AstPassInternals::CollectBinds(collector),
22+
}
23+
}
24+
25+
pub fn is_safe_to_cache_prepared(result: &'a mut bool) -> Self {
26+
AstPass {
27+
internals: AstPassInternals::IsSafeToCachePrepared(result),
28+
}
29+
}
30+
31+
/// Effectively copies `self`, with a narrower lifetime. This method
32+
/// matches the semantics of the implicit reborrow that occurs when passing
33+
/// a reference by value in Rust.
34+
pub fn reborrow(&mut self) -> AstPass<DB> {
35+
use self::AstPassInternals::*;
36+
let internals = match self.internals {
37+
CollectBinds(ref mut collector) => CollectBinds(&mut **collector),
38+
IsSafeToCachePrepared(ref mut result) => IsSafeToCachePrepared(&mut **result),
39+
};
40+
AstPass { internals }
41+
}
42+
43+
pub fn unsafe_to_cache_prepared(&mut self) {
44+
if let AstPassInternals::IsSafeToCachePrepared(ref mut result) = self.internals {
45+
**result = false
46+
}
47+
}
48+
49+
pub fn push_bind_param<T, U>(&mut self, bind: &U) -> QueryResult<()> where
50+
DB: HasSqlType<T>,
51+
U: ToSql<T, DB>,
52+
{
53+
if let AstPassInternals::CollectBinds(ref mut out) = self.internals {
54+
out.push_bound_value(bind)?;
55+
}
56+
Ok(())
57+
}
58+
}
59+
60+
#[allow(missing_debug_implementations)]
61+
/// This is separate from the struct to cause the enum to be opaque, forcing
62+
/// usage of the methods provided rather than matching on the enum directly.
63+
/// This essentially mimics the capabilities that would be available if
64+
/// `AstPass` were a trait.
65+
enum AstPassInternals<'a, DB> where
66+
DB: Backend,
67+
DB::BindCollector: 'a,
68+
{
69+
CollectBinds(&'a mut DB::BindCollector),
70+
IsSafeToCachePrepared(&'a mut bool),
71+
}
72+

diesel/src/query_builder/insert_statement.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,11 @@ impl<T, U, Op, Ret, DB> QueryFragment<DB> for InsertStatement<T, U, Op, Ret> whe
117117
}
118118

119119
fn walk_ast(&self, mut pass: AstPass<DB>) -> QueryResult<()> {
120-
if let AstPass::IsSafeToCachePrepared(result) = pass {
121-
*result = false;
122-
} else {
123-
let values = self.records.values();
124-
self.operator.walk_ast(pass.reborrow())?;
125-
self.target.from_clause().walk_ast(pass.reborrow())?;
126-
if let AstPass::CollectBinds(ref mut out) = pass {
127-
values.values_bind_params(out)?;
128-
}
129-
self.returning.walk_ast(pass.reborrow())?;
130-
}
120+
pass.unsafe_to_cache_prepared();
121+
self.operator.walk_ast(pass.reborrow())?;
122+
self.target.from_clause().walk_ast(pass.reborrow())?;
123+
self.records.values().walk_ast(pass.reborrow())?;
124+
self.returning.walk_ast(pass.reborrow())?;
131125
Ok(())
132126
}
133127
}

diesel/src/query_builder/mod.rs

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod query_id;
66
#[macro_use]
77
mod clause_macro;
88

9+
mod ast_pass;
910
pub mod bind_collector;
1011
mod delete_statement;
1112
#[doc(hidden)]
@@ -24,6 +25,7 @@ pub mod where_clause;
2425
pub mod insert_statement;
2526
pub mod update_statement;
2627

28+
pub use self::ast_pass::AstPass;
2729
pub use self::bind_collector::BindCollector;
2830
pub use self::query_id::QueryId;
2931
#[doc(hidden)]
@@ -73,32 +75,6 @@ impl<'a, T: Query> Query for &'a T {
7375
type SqlType = T::SqlType;
7476
}
7577

76-
#[doc(hidden)]
77-
#[allow(missing_debug_implementations)]
78-
pub enum AstPass<'a, DB> where
79-
DB: Backend,
80-
DB::BindCollector: 'a,
81-
{
82-
CollectBinds(&'a mut DB::BindCollector),
83-
IsSafeToCachePrepared(&'a mut bool),
84-
}
85-
86-
impl<'a, DB> AstPass<'a, DB> where
87-
DB: Backend,
88-
DB::BindCollector: 'a,
89-
{
90-
/// Effectively copies `self`, with a narrower lifetime. This method
91-
/// matches the semantics of the implicit reborrow that occurs when passing
92-
/// a reference by value in Rust.
93-
pub fn reborrow(&mut self) -> AstPass<DB> {
94-
use self::AstPass::*;
95-
match *self {
96-
CollectBinds(ref mut collector) => CollectBinds(&mut **collector),
97-
IsSafeToCachePrepared(ref mut result) => IsSafeToCachePrepared(&mut **result),
98-
}
99-
}
100-
}
101-
10278
/// An untyped fragment of SQL. This may be a complete SQL command (such as
10379
/// an update statement without a `RETURNING` clause), or a subsection (such as
10480
/// our internal types used to represent a `WHERE` clause). All methods on
@@ -109,12 +85,12 @@ pub trait QueryFragment<DB: Backend> {
10985
fn walk_ast(&self, pass: AstPass<DB>) -> QueryResult<()>;
11086

11187
fn collect_binds(&self, out: &mut DB::BindCollector) -> QueryResult<()> {
112-
self.walk_ast(AstPass::CollectBinds(out))
88+
self.walk_ast(AstPass::collect_binds(out))
11389
}
11490

11591
fn is_safe_to_cache_prepared(&self) -> QueryResult<bool> {
11692
let mut result = true;
117-
self.walk_ast(AstPass::IsSafeToCachePrepared(&mut result))?;
93+
self.walk_ast(AstPass::is_safe_to_cache_prepared(&mut result))?;
11894
Ok(result)
11995
}
12096
}

0 commit comments

Comments
 (0)