Skip to content

Commit 102f570

Browse files
committed
Update FindDsl to work with composite primary keys
This introduces a new expression method, `eq_all` which is not in the public API. For single values, it is identical to `eq`. For tuples, it represents all of the tuple elements matching. I unfortunately could not implement `eq_all` for all types, as the base impl would always overlap with the tuple impls. Since this is meant to be used with `FindDsl` for now, which always has a column as the left hand side, we brute force it. Additionally, I couldn't easily implement this for arbitrarily large tuples with a macro. This might get a bit easier when macro types lands (but the code might still be complex enough to not warrant it). The overwhelming majority of cases here will be 2 element tuples. I've manually done it for 4 elements just to be safe.
1 parent d02618d commit 102f570

9 files changed

Lines changed: 102 additions & 14 deletions

File tree

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use expression::Expression;
2+
use expression::predicates::And;
3+
use expression::expression_methods::*;
4+
use types::Bool;
5+
6+
/// This method is used by `FindDsl` to work with tuples. Because we cannot
7+
/// express this without specialization or overlapping impls, it is brute force
8+
/// implemented on columns in the `column!` macro.
9+
#[doc(hidden)]
10+
pub trait EqAll<Rhs> {
11+
type Output: Expression<SqlType=Bool>;
12+
13+
fn eq_all(self, rhs: Rhs) -> Self::Output;
14+
}
15+
16+
// FIXME: This is much easier to represent with a macro once macro types are stable
17+
// which appears to be slated for 1.13
18+
impl<L1, L2, R1, R2> EqAll<(R1, R2)> for (L1, L2) where
19+
L1: EqAll<R1>,
20+
L2: EqAll<R2>,
21+
{
22+
type Output = And<<L1 as EqAll<R1>>::Output, <L2 as EqAll<R2>>::Output>;
23+
24+
fn eq_all(self, rhs: (R1, R2)) -> Self::Output {
25+
self.0.eq_all(rhs.0).and(self.1.eq_all(rhs.1))
26+
}
27+
}
28+
29+
impl<L1, L2, L3, R1, R2, R3> EqAll<(R1, R2, R3)> for (L1, L2, L3) where
30+
L1: EqAll<R1>,
31+
L2: EqAll<R2>,
32+
L3: EqAll<R3>,
33+
{
34+
type Output = And<<L1 as EqAll<R1>>::Output, And<<L2 as EqAll<R2>>::Output, <L3 as EqAll<R3>>::Output>>;
35+
36+
fn eq_all(self, rhs: (R1, R2, R3)) -> Self::Output {
37+
self.0.eq_all(rhs.0).and(
38+
self.1.eq_all(rhs.1).and(self.2.eq_all(rhs.2)))
39+
}
40+
}
41+
42+
impl<L1, L2, L3, L4, R1, R2, R3, R4> EqAll<(R1, R2, R3, R4)> for (L1, L2, L3, L4) where
43+
L1: EqAll<R1>,
44+
L2: EqAll<R2>,
45+
L3: EqAll<R3>,
46+
L4: EqAll<R4>,
47+
{
48+
type Output = And<<L1 as EqAll<R1>>::Output, And<<L2 as EqAll<R2>>::Output, And<<L3 as EqAll<R3>>::Output, <L4 as EqAll<R4>>::Output>>>;
49+
50+
fn eq_all(self, rhs: (R1, R2, R3, R4)) -> Self::Output {
51+
self.0.eq_all(rhs.0).and(
52+
self.1.eq_all(rhs.1).and(
53+
self.2.eq_all(rhs.2).and(
54+
self.3.eq_all(rhs.3)
55+
)))
56+
}
57+
}

diesel/src/expression/expression_methods/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@ pub mod bool_expression_methods;
88
pub mod escape_expression_methods;
99
pub mod global_expression_methods;
1010
pub mod text_expression_methods;
11+
#[doc(hidden)]
12+
pub mod eq_all;
1113

1214
pub use self::bool_expression_methods::BoolExpressionMethods;
1315
pub use self::escape_expression_methods::EscapeExpressionMethods;
1416
pub use self::global_expression_methods::ExpressionMethods;
1517
pub use self::text_expression_methods::TextExpressionMethods;
18+
#[doc(hidden)]
19+
pub use self::eq_all::EqAll;
1620

1721
#[cfg(feature = "postgres")]
1822
pub use pg::expression::expression_methods::*;

diesel/src/macros/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ macro_rules! __diesel_column {
6161
stringify!($column_name)
6262
}
6363
}
64+
65+
impl<T> $crate::EqAll<T> for $column_name where
66+
T: $crate::expression::AsExpression<$Type>,
67+
$crate::expression::helper_types::Eq<$column_name, T>: $crate::Expression<SqlType=$crate::types::Bool>,
68+
{
69+
type Output = $crate::expression::helper_types::Eq<Self, T>;
70+
71+
fn eq_all(self, rhs: T) -> Self::Output {
72+
$crate::ExpressionMethods::eq(self, rhs)
73+
}
74+
}
6475
}
6576
}
6677

diesel/src/query_dsl/filter_dsl.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ impl<T: Table> NotFiltered for T {}
5757
impl<Left, Right> NotFiltered for InnerJoinSource<Left, Right> {}
5858
impl<Left, Right> NotFiltered for LeftOuterJoinSource<Left, Right> {}
5959

60-
use expression::AsExpression;
6160
use expression::expression_methods::*;
62-
use expression::helper_types::Eq;
63-
use helper_types::FindBy;
61+
use helper_types::Filter;
6462

6563
/// Attempts to find a single record from the given table by primary key.
6664
///
@@ -95,13 +93,13 @@ pub trait FindDsl<PK>: AsQuery {
9593
}
9694

9795
impl<T, PK> FindDsl<PK> for T where
98-
T: Table + FilterDsl<Eq<<T as Table>::PrimaryKey, PK>>,
99-
PK: AsExpression<<T::PrimaryKey as Expression>::SqlType>,
96+
T: Table + FilterDsl<<<T as Table>::PrimaryKey as EqAll<PK>>::Output>,
97+
T::PrimaryKey: EqAll<PK>,
10098
{
101-
type Output = FindBy<Self, T::PrimaryKey, PK>;
99+
type Output = Filter<Self, <T::PrimaryKey as EqAll<PK>>::Output>;
102100

103101
fn find(self, id: PK) -> Self::Output {
104102
let primary_key = self.primary_key();
105-
self.filter(primary_key.eq(id))
103+
self.filter(primary_key.eq_all(id))
106104
}
107105
}

diesel_codegen_shared/src/schema_inference/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ pub fn get_primary_keys(
9696
if primary_keys.is_empty() {
9797
Err(format!("Diesel only supports tables with primary keys. \
9898
Table {} has no primary key", table_name).into())
99+
} else if primary_keys.len() > 4 {
100+
Err(format!("Diesel does not currently support tables with \
101+
primary keys consisting of more than 4 columns. \
102+
Table {} has {} columns in its primary key. \
103+
Please open an issue and we will increase the \
104+
limit.", table_name, primary_keys.len()).into())
99105
} else {
100106
Ok(primary_keys)
101107
}

diesel_compile_tests/tests/compile-fail/find_requires_correct_type.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,4 @@ fn main() {
2929
//~| ERROR E0277
3030
//~| ERROR E0277
3131
//~| ERROR E0277
32-
//~| ERROR E0277
3332
}

diesel_tests/Cargo.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ stable_sqlite = ["with-syntex", "sqlite", "diesel_codegen_syntex/sqlite"]
3131
unstable_postgres = ["unstable", "postgres", "diesel_codegen/postgres"]
3232
unstable_sqlite = ["unstable", "sqlite", "diesel_codegen/sqlite"]
3333

34-
[lib]
35-
name = "integration_tests"
36-
path = "tests/lib.rs"
37-
3834
[[test]]
3935
name = "integration_tests"
4036
path = "tests/lib.rs"

diesel_tests/tests/find.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,20 @@ fn find_with_non_serial_pk() {
3535
assert_eq!(Ok(("Tess".to_string(),)), users.find("Tess".to_string()).first(&connection));
3636
assert_eq!(Ok(None::<(String,)>), users.find("Wibble").first(&connection).optional());
3737
}
38+
39+
#[test]
40+
fn find_with_composite_pk() {
41+
use schema::followings::dsl::*;
42+
43+
let first_following = Following { user_id: 1, post_id: 1, email_notifications: true };
44+
let second_following = Following { user_id: 1, post_id: 2, email_notifications: false };
45+
let third_following = Following { user_id: 2, post_id: 1, email_notifications: false };
46+
47+
let connection = connection();
48+
batch_insert(&[first_following, second_following, third_following], followings, &connection);
49+
50+
assert_eq!(Ok(first_following), followings.find((1, 1)).first(&connection));
51+
assert_eq!(Ok(second_following), followings.find((1, 2)).first(&connection));
52+
assert_eq!(Ok(third_following), followings.find((2, 1)).first(&connection));
53+
assert_eq!(Ok(None::<Following>), followings.find((2, 2)).first(&connection).optional());
54+
}

diesel_tests/tests/schema.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ impl Comment {
4949
}
5050
}
5151

52-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
53-
#[allow(dead_code)]
52+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable, Insertable)]
53+
#[table_name="followings"]
5454
pub struct Following {
5555
pub user_id: i32,
5656
pub post_id: i32,

0 commit comments

Comments
 (0)