Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to graph class (px2ang and max_neighbors are added) #66

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
30 changes: 22 additions & 8 deletions atomai/nets/fcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Unet(nn.Module):
Builds a fully convolutional Unet-like neural network model

Args:
n_channels:
Number of channels in the input image
nb_classes:
Number of classes in the ground truth
nb_filters:
Expand Down Expand Up @@ -48,6 +50,7 @@ class Unet(nn.Module):
(to maintain symmetry between encoder and decoder)
"""
def __init__(self,
n_channels: int = 1,
nb_classes: int = 1,
nb_filters: int = 16,
dropout: bool = False,
Expand All @@ -64,7 +67,7 @@ def __init__(self,
padding_values = dilation_values.copy()
dropout_vals = [.1, .2, .1] if dropout else [0, 0, 0]
self.c1 = ConvBlock(
2, nbl[0], 1, nb_filters,
2, nbl[0], n_channels, nb_filters,
batch_norm=batch_norm
)
self.c2 = ConvBlock(
Expand Down Expand Up @@ -148,6 +151,8 @@ class dilnet(nn.Module):
by utilizing a combination of regular and dilated convolutions

Args:
n_channels:
Number of channels in the input image
nb_classes:
Number of classes in the ground truth
nb_filters:
Expand All @@ -167,6 +172,7 @@ class dilnet(nn.Module):
"""

def __init__(self,
n_channels: int = 1,
nb_classes: int = 1,
nb_filters: int = 25,
dropout: bool = False,
Expand All @@ -184,7 +190,7 @@ def __init__(self,
padding_values_2 = dilation_values_2.copy()
dropout_vals = [.3, .3] if dropout else [0, 0]
self.c1 = ConvBlock(
2, nbl[0], 1, nb_filters,
2, nbl[0], n_channels, nb_filters,
batch_norm=batch_norm
)
self.at1 = DilatedBlock(
Expand Down Expand Up @@ -231,6 +237,8 @@ class ResHedNet(nn.Module):
Holistically nested edge detector with residual connections in each block

Args:
n_channels:
Number of channels in the input layer
nb_classes:
Number of classes in the ground truth
nb_filters:
Expand All @@ -247,6 +255,7 @@ class ResHedNet(nn.Module):

"""
def __init__(self,
n_channels: int = 1,
nb_classes: int = 1,
nb_filters: int = 64,
upsampling_mode: str = "bilinear",
Expand All @@ -257,7 +266,7 @@ def __init__(self,
super(ResHedNet, self).__init__()
nbl = kwargs.get("layers", [3, 4, 5])
self.upsample = upsampling_mode
self.net1 = ResModule(2, nbl[0], 1, nb_filters, True)
self.net1 = ResModule(2, nbl[0], n_channels, nb_filters, True)
self.net2 = nn.Sequential(
nn.MaxPool2d(2, 2),
ResModule(2, nbl[1], nb_filters, 2*nb_filters, True)
Expand Down Expand Up @@ -302,6 +311,8 @@ class SegResNet(nn.Module):
with residual blocks for semantic segmentation

Args:
n_channels:
Number of channels in the input image
nb_classes:
Number of classes in the ground truth
nb_filters:
Expand All @@ -321,6 +332,7 @@ class SegResNet(nn.Module):

'''
def __init__(self,
n_channels: int = 1,
nb_classes: int = 1,
nb_filters: int = 32,
batch_norm: bool = True,
Expand All @@ -333,7 +345,7 @@ def __init__(self,
super(SegResNet, self).__init__()
nbl = kwargs.get("layers", [2, 2, 2])
self.c1 = ConvBlock(
2, 1, 1, nb_filters, batch_norm=batch_norm
2, 1, n_channels, nb_filters, batch_norm=batch_norm
)
self.c2 = ResModule(
2, nbl[0], nb_filters, nb_filters*2, batch_norm=batch_norm
Expand Down Expand Up @@ -386,12 +398,14 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
meta_state_dict = {
'model_type': 'Seg', model: 'custom', 'nb_classes': nb_classes}
return model, meta_state_dict
n_channels = kwargs.get('n_channels', 1)
batch_norm = kwargs.get('batch_norm', True)
dropout = kwargs.get('dropout', False)
upsampling = kwargs.get('upsampling', "bilinear")
meta_state_dict = {
'model_type': 'seg',
'model': model,
'n_channels': n_channels,
'nb_classes': nb_classes,
'batch_norm': batch_norm,
'dropout': dropout,
Expand All @@ -402,7 +416,7 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
nb_filters = kwargs.get('nb_filters', 16)
layers = kwargs.get("layers", [1, 2, 2, 3])
net = Unet(
nb_classes, nb_filters, dropout,
n_channels, nb_classes, nb_filters, dropout,
batch_norm, upsampling, with_dilation,
layers=layers
)
Expand All @@ -411,22 +425,22 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
nb_filters = kwargs.get('nb_filters', 25)
layers = kwargs.get("layers", [1, 3, 3, 1])
net = dilnet(
nb_classes, nb_filters,
n_channels, nb_classes, nb_filters,
dropout, batch_norm, upsampling,
layers=layers
)
elif isinstance(model, str) and model == 'SegResNet':
nb_filters = kwargs.get('nb_filters', 32)
layers = kwargs.get("layers", [2, 2, 2])
net = SegResNet(
nb_classes, nb_filters,
n_channels, nb_classes, nb_filters,
batch_norm, upsampling, layers=layers
)
elif isinstance(model, str) and model == 'ResHedNet':
nb_filters = kwargs.get('nb_filters', 64)
layers = kwargs.get("layers", [3, 4, 5])
net = ResHedNet(
nb_classes, nb_filters,
n_channels, nb_classes, nb_filters,
upsampling, layers=layers
)
else:
Expand Down
68 changes: 50 additions & 18 deletions atomai/utils/graphx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class Graph:
"""

def __init__(self, coordinates: np.ndarray,
map_dict: Dict) -> None:
map_dict: Dict,
px2ang: float = 1) -> None:
"""
Initializes a graph object
"""
Expand All @@ -76,6 +77,8 @@ def __init__(self, coordinates: np.ndarray,
v = Node(i, coords[:-1].tolist(), map_dict[coords[-1]])
self.vertices.append(v)
self.coordinates = coordinates
self.coordinates_ang = deepcopy(coordinates)
self.coordinates_ang[:, :-1] = self.coordinates[:, :-1] * px2ang
self.map_dict = map_dict
self.size = len(coordinates)
self.rings = []
Expand All @@ -87,6 +90,10 @@ def find_neighbors(self, **kwargs: float):
Identifies neighbors of each graph node

Args:
**max_neighbors(int):
This is the maximum number of neighbors each node can have,
ususally used to form the graph with only nearest neighbors
Default is -1 which means it will find all the neighbors
**expand (float):
coefficient determining the maximum allowable expansion of
atomic bonds when constructing a graph. For example, the two
Expand All @@ -97,34 +104,59 @@ def find_neighbors(self, **kwargs: float):
del v.neighbors[:]
Rij = get_interatomic_r
e = kwargs.get("expand", 1.2)
tree = spatial.cKDTree(self.coordinates[:, :3])
uval = np.unique(self.coordinates[:, -1])
max_neighbors = kwargs.get("max_neighbors", -1)
tree = spatial.cKDTree(self.coordinates_ang[:, :3])
uval = np.unique(self.coordinates_ang[:, -1])
if len(uval) == 1:
rmax = Rij([self.map_dict[uval[0]], self.map_dict[uval[0]]], e)
neighbors = tree.query_ball_point(self.coordinates[:, :3], r=rmax)
if max_neighbors == -1:
neighbors = tree.query_ball_point(self.coordinates_ang[:, :3], r=rmax)
else:
_, neighbors = tree.query(self.coordinates_ang[:, :3], k=max_neighbors+1, distance_upper_bound = rmax)
for v, nn in zip(self.vertices, neighbors):
for n in nn:
if self.vertices[n] != v:
v.neighbors.append(self.vertices[n])
v.neighborscopy.append(self.vertices[n])
if not n >= len(self.vertices):
if self.vertices[n] != v:
v.neighbors.append(self.vertices[n])
v.neighborscopy.append(self.vertices[n])

else:
uval = [self.map_dict[u] for u in uval]
apairs = [(p[0], p[1]) for p in itertools.product(uval, repeat=2)]
rij = [Rij([a[0], a[1]], e) for a in apairs]
rmax = np.max(rij)
rij = dict(zip(apairs, rij))
for v, coords in zip(self.vertices, self.coordinates):
for v, coords in zip(self.vertices, self.coordinates_ang):
atom1 = self.map_dict[coords[-1]]
nn = tree.query_ball_point(coords[:3], r=rmax)
for n, coords2 in zip(nn, self.coordinates[nn]):
if self.vertices[n] != v:
atom2 = self.map_dict[coords2[-1]]
eucldist = np.linalg.norm(
coords[:3] - coords2[:3])
if eucldist <= rij[(atom1, atom2)]:
v.neighbors.append(self.vertices[n])
v.neighborscopy.append(self.vertices[n])

if max_neighbors == -1:
nn = tree.query_ball_point(coords[:3], r=rmax)
else:
_, nn = tree.query(coords[:3], k=max_neighbors+1, distance_upper_bound = rmax)

for n in nn:
if not n >= len(self.vertices):
coords2 = self.coordinates_ang[n]
if self.vertices[n] != v:
atom2 = self.map_dict[coords2[-1]]
eucldist = np.linalg.norm(
coords[:3] - coords2[:3])
if eucldist <= rij[(atom1, atom2)]:
v.neighbors.append(self.vertices[n])
v.neighborscopy.append(self.vertices[n])

#Making the graph symmetric when max_neighbors is used
for v in self.vertices:
id = v.id
rem_ids = []
for nn in v.neighbors:
nn_neighbors_list = [nn.neighbors[l].id for l in range(len(nn.neighbors))]
if id not in nn_neighbors_list:
rem_ids.append(nn)

for rem_id in rem_ids:
v.neighbors.remove(rem_id)


def find_rings(self,
v: Type[Node],
rings: List[List[Type[Node]]] = [],
Expand Down