Skip to content

Commit 45f3dc1

Browse files
committed
Omit under attributes, InitVar, ClassVar on dataclasses
1 parent 80299ba commit 45f3dc1

7 files changed

Lines changed: 103 additions & 31 deletions

File tree

ci/azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ jobs:
141141
pool:
142142
vmImage: windows-2019
143143
variables:
144-
interpreter: C:\hostedtoolcache\windows\Python\3.8.2\x64\python.exe
144+
interpreter: C:\hostedtoolcache\windows\Python\3.8.3\x64\python.exe
145145
target: x86_64-pc-windows-msvc
146146
steps:
147147
- task: UsePythonVersion@0

pydataclass

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing import List
1111
import orjson
1212
import rapidjson
1313
import simplejson
14+
import ujson
1415
from tabulate import tabulate
1516

1617
os.sched_setaffinity(os.getpid(), {0, 1})
@@ -63,16 +64,15 @@ table = []
6364
for lib_name in LIBRARIES:
6465
if lib_name == "json":
6566
as_dict = timeit(
66-
lambda: json.dumps(objects_as_dict, default=default).encode("utf-8"),
67-
number=ITERATIONS,
67+
lambda: json.dumps(objects_as_dict).encode("utf-8"), number=ITERATIONS,
6868
)
6969
as_dataclass = timeit(
7070
lambda: json.dumps(objects_as_dataclass, default=default).encode("utf-8"),
7171
number=ITERATIONS,
7272
)
7373
elif lib_name == "simplejson":
7474
as_dict = timeit(
75-
lambda: simplejson.dumps(objects_as_dict, default=default).encode("utf-8"),
75+
lambda: simplejson.dumps(objects_as_dict).encode("utf-8"),
7676
number=ITERATIONS,
7777
)
7878
as_dataclass = timeit(
@@ -82,12 +82,13 @@ for lib_name in LIBRARIES:
8282
number=ITERATIONS,
8383
)
8484
elif lib_name == "ujson":
85-
as_dict = None
85+
as_dict = timeit(
86+
lambda: ujson.dumps(objects_as_dict).encode("utf-8"), number=ITERATIONS,
87+
)
8688
as_dataclass = None
8789
elif lib_name == "rapidjson":
8890
as_dict = timeit(
89-
lambda: rapidjson.dumps(objects_as_dict, default=default).encode("utf-8"),
90-
number=ITERATIONS,
91+
lambda: rapidjson.dumps(objects_as_dict).encode("utf-8"), number=ITERATIONS,
9192
)
9293
as_dataclass = timeit(
9394
lambda: rapidjson.dumps(objects_as_dataclass, default=default).encode(
@@ -112,7 +113,7 @@ for lib_name in LIBRARIES:
112113

113114
if lib_name == "orjson":
114115
compared_to_orjson = 1
115-
elif as_dict:
116+
elif as_dataclass:
116117
compared_to_orjson = int(as_dataclass / orjson_as_dataclass)
117118
else:
118119
compared_to_orjson = None

src/dataclass.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,19 @@ impl<'p> Serialize for DataclassSerializer {
6161
std::ptr::null_mut(),
6262
)
6363
};
64+
if unsafe { ffi!(PyObject_GetAttr(field, FIELD_TYPE_STR)) != FIELD_TYPE.as_ptr() } {
65+
continue;
66+
}
6467
{
6568
let data = read_utf8_from_str(attr, &mut str_size);
6669
if unlikely!(data.is_null()) {
6770
err!(INVALID_STR);
6871
}
69-
map.serialize_key(str_from_slice!(data, str_size)).unwrap();
72+
let key_as_str = str_from_slice!(data, str_size);
73+
if key_as_str.as_bytes()[0] == b'_' {
74+
continue;
75+
}
76+
map.serialize_key(key_as_str).unwrap();
7077
}
7178

7279
let value = ffi!(PyObject_GetAttr(self.ptr, attr));

src/encode.rs

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,7 @@ impl<'p> Serialize for SerializePyObject {
239239
let len = unsafe { PyDict_GET_SIZE(self.ptr) as usize };
240240
if unlikely!(len == 0) {
241241
serializer.serialize_map(Some(0)).unwrap().end()
242-
} else if likely!(
243-
self.opts & SORT_OR_NON_STR_KEYS == 0 || self.opts & DATACLASS_DICT_PATH != 0
244-
) {
245-
let opts = self.opts & !DATACLASS_DICT_PATH;
242+
} else if likely!(self.opts & SORT_OR_NON_STR_KEYS == 0) {
246243
let mut map = serializer.serialize_map(None).unwrap();
247244
let mut pos = 0isize;
248245
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
@@ -271,7 +268,7 @@ impl<'p> Serialize for SerializePyObject {
271268

272269
map.serialize_value(&SerializePyObject::new(
273270
value,
274-
opts,
271+
self.opts,
275272
self.default_calls,
276273
self.recursion + 1,
277274
self.default,
@@ -354,15 +351,46 @@ impl<'p> Serialize for SerializePyObject {
354351
let dict = ffi!(PyObject_GetAttr(self.ptr, DICT_STR));
355352
if !dict.is_null() {
356353
ffi!(Py_DECREF(dict));
357-
SerializePyObject::with_obtype(
358-
dict,
359-
ObType::Dict,
360-
self.opts | DATACLASS_DICT_PATH,
361-
self.default_calls,
362-
self.recursion,
363-
self.default,
364-
)
365-
.serialize(serializer)
354+
let len = unsafe { PyDict_GET_SIZE(dict) as usize };
355+
let mut map = serializer.serialize_map(None).unwrap();
356+
let mut pos = 0isize;
357+
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
358+
let mut key: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
359+
let mut value: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
360+
for _ in 0..=len - 1 {
361+
unsafe {
362+
pyo3::ffi::_PyDict_Next(
363+
dict,
364+
&mut pos,
365+
&mut key,
366+
&mut value,
367+
std::ptr::null_mut(),
368+
)
369+
};
370+
if unlikely!(ob_type!(key) != STR_TYPE) {
371+
err!(KEY_MUST_BE_STR)
372+
}
373+
{
374+
let data = read_utf8_from_str(key, &mut str_size);
375+
if unlikely!(data.is_null()) {
376+
err!(INVALID_STR)
377+
}
378+
let key_as_str = str_from_slice!(data, str_size);
379+
if unlikely!(key_as_str.as_bytes()[0] == b'_') {
380+
continue;
381+
}
382+
map.serialize_key(key_as_str).unwrap();
383+
}
384+
385+
map.serialize_value(&SerializePyObject::new(
386+
value,
387+
self.opts,
388+
self.default_calls,
389+
self.recursion + 1,
390+
self.default,
391+
))?;
392+
}
393+
map.end()
366394
} else {
367395
unsafe { pyo3::ffi::PyErr_Clear() };
368396
DataclassSerializer::new(

src/opt.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
22

3-
pub type Opt = u16;
3+
pub type Opt = u8;
44

55
pub const INDENT_2: Opt = 1;
66
pub const NAIVE_UTC: Opt = 1 << 1;
@@ -15,7 +15,6 @@ pub const UTC_Z: Opt = 1 << 7;
1515
pub const SERIALIZE_DATACLASS: Opt = 0;
1616
pub const SERIALIZE_UUID: Opt = 0;
1717

18-
pub const DATACLASS_DICT_PATH: Opt = 1 << 8;
1918
pub const SORT_OR_NON_STR_KEYS: Opt = SORT_KEYS | NON_STR_KEYS;
2019

2120
pub const MAX_OPT: i32 = (INDENT_2

src/typeref.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub static mut UUID_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
2727
pub static mut ENUM_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
2828
pub static mut ARRAY_TYPE: Lazy<Option<NonNull<PyTypeObject>>> =
2929
Lazy::new(|| unsafe { look_up_array_type() });
30+
pub static mut FIELD_TYPE: Lazy<NonNull<PyObject>> = Lazy::new(|| unsafe { look_up_field_type() });
3031

3132
pub static mut BYTES_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
3233
pub static mut BYTEARRAY_TYPE: *mut PyTypeObject = 0 as *mut PyTypeObject;
@@ -38,6 +39,7 @@ pub static mut CONVERT_METHOD_STR: *mut PyObject = 0 as *mut PyObject;
3839
pub static mut DST_STR: *mut PyObject = 0 as *mut PyObject;
3940
pub static mut DICT_STR: *mut PyObject = 0 as *mut PyObject;
4041
pub static mut DATACLASS_FIELDS_STR: *mut PyObject = 0 as *mut PyObject;
42+
pub static mut FIELD_TYPE_STR: *mut PyObject = 0 as *mut PyObject;
4143
pub static mut ARRAY_STRUCT_STR: *mut PyObject = 0 as *mut PyObject;
4244
pub static mut VALUE_STR: *mut PyObject = 0 as *mut PyObject;
4345
pub static mut STR_HASH_FUNCTION: Option<hashfunc> = None;
@@ -85,6 +87,7 @@ pub fn init_typerefs() {
8587
DICT_STR = PyUnicode_InternFromString("__dict__\0".as_ptr() as *const c_char);
8688
DATACLASS_FIELDS_STR =
8789
PyUnicode_InternFromString("__dataclass_fields__\0".as_ptr() as *const c_char);
90+
FIELD_TYPE_STR = PyUnicode_InternFromString("_field_type\0".as_ptr() as *const c_char);
8891
ARRAY_STRUCT_STR =
8992
pyo3::ffi::PyUnicode_InternFromString("__array_struct__\0".as_ptr() as *const c_char);
9093
VALUE_STR = pyo3::ffi::PyUnicode_InternFromString("value\0".as_ptr() as *const c_char);
@@ -127,6 +130,16 @@ unsafe fn look_up_array_type() -> Option<NonNull<PyTypeObject>> {
127130
}
128131
}
129132

133+
unsafe fn look_up_field_type() -> NonNull<PyObject> {
134+
let module = PyImport_ImportModule("dataclasses\0".as_ptr() as *const c_char);
135+
let module_dict = PyModule_GetDict(module);
136+
let ptr = PyMapping_GetItemString(module_dict, "_FIELD\0".as_ptr() as *const c_char)
137+
as *mut PyTypeObject;
138+
Py_DECREF(module_dict);
139+
Py_DECREF(module);
140+
NonNull::new_unchecked(ptr as *mut PyObject)
141+
}
142+
130143
unsafe fn look_up_enum_type() -> *mut PyTypeObject {
131144
let module = PyImport_ImportModule("enum\0".as_ptr() as *const c_char);
132145
let module_dict = PyModule_GetDict(module);

test/test_dataclass.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import unittest
44
import uuid
5-
from dataclasses import dataclass, field
5+
from dataclasses import InitVar, dataclass, field
66
from enum import Enum
7-
from typing import Dict, Optional
7+
from typing import ClassVar, Dict, Optional
88

99
import orjson
1010

@@ -51,9 +51,12 @@ class Datasubclass(Dataclass1):
5151

5252
@dataclass
5353
class Slotsdataclass:
54-
__slots__ = ("a", "b")
54+
__slots__ = ("a", "b", "_c", "d")
5555
a: str
5656
b: int
57+
_c: str
58+
d: InitVar[str]
59+
cls_var: ClassVar[str] = "cls"
5760

5861

5962
@dataclass
@@ -70,6 +73,18 @@ class UnsortedDataclass:
7073
d: Optional[Dict]
7174

7275

76+
@dataclass
77+
class InitDataclass:
78+
a: InitVar[str]
79+
b: InitVar[str]
80+
cls_var: ClassVar[str] = "cls"
81+
ab: str = ""
82+
83+
def __post_init__(self, a: str, b: str):
84+
self._other = 1
85+
self.ab = f"{a} {b}"
86+
87+
7388
class DataclassTests(unittest.TestCase):
7489
def test_dataclass(self):
7590
"""
@@ -146,9 +161,9 @@ def test_dataclass_subclass(self):
146161

147162
def test_dataclass_slots(self):
148163
"""
149-
dumps() dataclass with __slots__
164+
dumps() dataclass with __slots__ does not include under attributes, InitVar, or ClassVar
150165
"""
151-
obj = Slotsdataclass("a", 1)
166+
obj = Slotsdataclass("a", 1, "c", "d")
152167
assert "__dict__" not in dir(obj)
153168
self.assertEqual(orjson.dumps(obj), b'{"a":"a","b":1}')
154169

@@ -191,9 +206,18 @@ def test_dataclass_sort_sub(self):
191206
b'{"c":1,"b":2,"a":3,"d":{"e":1,"f":2}}',
192207
)
193208

209+
def test_dataclass_under(self):
210+
"""
211+
dumps() does not include under attributes, InitVar, or ClassVar
212+
"""
213+
obj = InitDataclass("zxc", "vbn")
214+
self.assertEqual(
215+
orjson.dumps(obj), b'{"ab":"zxc vbn"}',
216+
)
217+
194218
def test_dataclass_option(self):
195219
"""
196-
dumps() accepts deprecated OPT_SERIALIZE_DATACALSS
220+
dumps() accepts deprecated OPT_SERIALIZE_DATACLASS
197221
"""
198222
obj = Dataclass1("a", 1, None)
199223
self.assertEqual(

0 commit comments

Comments
 (0)