Skip to content

Commit da4447b

Browse files
committed
Fix the custom sql function implementation for Sqlite
1 parent e581285 commit da4447b

7 files changed

Lines changed: 235 additions & 128 deletions

File tree

diesel/src/sqlite/connection/functions.rs

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
extern crate libsqlite3_sys as ffi;
22

33
use super::raw::RawConnection;
4+
use super::row::PrivateSqliteRow;
45
use super::serialized_value::SerializedValue;
56
use super::{Sqlite, SqliteAggregateFunction};
67
use crate::deserialize::{FromSqlRow, StaticallySizedRow};
78
use crate::result::{DatabaseErrorKind, Error, QueryResult};
89
use crate::row::{Field, PartialRow, Row, RowIndex};
910
use crate::serialize::{IsNull, Output, ToSql};
1011
use crate::sql_types::HasSqlType;
12+
use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
13+
use crate::sqlite::SqliteValue;
14+
use std::cell::{Ref, RefCell};
1115
use std::marker::PhantomData;
16+
use std::mem::ManuallyDrop;
17+
use std::ops::DerefMut;
18+
use std::rc::Rc;
1219

1320
pub fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
1421
conn: &RawConnection,
@@ -85,7 +92,7 @@ where
8592
}
8693

8794
pub(crate) fn build_sql_function_args<ArgsSqlType, Args>(
88-
args: &[*mut ffi::sqlite3_value],
95+
args: &mut [*mut ffi::sqlite3_value],
8996
) -> Result<Args, Error>
9097
where
9198
Args: FromSqlRow<ArgsSqlType, Sqlite>,
@@ -117,14 +124,67 @@ where
117124
})
118125
}
119126

120-
#[derive(Clone)]
121127
struct FunctionRow<'a> {
122-
args: &'a [*mut ffi::sqlite3_value],
128+
// we use `ManuallyDrop` to prevent dropping the content of the internal vector
129+
// as this buffer is owned by sqlite not by diesel
130+
args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a>>>>,
131+
field_count: usize,
132+
marker: PhantomData<&'a ffi::sqlite3_value>,
133+
}
134+
135+
impl<'a> Drop for FunctionRow<'a> {
136+
fn drop(&mut self) {
137+
if let Some(args) = Rc::get_mut(&mut self.args) {
138+
if let PrivateSqliteRow::Duplicated { column_names, .. } =
139+
DerefMut::deref_mut(RefCell::get_mut(args))
140+
{
141+
if let Some(inner) = Rc::get_mut(column_names) {
142+
// an empty Vector does not allocate according to the documentation
143+
// so this prevents leaking memory
144+
std::mem::drop(std::mem::replace(inner, Vec::new()));
145+
}
146+
}
147+
}
148+
}
123149
}
124150

125151
impl<'a> FunctionRow<'a> {
126-
fn new(args: &'a [*mut ffi::sqlite3_value]) -> Self {
127-
Self { args }
152+
fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
153+
let lenghts = args.len();
154+
let args = unsafe {
155+
Vec::from_raw_parts(
156+
// This cast is safe because:
157+
// * Casting from a pointer to an arry to a pointer to the first array
158+
// element is safe
159+
// * Casting from a raw pointer to `NonNull<T>` is safe,
160+
// because `NonNull` is #[repr(transparent)]
161+
// * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
162+
// as the struct is `#[repr(transparent)]
163+
// * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
164+
// states: "This is so that enums may use this forbidden value as a discriminant –
165+
// Option<NonNull<T>> has the same size as *mut T"
166+
// * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
167+
// guarantees the same layout as the inner type
168+
// * It's unsafe to drop the vector (and the vector elements)
169+
// because of this we wrap the vector (or better the Row)
170+
// Into `ManualDrop` to prevent the dropping
171+
args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
172+
as *mut Option<OwnedSqliteValue>,
173+
lenghts,
174+
lenghts,
175+
)
176+
};
177+
178+
Self {
179+
field_count: lenghts,
180+
args: Rc::new(RefCell::new(ManuallyDrop::new(
181+
PrivateSqliteRow::Duplicated {
182+
values: args,
183+
column_names: Rc::new(vec![None; lenghts]),
184+
},
185+
))),
186+
marker: PhantomData,
187+
}
128188
}
129189
}
130190

@@ -133,18 +193,17 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
133193
type InnerPartialRow = Self;
134194

135195
fn field_count(&self) -> usize {
136-
self.args.len()
196+
self.field_count
137197
}
138198

139199
fn get<I>(&self, idx: I) -> Option<Self::Field>
140200
where
141201
Self: crate::row::RowIndex<I>,
142202
{
143203
let idx = self.idx(idx)?;
144-
145-
self.args.get(idx).map(|arg| FunctionArgument {
146-
arg: *arg,
147-
p: PhantomData,
204+
Some(FunctionArgument {
205+
args: self.args.clone(),
206+
col_idx: idx as i32,
148207
})
149208
}
150209

@@ -155,7 +214,7 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
155214

156215
impl<'a> RowIndex<usize> for FunctionRow<'a> {
157216
fn idx(&self, idx: usize) -> Option<usize> {
158-
if idx < self.args.len() {
217+
if idx < self.field_count() {
159218
Some(idx)
160219
} else {
161220
None
@@ -170,8 +229,8 @@ impl<'a, 'b> RowIndex<&'a str> for FunctionRow<'b> {
170229
}
171230

172231
struct FunctionArgument<'a> {
173-
arg: *mut ffi::sqlite3_value,
174-
p: PhantomData<&'a ()>,
232+
args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a>>>>,
233+
col_idx: i32,
175234
}
176235

177236
impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
@@ -187,7 +246,9 @@ impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
187246
where
188247
'a: 'b,
189248
{
190-
todo!()
191-
// unsafe { SqliteValue::new(self.arg) }
249+
SqliteValue::new(
250+
Ref::map(self.args.borrow(), |drop| std::ops::Deref::deref(drop)),
251+
self.col_idx,
252+
)
192253
}
193254
}

diesel/src/sqlite/connection/raw.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl RawConnection {
9191
f: F,
9292
) -> QueryResult<()>
9393
where
94-
F: FnMut(&Self, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
94+
F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
9595
+ std::panic::UnwindSafe
9696
+ Send
9797
+ 'static,
@@ -269,7 +269,7 @@ extern "C" fn run_custom_function<F>(
269269
num_args: libc::c_int,
270270
value_ptr: *mut *mut ffi::sqlite3_value,
271271
) where
272-
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
272+
F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
273273
+ std::panic::UnwindSafe
274274
+ Send
275275
+ 'static,
@@ -278,7 +278,6 @@ extern "C" fn run_custom_function<F>(
278278
static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
279279
static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
280280

281-
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
282281
let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
283282
// We use `ManuallyDrop` here because we do not want to run the
284283
// Drop impl of `RawConnection` as this would close the connection
@@ -306,13 +305,16 @@ extern "C" fn run_custom_function<F>(
306305
// this is sound as `F` itself and the stored string is `UnwindSafe`
307306
let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
308307

309-
let result =
310-
std::panic::catch_unwind(move || Ok((callback.0)(&*conn, args)?)).unwrap_or_else(|p| {
311-
Err(SqliteCallbackError::Panic(
312-
p,
313-
data_ptr.function_name.clone(),
314-
))
315-
});
308+
let result = std::panic::catch_unwind(move || {
309+
let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
310+
Ok((callback.0)(&*conn, args)?)
311+
})
312+
.unwrap_or_else(|p| {
313+
Err(SqliteCallbackError::Panic(
314+
p,
315+
data_ptr.function_name.clone(),
316+
))
317+
});
316318
match result {
317319
Ok(value) => value.result_of(ctx),
318320
Err(e) => {
@@ -342,15 +344,16 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
342344
Ret: ToSql<RetSqlType, Sqlite>,
343345
Sqlite: HasSqlType<RetSqlType>,
344346
{
345-
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
346-
let result =
347-
std::panic::catch_unwind(move || run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args))
348-
.unwrap_or_else(|e| {
349-
Err(SqliteCallbackError::Panic(
350-
e,
351-
format!("{}::step() paniced", std::any::type_name::<A>()),
352-
))
353-
});
347+
let result = std::panic::catch_unwind(move || {
348+
let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
349+
run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
350+
})
351+
.unwrap_or_else(|e| {
352+
Err(SqliteCallbackError::Panic(
353+
e,
354+
format!("{}::step() paniced", std::any::type_name::<A>()),
355+
))
356+
});
354357

355358
match result {
356359
Ok(()) => {}
@@ -360,7 +363,7 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
360363

361364
fn run_aggregator_step<A, Args, ArgsSqlType>(
362365
ctx: *mut ffi::sqlite3_context,
363-
args: &[*mut ffi::sqlite3_value],
366+
args: &mut [*mut ffi::sqlite3_value],
364367
) -> Result<(), SqliteCallbackError>
365368
where
366369
A: SqliteAggregateFunction<Args>,

diesel/src/sqlite/connection/row.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,119 @@ impl<'a> Field<'a, Sqlite> for SqliteField<'a> {
164164
SqliteValue::new(self.row.inner.borrow(), self.col_idx)
165165
}
166166
}
167+
168+
#[test]
169+
fn fun_with_row_iters() {
170+
crate::table! {
171+
#[allow(unused_parens)]
172+
users(id) {
173+
id -> Integer,
174+
name -> Text,
175+
}
176+
}
177+
178+
use crate::deserialize::{FromSql, FromSqlRow};
179+
use crate::prelude::*;
180+
use crate::row::{Field, Row};
181+
use crate::sql_types;
182+
183+
let conn = &mut crate::test_helpers::connection();
184+
185+
crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
186+
.execute(conn)
187+
.unwrap();
188+
189+
crate::insert_into(users::table)
190+
.values(vec![
191+
(users::id.eq(1), users::name.eq("Sean")),
192+
(users::id.eq(2), users::name.eq("Tess")),
193+
])
194+
.execute(conn)
195+
.unwrap();
196+
197+
let query = users::table.select((users::id, users::name));
198+
199+
let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];
200+
201+
let row_iter = conn.load(&query).unwrap();
202+
for (row, expected) in row_iter.zip(&expected) {
203+
let row = row.unwrap();
204+
205+
let deserialized = <(i32, String) as FromSqlRow<
206+
(sql_types::Integer, sql_types::Text),
207+
_,
208+
>>::build_from_row(&row)
209+
.unwrap();
210+
211+
assert_eq!(&deserialized, expected);
212+
}
213+
214+
{
215+
let collected_rows = conn.load(&query).unwrap().collect::<Vec<_>>();
216+
217+
for (row, expected) in collected_rows.iter().zip(&expected) {
218+
let deserialized = row
219+
.as_ref()
220+
.map(|row| {
221+
<(i32, String) as FromSqlRow<
222+
(sql_types::Integer, sql_types::Text),
223+
_,
224+
>>::build_from_row(row).unwrap()
225+
})
226+
.unwrap();
227+
228+
assert_eq!(&deserialized, expected);
229+
}
230+
}
231+
232+
let mut row_iter = conn.load(&query).unwrap();
233+
234+
let first_row = row_iter.next().unwrap().unwrap();
235+
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
236+
let first_values = (first_fields.0.value(), first_fields.1.value());
237+
238+
assert!(row_iter.next().unwrap().is_err());
239+
std::mem::drop(first_values);
240+
241+
let second_row = row_iter.next().unwrap().unwrap();
242+
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
243+
let second_values = (second_fields.0.value(), second_fields.1.value());
244+
245+
assert!(row_iter.next().unwrap().is_err());
246+
std::mem::drop(second_values);
247+
248+
assert!(row_iter.next().is_none());
249+
250+
let first_values = (first_fields.0.value(), first_fields.1.value());
251+
let second_values = (second_fields.0.value(), second_fields.1.value());
252+
253+
assert_eq!(
254+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
255+
expected[0].0
256+
);
257+
assert_eq!(
258+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
259+
expected[0].1
260+
);
261+
262+
assert_eq!(
263+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_values.0).unwrap(),
264+
expected[1].0
265+
);
266+
assert_eq!(
267+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_values.1).unwrap(),
268+
expected[1].1
269+
);
270+
271+
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
272+
let first_values = (first_fields.0.value(), first_fields.1.value());
273+
274+
assert_eq!(
275+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
276+
expected[0].0
277+
);
278+
assert_eq!(
279+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
280+
expected[0].1
281+
);
282+
}

diesel/src/sqlite/connection/sqlite_value.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct SqliteValue<'a, 'b> {
2323
col_idx: i32,
2424
}
2525

26+
#[repr(transparent)]
2627
pub struct OwnedSqliteValue {
2728
pub(super) value: NonNull<ffi::sqlite3_value>,
2829
}

diesel/src/sqlite/connection/statement_iterator.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,23 @@ impl<'a> Iterator for StatementIterator<'a> {
8888
// a user stored the row in some long time container before calling next another time
8989
// In this case we copy out the current values into a temporary store and advance
9090
// the statement iterator internally afterwards
91-
if let PrivateSqliteRow::Direct(stmt) =
92-
last_row.replace_with(|inner| inner.duplicate(&mut self.column_names))
93-
{
91+
let last_row = {
92+
let mut last_row = match last_row.try_borrow_mut() {
93+
Ok(o) => o,
94+
Err(_e) => {
95+
self.inner = Started(last_row.clone());
96+
return Some(Err(crate::result::Error::DeserializationError(
97+
"Failed to reborrow row. Try to release any `SqliteValue` \
98+
that exists at this point"
99+
.into(),
100+
)));
101+
}
102+
};
103+
let last_row = &mut *last_row;
104+
let duplicated = last_row.duplicate(&mut self.column_names);
105+
std::mem::replace(last_row, duplicated)
106+
};
107+
if let PrivateSqliteRow::Direct(stmt) = last_row {
94108
match stmt.step() {
95109
Err(e) => Some(Err(e)),
96110
Ok(None) => None,

0 commit comments

Comments
 (0)