Skip to content

Commit

Permalink
add support for strides for offsets
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Mar 6, 2024
1 parent 797a904 commit 2674200
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 30 deletions.
3 changes: 2 additions & 1 deletion mwatershed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def agglom(
offsets: list[list[int]],
seeds: Optional[np.ndarray] = None,
edges: Optional[list[tuple[bool, int, int]]] = None,
strides: Optional[list[list[int]]] = None
):
return agglom_rs(affinities, offsets, seeds, edges)
return agglom_rs(affinities, offsets, seeds, edges, strides)


__all__ = ["agglom", "cluster"]
118 changes: 89 additions & 29 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub fn get_edges<const D: usize>(
affinities: &Array<f64, IxDyn>,
offsets: Vec<Vec<isize>>,
seeds: &Array<usize, IxDyn>,
strides: Option<Vec<Vec<usize>>>,
) -> (Vec<AgglomEdge>, HashSet<usize>) {
// let (_, array_shape) = get_dims::<D>(seeds.dim(), 0);
let offsets: Vec<[isize; D]> = offsets
Expand All @@ -38,36 +39,46 @@ pub fn get_edges<const D: usize>(

let mut to_filter: HashSet<usize> = HashSet::from_iter(seeds.iter().copied());

let strides = strides.unwrap_or_else(|| {
(0..D)
.map(|_| (0..D).map(|_| 1).collect())
.collect::<Vec<Vec<usize>>>()
});

offsets
.iter()
.zip(strides.iter())
.enumerate()
.for_each(|(offset_index, offset)| {
.for_each(|(offset_index, (offset, stride))| {
let all_offset_affs = affinities.index_axis(Axis(0), offset_index);
let offset_affs = all_offset_affs.slice_each_axis(|ax| {
Slice::from(
std::cmp::max(0, -offset[ax.axis.index()])
..std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
),
Slice::new(
std::cmp::max(0, -offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
)
});
let u_seeds = seeds.slice_each_axis(|ax| {
Slice::from(
std::cmp::max(0, -offset[ax.axis.index()])
..std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
),
Slice::new(
std::cmp::max(0, -offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) - offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
)
});
let v_seeds = seeds.slice_each_axis(|ax| {
Slice::from(
std::cmp::max(0, offset[ax.axis.index()])
..std::cmp::min(
ax.len as isize,
(ax.len as isize) + offset[ax.axis.index()],
),
Slice::new(
std::cmp::max(0, offset[ax.axis.index()]),
Some(std::cmp::min(
ax.len as isize,
(ax.len as isize) + offset[ax.axis.index()],
)),
stride[ax.axis.index()].try_into().unwrap(),
)
});
offset_affs.indexed_iter().for_each(|(index, aff)| {
Expand Down Expand Up @@ -104,6 +115,7 @@ pub fn agglomerate<const D: usize>(
offsets: Vec<Vec<isize>>,
mut edges: Vec<AgglomEdge>,
mut seeds: Array<usize, IxDyn>,
strides: Option<Vec<Vec<usize>>>,
) -> Array<usize, IxDyn> {
// relabel to consecutive ids
let mut lookup = HashMap::new();
Expand All @@ -130,7 +142,8 @@ pub fn agglomerate<const D: usize>(
});

// main algorithm
let (sorted_edges, mut filtered_background) = get_edges::<D>(affinities, offsets, &seeds);
let (sorted_edges, mut filtered_background) =
get_edges::<D>(affinities, offsets, &seeds, strides);
edges.extend(sorted_edges);
lookup.values().for_each(|node_id| {
filtered_background.remove(node_id);
Expand Down Expand Up @@ -211,6 +224,7 @@ fn agglom_rs<'py>(
offsets: Vec<Vec<isize>>,
seeds: Option<&PyArrayDyn<usize>>,
edges: Option<Vec<(bool, usize, usize)>>,
strides: Option<Vec<Vec<usize>>>,
) -> PyResult<&'py PyArrayDyn<usize>> {
let affinities = unsafe { affinities.as_array() }.to_owned();
let seeds = match seeds {
Expand All @@ -224,12 +238,12 @@ fn agglom_rs<'py>(
.map(|(pos, u, v)| AgglomEdge(pos, u, v))
.collect();
let result = match dim {
1 => agglomerate::<1>(&affinities, offsets, edges, seeds),
2 => agglomerate::<2>(&affinities, offsets, edges, seeds),
3 => agglomerate::<3>(&affinities, offsets, edges, seeds),
4 => agglomerate::<4>(&affinities, offsets, edges, seeds),
5 => agglomerate::<5>(&affinities, offsets, edges, seeds),
6 => agglomerate::<6>(&affinities, offsets, edges, seeds),
1 => agglomerate::<1>(&affinities, offsets, edges, seeds, strides),
2 => agglomerate::<2>(&affinities, offsets, edges, seeds, strides),
3 => agglomerate::<3>(&affinities, offsets, edges, seeds, strides),
4 => agglomerate::<4>(&affinities, offsets, edges, seeds, strides),
5 => agglomerate::<5>(&affinities, offsets, edges, seeds, strides),
6 => agglomerate::<6>(&affinities, offsets, edges, seeds, strides),
_ => panic!["Only 1-6 dimensional arrays supported"],
};
Ok(result.into_pyarray(_py))
Expand All @@ -256,6 +270,8 @@ fn mwatershed(_py: Python, m: &PyModule) -> PyResult<()> {

#[cfg(test)]
mod tests {
use std::vec;

use super::*;
use itertools::Itertools;
use ndarray::array;
Expand Down Expand Up @@ -292,7 +308,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, 1], vec![1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let ids = components
.clone()
.into_iter()
Expand All @@ -304,6 +320,50 @@ mod tests {
assert!(!ids.contains(&0), "{:?}", components);
assert!(ids.len() == 4, "{:?}", components);
}
/// Seeds
/// 1 2 0
/// 4 0 0
/// 0 0 0
///
/// Affs
/// offset [0, 1]
/// 0 1 0
/// 0 1 0
/// 0 1 0
///
/// offset [1, 0]
/// 0 0 0
/// 1 1 1
/// 0 0 0
///
/// Expected Components
/// 1 2 2
/// 4 0 x
/// 4 x x
///
#[test]
fn test_agglom_with_strides() {
let affinities = array![
[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]],
[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
]
.into_dyn()
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, 1], vec![1, 0]];
let strides = vec![vec![2, 1], vec![1, 2]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides));
let ids = components
.clone()
.into_iter()
.unique()
.collect::<Vec<usize>>();
for id in [1, 2, 4].iter() {
assert!(ids.contains(id), "{:?}", components);
}
assert!(ids.contains(&0), "{:?}", components);
assert!(ids.len() == 5, "{:?}", components);
}

/// Seeds
/// 1 2 0
Expand Down Expand Up @@ -336,7 +396,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, -1], vec![-1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let ids = components
.clone()
.into_iter()
Expand Down Expand Up @@ -380,7 +440,7 @@ mod tests {
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, -1], vec![-1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds);
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None);
let ids = components
.clone()
.into_iter()
Expand Down
29 changes: 29 additions & 0 deletions tests/test_agglom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np

import mwatershed


def test_agglom_2d_with_strides():
offsets = [(0, 1), (1, 0)]
strides = [(1, 2), (2, 1)]
affinities = (
np.array(
[[[0, 1, 0], [0, 1, 0], [0, 1, 0]], [[0, 0, 0], [1, 1, 1], [0, 0, 0]]],
dtype=float,
)
- 0.5
)
# 9 nodes. connecting edges:
# 2-3, 5-6, 8-9, 4-7, 5-8, 6-9
# components: [(1,),(2,3),(4,7),(5,6,8,9)]

components = mwatershed.agglom(affinities, offsets, strides=strides)

_, counts = np.unique(components, return_counts=True)
counts = sorted(counts)
assert len(counts) == 4

assert counts[0] == 1
assert counts[1] == 2
assert counts[2] == 2
assert counts[3] == 4

0 comments on commit 2674200

Please sign in to comment.