Skip to content

Commit

Permalink
Add support for directed graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
funkey committed Aug 21, 2024
1 parent d60fca1 commit 72b7cc8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 43 deletions.
6 changes: 3 additions & 3 deletions spatial_graph/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __new__(
wrapper_template.node_dtype = node_dtype
wrapper_template.node_attr_dtypes = node_attr_dtypes
wrapper_template.edge_attr_dtypes = edge_attr_dtypes
wrapper_template.directed = directed

wrapper = witty.compile_module(
str(wrapper_template),
Expand All @@ -42,9 +43,8 @@ def __new__(
language="c++",
quiet=True,
)
Graph = wrapper.DirectedGraph if directed else wrapper.UndirectedGraph
GraphType = type(cls.__name__, (cls, Graph), {})
return Graph.__new__(GraphType)
GraphType = type(cls.__name__, (cls, wrapper.Graph), {})
return wrapper.Graph.__new__(GraphType)

def __init__(self, node_dtype, node_attr_dtypes, edge_attr_dtypes, directed=False):
super().__init__()
Expand Down
113 changes: 76 additions & 37 deletions spatial_graph/graph/wrapper_template.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,25 @@ cdef extern from *:
#include "src/graph_lite.h"
// partial template instantiation of graph_lite::Graph as
// UndirectedGraphTmpl
// GraphTmpl
template<typename NodeType, typename NodeData, typename EdgeData>
class UndirectedGraphTmpl : public graph_lite::Graph<
class GraphTmpl : public graph_lite::Graph<
NodeType,
NodeData,
EdgeData,
%if $directed
graph_lite::EdgeDirection::DIRECTED,
%else
graph_lite::EdgeDirection::UNDIRECTED,
%end if
graph_lite::MultiEdge::DISALLOWED,
graph_lite::SelfLoop::DISALLOWED,
graph_lite::Map::UNORDERED_MAP,
graph_lite::Container::VEC
> {};
"""

cdef cppclass UndirectedGraphTmpl[NodeType, NodeData, EdgeData]:
cdef cppclass GraphTmpl[NodeType, NodeData, EdgeData]:

cppclass Iterator:
NodeType operator*()
Expand All @@ -49,12 +53,24 @@ cdef extern from *:

EdgeData& edge_prop[T](T& u, T& v)

%if $directed
pair[NeighborsIterator, NeighborsIterator] out_neighbors(Iterator& node)
pair[NeighborsIterator, NeighborsIterator] out_neighbors(NodeType& node)
pair[NeighborsIterator, NeighborsIterator] in_neighbors(Iterator& node)
pair[NeighborsIterator, NeighborsIterator] in_neighbors(NodeType& node)
%else
pair[NeighborsIterator, NeighborsIterator] neighbors(Iterator& node)
pair[NeighborsIterator, NeighborsIterator] neighbors(NodeType& node)
%end if

int remove_nodes(NodeType& node)

%if $directed
int count_in_neighbors(NodeType& node)
int count_out_neighbors(NodeType& node)
%else
int count_neighbors(NodeType& node)
%end if

size_t size() const

Expand Down Expand Up @@ -144,14 +160,14 @@ cdef class ${class_name}View:

%end for

ctypedef UndirectedGraphTmpl[NodeType, NodeData, EdgeData] UndirectedGraphType
ctypedef UndirectedGraphType.Iterator NodeIterator
ctypedef UndirectedGraphType.NeighborsIterator NeighborsIterator
ctypedef GraphTmpl[NodeType, NodeData, EdgeData] GraphType
ctypedef GraphType.Iterator NodeIterator
ctypedef GraphType.NeighborsIterator NeighborsIterator


cdef class UndirectedGraph:
cdef class Graph:

cdef UndirectedGraphType _graph
cdef GraphType _graph

%for kind, Kind, dtypes in [
("node", "Node", $node_attr_dtypes),
Expand Down Expand Up @@ -260,62 +276,71 @@ cdef class UndirectedGraph:
yield deref(it)
inc(it)

def edges(self, node=None, bint data=False):
%if $directed
%set $prefixes=["in_", "out_"]
%else
%set $prefixes=[""]
%end if
%for prefix in $prefixes
def ${prefix}edges(self, node=None, bint data=False):

if node is not None:
yield from self._neighbors(<NodeType>node, data)
yield from self._${prefix}neighbors(<NodeType>node, data)
return

# iterate over all edges by iterating over all nodes u and their
# neighbors v with u < v
cdef NodeIterator node_it = self._graph.begin()
cdef NodeIterator node_end = self._graph.end()
cdef pair[NeighborsIterator, NeighborsIterator] view
cdef NodeType u, v
cdef EdgeDataView edge_data = EdgeDataView()

while node_it != node_end:
view = self._graph.neighbors(node_it)
view = self._graph.${prefix}neighbors(node_it)
u = deref(node_it)
it = view.first
end = view.second
while it != end:
v = deref(it).first
if u < v:
if data:
edge_data.set_ptr(&deref(it).second.prop())
yield (u, v), edge_data
else:
yield (u, v)
%if not $directed
# avoid double-reporting undirected edges by returning only
# edges where u < v
if u >= v:
inc(it)
continue
%end if
if data:
edge_data.set_ptr(&deref(it).second.prop())
yield (u, v), edge_data
else:
yield (u, v)
inc(it)
inc(node_it)

# same as above, but for fast access to edges incident to an array of nodes
def edges_by_nodes(self, NodeType[::1] nodes):
# NOTE: this will double-report edges between "nodes"
def ${prefix}edges_by_nodes(self, NodeType[::1] nodes):

# iterate over all edges by iterating over all nodes u and their
# neighbors v with u < v
cdef pair[NeighborsIterator, NeighborsIterator] view
cdef NodeType u, v
cdef Py_ssize_t i = 0

num_edges = self._num_edges(nodes)
num_edges = self._num_${prefix}edges(nodes)
data = np.empty(shape=(num_edges, 2), dtype="$node_dtype.base")
cdef NodeType[:, ::1] edges = data

for u in nodes:
view = self._graph.neighbors(u)
view = self._graph.${prefix}neighbors(u)
it = view.first
end = view.second
while it != end:
v = deref(it).first
if u < v:
edges[i, 0] = u
edges[i, 1] = v
i += 1
edges[i, 0] = u
edges[i, 1] = v
i += 1
inc(it)

return data[:i]
%end for

# generator access to node and edge data

Expand Down Expand Up @@ -442,11 +467,17 @@ cdef class UndirectedGraph:

if us is None:

# iterate over all edges by iterating over all nodes u and their
# neighbors v with u < v

while node_it != node_end:
%if $directed
# iterate over all edges by iterating over all nodes u and
# their out neighbors
edges_view = self._graph.out_neighbors(node_it)
%else
# iterate over all edges by iterating over all nodes u and
# their neighbors v with u < v
edges_view = self._graph.neighbors(node_it)
%end if
u = deref(node_it)
it = edges_view.first
end = edges_view.second
Expand Down Expand Up @@ -507,19 +538,26 @@ cdef class UndirectedGraph:

# read-only graph properties

def count_neighbors(self, NodeType[:] nodes):
%if $directed
%set $prefixes=["in_", "out_"]
%else
%set $prefixes=[""]
%end if
%for prefix in $prefixes
def count_${prefix}neighbors(self, NodeType[:] nodes):
num_nodes = len(nodes)
cdef int[:] counts = view.array(
shape=(num_nodes,),
itemsize=sizeof(int),
format="i")
for i in range(num_nodes):
counts[i] = self._graph.count_neighbors(nodes[i])
counts[i] = self._graph.count_${prefix}neighbors(nodes[i])
return counts

def _neighbors(self, NodeType node, bint data):
def _${prefix}neighbors(self, NodeType node, bint data):

cdef pair[NeighborsIterator, NeighborsIterator] view = self._graph.neighbors(node)
cdef pair[NeighborsIterator, NeighborsIterator] view = \
self._graph.${prefix}neighbors(node)
cdef NeighborsIterator it = view.first
cdef NeighborsIterator end = view.second
cdef EdgeDataView edge_data = EdgeDataView()
Expand All @@ -534,11 +572,12 @@ cdef class UndirectedGraph:
yield deref(it).first
inc(it)

def _num_${prefix}edges(self, NodeType[::1] nodes):
return np.sum(self.count_${prefix}neighbors(nodes))
%end for

def __len__(self):
return self._graph.size()

def num_edges(self):
return self._graph.num_edges()

def _num_edges(self, NodeType[::1] nodes):
return np.sum(self.count_neighbors(nodes))
49 changes: 46 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,55 @@
@pytest.mark.parametrize("edge_attr_dtypes", edge_attr_dtypes)
@pytest.mark.parametrize("directed", [True, False])
def test_construction(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed):
# TODO (directed graphs not yet wrapped)
if directed:
return
graph = sg.Graph(node_dtype, node_attr_dtypes, edge_attr_dtypes, directed)


@pytest.mark.parametrize("directed", [True, False])
def test_operations(directed):
graph = sg.Graph(
"uint64", {"score": "float"}, {"score": "float"}, directed=directed
)

nodes = [1, 2, 3, 4, 5]
graph.add_nodes(
np.array(nodes, dtype="uint64"),
score=np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype="float32"),
)
for u in nodes:
for v in nodes:
if v == u:
continue
graph.add_edge(np.array([u, v], dtype="uint64"), score=u * 100 + v)

if directed:
assert graph.num_edges() == len(nodes) ** 2 - len(nodes)

for node in nodes:
in_neighbors = graph.count_in_neighbors(np.array([node], dtype="uint64"))
out_neighbors = graph.count_in_neighbors(np.array([node], dtype="uint64"))
assert len(in_neighbors) == 1
assert len(out_neighbors) == 1
assert in_neighbors[0] == len(nodes) - 1
assert out_neighbors[0] == len(nodes) - 1

for edge, attrs in graph.out_edges(data=True):
assert attrs.score == edge[0] * 100 + edge[1]

for edge, attrs in graph.in_edges(data=True):
assert attrs.score == edge[1] * 100 + edge[0]

else:
assert graph.num_edges() == (len(nodes) ** 2 - len(nodes)) / 2

for node in nodes:
neighbors = graph.count_neighbors(np.array([node], dtype="uint64"))
assert len(neighbors) == 1
assert neighbors[0] == len(nodes) - 1

for edge, attrs in graph.edges(data=True):
assert attrs.score == edge[0] * 100 + edge[1]


def test_attribute_modification():
graph = sg.Graph(
"uint64",
Expand Down

0 comments on commit 72b7cc8

Please sign in to comment.