-
Notifications
You must be signed in to change notification settings - Fork 3
/
visualize.py
36 lines (31 loc) · 1.2 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import io
import PIL
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from torchvision.transforms import ToTensor
# Use the non-interactive Agg raster backend: draw_stack only renders figures
# into in-memory buffers, so no display/GUI toolkit is needed.
plt.switch_backend("agg")

# Discrete palette for the atom overlay in draw_stack. Integer atom-type values
# passed to a ListedColormap select entries by position (0 -> "grey", 1 -> "white", ...).
# NOTE(review): which chemical element each index denotes is not visible here;
# the colours resemble CPK conventions — confirm against the dataset's type encoding.
cmap = ListedColormap(
    [
        "grey",
        "white",
        "red",
        "blue",
        "green",
        "white",
    ]
)
def draw_stack(density, atom_type=None, atom_coord=None, dim=-1):
    """
    Draw a 2D projection of a 3D density grid, optionally overlaid with atoms.

    The density is summed along ``dim`` and shown with ``imshow`` (viridis,
    with a colorbar). If ``atom_type`` is given, atom positions are scattered
    on top, coloured by looking their types up in the module-level ``cmap``.
    The figure is rendered to an in-memory JPEG and converted to a tensor.

    :param density: density grid for a single sample, tensor of shape
        (nx, ny, nz). There is no batch dimension: the projection must be 2D
        for ``imshow``.
    :param atom_type: atom-type indices into ``cmap``, tensor of shape
        (n_atom,), or None to skip the atom overlay.
    :param atom_coord: atom coordinates, tensor of shape (n_atom, 3); required
        when ``atom_type`` is given. Presumably expressed in grid-index units
        so that atoms line up with the projection's pixels — TODO confirm.
    :param dim: axis of ``density`` along which to sum (negative values allowed).
    :return: the rendered image as produced by ``torchvision.transforms.ToTensor``,
        i.e. a float tensor of shape (C, H, W) with values in [0, 1].
    :raises ValueError: if ``density`` is not 3D, or if ``atom_type`` is given
        without ``atom_coord``.
    """
    # `import PIL` alone does not load the Image submodule; import it explicitly
    # instead of relying on matplotlib/torchvision having done so.
    from PIL import Image

    if density.ndim != 3:
        raise ValueError(
            f"density must be a 3D grid (nx, ny, nz), got shape {tuple(density.shape)}"
        )
    if atom_type is not None and atom_coord is None:
        raise ValueError("atom_coord is required when atom_type is given")

    fig = plt.figure(figsize=(3, 3))
    try:
        plt.imshow(density.sum(dim).detach().cpu().numpy(), cmap='viridis')
        plt.colorbar()
        if atom_type is not None:
            # The two axes that survive the projection, in ascending order.
            # imshow puts the first one on the image rows (y) and the second on
            # the columns (x), so the scatter's x/y take idx[1]/idx[0].
            idx = [i for i in range(3) if i != dim % 3]
            coord = atom_coord.detach().cpu().numpy()
            # NOTE(review): expects integer type indices; float values would be
            # interpreted by the colormap as [0, 1] fractions instead.
            color = cmap(atom_type.detach().cpu().numpy())
            plt.scatter(coord[:, idx[1]], coord[:, idx[0]], c=color, alpha=0.8)
        with io.BytesIO() as buf:
            fig.savefig(buf, format='jpg')
            buf.seek(0)
            # ToTensor reads the pixel data while the image and buffer are
            # still open; the returned tensor does not reference either.
            with Image.open(buf) as image:
                return ToTensor()(image)
    finally:
        # Always release this figure, even on error, so pyplot's global
        # figure registry does not accumulate figures across calls.
        plt.close(fig)