From 79a0378ca21849565d935831319d871684722231 Mon Sep 17 00:00:00 2001 From: Pedro Arruda Date: Mon, 31 Oct 2022 20:25:14 -0300 Subject: [PATCH 1/2] updating dependencies on this crate (will need new release) --- Cargo.toml | 8 +++--- src/from_numpy.rs | 62 ++++++++++++++--------------------------------- src/to_numpy.rs | 5 ++-- tests/errors.rs | 24 +++++++++--------- tests/to_numpy.rs | 13 +++++----- 5 files changed, 43 insertions(+), 69 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b867200..b9476b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,10 @@ keywords = ["python", "numpy", "nalgebra", "matrix", "conversion"] categories = ["science"] [dependencies] -nalgebra = "0.24.1" -numpy = "0.11" -pyo3 = "0.11" +nalgebra = "0.30" +numpy = "0.17" +pyo3 = "0.17" [dev-dependencies] -inline-python = "0.6.0" +inline-python = "0.10.0" assert2 = "0.3.4" diff --git a/src/from_numpy.rs b/src/from_numpy.rs index c992827..aec3e63 100644 --- a/src/from_numpy.rs +++ b/src/from_numpy.rs @@ -39,7 +39,7 @@ pub struct WrongObjectTypeError { pub struct IncompatibleArrayError { pub expected_shape: Shape, pub actual_shape: Vec, - pub expected_dtype: numpy::DataType, + pub expected_dtype: String, pub actual_dtype: String, } @@ -58,7 +58,7 @@ pub struct UnalignedArrayError; /// The user must ensure that the data is not modified through other pointers or references. #[allow(clippy::needless_lifetimes)] pub unsafe fn matrix_slice_from_numpy<'a, N, R, C>( - _py: pyo3::Python, + py: pyo3::Python, input: &'a PyAny, ) -> Result, Error> where @@ -66,7 +66,7 @@ where R: nalgebra::Dim, C: nalgebra::Dim, { - matrix_slice_from_numpy_ptr(input.as_ptr()) + matrix_slice_from_numpy_ptr(py, input.as_ptr()) } /// Create a mutable nalgebra view from a numpy array. @@ -80,7 +80,7 @@ where /// The user must ensure that no other Rust references to the same data exist. #[allow(clippy::needless_lifetimes)] pub unsafe fn matrix_slice_mut_from_numpy<'a, N, R, C>( - _py: pyo3::Python, + py: pyo3::Python, input: &'a PyAny, ) -> Result, Error> where @@ -88,7 +88,7 @@ where R: nalgebra::Dim, C: nalgebra::Dim, { - matrix_slice_mut_from_numpy_ptr(input.as_ptr()) + matrix_slice_mut_from_numpy_ptr(py, input.as_ptr()) } /// Create an owning nalgebra matrix from a numpy array. @@ -98,7 +98,7 @@ where /// The array dtype must match the output type exactly. /// If desired, you can convert the array to the desired type in Python /// using [`numpy.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html). -pub fn matrix_from_numpy(py: pyo3::Python, input: &PyAny) -> Result, Error> +pub fn matrix_from_numpy(py: pyo3::Python, input: &PyAny) -> Result, Error> where N: nalgebra::Scalar + numpy::Element, R: nalgebra::Dim, @@ -111,6 +111,7 @@ where /// Same as [`matrix_slice_from_numpy`], but takes a raw [`PyObject`](pyo3::ffi::PyObject) pointer. #[allow(clippy::missing_safety_doc)] pub unsafe fn matrix_slice_from_numpy_ptr<'a, N, R, C>( + py: pyo3::Python, array: *mut pyo3::ffi::PyObject, ) -> Result, Error> where @@ -118,8 +119,8 @@ where R: nalgebra::Dim, C: nalgebra::Dim, { - let array = cast_to_py_array(array)?; - let shape = check_array_compatible::(array)?; + let array = cast_to_py_array(py, array)?; + let shape = check_array_compatible::(py, array)?; check_array_alignment(array)?; let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::()); @@ -132,6 +133,7 @@ where /// Same as [`matrix_slice_mut_from_numpy`], but takes a raw [`PyObject`](pyo3::ffi::PyObject) pointer. #[allow(clippy::missing_safety_doc)] pub unsafe fn matrix_slice_mut_from_numpy_ptr<'a, N, R, C>( + py: pyo3::Python, array: *mut pyo3::ffi::PyObject, ) -> Result, Error> where @@ -139,8 +141,8 @@ where R: nalgebra::Dim, C: nalgebra::Dim, { - let array = cast_to_py_array(array)?; - let shape = check_array_compatible::(array)?; + let array = cast_to_py_array(py, array)?; + let shape = check_array_compatible::(py, array)?; check_array_alignment(array)?; let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::()); @@ -151,8 +153,8 @@ where } /// Check if an object is numpy array and cast the pointer. -unsafe fn cast_to_py_array(object: *mut pyo3::ffi::PyObject) -> Result<*mut PyArrayObject, WrongObjectTypeError> { - if npyffi::array::PyArray_Check(object) == 1 { +unsafe fn cast_to_py_array(py: pyo3::Python, object: *mut pyo3::ffi::PyObject) -> Result<*mut PyArrayObject, WrongObjectTypeError> { + if npyffi::array::PyArray_Check(py, object) == 1 { Ok(&mut *(object as *mut npyffi::objects::PyArrayObject)) } else { Err(WrongObjectTypeError { @@ -162,7 +164,7 @@ unsafe fn cast_to_py_array(object: *mut pyo3::ffi::PyObject) -> Result<*mut PyAr } /// Check if a numpy array is compatible and return the runtime shape. -unsafe fn check_array_compatible(array: *mut PyArrayObject) -> Result<(R, C), IncompatibleArrayError> +unsafe fn check_array_compatible(py: pyo3::Python, array: *mut PyArrayObject) -> Result<(R, C), IncompatibleArrayError> where N: numpy::Element, R: nalgebra::Dim, @@ -177,7 +179,7 @@ where IncompatibleArrayError { expected_shape, actual_shape: shape(array), - expected_dtype: N::DATA_TYPE, + expected_dtype: N::get_dtype(py).to_string(), actual_dtype: data_type_string(array), } }; @@ -201,7 +203,7 @@ where } // Check the data type of the input array. - if npyffi::array::PY_ARRAY_API.PyArray_EquivTypenums((*(*array).descr).type_num, N::ffi_dtype() as u32 as i32) != 1 { + if npyffi::array::PY_ARRAY_API.PyArray_EquivTypenums(py, (*(*array).descr).type_num, N::get_dtype(py).num()) != 1 { return Err(make_error()); } @@ -309,10 +311,7 @@ impl std::fmt::Display for IncompatibleArrayError { write!( f, "incompatible array: expected ndarray(shape={}, dtype='{}'), found ndarray(shape={:?}, dtype={:?})", - self.expected_shape, - FormatDataType(&self.expected_dtype), - self.actual_shape, - self.actual_dtype, + self.expected_shape, &self.expected_dtype, self.actual_shape, self.actual_dtype, ) } } @@ -323,31 +322,6 @@ impl std::fmt::Display for UnalignedArrayError { } } -/// Helper to format [`numpy::DataType`] more consistently. -struct FormatDataType<'a>(&'a numpy::DataType); - -impl std::fmt::Display for FormatDataType<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let Self(dtype) = self; - match dtype { - numpy::DataType::Bool => write!(f, "bool"), - numpy::DataType::Complex32 => write!(f, "complex32"), - numpy::DataType::Complex64 => write!(f, "complex64"), - numpy::DataType::Float32 => write!(f, "float32"), - numpy::DataType::Float64 => write!(f, "float64"), - numpy::DataType::Int8 => write!(f, "int8"), - numpy::DataType::Int16 => write!(f, "int16"), - numpy::DataType::Int32 => write!(f, "int32"), - numpy::DataType::Int64 => write!(f, "int64"), - numpy::DataType::Object => write!(f, "object"), - numpy::DataType::Uint8 => write!(f, "uint8"), - numpy::DataType::Uint16 => write!(f, "uint16"), - numpy::DataType::Uint32 => write!(f, "uint32"), - numpy::DataType::Uint64 => write!(f, "uint64"), - } - } -} - impl std::error::Error for Error {} impl std::error::Error for WrongObjectTypeError {} impl std::error::Error for IncompatibleArrayError {} diff --git a/src/to_numpy.rs b/src/to_numpy.rs index d3be3a5..af160ae 100644 --- a/src/to_numpy.rs +++ b/src/to_numpy.rs @@ -13,11 +13,12 @@ where C: nalgebra::Dim, S: nalgebra::storage::Storage, { - let array = PyArray::new(py, (matrix.nrows(), matrix.ncols()), false); + // TODO: safety!? + let array = unsafe { PyArray::new(py, (matrix.nrows(), matrix.ncols()), false) }; for r in 0..matrix.nrows() { for c in 0..matrix.ncols() { unsafe { - *array.uget_mut((r, c)) = matrix[(r, c)].inlined_clone(); + *array.uget_mut((r, c)) = matrix[(r, c)].clone(); } } } diff --git a/tests/errors.rs b/tests/errors.rs index 7226358..34bbd3c 100644 --- a/tests/errors.rs +++ b/tests/errors.rs @@ -5,21 +5,21 @@ use nalgebra_numpy::{matrix_from_numpy, Error}; #[test] fn wrong_type() { - let gil = pyo3::Python::acquire_gil(); - let py = gil.python(); - let context = Context::new_with_gil(py); + pyo3::Python::with_gil(|py| { + let context = Context::new_with_gil(py); - context.run(python! { - float = 3.4 - int = 8 - list = [1.0, 2.0, 3.0] - }); + context.run(python! { + float = 3.4 + int = 8 + list = [1.0, 2.0, 3.0] + }); - let get_global = |name| context.globals(py).get_item(name).unwrap(); + let get_global = |name| context.globals(py).get_item(name).unwrap(); - assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("float"))); - assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("int"))); - assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("list"))); + assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("float"))); + assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("int"))); + assert!(let Err(Error::WrongObjectType(_)) = matrix_from_numpy::(py, get_global("list"))); + }); } #[test] diff --git a/tests/to_numpy.rs b/tests/to_numpy.rs index a8b611a..885e944 100644 --- a/tests/to_numpy.rs +++ b/tests/to_numpy.rs @@ -5,9 +5,8 @@ use nalgebra_numpy::matrix_to_numpy; #[test] #[rustfmt::skip] fn fixed_size() { - let gil = pyo3::Python::acquire_gil(); - - let matrix = matrix_to_numpy(gil.python(), &Matrix3::::new( + pyo3::Python::with_gil(|py| { + let matrix = matrix_to_numpy(py, &Matrix3::::new( 0, 1, 2, 3, 4, 5, 6, 7, 8, @@ -21,14 +20,14 @@ fn fixed_size() { [6, 7, 8], ]) } +}) } #[test] #[rustfmt::skip] fn dynamic_size() { - let gil = pyo3::Python::acquire_gil(); - - let matrix = matrix_to_numpy(gil.python(), &DMatrix::::from_row_slice(3, 4, &[ + pyo3::Python::with_gil(|py| -> () { + let matrix = matrix_to_numpy(py, &DMatrix::::from_row_slice(3, 4, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, @@ -41,5 +40,5 @@ fn dynamic_size() { [4, 5, 6, 7], [8, 9, 10, 11], ]) - } + }}); } From b95bad1ac3b7e6d38615537b4966e09b2bd4a83a Mon Sep 17 00:00:00 2001 From: Pedro Arruda Date: Mon, 31 Oct 2022 23:26:56 -0300 Subject: [PATCH 2/2] more recent version of nalgebra --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b9476b6..1957c83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ keywords = ["python", "numpy", "nalgebra", "matrix", "conversion"] categories = ["science"] [dependencies] -nalgebra = "0.30" +nalgebra = "0.31" numpy = "0.17" pyo3 = "0.17"