Skip to content

Commit 8fb0429

Browse files
kivikakksgrif
authored andcommitted
Add a function for constructing PG array literals
1 parent b67624a commit 8fb0429

8 files changed

Lines changed: 227 additions & 1 deletion

File tree

diesel/src/expression/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,17 @@ impl<'a, QS, ST, DB> QueryId for BoxableExpression<QS, DB, SqlType = ST> + 'a {
339339

340340
const HAS_STATIC_QUERY_ID: bool = false;
341341
}
342+
343+
/// Converts a tuple of values into a tuple of Diesel expressions.
344+
///
345+
/// This trait is similar to [`AsExpression`], but it operates on tuples.
346+
/// The expressions must all be of the same SQL type.
347+
///
348+
/// [`AsExpression`]: trait.AsExpression.html
349+
pub trait AsExpressionList<ST> {
350+
/// The final output expression
351+
type Expression;
352+
353+
/// Perform the conversion
354+
fn as_expression_list(self) -> Self::Expression;
355+
}

diesel/src/pg/expression/array.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
use std::marker::PhantomData;
2+
use backend::Backend;
3+
use expression::{AppearsOnTable, AsExpressionList, Expression, NonAggregate, SelectableExpression};
4+
use query_builder::{AstPass, QueryFragment};
5+
use types;
6+
7+
/// An ARRAY[...] literal.
8+
#[derive(Debug, Clone, Copy, QueryId)]
9+
pub struct ArrayLiteral<T, ST> {
10+
elements: T,
11+
_marker: PhantomData<ST>,
12+
}
13+
14+
/// Creates an `ARRAY[...]` expression.
15+
///
16+
/// The argument should be a tuple of expressions which can be represented by the
17+
/// same SQL type.
18+
///
19+
/// # Examples
20+
///
21+
/// ```rust
22+
/// # #[macro_use] extern crate diesel;
23+
/// # include!("../../doctest_setup.rs");
24+
/// #
25+
/// # fn main() {
26+
/// # run_test().unwrap();
27+
/// # }
28+
/// #
29+
/// # fn run_test() -> QueryResult<()> {
30+
/// # use schema::users::dsl::*;
31+
/// # use diesel::dsl::array;
32+
/// # use diesel::types::Integer;
33+
/// # let connection = establish_connection();
34+
/// let ints = diesel::select(array::<Integer, _>((1, 2)))
35+
/// .get_result::<Vec<i32>>(&connection)?;
36+
/// assert_eq!(vec![1, 2], ints);
37+
///
38+
/// let ids = users.select(array((id, id * 2)))
39+
/// .get_results::<Vec<i32>>(&connection)?;
40+
/// let expected = vec![
41+
/// vec![1, 2],
42+
/// vec![2, 4],
43+
/// ];
44+
/// assert_eq!(expected, ids);
45+
/// # Ok(())
46+
/// # }
47+
/// ```
48+
pub fn array<ST, T>(elements: T) -> ArrayLiteral<T::Expression, ST>
49+
where
50+
T: AsExpressionList<ST>,
51+
{
52+
ArrayLiteral {
53+
elements: elements.as_expression_list(),
54+
_marker: PhantomData,
55+
}
56+
}
57+
58+
impl<T, ST> Expression for ArrayLiteral<T, ST>
59+
where
60+
T: Expression,
61+
{
62+
type SqlType = types::Array<ST>;
63+
}
64+
65+
impl<T, ST, DB> QueryFragment<DB> for ArrayLiteral<T, ST>
66+
where
67+
DB: Backend,
68+
for<'a> (&'a T): QueryFragment<DB>,
69+
{
70+
fn walk_ast(&self, mut out: AstPass<DB>) -> ::result::QueryResult<()> {
71+
out.push_sql("ARRAY[");
72+
QueryFragment::walk_ast(&&self.elements, out.reborrow())?;
73+
out.push_sql("]");
74+
Ok(())
75+
}
76+
}
77+
78+
impl<T, ST, QS> SelectableExpression<QS> for ArrayLiteral<T, ST>
79+
where
80+
T: SelectableExpression<QS>,
81+
ArrayLiteral<T, ST>: AppearsOnTable<QS>,
82+
{
83+
}
84+
85+
impl<T, ST, QS> AppearsOnTable<QS> for ArrayLiteral<T, ST>
86+
where
87+
T: AppearsOnTable<QS>,
88+
ArrayLiteral<T, ST>: Expression,
89+
{
90+
}
91+
92+
impl<T, ST> NonAggregate for ArrayLiteral<T, ST>
93+
where
94+
T: NonAggregate,
95+
ArrayLiteral<T, ST>: Expression,
96+
{
97+
}

diesel/src/pg/expression/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
#[doc(hidden)]
88
pub mod array_comparison;
9+
pub(crate) mod array;
910
pub(crate) mod expression_methods;
1011
pub mod extensions;
1112
#[doc(hidden)]
@@ -22,5 +23,8 @@ pub mod dsl {
2223
#[doc(inline)]
2324
pub use super::array_comparison::{all, any};
2425

26+
#[doc(inline)]
27+
pub use super::array::array;
28+
2529
pub use super::extensions::*;
2630
}

diesel/src/types/impls/tuples.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::error::Error;
33
use associations::BelongsTo;
44
use backend::Backend;
55
use deserialize;
6-
use expression::{AppearsOnTable, Expression, NonAggregate, SelectableExpression};
6+
use expression::{AsExpression, AppearsOnTable, Expression, AsExpressionList, NonAggregate, SelectableExpression};
77
use insertable::{CanInsertInSingleQuery, InsertValues, Insertable};
88
use query_builder::*;
99
use query_source::*;
@@ -246,6 +246,16 @@ macro_rules! tuple_impls {
246246
($($T,)+ next)
247247
}
248248
}
249+
250+
impl<$($T,)+ ST> AsExpressionList<ST> for ($($T,)+) where
251+
$($T: AsExpression<ST>,)+
252+
{
253+
type Expression = ($($T::Expression,)+);
254+
255+
fn as_expression_list(self) -> Self::Expression {
256+
($(self.$idx.as_expression(),)+)
257+
}
258+
}
249259
)+
250260
}
251261
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#[macro_use]
2+
extern crate diesel;
3+
4+
use diesel::*;
5+
use diesel::dsl::*;
6+
7+
fn main() {
8+
let connection = PgConnection::establish("").unwrap();
9+
select(array((1, 3))).get_result::<Vec<i32>>(&connection);
10+
select(array((1f64, 3f64))).get_result::<Vec<i32>>(&connection);
11+
//~^ ERROR E0277
12+
//~| ERROR E0277
13+
//~| ERROR E0277
14+
//~| ERROR E0277
15+
//~| ERROR E0277
16+
//~| ERROR E0277
17+
//~| ERROR E0277
18+
//~| ERROR E0277
19+
select(array((1f64, 3f64))).get_result::<Vec<f64>>(&connection);
20+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#[macro_use]
2+
extern crate diesel;
3+
4+
use diesel::*;
5+
use diesel::dsl::*;
6+
7+
fn main() {
8+
let connection = PgConnection::establish("").unwrap();
9+
select(array((1, 3))).get_result::<Vec<i32>>(&connection).unwrap();
10+
select(array((1f64, 3f64))).get_result::<Vec<f64>>(&connection).unwrap();
11+
12+
select(array((1, 3f64))).get_result::<Vec<i32>>(&connection).unwrap();
13+
//~^ ERROR E0277
14+
//~| ERROR E0277
15+
//~| ERROR E0277
16+
//~| ERROR E0277
17+
//~| ERROR E0277
18+
//~| ERROR E0277
19+
//~| ERROR E0277
20+
21+
select(array((1, 3f64))).get_result::<Vec<f64>>(&connection).unwrap();
22+
//~^ ERROR E0277
23+
//~| ERROR E0277
24+
//~| ERROR E0277
25+
//~| ERROR E0277
26+
//~| ERROR E0277
27+
//~| ERROR E0277
28+
//~| ERROR E0277
29+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#[macro_use]
2+
extern crate diesel;
3+
4+
use diesel::*;
5+
use diesel::dsl::*;
6+
7+
fn main() {
8+
let connection = SqliteConnection::establish("").unwrap();
9+
select(array((1,))).get_result::<Vec<i32>>(&connection);
10+
//~^ ERROR E0271
11+
//~| ERROR E0277
12+
13+
let connection = MysqlConnection::establish("").unwrap();
14+
select(array((1,))).get_result::<Vec<i32>>(&connection);
15+
//~^ ERROR E0271
16+
//~| ERROR E0277
17+
}

diesel_tests/tests/expressions/mod.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,38 @@ fn test_avg_for_numeric() {
406406
};
407407
assert_eq!(Ok(Some(expected_result)), result);
408408
}
409+
410+
#[test]
411+
#[cfg(feature = "postgres")]
412+
fn test_arrays_a() {
413+
let connection = connection();
414+
415+
use diesel::types::Int4;
416+
let value = select(array::<Int4, _>((1, 2)))
417+
.get_result::<Vec<i32>>(&connection)
418+
.unwrap();
419+
420+
assert_eq!(value, vec![1, 2]);
421+
}
422+
423+
#[test]
424+
#[cfg(feature = "postgres")]
425+
fn test_arrays_b() {
426+
use diesel::types::{Array, Int4};
427+
sql_function!(unnest, unnest_t, (a: Array<Int4>) -> Int4);
428+
429+
use self::numbers::columns::*;
430+
use self::numbers::table as numbers;
431+
432+
let connection = connection();
433+
connection
434+
.execute("INSERT INTO numbers (n) VALUES (7)")
435+
.unwrap();
436+
437+
let value = numbers
438+
.select(unnest(array((n, n + n))))
439+
.load::<i32>(&connection)
440+
.unwrap();
441+
442+
assert_eq!(value, vec![7, 14]);
443+
}

0 commit comments

Comments
 (0)