Skip to content

Commit

Permalink
Update pyo3 and rust-numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
dustalov committed Dec 1, 2024
1 parent ca1dba6 commit 0a2c853
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 27 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ crate-type = ["cdylib"]

[dependencies]
approx = "^0.5.1"
ndarray = "^0.16.1"
ndarray = "^0.16.1" # numpy supports only >= 0.15, < 0.17
num-traits = "^0.2.19"
pyo3 = { version = "^0.22.3", features = ["extension-module", "abi3-py38"], optional = true }
numpy = { version = "^0.22.0", optional = true }
pyo3 = { version = "^0.23.2", features = ["extension-module", "abi3-py38"], optional = true }
numpy = { version = "^0.23.0", optional = true }

[features]
python = ["dep:pyo3", "dep:numpy"]
5 changes: 4 additions & 1 deletion src/bradley_terry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ pub fn newman(

v = one_nan_to_num(v_new, tolerance);

let broadcast_scores_t = scores.clone().into_shape_with_order((1, scores.len())).unwrap();
let broadcast_scores_t = scores
.clone()
.into_shape_with_order((1, scores.len()))
.unwrap();
let sqrt_scores_outer =
(&broadcast_scores_t * &broadcast_scores_t.t()).mapv_into(f64::sqrt);
let sum_scores = &broadcast_scores_t + &broadcast_scores_t.t();
Expand Down
36 changes: 13 additions & 23 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ unsafe impl Element for Winner {
Clone::clone(self)
}

fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
numpy::dtype_bound::<u8>(py)
fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
numpy::dtype::<u8>(py)
}
}

create_exception!(evalica, LengthMismatchError, PyValueError);

#[pyfunction]
fn matrices_pyo3<'py>(
py: Python<'py>,
Expand All @@ -63,8 +62,8 @@ fn matrices_pyo3<'py>(
total,
) {
Ok((wins, ties)) => Ok((
wins.into_pyarray_bound(py).unbind(),
ties.into_pyarray_bound(py).unbind(),
wins.into_pyarray(py).unbind(),
ties.into_pyarray(py).unbind(),
)),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
Expand All @@ -77,7 +76,7 @@ fn pairwise_scores_pyo3<'py>(
) -> PyResult<Py<PyArray2<f64>>> {
let pairwise = pairwise_scores(&scores.as_array());

Ok(pairwise.into_pyarray_bound(py).unbind())
Ok(pairwise.into_pyarray(py).unbind())
}

#[pyfunction]
Expand All @@ -100,7 +99,7 @@ fn counting_pyo3<'py>(
win_weight,
tie_weight,
) {
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand All @@ -125,7 +124,7 @@ fn average_win_rate_pyo3<'py>(
win_weight,
tie_weight,
) {
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand Down Expand Up @@ -160,9 +159,7 @@ fn bradley_terry_pyo3<'py>(
);

match bradley_terry(&matrix.view(), tolerance, limit) {
Ok((scores, iterations)) => {
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
}
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).into(), iterations)),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand Down Expand Up @@ -206,7 +203,7 @@ fn newman_pyo3<'py>(
limit,
) {
Ok((scores, v, iterations)) => {
Ok((scores.into_pyarray_bound(py).unbind(), v, iterations))
Ok((scores.into_pyarray(py).unbind(), v, iterations))
}
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
Expand Down Expand Up @@ -243,7 +240,7 @@ fn elo_pyo3<'py>(
win_weight,
tie_weight,
) {
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand Down Expand Up @@ -278,9 +275,7 @@ fn eigen_pyo3<'py>(
);

match eigen(&matrix.view(), tolerance, limit) {
Ok((scores, iterations)) => {
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
}
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand Down Expand Up @@ -319,9 +314,7 @@ fn pagerank_pyo3<'py>(
);

match pagerank(&matrix.view(), damping, tolerance, limit) {
Ok((scores, iterations)) => {
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
}
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)),
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
}
}
Expand All @@ -332,10 +325,7 @@ fn pagerank_pyo3<'py>(
#[pymodule]
fn evalica(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add(
"LengthMismatchError",
py.get_type_bound::<LengthMismatchError>(),
)?;
m.add("LengthMismatchError", py.get_type::<LengthMismatchError>())?;
m.add_function(wrap_pyfunction!(matrices_pyo3, m)?)?;
m.add_function(wrap_pyfunction!(pairwise_scores_pyo3, m)?)?;
m.add_function(wrap_pyfunction!(counting_pyo3, m)?)?;
Expand Down

0 comments on commit 0a2c853

Please sign in to comment.