1- // SPDX-License-Identifier: (Apache-2.0 OR MIT)
2-
3- use crate :: typeref:: ARRAY_STRUCT_STR ;
1+ use crate :: typeref:: { ARRAY_STRUCT_STR , NUMPY_TYPES } ;
42use pyo3:: ffi:: * ;
53use serde:: ser:: { Serialize , SerializeSeq , Serializer } ;
4+ use std:: ops:: DerefMut ;
65use std:: os:: raw:: { c_char, c_int, c_void} ;
76
87macro_rules! slice {
@@ -21,8 +20,6 @@ pub struct PyCapsule {
2120 pub destructor : * mut c_void , // should be typedef void (*PyCapsule_Destructor)(PyObject *);
2221}
2322
24- // https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__
25-
2623#[ repr( C ) ]
2724pub struct PyArrayInterface {
2825 pub two : c_int ,
@@ -53,6 +50,11 @@ pub enum PyArrayError {
5350 UnsupportedDataType ,
5451}
5552
53+ pub enum NumpyError {
54+ NotAvailable ,
55+ InvalidType ,
56+ }
57+
5658// >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32)
5759// >>> arr.ndim
5860// 3
@@ -256,7 +258,7 @@ impl<'p> Serialize for DataTypeF32 {
256258}
257259
258260#[ repr( transparent) ]
259- struct DataTypeF64 {
261+ pub struct DataTypeF64 {
260262 pub obj : f64 ,
261263}
262264
@@ -270,7 +272,7 @@ impl<'p> Serialize for DataTypeF64 {
270272}
271273
272274#[ repr( transparent) ]
273- struct DataTypeI32 {
275+ pub struct DataTypeI32 {
274276 pub obj : i32 ,
275277}
276278
@@ -284,7 +286,7 @@ impl<'p> Serialize for DataTypeI32 {
284286}
285287
286288#[ repr( transparent) ]
287- struct DataTypeI64 {
289+ pub struct DataTypeI64 {
288290 pub obj : i64 ,
289291}
290292
@@ -298,7 +300,7 @@ impl<'p> Serialize for DataTypeI64 {
298300}
299301
300302#[ repr( transparent) ]
301- struct DataTypeU32 {
303+ pub struct DataTypeU32 {
302304 pub obj : u32 ,
303305}
304306
@@ -312,7 +314,7 @@ impl<'p> Serialize for DataTypeU32 {
312314}
313315
314316#[ repr( transparent) ]
315- struct DataTypeU64 {
317+ pub struct DataTypeU64 {
316318 pub obj : u64 ,
317319}
318320
@@ -326,7 +328,7 @@ impl<'p> Serialize for DataTypeU64 {
326328}
327329
328330#[ repr( transparent) ]
329- struct DataTypeBOOL {
331+ pub struct DataTypeBOOL {
330332 pub obj : u8 ,
331333}
332334
@@ -338,3 +340,185 @@ impl<'p> Serialize for DataTypeBOOL {
338340 serializer. serialize_bool ( self . obj == 1 )
339341 }
340342}
343+
344+ #[ repr( C ) ]
345+ #[ derive( Copy , Clone ) ]
346+ pub struct NumpyInt32 {
347+ pub ob_refcnt : Py_ssize_t ,
348+ pub ob_type : * mut PyTypeObject ,
349+ pub value : i32 ,
350+ }
351+
352+ impl < ' p > Serialize for NumpyInt32 {
353+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
354+ where
355+ S : Serializer ,
356+ {
357+ serializer. serialize_i32 ( self . value )
358+ }
359+ }
360+
361+ #[ repr( C ) ]
362+ #[ derive( Copy , Clone ) ]
363+ pub struct NumpyInt64 {
364+ pub ob_refcnt : Py_ssize_t ,
365+ pub ob_type : * mut PyTypeObject ,
366+ pub value : i64 ,
367+ }
368+
369+ impl < ' p > Serialize for NumpyInt64 {
370+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
371+ where
372+ S : Serializer ,
373+ {
374+ serializer. serialize_i64 ( self . value )
375+ }
376+ }
377+
378+ #[ repr( C ) ]
379+ #[ derive( Copy , Clone ) ]
380+ pub struct NumpyUint32 {
381+ pub ob_refcnt : Py_ssize_t ,
382+ pub ob_type : * mut PyTypeObject ,
383+ pub value : u32 ,
384+ }
385+
386+ impl < ' p > Serialize for NumpyUint32 {
387+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
388+ where
389+ S : Serializer ,
390+ {
391+ serializer. serialize_u32 ( self . value )
392+ }
393+ }
394+
395+ #[ repr( C ) ]
396+ #[ derive( Copy , Clone ) ]
397+ pub struct NumpyUint64 {
398+ pub ob_refcnt : Py_ssize_t ,
399+ pub ob_type : * mut PyTypeObject ,
400+ pub value : u64 ,
401+ }
402+
403+ impl < ' p > Serialize for NumpyUint64 {
404+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
405+ where
406+ S : Serializer ,
407+ {
408+ serializer. serialize_u64 ( self . value )
409+ }
410+ }
411+
412+ #[ repr( C ) ]
413+ #[ derive( Copy , Clone ) ]
414+ pub struct NumpyFloat32 {
415+ pub ob_refcnt : Py_ssize_t ,
416+ pub ob_type : * mut PyTypeObject ,
417+ pub value : f32 ,
418+ }
419+
420+ impl < ' p > Serialize for NumpyFloat32 {
421+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
422+ where
423+ S : Serializer ,
424+ {
425+ serializer. serialize_f32 ( self . value )
426+ }
427+ }
428+
429+ #[ repr( C ) ]
430+ #[ derive( Copy , Clone ) ]
431+ pub struct NumpyFloat64 {
432+ pub ob_refcnt : Py_ssize_t ,
433+ pub ob_type : * mut PyTypeObject ,
434+ pub value : f64 ,
435+ }
436+
437+ impl < ' p > Serialize for NumpyFloat64 {
438+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
439+ where
440+ S : Serializer ,
441+ {
442+ serializer. serialize_f64 ( self . value )
443+ }
444+ }
445+
446+ pub fn is_numpy_scalar ( ob_type : * mut PyTypeObject ) -> bool {
447+ let available_types;
448+ unsafe {
449+ match NUMPY_TYPES . deref_mut ( ) {
450+ Some ( v) => available_types = v,
451+ _ => return false ,
452+ }
453+ }
454+
455+ let numpy_scalars = [
456+ available_types. float32 ,
457+ available_types. float64 ,
458+ available_types. int32 ,
459+ available_types. int64 ,
460+ available_types. uint32 ,
461+ available_types. uint64 ,
462+ ] ;
463+ numpy_scalars. contains ( & ob_type)
464+ }
465+
466+ pub fn is_numpy_array ( ob_type : * mut PyTypeObject ) -> bool {
467+ let available_types;
468+ unsafe {
469+ match NUMPY_TYPES . deref_mut ( ) {
470+ Some ( v) => available_types = v,
471+ _ => return false ,
472+ }
473+ }
474+ available_types. array == ob_type
475+ }
476+
477+ // pub fn serialize_numpy_scalar<S>(obj: *mut pyo3::ffi::PyObject, serializer: S) -> Result<S::Ok, S::Error>
478+ // where
479+ // S: Serializer,
480+ // {
481+ // let ob_type = ob_type!(obj);
482+ // let numpy = match ob_type {
483+
484+ // }
485+ // }
486+
487+ pub enum NumpyObjects {
488+ Float32 ( NumpyFloat32 ) ,
489+ Float64 ( NumpyFloat64 ) ,
490+ Int32 ( NumpyInt32 ) ,
491+ Int64 ( NumpyInt64 ) ,
492+ Uint32 ( NumpyUint32 ) ,
493+ Uint64 ( NumpyUint64 ) ,
494+ }
495+
496+ pub fn pyobj_to_numpy_obj ( obj : * mut pyo3:: ffi:: PyObject ) -> Result < NumpyObjects , NumpyError > {
497+ let available_types;
498+ unsafe {
499+ match NUMPY_TYPES . deref_mut ( ) {
500+ Some ( v) => available_types = v,
501+ _ => return Err ( NumpyError :: NotAvailable ) ,
502+ }
503+ }
504+
505+ let ob_type = ob_type ! ( obj) ;
506+
507+ unsafe {
508+ if ob_type == available_types. float32 {
509+ return Ok ( NumpyObjects :: Float32 ( * ( obj as * mut NumpyFloat32 ) ) ) ;
510+ } else if ob_type == available_types. float64 {
511+ return Ok ( NumpyObjects :: Float64 ( * ( obj as * mut NumpyFloat64 ) ) ) ;
512+ } else if ob_type == available_types. int32 {
513+ return Ok ( NumpyObjects :: Int32 ( * ( obj as * mut NumpyInt32 ) ) ) ;
514+ } else if ob_type == available_types. int64 {
515+ return Ok ( NumpyObjects :: Int64 ( * ( obj as * mut NumpyInt64 ) ) ) ;
516+ } else if ob_type == available_types. uint32 {
517+ return Ok ( NumpyObjects :: Uint32 ( * ( obj as * mut NumpyUint32 ) ) ) ;
518+ } else if ob_type == available_types. uint64 {
519+ return Ok ( NumpyObjects :: Uint64 ( * ( obj as * mut NumpyUint64 ) ) ) ;
520+ } else {
521+ return Err ( NumpyError :: InvalidType ) ;
522+ }
523+ }
524+ }
0 commit comments