diff --git a/spatial_graph/graph/graph.py b/spatial_graph/graph/graph.py index d53e608..3e0ad6c 100644 --- a/spatial_graph/graph/graph.py +++ b/spatial_graph/graph/graph.py @@ -34,6 +34,7 @@ def __new__( wrapper_template.node_dtype = node_dtype wrapper_template.node_attr_dtypes = node_attr_dtypes wrapper_template.edge_attr_dtypes = edge_attr_dtypes + wrapper_template.directed = directed wrapper = witty.compile_module( str(wrapper_template), @@ -42,9 +43,8 @@ def __new__( language="c++", quiet=True, ) - Graph = wrapper.DirectedGraph if directed else wrapper.UndirectedGraph - GraphType = type(cls.__name__, (cls, Graph), {}) - return Graph.__new__(GraphType) + GraphType = type(cls.__name__, (cls, wrapper.Graph), {}) + return wrapper.Graph.__new__(GraphType) def __init__(self, node_dtype, node_attr_dtypes, edge_attr_dtypes, directed=False): super().__init__() diff --git a/spatial_graph/graph/wrapper_template.pyx b/spatial_graph/graph/wrapper_template.pyx index d403e72..9bc04e3 100644 --- a/spatial_graph/graph/wrapper_template.pyx +++ b/spatial_graph/graph/wrapper_template.pyx @@ -10,13 +10,17 @@ cdef extern from *: #include "src/graph_lite.h" // partial template instantiation of graph_lite::Graph as - // UndirectedGraphTmpl + // GraphTmpl template - class UndirectedGraphTmpl : public graph_lite::Graph< + class GraphTmpl : public graph_lite::Graph< NodeType, NodeData, EdgeData, + %if $directed + graph_lite::EdgeDirection::DIRECTED, + %else graph_lite::EdgeDirection::UNDIRECTED, + %end if graph_lite::MultiEdge::DISALLOWED, graph_lite::SelfLoop::DISALLOWED, graph_lite::Map::UNORDERED_MAP, @@ -24,7 +28,7 @@ cdef extern from *: > {}; """ - cdef cppclass UndirectedGraphTmpl[NodeType, NodeData, EdgeData]: + cdef cppclass GraphTmpl[NodeType, NodeData, EdgeData]: cppclass Iterator: NodeType operator*() @@ -49,12 +53,24 @@ cdef extern from *: EdgeData& edge_prop[T](T& u, T& v) + %if $directed + pair[NeighborsIterator, NeighborsIterator] out_neighbors(Iterator& node) + pair[NeighborsIterator, NeighborsIterator] out_neighbors(NodeType& node) + pair[NeighborsIterator, NeighborsIterator] in_neighbors(Iterator& node) + pair[NeighborsIterator, NeighborsIterator] in_neighbors(NodeType& node) + %else pair[NeighborsIterator, NeighborsIterator] neighbors(Iterator& node) pair[NeighborsIterator, NeighborsIterator] neighbors(NodeType& node) + %end if int remove_nodes(NodeType& node) + %if $directed + int count_in_neighbors(NodeType& node) + int count_out_neighbors(NodeType& node) + %else int count_neighbors(NodeType& node) + %end if size_t size() const @@ -144,14 +160,14 @@ cdef class ${class_name}View: %end for -ctypedef UndirectedGraphTmpl[NodeType, NodeData, EdgeData] UndirectedGraphType -ctypedef UndirectedGraphType.Iterator NodeIterator -ctypedef UndirectedGraphType.NeighborsIterator NeighborsIterator +ctypedef GraphTmpl[NodeType, NodeData, EdgeData] GraphType +ctypedef GraphType.Iterator NodeIterator +ctypedef GraphType.NeighborsIterator NeighborsIterator -cdef class UndirectedGraph: +cdef class Graph: - cdef UndirectedGraphType _graph + cdef GraphType _graph %for kind, Kind, dtypes in [ ("node", "Node", $node_attr_dtypes), @@ -260,14 +276,18 @@ cdef class UndirectedGraph: yield deref(it) inc(it) - def edges(self, node=None, bint data=False): + %if $directed + %set $prefixes=["in_", "out_"] + %else + %set $prefixes=[""] + %end if + %for prefix in $prefixes + def ${prefix}edges(self, node=None, bint data=False): if node is not None: - yield from self._neighbors(node, data) + yield from self._${prefix}neighbors(node, data) return - # iterate over all edges by iterating over all nodes u and their - # neighbors v with u < v cdef NodeIterator node_it = self._graph.begin() cdef NodeIterator node_end = self._graph.end() cdef pair[NeighborsIterator, NeighborsIterator] view @@ -275,47 +295,52 @@ cdef class UndirectedGraph: cdef EdgeDataView edge_data = EdgeDataView() while node_it != node_end: - view = self._graph.neighbors(node_it) + view = self._graph.${prefix}neighbors(node_it) u = deref(node_it) it = view.first end = view.second while it != end: v = deref(it).first - if u < v: - if data: - edge_data.set_ptr(&deref(it).second.prop()) - yield (u, v), edge_data - else: - yield (u, v) + %if not $directed + # avoid double-reporting undirected edges by returning only + # edges where u < v + if u >= v: + inc(it) + continue + %end if + if data: + edge_data.set_ptr(&deref(it).second.prop()) + yield (u, v), edge_data + else: + yield (u, v) inc(it) inc(node_it) # same as above, but for fast access to edges incident to an array of nodes - def edges_by_nodes(self, NodeType[::1] nodes): + # NOTE: this will double-report edges between "nodes" + def ${prefix}edges_by_nodes(self, NodeType[::1] nodes): - # iterate over all edges by iterating over all nodes u and their - # neighbors v with u < v cdef pair[NeighborsIterator, NeighborsIterator] view cdef NodeType u, v cdef Py_ssize_t i = 0 - num_edges = self._num_edges(nodes) + num_edges = self._num_${prefix}edges(nodes) data = np.empty(shape=(num_edges, 2), dtype="$node_dtype.base") cdef NodeType[:, ::1] edges = data for u in nodes: - view = self._graph.neighbors(u) + view = self._graph.${prefix}neighbors(u) it = view.first end = view.second while it != end: v = deref(it).first - if u < v: - edges[i, 0] = u - edges[i, 1] = v - i += 1 + edges[i, 0] = u + edges[i, 1] = v + i += 1 inc(it) return data[:i] + %end for # generator access to node and edge data @@ -442,11 +467,17 @@ cdef class UndirectedGraph: if us is None: - # iterate over all edges by iterating over all nodes u and their - # neighbors v with u < v while node_it != node_end: + %if $directed + # iterate over all edges by iterating over all nodes u and + # their out neighbors + edges_view = self._graph.out_neighbors(node_it) + %else + # iterate over all edges by iterating over all nodes u and + # their neighbors v with u < v edges_view = self._graph.neighbors(node_it) + %end if u = deref(node_it) it = edges_view.first end = edges_view.second @@ -507,19 +538,26 @@ cdef class UndirectedGraph: # read-only graph properties - def count_neighbors(self, NodeType[:] nodes): + %if $directed + %set $prefixes=["in_", "out_"] + %else + %set $prefixes=[""] + %end if + %for prefix in $prefixes + def count_${prefix}neighbors(self, NodeType[:] nodes): num_nodes = len(nodes) cdef int[:] counts = view.array( shape=(num_nodes,), itemsize=sizeof(int), format="i") for i in range(num_nodes): - counts[i] = self._graph.count_neighbors(nodes[i]) + counts[i] = self._graph.count_${prefix}neighbors(nodes[i]) return counts - def _neighbors(self, NodeType node, bint data): + def _${prefix}neighbors(self, NodeType node, bint data): - cdef pair[NeighborsIterator, NeighborsIterator] view = self._graph.neighbors(node) + cdef pair[NeighborsIterator, NeighborsIterator] view = \ + self._graph.${prefix}neighbors(node) cdef NeighborsIterator it = view.first cdef NeighborsIterator end = view.second cdef EdgeDataView edge_data = EdgeDataView() @@ -534,11 +572,12 @@ cdef class UndirectedGraph: yield deref(it).first inc(it) + def _num_${prefix}edges(self, NodeType[::1] nodes): + return np.sum(self.count_${prefix}neighbors(nodes)) + %end for + def __len__(self): return self._graph.size() def num_edges(self): return self._graph.num_edges() - - def _num_edges(self, NodeType[::1] nodes): - return np.sum(self.count_neighbors(nodes)) diff --git a/tests/test_graph.py b/tests/test_graph.py index 2c08622..8a70290 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -20,12 +20,55 @@ @pytest.mark.parametrize("edge_attr_dtypes", edge_attr_dtypes) @pytest.mark.parametrize("directed", [True, False]) def test_construction(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed): - # TODO (directed graphs not yet wrapped) - if directed: - return graph = sg.Graph(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed) +@pytest.mark.parametrize("directed", [True, False]) +def test_operations(directed): + graph = sg.Graph( + "uint64", {"score": "float"}, {"score": "float"}, directed=directed + ) + + nodes = [1, 2, 3, 4, 5] + graph.add_nodes( + np.array(nodes, dtype="uint64"), + score=np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype="float32"), + ) + for u in nodes: + for v in nodes: + if v == u: + continue + graph.add_edge(np.array([u, v], dtype="uint64"), score=u * 100 + v) + + if directed: + assert graph.num_edges() == len(nodes) ** 2 - len(nodes) + + for node in nodes: + in_neighbors = graph.count_in_neighbors(np.array([node], dtype="uint64")) + out_neighbors = graph.count_in_neighbors(np.array([node], dtype="uint64")) + assert len(in_neighbors) == 1 + assert len(out_neighbors) == 1 + assert in_neighbors[0] == len(nodes) - 1 + assert out_neighbors[0] == len(nodes) - 1 + + for edge, attrs in graph.out_edges(data=True): + assert attrs.score == edge[0] * 100 + edge[1] + + for edge, attrs in graph.in_edges(data=True): + assert attrs.score == edge[1] * 100 + edge[0] + + else: + assert graph.num_edges() == (len(nodes) ** 2 - len(nodes)) / 2 + + for node in nodes: + neighbors = graph.count_neighbors(np.array([node], dtype="uint64")) + assert len(neighbors) == 1 + assert neighbors[0] == len(nodes) - 1 + + for edge, attrs in graph.edges(data=True): + assert attrs.score == edge[0] * 100 + edge[1] + + def test_attribute_modification(): graph = sg.Graph( "uint64",