Skip to content

Commit b7d3e3b

Browse files
committed
Handle empty dict in encode
1 parent cab086e commit b7d3e3b

2 files changed

Lines changed: 162 additions & 144 deletions

File tree

src/dict.rs

Lines changed: 151 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,36 @@ use crate::exc::*;
77
use crate::typeref::*;
88
use crate::unicode::*;
99
use crate::uuid::*;
10+
use core::ffi::c_void;
1011
use inlinable_string::InlinableString;
12+
use pyo3::ffi::*;
1113
use serde::ser::{Serialize, SerializeMap, Serializer};
1214
use smallvec::SmallVec;
1315
use std::ptr::NonNull;
1416

17+
#[repr(C)]
18+
pub struct PyDictObject {
19+
pub ob_refcnt: Py_ssize_t,
20+
pub ob_type: *mut PyTypeObject,
21+
pub ma_used: Py_ssize_t,
22+
pub ma_version_tag: u64,
23+
pub ma_keys: c_void,
24+
pub ma_values: c_void,
25+
}
26+
27+
#[allow(non_snake_case)]
28+
#[inline(always)]
29+
pub unsafe fn PyDict_GET_SIZE(op: *mut PyObject) -> Py_ssize_t {
30+
(*op.cast::<PyDictObject>()).ma_used
31+
}
32+
1533
pub struct DictSortedKey {
1634
ptr: *mut pyo3::ffi::PyObject,
1735
opts: u16,
1836
default_calls: u8,
1937
recursion: u8,
2038
default: Option<NonNull<pyo3::ffi::PyObject>>,
39+
len: usize,
2140
}
2241

2342
impl DictSortedKey {
@@ -27,13 +46,15 @@ impl DictSortedKey {
2746
default_calls: u8,
2847
recursion: u8,
2948
default: Option<NonNull<pyo3::ffi::PyObject>>,
49+
len: usize,
3050
) -> Self {
3151
DictSortedKey {
3252
ptr: ptr,
3353
opts: opts,
3454
default_calls: default_calls,
3555
recursion: recursion,
3656
default: default,
57+
len: len,
3758
}
3859
}
3960
}
@@ -44,45 +65,41 @@ impl<'p> Serialize for DictSortedKey {
4465
where
4566
S: Serializer,
4667
{
47-
let len = ffi!(PyDict_Size(self.ptr)) as usize;
48-
if len == 0 {
49-
serializer.serialize_map(Some(0)).unwrap().end()
50-
} else {
51-
let mut items: SmallVec<[(&str, *mut pyo3::ffi::PyObject); 8]> =
52-
SmallVec::with_capacity(len);
53-
let mut pos = 0isize;
54-
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
55-
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
56-
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
57-
while unsafe { pyo3::ffi::PyDict_Next(self.ptr, &mut pos, &mut key, &mut value) != 0 } {
58-
if unlikely!((*key).ob_type != STR_TYPE) {
59-
err!("Dict key must be str")
60-
}
61-
let data = read_utf8_from_str(key, &mut str_size);
62-
if unlikely!(data.is_null()) {
63-
err!(INVALID_STR)
64-
}
65-
items.push((str_from_slice!(data, str_size), value));
68+
let mut items: SmallVec<[(&str, *mut pyo3::ffi::PyObject); 8]> =
69+
SmallVec::with_capacity(self.len);
70+
let mut pos = 0isize;
71+
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
72+
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
73+
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
74+
for _ in 0..=self.len - 1 {
75+
unsafe { pyo3::ffi::PyDict_Next(self.ptr, &mut pos, &mut key, &mut value) };
76+
if unlikely!((*key).ob_type != STR_TYPE) {
77+
err!("Dict key must be str")
78+
}
79+
let data = read_utf8_from_str(key, &mut str_size);
80+
if unlikely!(data.is_null()) {
81+
err!(INVALID_STR)
6682
}
83+
items.push((str_from_slice!(data, str_size), value));
84+
}
6785

68-
items.sort_unstable_by(|a, b| a.0.cmp(b.0));
86+
items.sort_unstable_by(|a, b| a.0.cmp(b.0));
6987

70-
let mut map = serializer.serialize_map(None).unwrap();
71-
for (key, val) in items.iter() {
72-
map.serialize_entry(
73-
key,
74-
&SerializePyObject::new(
75-
*val,
76-
None,
77-
self.opts,
78-
self.default_calls,
79-
self.recursion + 1,
80-
self.default,
81-
),
82-
)?;
83-
}
84-
map.end()
88+
let mut map = serializer.serialize_map(None).unwrap();
89+
for (key, val) in items.iter() {
90+
map.serialize_entry(
91+
key,
92+
&SerializePyObject::new(
93+
*val,
94+
None,
95+
self.opts,
96+
self.default_calls,
97+
self.recursion + 1,
98+
self.default,
99+
),
100+
)?;
85101
}
102+
map.end()
86103
}
87104
}
88105

@@ -92,6 +109,7 @@ pub struct NonStrKey {
92109
default_calls: u8,
93110
recursion: u8,
94111
default: Option<NonNull<pyo3::ffi::PyObject>>,
112+
len: usize,
95113
}
96114

97115
impl NonStrKey {
@@ -101,13 +119,15 @@ impl NonStrKey {
101119
default_calls: u8,
102120
recursion: u8,
103121
default: Option<NonNull<pyo3::ffi::PyObject>>,
122+
len: usize,
104123
) -> Self {
105124
NonStrKey {
106125
ptr: ptr,
107126
opts: opts,
108127
default_calls: default_calls,
109128
recursion: recursion,
110129
default: default,
130+
len: len,
111131
}
112132
}
113133
}
@@ -118,125 +138,119 @@ impl<'p> Serialize for NonStrKey {
118138
where
119139
S: Serializer,
120140
{
121-
let len = ffi!(PyDict_Size(self.ptr)) as usize;
122-
if len == 0 {
123-
serializer.serialize_map(Some(0)).unwrap().end()
124-
} else {
125-
let mut items: SmallVec<[(InlinableString, *mut pyo3::ffi::PyObject); 8]> =
126-
SmallVec::with_capacity(len);
127-
let mut pos = 0isize;
128-
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
129-
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
130-
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
131-
while unsafe { pyo3::ffi::PyDict_Next(self.ptr, &mut pos, &mut key, &mut value) != 0 } {
132-
if unsafe { (*key).ob_type == STR_TYPE } {
133-
let data = read_utf8_from_str(key, &mut str_size);
134-
if unlikely!(data.is_null()) {
135-
err!(INVALID_STR)
141+
let mut items: SmallVec<[(InlinableString, *mut pyo3::ffi::PyObject); 8]> =
142+
SmallVec::with_capacity(self.len);
143+
let mut pos = 0isize;
144+
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
145+
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
146+
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
147+
for _ in 0..=self.len - 1 {
148+
unsafe { pyo3::ffi::PyDict_Next(self.ptr, &mut pos, &mut key, &mut value) };
149+
if unsafe { (*key).ob_type == STR_TYPE } {
150+
let data = read_utf8_from_str(key, &mut str_size);
151+
if unlikely!(data.is_null()) {
152+
err!(INVALID_STR)
153+
}
154+
items.push((
155+
InlinableString::from(str_from_slice!(data, str_size)),
156+
value,
157+
));
158+
} else {
159+
match pyobject_to_obtype(key, self.opts | SERIALIZE_UUID) {
160+
ObType::NONE => {
161+
items.push((InlinableString::from("null"), value));
136162
}
137-
items.push((
138-
InlinableString::from(str_from_slice!(data, str_size)),
139-
value,
140-
));
141-
} else if self.opts & NON_STR_KEYS != NON_STR_KEYS {
142-
err!(KEY_MUST_BE_STR)
143-
} else {
144-
match pyobject_to_obtype(key, self.opts | SERIALIZE_UUID) {
145-
ObType::NONE => {
146-
items.push((InlinableString::from("null"), value));
163+
ObType::BOOL => {
164+
let key_as_str: &str;
165+
if unsafe { key == TRUE } {
166+
key_as_str = "true";
167+
} else {
168+
key_as_str = "false";
147169
}
148-
ObType::BOOL => {
149-
let key_as_str: &str;
150-
if unsafe { key == TRUE } {
151-
key_as_str = "true";
152-
} else {
153-
key_as_str = "false";
154-
}
155-
items.push((InlinableString::from(key_as_str), value));
170+
items.push((InlinableString::from(key_as_str), value));
171+
}
172+
ObType::INT => {
173+
let val = ffi!(PyLong_AsLongLong(key));
174+
if unlikely!(val == -1 && !pyo3::ffi::PyErr_Occurred().is_null()) {
175+
err!("Dict integer key must be within 64-bit range")
156176
}
157-
ObType::INT => {
158-
let val = ffi!(PyLong_AsLongLong(key));
159-
if unlikely!(val == -1 && !pyo3::ffi::PyErr_Occurred().is_null()) {
160-
err!("Dict integer key must be within 64-bit range")
161-
}
177+
items.push((
178+
InlinableString::from(itoa::Buffer::new().format(val)),
179+
value,
180+
));
181+
}
182+
ObType::FLOAT => {
183+
let val = ffi!(PyFloat_AS_DOUBLE(key));
184+
if !val.is_finite() {
185+
items.push((InlinableString::from("null"), value));
186+
} else {
162187
items.push((
163-
InlinableString::from(itoa::Buffer::new().format(val)),
188+
InlinableString::from(ryu::Buffer::new().format_finite(val)),
164189
value,
165190
));
166191
}
167-
ObType::FLOAT => {
168-
let val = ffi!(PyFloat_AS_DOUBLE(key));
169-
if !val.is_finite() {
170-
items.push((InlinableString::from("null"), value));
171-
} else {
172-
items.push((
173-
InlinableString::from(ryu::Buffer::new().format_finite(val)),
174-
value,
175-
));
176-
}
177-
}
178-
ObType::DATETIME => {
179-
let mut buf: DateTimeBuffer = heapless::Vec::new();
180-
let dt = DateTime::new(key, self.opts);
181-
if dt.write_buf(&mut buf).is_err() {
182-
err!(DATETIME_LIBRARY_UNSUPPORTED)
183-
}
184-
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
185-
items.push((InlinableString::from(key_as_str), value));
192+
}
193+
ObType::DATETIME => {
194+
let mut buf: DateTimeBuffer = heapless::Vec::new();
195+
let dt = DateTime::new(key, self.opts);
196+
if dt.write_buf(&mut buf).is_err() {
197+
err!(DATETIME_LIBRARY_UNSUPPORTED)
186198
}
187-
ObType::DATE => {
199+
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
200+
items.push((InlinableString::from(key_as_str), value));
201+
}
202+
ObType::DATE => {
203+
let mut buf: DateTimeBuffer = heapless::Vec::new();
204+
Date::new(key).write_buf(&mut buf);
205+
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
206+
items.push((InlinableString::from(key_as_str), value));
207+
}
208+
ObType::TIME => match Time::new(key, self.opts) {
209+
Ok(val) => {
188210
let mut buf: DateTimeBuffer = heapless::Vec::new();
189-
Date::new(key).write_buf(&mut buf);
211+
val.write_buf(&mut buf);
190212
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
191213
items.push((InlinableString::from(key_as_str), value));
192214
}
193-
ObType::TIME => match Time::new(key, self.opts) {
194-
Ok(val) => {
195-
let mut buf: DateTimeBuffer = heapless::Vec::new();
196-
val.write_buf(&mut buf);
197-
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
198-
items.push((InlinableString::from(key_as_str), value));
199-
}
200-
Err(TimeError::HasTimezone) => err!(TIME_HAS_TZINFO),
201-
},
202-
ObType::UUID => {
203-
let mut buf: UUIDBuffer = heapless::Vec::new();
204-
UUID::new(key).write_buf(&mut buf);
205-
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
206-
items.push((InlinableString::from(key_as_str), value));
207-
}
208-
ObType::TUPLE
209-
| ObType::ARRAY
210-
| ObType::DICT
211-
| ObType::LIST
212-
| ObType::DATACLASS
213-
| ObType::UNKNOWN => {
214-
err!("Dict key must a type serializable with NON_STR_KEYS")
215-
}
216-
ObType::STR => unsafe { std::hint::unreachable_unchecked() },
215+
Err(TimeError::HasTimezone) => err!(TIME_HAS_TZINFO),
216+
},
217+
ObType::UUID => {
218+
let mut buf: UUIDBuffer = heapless::Vec::new();
219+
UUID::new(key).write_buf(&mut buf);
220+
let key_as_str = str_from_slice!(buf.as_ptr(), buf.len());
221+
items.push((InlinableString::from(key_as_str), value));
222+
}
223+
ObType::TUPLE
224+
| ObType::ARRAY
225+
| ObType::DICT
226+
| ObType::LIST
227+
| ObType::DATACLASS
228+
| ObType::UNKNOWN => {
229+
err!("Dict key must a type serializable with NON_STR_KEYS")
217230
}
231+
ObType::STR => unsafe { std::hint::unreachable_unchecked() },
218232
}
219233
}
234+
}
220235

221-
if self.opts & SORT_KEYS == SORT_KEYS {
222-
items.sort_unstable_by(|a, b| a.0.cmp(&b.0));
223-
}
236+
if self.opts & SORT_KEYS == SORT_KEYS {
237+
items.sort_unstable_by(|a, b| a.0.cmp(&b.0));
238+
}
224239

225-
let mut map = serializer.serialize_map(None).unwrap();
226-
for (key, val) in items.iter() {
227-
map.serialize_entry(
228-
str_from_slice!(key.as_ptr(), key.len()),
229-
&SerializePyObject::new(
230-
*val,
231-
None,
232-
self.opts,
233-
self.default_calls,
234-
self.recursion + 1,
235-
self.default,
236-
),
237-
)?;
238-
}
239-
map.end()
240+
let mut map = serializer.serialize_map(None).unwrap();
241+
for (key, val) in items.iter() {
242+
map.serialize_entry(
243+
str_from_slice!(key.as_ptr(), key.len()),
244+
&SerializePyObject::new(
245+
*val,
246+
None,
247+
self.opts,
248+
self.default_calls,
249+
self.recursion + 1,
250+
self.default,
251+
),
252+
)?;
240253
}
254+
map.end()
241255
}
242256
}

0 commit comments

Comments
 (0)