Skip to content

Commit f11de5f

Browse files
Aviram Hassanijl
authored andcommitted
add numpy primitives encoding
1 parent 8ce5de8 commit f11de5f

5 files changed

Lines changed: 308 additions & 37 deletions

File tree

src/serialize/dict.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ impl NonStrKey {
215215
}
216216
}
217217
ObType::Tuple
218-
| ObType::Array
218+
| ObType::NumpyScalar
219+
| ObType::NumpyArray
219220
| ObType::Dict
220221
| ObType::List
221222
| ObType::Dataclass

src/serialize/encode.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
22

33
use crate::exc::*;
4-
use crate::ffi::PyDict_GET_SIZE;
54
use crate::ffi::*;
65
use crate::opt::*;
76
use crate::serialize::dataclass::*;
@@ -33,7 +32,7 @@ pub fn serialize(
3332
let mut buf = BytesWriter::new();
3433
let obtype = pyobject_to_obtype(ptr, opts);
3534
match obtype {
36-
ObType::List | ObType::Dict | ObType::Dataclass | ObType::Array => {
35+
ObType::List | ObType::Dict | ObType::Dataclass | ObType::NumpyArray => {
3736
buf.resize(1024);
3837
}
3938
_ => {}
@@ -75,7 +74,8 @@ pub enum ObType {
7574
Tuple,
7675
Uuid,
7776
Dataclass,
78-
Array,
77+
NumpyScalar,
78+
NumpyArray,
7979
Enum,
8080
StrSubclass,
8181
Unknown,
@@ -145,11 +145,10 @@ pub fn pyobject_to_obtype_unlikely(obj: *mut pyo3::ffi::PyObject, opts: Opt) ->
145145
ObType::Dict
146146
} else if ffi!(PyDict_Contains((*ob_type).tp_dict, DATACLASS_FIELDS_STR)) == 1 {
147147
ObType::Dataclass
148-
} else if opts & SERIALIZE_NUMPY != 0
149-
&& ARRAY_TYPE.is_some()
150-
&& ob_type == ARRAY_TYPE.unwrap().as_ptr()
151-
{
152-
ObType::Array
148+
} else if opts & SERIALIZE_NUMPY != 0 && is_numpy_scalar(ob_type) {
149+
ObType::NumpyScalar
150+
} else if opts & SERIALIZE_NUMPY != 0 && is_numpy_array(ob_type) {
151+
ObType::NumpyArray
153152
} else {
154153
ObType::Unknown
155154
}
@@ -433,7 +432,7 @@ impl<'p> Serialize for SerializePyObject {
433432
)
434433
.serialize(serializer)
435434
}
436-
ObType::Array => match PyArray::new(self.ptr) {
435+
ObType::NumpyArray => match PyArray::new(self.ptr) {
437436
Ok(val) => val.serialize(serializer),
438437
Err(PyArrayError::Malformed) => err!("numpy array is malformed"),
439438
Err(PyArrayError::NotContiguous) | Err(PyArrayError::UnsupportedDataType) => {
@@ -447,6 +446,18 @@ impl<'p> Serialize for SerializePyObject {
447446
.serialize(serializer)
448447
}
449448
},
449+
ObType::NumpyScalar => match pyobj_to_numpy_obj(self.ptr) {
450+
Ok(numpy_obj) => match numpy_obj {
451+
NumpyObjects::Float32(obj) => obj.serialize(serializer),
452+
NumpyObjects::Float64(obj) => obj.serialize(serializer),
453+
NumpyObjects::Int32(obj) => obj.serialize(serializer),
454+
NumpyObjects::Int64(obj) => obj.serialize(serializer),
455+
NumpyObjects::Uint32(obj) => obj.serialize(serializer),
456+
NumpyObjects::Uint64(obj) => obj.serialize(serializer),
457+
},
458+
Err(NumpyError::InvalidType) => err!("invalid numpy type"),
459+
Err(NumpyError::NotAvailable) => err!("numpy not available"),
460+
},
450461
ObType::Unknown => DefaultSerializer::new(
451462
self.ptr,
452463
self.opts,

src/serialize/numpy.rs

Lines changed: 195 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
2-
3-
use crate::typeref::ARRAY_STRUCT_STR;
1+
use crate::typeref::{ARRAY_STRUCT_STR, NUMPY_TYPES};
42
use pyo3::ffi::*;
53
use serde::ser::{Serialize, SerializeSeq, Serializer};
4+
use std::ops::DerefMut;
65
use std::os::raw::{c_char, c_int, c_void};
76

87
macro_rules! slice {
@@ -21,8 +20,6 @@ pub struct PyCapsule {
2120
pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *);
2221
}
2322

24-
// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__
25-
2623
#[repr(C)]
2724
pub struct PyArrayInterface {
2825
pub two: c_int,
@@ -53,6 +50,11 @@ pub enum PyArrayError {
5350
UnsupportedDataType,
5451
}
5552

53+
pub enum NumpyError {
54+
NotAvailable,
55+
InvalidType,
56+
}
57+
5658
// >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32)
5759
// >>> arr.ndim
5860
// 3
@@ -256,7 +258,7 @@ impl<'p> Serialize for DataTypeF32 {
256258
}
257259

258260
#[repr(transparent)]
259-
struct DataTypeF64 {
261+
pub struct DataTypeF64 {
260262
pub obj: f64,
261263
}
262264

@@ -270,7 +272,7 @@ impl<'p> Serialize for DataTypeF64 {
270272
}
271273

272274
#[repr(transparent)]
273-
struct DataTypeI32 {
275+
pub struct DataTypeI32 {
274276
pub obj: i32,
275277
}
276278

@@ -284,7 +286,7 @@ impl<'p> Serialize for DataTypeI32 {
284286
}
285287

286288
#[repr(transparent)]
287-
struct DataTypeI64 {
289+
pub struct DataTypeI64 {
288290
pub obj: i64,
289291
}
290292

@@ -298,7 +300,7 @@ impl<'p> Serialize for DataTypeI64 {
298300
}
299301

300302
#[repr(transparent)]
301-
struct DataTypeU32 {
303+
pub struct DataTypeU32 {
302304
pub obj: u32,
303305
}
304306

@@ -312,7 +314,7 @@ impl<'p> Serialize for DataTypeU32 {
312314
}
313315

314316
#[repr(transparent)]
315-
struct DataTypeU64 {
317+
pub struct DataTypeU64 {
316318
pub obj: u64,
317319
}
318320

@@ -326,7 +328,7 @@ impl<'p> Serialize for DataTypeU64 {
326328
}
327329

328330
#[repr(transparent)]
329-
struct DataTypeBOOL {
331+
pub struct DataTypeBOOL {
330332
pub obj: u8,
331333
}
332334

@@ -338,3 +340,185 @@ impl<'p> Serialize for DataTypeBOOL {
338340
serializer.serialize_bool(self.obj == 1)
339341
}
340342
}
343+
344+
#[repr(C)]
345+
#[derive(Copy, Clone)]
346+
pub struct NumpyInt32 {
347+
pub ob_refcnt: Py_ssize_t,
348+
pub ob_type: *mut PyTypeObject,
349+
pub value: i32,
350+
}
351+
352+
impl<'p> Serialize for NumpyInt32 {
353+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
354+
where
355+
S: Serializer,
356+
{
357+
serializer.serialize_i32(self.value)
358+
}
359+
}
360+
361+
#[repr(C)]
362+
#[derive(Copy, Clone)]
363+
pub struct NumpyInt64 {
364+
pub ob_refcnt: Py_ssize_t,
365+
pub ob_type: *mut PyTypeObject,
366+
pub value: i64,
367+
}
368+
369+
impl<'p> Serialize for NumpyInt64 {
370+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
371+
where
372+
S: Serializer,
373+
{
374+
serializer.serialize_i64(self.value)
375+
}
376+
}
377+
378+
#[repr(C)]
379+
#[derive(Copy, Clone)]
380+
pub struct NumpyUint32 {
381+
pub ob_refcnt: Py_ssize_t,
382+
pub ob_type: *mut PyTypeObject,
383+
pub value: u32,
384+
}
385+
386+
impl<'p> Serialize for NumpyUint32 {
387+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
388+
where
389+
S: Serializer,
390+
{
391+
serializer.serialize_u32(self.value)
392+
}
393+
}
394+
395+
#[repr(C)]
396+
#[derive(Copy, Clone)]
397+
pub struct NumpyUint64 {
398+
pub ob_refcnt: Py_ssize_t,
399+
pub ob_type: *mut PyTypeObject,
400+
pub value: u64,
401+
}
402+
403+
impl<'p> Serialize for NumpyUint64 {
404+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
405+
where
406+
S: Serializer,
407+
{
408+
serializer.serialize_u64(self.value)
409+
}
410+
}
411+
412+
#[repr(C)]
413+
#[derive(Copy, Clone)]
414+
pub struct NumpyFloat32 {
415+
pub ob_refcnt: Py_ssize_t,
416+
pub ob_type: *mut PyTypeObject,
417+
pub value: f32,
418+
}
419+
420+
impl<'p> Serialize for NumpyFloat32 {
421+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
422+
where
423+
S: Serializer,
424+
{
425+
serializer.serialize_f32(self.value)
426+
}
427+
}
428+
429+
#[repr(C)]
430+
#[derive(Copy, Clone)]
431+
pub struct NumpyFloat64 {
432+
pub ob_refcnt: Py_ssize_t,
433+
pub ob_type: *mut PyTypeObject,
434+
pub value: f64,
435+
}
436+
437+
impl<'p> Serialize for NumpyFloat64 {
438+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
439+
where
440+
S: Serializer,
441+
{
442+
serializer.serialize_f64(self.value)
443+
}
444+
}
445+
446+
pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
447+
let available_types;
448+
unsafe {
449+
match NUMPY_TYPES.deref_mut() {
450+
Some(v) => available_types = v,
451+
_ => return false,
452+
}
453+
}
454+
455+
let numpy_scalars = [
456+
available_types.float32,
457+
available_types.float64,
458+
available_types.int32,
459+
available_types.int64,
460+
available_types.uint32,
461+
available_types.uint64,
462+
];
463+
numpy_scalars.contains(&ob_type)
464+
}
465+
466+
pub fn is_numpy_array(ob_type: *mut PyTypeObject) -> bool {
467+
let available_types;
468+
unsafe {
469+
match NUMPY_TYPES.deref_mut() {
470+
Some(v) => available_types = v,
471+
_ => return false,
472+
}
473+
}
474+
available_types.array == ob_type
475+
}
476+
477+
// pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error>
478+
// where
479+
// S: Serializer,
480+
// {
481+
// let ob_type = ob_type!(obj);
482+
// let numpy = match ob_type {
483+
484+
// }
485+
// }
486+
487+
pub enum NumpyObjects {
488+
Float32(NumpyFloat32),
489+
Float64(NumpyFloat64),
490+
Int32(NumpyInt32),
491+
Int64(NumpyInt64),
492+
Uint32(NumpyUint32),
493+
Uint64(NumpyUint64),
494+
}
495+
496+
pub fn pyobj_to_numpy_obj(obj: *mut pyo3::ffi::PyObject) -> Result<NumpyObjects, NumpyError> {
497+
let available_types;
498+
unsafe {
499+
match NUMPY_TYPES.deref_mut() {
500+
Some(v) => available_types = v,
501+
_ => return Err(NumpyError::NotAvailable),
502+
}
503+
}
504+
505+
let ob_type = ob_type!(obj);
506+
507+
unsafe {
508+
if ob_type == available_types.float32 {
509+
return Ok(NumpyObjects::Float32(*(obj as *mut NumpyFloat32)));
510+
} else if ob_type == available_types.float64 {
511+
return Ok(NumpyObjects::Float64(*(obj as *mut NumpyFloat64)));
512+
} else if ob_type == available_types.int32 {
513+
return Ok(NumpyObjects::Int32(*(obj as *mut NumpyInt32)));
514+
} else if ob_type == available_types.int64 {
515+
return Ok(NumpyObjects::Int64(*(obj as *mut NumpyInt64)));
516+
} else if ob_type == available_types.uint32 {
517+
return Ok(NumpyObjects::Uint32(*(obj as *mut NumpyUint32)));
518+
} else if ob_type == available_types.uint64 {
519+
return Ok(NumpyObjects::Uint64(*(obj as *mut NumpyUint64)));
520+
} else {
521+
return Err(NumpyError::InvalidType);
522+
}
523+
}
524+
}

0 commit comments

Comments
 (0)