Skip to content

Commit 85e400e

Browse files
committed
Refactor type serializers
1 parent 72f568b commit 85e400e

8 files changed

Lines changed: 324 additions & 271 deletions

File tree

src/datetime.rs

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
22

3+
use crate::exc::*;
34
use crate::typeref::*;
4-
use serde::ser::{Error, Serialize, Serializer};
5+
use serde::ser::{Serialize, Serializer};
56

67
pub const NAIVE_UTC: u16 = 1 << 1;
78
pub const OMIT_MICROSECONDS: u16 = 1 << 2;
@@ -15,12 +16,6 @@ const COLON: u8 = 58; // ":"
1516
const PERIOD: u8 = 46; // ":"
1617
const Z: u8 = 90; // "Z"
1718

18-
macro_rules! err {
19-
($msg:expr) => {
20-
return Err(Error::custom($msg));
21-
};
22-
}
23-
2419
pub type DateTimeBuffer = heapless::Vec<u8, heapless::consts::U32>;
2520

2621
macro_rules! write_double_digit {
@@ -55,13 +50,7 @@ impl Date {
5550
pub fn new(ptr: *mut pyo3::ffi::PyObject) -> Self {
5651
Date { ptr: ptr }
5752
}
58-
}
59-
impl<'p> Serialize for Date {
60-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61-
where
62-
S: Serializer,
63-
{
64-
let mut buf: DateTimeBuffer = heapless::Vec::new();
53+
pub fn write_buf(&self, buf: &mut DateTimeBuffer) {
6554
{
6655
let year = ffi!(PyDateTime_GET_YEAR(self.ptr)) as i32;
6756
buf.extend_from_slice(itoa::Buffer::new().format(year).as_bytes())
@@ -77,33 +66,39 @@ impl<'p> Serialize for Date {
7766
let day = ffi!(PyDateTime_GET_DAY(self.ptr)) as u32;
7867
write_double_digit!(buf, day);
7968
}
69+
}
70+
}
71+
impl<'p> Serialize for Date {
72+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73+
where
74+
S: Serializer,
75+
{
76+
let mut buf: DateTimeBuffer = heapless::Vec::new();
77+
self.write_buf(&mut buf);
8078
serializer.serialize_str(str_from_slice!(buf.as_ptr(), buf.len()))
8179
}
8280
}
8381

82+
pub enum TimeError {
83+
HasTimezone,
84+
}
85+
8486
pub struct Time {
8587
ptr: *mut pyo3::ffi::PyObject,
8688
opts: u16,
8789
}
8890

8991
impl Time {
90-
pub fn new(ptr: *mut pyo3::ffi::PyObject, opts: u16) -> Self {
91-
Time {
92+
pub fn new(ptr: *mut pyo3::ffi::PyObject, opts: u16) -> Result<Self, TimeError> {
93+
if unsafe { (*(ptr as *mut pyo3::ffi::PyDateTime_Time)).hastzinfo == 1 } {
94+
return Err(TimeError::HasTimezone);
95+
}
96+
Ok(Time {
9297
ptr: ptr,
9398
opts: opts,
94-
}
99+
})
95100
}
96-
}
97-
98-
impl<'p> Serialize for Time {
99-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
100-
where
101-
S: Serializer,
102-
{
103-
if unsafe { (*(self.ptr as *mut pyo3::ffi::PyDateTime_Time)).hastzinfo == 1 } {
104-
err!("datetime.time must not have tzinfo set")
105-
}
106-
let mut buf: DateTimeBuffer = heapless::Vec::new();
101+
pub fn write_buf(&self, buf: &mut DateTimeBuffer) {
107102
{
108103
let hour = ffi!(PyDateTime_TIME_GET_HOUR(self.ptr)) as u8;
109104
write_double_digit!(buf, hour);
@@ -122,10 +117,24 @@ impl<'p> Serialize for Time {
122117
let microsecond = ffi!(PyDateTime_TIME_GET_MICROSECOND(self.ptr)) as u32;
123118
write_microsecond!(buf, microsecond);
124119
}
120+
}
121+
}
122+
123+
impl<'p> Serialize for Time {
124+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
125+
where
126+
S: Serializer,
127+
{
128+
let mut buf: DateTimeBuffer = heapless::Vec::new();
129+
self.write_buf(&mut buf);
125130
serializer.serialize_str(str_from_slice!(buf.as_ptr(), buf.len()))
126131
}
127132
}
128133

134+
pub enum DateTimeError {
135+
LibraryUnsupported,
136+
}
137+
129138
pub struct DateTime {
130139
ptr: *mut pyo3::ffi::PyObject,
131140
opts: u16,
@@ -138,14 +147,7 @@ impl DateTime {
138147
opts: opts,
139148
}
140149
}
141-
}
142-
143-
impl<'p> Serialize for DateTime {
144-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145-
where
146-
S: Serializer,
147-
{
148-
let mut buf: DateTimeBuffer = heapless::Vec::new();
150+
pub fn write_buf(&self, buf: &mut DateTimeBuffer) -> Result<(), DateTimeError> {
149151
let has_tz = unsafe { (*(self.ptr as *mut pyo3::ffi::PyDateTime_DateTime)).hastzinfo == 1 };
150152
let offset_day: i32;
151153
let mut offset_second: i32;
@@ -195,7 +197,7 @@ impl<'p> Serialize for DateTime {
195197
offset_second = ffi!(PyDateTime_DELTA_GET_SECONDS(offset)) as i32;
196198
offset_day = ffi!(PyDateTime_DELTA_GET_DAYS(offset));
197199
} else {
198-
err!("datetime's timezone library is not supported: use datetime.timezone.utc, pendulum, pytz, or dateutil")
200+
return Err(DateTimeError::LibraryUnsupported);
199201
}
200202
} else {
201203
offset_second = 0;
@@ -285,6 +287,19 @@ impl<'p> Serialize for DateTime {
285287
}
286288
}
287289
}
290+
Ok(())
291+
}
292+
}
293+
294+
impl<'p> Serialize for DateTime {
295+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296+
where
297+
S: Serializer,
298+
{
299+
let mut buf: DateTimeBuffer = heapless::Vec::new();
300+
if self.write_buf(&mut buf).is_err() {
301+
err!(DATETIME_LIBRARY_UNSUPPORTED)
302+
}
288303
serializer.serialize_str(str_from_slice!(buf.as_ptr(), buf.len()))
289304
}
290305
}

src/default.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
2+
3+
use crate::encode::*;
4+
5+
use serde::ser::{Serialize, Serializer};
6+
use std::ffi::CStr;
7+
8+
use std::ptr::NonNull;
9+
10+
macro_rules! obj_name {
11+
($obj:ident) => {
12+
unsafe { CStr::from_ptr((*$obj).tp_name).to_string_lossy() }
13+
};
14+
}
15+
16+
pub struct DefaultSerializer {
17+
ptr: *mut pyo3::ffi::PyObject,
18+
opts: u16,
19+
default_calls: u8,
20+
recursion: u8,
21+
default: Option<NonNull<pyo3::ffi::PyObject>>,
22+
}
23+
24+
impl DefaultSerializer {
25+
pub fn new(
26+
ptr: *mut pyo3::ffi::PyObject,
27+
opts: u16,
28+
default_calls: u8,
29+
recursion: u8,
30+
default: Option<NonNull<pyo3::ffi::PyObject>>,
31+
) -> Self {
32+
DefaultSerializer {
33+
ptr: ptr,
34+
opts: opts,
35+
default_calls: default_calls,
36+
recursion: recursion,
37+
default: default,
38+
}
39+
}
40+
}
41+
42+
impl<'p> Serialize for DefaultSerializer {
43+
#[inline(never)]
44+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
45+
where
46+
S: Serializer,
47+
{
48+
match self.default {
49+
Some(callable) => {
50+
if unlikely!(self.default_calls == RECURSION_LIMIT) {
51+
err!("default serializer exceeds recursion limit")
52+
}
53+
let obj_ptr = unsafe { (*self.ptr).ob_type };
54+
let default_obj = unsafe {
55+
pyo3::ffi::PyObject_CallFunctionObjArgs(
56+
callable.as_ptr(),
57+
self.ptr,
58+
std::ptr::null_mut() as *mut pyo3::ffi::PyObject,
59+
)
60+
};
61+
if default_obj.is_null() {
62+
err!(format_args!(
63+
"Type is not JSON serializable: {}",
64+
obj_name!(obj_ptr)
65+
))
66+
} else if !ffi!(PyErr_Occurred()).is_null() {
67+
err!(format_args!(
68+
"Type raised exception in default function: {}",
69+
obj_name!(obj_ptr)
70+
))
71+
} else {
72+
let res = SerializePyObject::new(
73+
default_obj,
74+
None,
75+
self.opts,
76+
self.default_calls + 1,
77+
self.recursion,
78+
self.default,
79+
)
80+
.serialize(serializer);
81+
ffi!(Py_DECREF(default_obj));
82+
res
83+
}
84+
}
85+
None => {
86+
let obj_ptr = unsafe { (*self.ptr).ob_type };
87+
err!(format_args!(
88+
"Type is not JSON serializable: {}",
89+
obj_name!(obj_ptr)
90+
))
91+
}
92+
}
93+
}
94+
}

src/dict.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
2+
3+
use crate::encode::*;
4+
use crate::exc::*;
5+
use crate::typeref::*;
6+
use crate::unicode::*;
7+
use serde::ser::{Serialize, SerializeMap, Serializer};
8+
use smallvec::SmallVec;
9+
use std::ptr::NonNull;
10+
11+
pub struct DictSortedKey {
12+
ptr: *mut pyo3::ffi::PyObject,
13+
opts: u16,
14+
default_calls: u8,
15+
recursion: u8,
16+
default: Option<NonNull<pyo3::ffi::PyObject>>,
17+
}
18+
19+
impl DictSortedKey {
20+
pub fn new(
21+
ptr: *mut pyo3::ffi::PyObject,
22+
opts: u16,
23+
default_calls: u8,
24+
recursion: u8,
25+
default: Option<NonNull<pyo3::ffi::PyObject>>,
26+
) -> Self {
27+
DictSortedKey {
28+
ptr: ptr,
29+
opts: opts,
30+
default_calls: default_calls,
31+
recursion: recursion,
32+
default: default,
33+
}
34+
}
35+
}
36+
37+
impl<'p> Serialize for DictSortedKey {
38+
#[inline(never)]
39+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40+
where
41+
S: Serializer,
42+
{
43+
let len = ffi!(PyDict_Size(self.ptr)) as usize;
44+
if len == 0 {
45+
serializer.serialize_map(Some(0)).unwrap().end()
46+
} else {
47+
let mut items: SmallVec<[(&str, *mut pyo3::ffi::PyObject); 8]> =
48+
SmallVec::with_capacity(len);
49+
let mut pos = 0isize;
50+
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
51+
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
52+
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
53+
while unsafe { pyo3::ffi::PyDict_Next(self.ptr, &mut pos, &mut key, &mut value) != 0 } {
54+
if unlikely!((*key).ob_type != STR_TYPE) {
55+
err!("Dict key must be str")
56+
}
57+
let data = read_utf8_from_str(key, &mut str_size);
58+
if unlikely!(data.is_null()) {
59+
err!(INVALID_STR)
60+
}
61+
items.push((str_from_slice!(data, str_size), value));
62+
}
63+
64+
items.sort_unstable_by(|a, b| a.0.cmp(b.0));
65+
66+
let mut map = serializer.serialize_map(None).unwrap();
67+
for (key, val) in items.iter() {
68+
map.serialize_entry(
69+
key,
70+
&SerializePyObject::new(
71+
*val,
72+
None,
73+
self.opts,
74+
self.default_calls,
75+
self.recursion + 1,
76+
self.default,
77+
),
78+
)?;
79+
}
80+
map.end()
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)