Skip to content

Commit a5fb4a7

Browse files
marcoffeeijl
authored andcommitted
Fixed bug on empty multidimensional numpy array
Extended empty array fix for any zero dimension Added numpy tests for any dimension equal to 0 Joined all numpy tests related to dimensions equal to 0 Added new test for multidimensional array with last dimension equal to 0
1 parent ac88b2b commit a5fb4a7

2 files changed

Lines changed: 79 additions & 42 deletions

File tree

src/serialize/numpy.rs

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -242,54 +242,58 @@ impl<'p> Serialize for NumpyArray {
242242
S: Serializer,
243243
{
244244
let mut seq = serializer.serialize_seq(None).unwrap();
245-
if !self.children.is_empty() {
246-
for child in &self.children {
247-
seq.serialize_element(child).unwrap();
248-
}
249-
} else {
250-
let data_ptr = self.data();
251-
let num_items = self.num_items();
252-
match self.kind().unwrap() {
253-
ItemType::F64 => {
254-
let slice: &[f64] = slice!(data_ptr as *const f64, num_items);
255-
for &each in slice.iter() {
256-
seq.serialize_element(&DataTypeF64 { obj: each }).unwrap();
257-
}
245+
246+
if self.depth >= self.shape().len() || self.shape()[self.depth] != 0 {
247+
if !self.children.is_empty() {
248+
for child in &self.children {
249+
seq.serialize_element(child).unwrap();
258250
}
259-
ItemType::F32 => {
260-
let slice: &[f32] = slice!(data_ptr as *const f32, num_items);
261-
for &each in slice.iter() {
262-
seq.serialize_element(&DataTypeF32 { obj: each }).unwrap();
251+
252+
} else {
253+
let data_ptr = self.data();
254+
let num_items = self.num_items();
255+
match self.kind().unwrap() {
256+
ItemType::F64 => {
257+
let slice: &[f64] = slice!(data_ptr as *const f64, num_items);
258+
for &each in slice.iter() {
259+
seq.serialize_element(&DataTypeF64 { obj: each }).unwrap();
260+
}
263261
}
264-
}
265-
ItemType::I64 => {
266-
let slice: &[i64] = slice!(data_ptr as *const i64, num_items);
267-
for &each in slice.iter() {
268-
seq.serialize_element(&DataTypeI64 { obj: each }).unwrap();
262+
ItemType::F32 => {
263+
let slice: &[f32] = slice!(data_ptr as *const f32, num_items);
264+
for &each in slice.iter() {
265+
seq.serialize_element(&DataTypeF32 { obj: each }).unwrap();
266+
}
269267
}
270-
}
271-
ItemType::I32 => {
272-
let slice: &[i32] = slice!(data_ptr as *const i32, num_items);
273-
for &each in slice.iter() {
274-
seq.serialize_element(&DataTypeI32 { obj: each }).unwrap();
268+
ItemType::I64 => {
269+
let slice: &[i64] = slice!(data_ptr as *const i64, num_items);
270+
for &each in slice.iter() {
271+
seq.serialize_element(&DataTypeI64 { obj: each }).unwrap();
272+
}
275273
}
276-
}
277-
ItemType::U64 => {
278-
let slice: &[u64] = slice!(data_ptr as *const u64, num_items);
279-
for &each in slice.iter() {
280-
seq.serialize_element(&DataTypeU64 { obj: each }).unwrap();
274+
ItemType::I32 => {
275+
let slice: &[i32] = slice!(data_ptr as *const i32, num_items);
276+
for &each in slice.iter() {
277+
seq.serialize_element(&DataTypeI32 { obj: each }).unwrap();
278+
}
281279
}
282-
}
283-
ItemType::U32 => {
284-
let slice: &[u32] = slice!(data_ptr as *const u32, num_items);
285-
for &each in slice.iter() {
286-
seq.serialize_element(&DataTypeU32 { obj: each }).unwrap();
280+
ItemType::U64 => {
281+
let slice: &[u64] = slice!(data_ptr as *const u64, num_items);
282+
for &each in slice.iter() {
283+
seq.serialize_element(&DataTypeU64 { obj: each }).unwrap();
284+
}
287285
}
288-
}
289-
ItemType::BOOL => {
290-
let slice: &[u8] = slice!(data_ptr as *const u8, num_items);
291-
for &each in slice.iter() {
292-
seq.serialize_element(&DataTypeBOOL { obj: each }).unwrap();
286+
ItemType::U32 => {
287+
let slice: &[u32] = slice!(data_ptr as *const u32, num_items);
288+
for &each in slice.iter() {
289+
seq.serialize_element(&DataTypeU32 { obj: each }).unwrap();
290+
}
291+
}
292+
ItemType::BOOL => {
293+
let slice: &[u8] = slice!(data_ptr as *const u8, num_items);
294+
for &each in slice.iter() {
295+
seq.serialize_element(&DataTypeBOOL { obj: each }).unwrap();
296+
}
293297
}
294298
}
295299
}

test/test_numpy.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,39 @@ def test_numpy_array_dimension_zero(self):
235235
with self.assertRaises(orjson.JSONEncodeError):
236236
orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY)
237237

238+
array = numpy.empty((0, 4, 2))
239+
self.assertEqual(
240+
orjson.loads(
241+
orjson.dumps(
242+
array,
243+
option=orjson.OPT_SERIALIZE_NUMPY,
244+
)
245+
),
246+
array.tolist(),
247+
)
248+
249+
array = numpy.empty((4, 0, 2))
250+
self.assertEqual(
251+
orjson.loads(
252+
orjson.dumps(
253+
array,
254+
option=orjson.OPT_SERIALIZE_NUMPY,
255+
)
256+
),
257+
array.tolist(),
258+
)
259+
260+
array = numpy.empty((2, 4, 0))
261+
self.assertEqual(
262+
orjson.loads(
263+
orjson.dumps(
264+
array,
265+
option=orjson.OPT_SERIALIZE_NUMPY,
266+
)
267+
),
268+
array.tolist(),
269+
)
270+
238271
def test_numpy_array_dimension_max(self):
239272
array = numpy.random.rand(
240273
1,

0 commit comments

Comments
 (0)