Skip to content

Commit 58b6982

Browse files
authored
Merge pull request diesel-rs#1691 from diesel-rs/sg-sqlite-custom-functions
Support custom functions on SQLite
2 parents 2ed128d + fbf9002 commit 58b6982

8 files changed

Lines changed: 517 additions & 107 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
2323
* Added `sqlite-bundled` feature to `diesel_cli` to make installing on
2424
some platforms easier.
2525

26+
* Custom SQL functions can now be used with SQLite. See [the
27+
docs][sql-function-sqlite-1-3-0] for details.
28+
2629
* All functions and operators provided by Diesel can now be used with numeric
2730
operators if the SQL type supports it.
2831

diesel/src/expression/functions/mod.rs

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,94 @@ macro_rules! __diesel_sql_function_body {
260260
}
261261
}
262262
}
263+
264+
__diesel_sqlite_register_fn! {
265+
type_args = ($($type_args)*),
266+
fn_name = $fn_name,
267+
args = ($($arg_name,)+),
268+
sql_args = ($($arg_type,)+),
269+
ret = $return_type,
270+
}
263271
}
264272
}
265273
}
266274

275+
#[macro_export]
276+
#[doc(hidden)]
277+
#[cfg(feature = "sqlite")]
278+
macro_rules! __diesel_sqlite_register_fn {
279+
// We can't handle generic functions for SQLite
280+
(
281+
type_args = ($($type_args:tt)+),
282+
$($rest:tt)*
283+
) => {
284+
};
285+
286+
(
287+
type_args = (),
288+
fn_name = $fn_name:ident,
289+
args = ($($args:ident,)+),
290+
sql_args = $sql_args:ty,
291+
ret = $ret:ty,
292+
) => {
293+
#[allow(dead_code)]
294+
/// Registers an implementation for this function on the given connection
295+
///
296+
/// This function must be called for every `SqliteConnection` before
297+
/// this SQL function can be used on SQLite. The implementation must be
298+
/// deterministic (returns the same result given the same arguments). If
299+
/// the function is nondeterministic, call
300+
/// `register_nondeterministic_impl` instead.
301+
pub fn register_impl<F, Ret, $($args,)+>(
302+
conn: &$crate::SqliteConnection,
303+
f: F,
304+
) -> $crate::QueryResult<()>
305+
where
306+
F: Fn($($args,)+) -> Ret + Send + 'static,
307+
($($args,)+): $crate::deserialize::Queryable<$sql_args, $crate::sqlite::Sqlite>,
308+
Ret: $crate::serialize::ToSql<$ret, $crate::sqlite::Sqlite>,
309+
{
310+
conn.register_sql_function::<$sql_args, $ret, _, _, _>(
311+
stringify!($fn_name),
312+
true,
313+
move |($($args,)+)| f($($args),+),
314+
)
315+
}
316+
317+
#[allow(dead_code)]
318+
/// Registers an implementation for this function on the given connection
319+
///
320+
/// This function must be called for every `SqliteConnection` before
321+
/// this SQL function can be used on SQLite.
322+
/// `register_nondeterministic_impl` should only be used if your
323+
/// function can return different results with the same arguments (e.g.
324+
/// `random`). If your function is deterministic, you should call
325+
/// `register_impl` instead.
326+
pub fn register_nondeterministic_impl<F, Ret, $($args,)+>(
327+
conn: &$crate::SqliteConnection,
328+
mut f: F,
329+
) -> $crate::QueryResult<()>
330+
where
331+
F: FnMut($($args,)+) -> Ret + Send + 'static,
332+
($($args,)+): $crate::deserialize::Queryable<$sql_args, $crate::sqlite::Sqlite>,
333+
Ret: $crate::serialize::ToSql<$ret, $crate::sqlite::Sqlite>,
334+
{
335+
conn.register_sql_function::<$sql_args, $ret, _, _, _>(
336+
stringify!($fn_name),
337+
false,
338+
move |($($args,)+)| f($($args),+),
339+
)
340+
}
341+
};
342+
}
343+
344+
#[macro_export]
345+
#[doc(hidden)]
346+
#[cfg(not(feature = "sqlite"))]
347+
macro_rules! __diesel_sqlite_register_fn {
348+
($($token:tt)*) => {};
349+
}
350+
267351
#[macro_export]
268352
/// Declare a sql function for use in your code.
269353
///
@@ -325,7 +409,7 @@ macro_rules! __diesel_sql_function_body {
325409
/// Most attributes given to this macro will be put on the generated function
326410
/// (including doc comments).
327411
///
328-
/// # Example
412+
/// # Adding Doc Comments
329413
///
330414
/// ```no_run
331415
/// # #[macro_use] extern crate diesel;
@@ -350,6 +434,8 @@ macro_rules! __diesel_sql_function_body {
350434
/// # }
351435
/// ```
352436
///
437+
/// # Special Attributes
438+
///
353439
/// There are a handful of special attributes that Diesel will recognize. They
354440
/// are:
355441
///
@@ -382,6 +468,49 @@ macro_rules! __diesel_sql_function_body {
382468
/// crates.select(sum(id));
383469
/// # }
384470
/// ```
471+
///
472+
/// # Use with SQLite
473+
///
474+
/// On most backends, the implementation of the function is defined in a
475+
/// migration using `CREATE FUNCTION`. On SQLite, the function is implemented in
476+
/// Rust instead. You must call `register_impl` or
477+
/// `register_nondeterministic_impl` with every connection before you can use
478+
/// the function.
479+
///
480+
/// These functions will only be generated if the `sqlite` feature is enabled,
481+
/// and the function is not generic. Generic functions and variadic functions
482+
/// are not supported on SQLite.
483+
///
484+
/// ```rust
485+
/// # #[macro_use] extern crate diesel;
486+
/// # use diesel::*;
487+
/// #
488+
/// # #[cfg(feature = "sqlite")]
489+
/// # fn main() {
490+
/// # run_test().unwrap();
491+
/// # }
492+
/// #
493+
/// # #[cfg(not(feature = "sqlite"))]
494+
/// # fn main() {
495+
/// # }
496+
/// #
497+
/// use diesel::sql_types::{Integer, Double};
498+
/// sql_function!(fn add_mul(x: Integer, y: Integer, z: Double) -> Double);
499+
///
500+
/// # #[cfg(feature = "sqlite")]
501+
/// # fn run_test() -> Result<(), Box<::std::error::Error>> {
502+
/// let connection = SqliteConnection::establish(":memory:")?;
503+
///
504+
/// add_mul::register_impl(&connection, |x: i32, y: i32, z: f64| {
505+
/// (x + y) as f64 * z
506+
/// })?;
507+
///
508+
/// let result = select(add_mul(1, 2, 1.5))
509+
/// .get_result::<f64>(&connection)?;
510+
/// assert_eq!(4.5, result);
511+
/// # Ok(())
512+
/// # }
513+
/// ```
385514
macro_rules! sql_function {
386515
($(#$meta:tt)* fn $fn_name:ident $args:tt $(;)*) => {
387516
sql_function!($(#[$meta])* fn $fn_name $args -> ());
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
extern crate libsqlite3_sys as ffi;
2+
3+
use deserialize::{FromSqlRow, Queryable};
4+
use result::{DatabaseErrorKind, Error, QueryResult};
5+
use row::Row;
6+
use serialize::{IsNull, Output, ToSql};
7+
use sql_types::HasSqlType;
8+
use super::raw::RawConnection;
9+
use super::serialized_value::SerializedValue;
10+
use super::{Sqlite, SqliteValue};
11+
12+
pub fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
13+
conn: &RawConnection,
14+
fn_name: &str,
15+
deterministic: bool,
16+
mut f: F,
17+
) -> QueryResult<()>
18+
where
19+
F: FnMut(Args) -> Ret + Send + 'static,
20+
Args: Queryable<ArgsSqlType, Sqlite>,
21+
Ret: ToSql<RetSqlType, Sqlite>,
22+
Sqlite: HasSqlType<RetSqlType>,
23+
{
24+
let fields_needed = Args::Row::FIELDS_NEEDED;
25+
if fields_needed > 127 {
26+
return Err(Error::DatabaseError(
27+
DatabaseErrorKind::UnableToSendCommand,
28+
Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
29+
));
30+
}
31+
32+
conn.register_sql_function(fn_name, fields_needed, deterministic, move |args| {
33+
let mut row = FunctionRow { args };
34+
let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?;
35+
let args = Args::build(args_row);
36+
37+
let result = f(args);
38+
39+
let mut buf = Output::new(Vec::new(), &());
40+
let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
41+
42+
let bytes = if let IsNull::Yes = is_null {
43+
None
44+
} else {
45+
Some(buf.into_inner())
46+
};
47+
48+
Ok(SerializedValue {
49+
ty: Sqlite::metadata(&()),
50+
data: bytes,
51+
})
52+
})?;
53+
Ok(())
54+
}
55+
56+
struct FunctionRow<'a> {
57+
args: &'a [*mut ffi::sqlite3_value],
58+
}
59+
60+
impl<'a> Row<Sqlite> for FunctionRow<'a> {
61+
fn take(&mut self) -> Option<&SqliteValue> {
62+
self.args.split_first().and_then(|(&first, rest)| {
63+
self.args = rest;
64+
unsafe { SqliteValue::new(first) }
65+
})
66+
}
67+
68+
fn next_is_null(&self, count: usize) -> bool {
69+
self.args[..count]
70+
.iter()
71+
.all(|&p| unsafe { SqliteValue::new(p) }.is_none())
72+
}
73+
}

diesel/src/sqlite/connection/mod.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
extern crate libsqlite3_sys as ffi;
22

3+
mod functions;
34
#[doc(hidden)]
45
pub mod raw;
6+
mod serialized_value;
57
mod stmt;
68
mod statement_iterator;
79
mod sqlite_value;
@@ -18,6 +20,7 @@ use result::*;
1820
use self::raw::RawConnection;
1921
use self::statement_iterator::*;
2022
use self::stmt::{Statement, StatementUse};
23+
use serialize::ToSql;
2124
use sql_types::HasSqlType;
2225
use sqlite::Sqlite;
2326

@@ -208,6 +211,22 @@ impl SqliteConnection {
208211
Statement::prepare(&self.raw_connection, sql)
209212
})
210213
}
214+
215+
#[doc(hidden)]
216+
pub fn register_sql_function<ArgsSqlType, RetSqlType, Args, Ret, F>(
217+
&self,
218+
fn_name: &str,
219+
deterministic: bool,
220+
f: F,
221+
) -> QueryResult<()>
222+
where
223+
F: FnMut(Args) -> Ret + Send + 'static,
224+
Args: Queryable<ArgsSqlType, Sqlite>,
225+
Ret: ToSql<RetSqlType, Sqlite>,
226+
Sqlite: HasSqlType<RetSqlType>,
227+
{
228+
functions::register(&self.raw_connection, fn_name, deterministic, f)
229+
}
211230
}
212231

213232
fn error_message(err_code: libc::c_int) -> &'static str {
@@ -269,4 +288,56 @@ mod tests {
269288
assert_eq!(Ok(true), query.get_result(&connection));
270289
assert_eq!(1, connection.statement_cache.len());
271290
}
291+
292+
use sql_types::Text;
293+
sql_function!(fn fun_case(x: Text) -> Text);
294+
295+
#[test]
296+
fn register_custom_function() {
297+
let connection = SqliteConnection::establish(":memory:").unwrap();
298+
fun_case::register_impl(&connection, |x: String| {
299+
x.chars()
300+
.enumerate()
301+
.map(|(i, c)| {
302+
if i % 2 == 0 {
303+
c.to_lowercase().to_string()
304+
} else {
305+
c.to_uppercase().to_string()
306+
}
307+
})
308+
.collect::<String>()
309+
}).unwrap();
310+
311+
let mapped_string = ::select(fun_case("foobar"))
312+
.get_result::<String>(&connection)
313+
.unwrap();
314+
assert_eq!("fOoBaR", mapped_string);
315+
}
316+
317+
sql_function!(fn my_add(x: Integer, y: Integer) -> Integer);
318+
319+
#[test]
320+
fn register_multiarg_function() {
321+
let connection = SqliteConnection::establish(":memory:").unwrap();
322+
my_add::register_impl(&connection, |x: i32, y: i32| x + y).unwrap();
323+
324+
let added = ::select(my_add(1, 2)).get_result::<i32>(&connection);
325+
assert_eq!(Ok(3), added);
326+
}
327+
328+
sql_function!(fn add_counter(x: Integer) -> Integer);
329+
330+
#[test]
331+
fn register_nondeterministic_function() {
332+
let connection = SqliteConnection::establish(":memory:").unwrap();
333+
let mut y = 0;
334+
add_counter::register_nondeterministic_impl(&connection, move |x: i32| {
335+
y += 1;
336+
x + y
337+
}).unwrap();
338+
339+
let added = ::select((add_counter(1), add_counter(1), add_counter(1)))
340+
.get_result::<(i32, i32, i32)>(&connection);
341+
assert_eq!(Ok((2, 3, 4)), added);
342+
}
272343
}

0 commit comments

Comments
 (0)