Skip to content

Commit

Permalink
added flat colors to plot_network
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM committed Nov 23, 2023
1 parent d65f7c1 commit b6b430e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
40 changes: 24 additions & 16 deletions decoupler/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,17 +1768,22 @@ def get_obs_act_net(act, obs, net, n_sources, n_targets, by_abs):
def add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap):

mpl = check_if_matplotlib(return_mpl=True)
s_cmap = mpl.colormaps.get_cmap(s_cmap)
t_cmap = mpl.colormaps.get_cmap(t_cmap)

color = []
for i, k in enumerate(g.vs['label']):
if g.vs['type'][i]:
color.append(t_cmap(t_norm(obs[k].values[0])))
else:
color.append(s_cmap(s_norm(act[k].values[0])))

cmaps = mpl.colormaps.keys()
if (s_cmap in cmaps) and (t_cmap in cmaps):
s_cmap = mpl.colormaps.get_cmap(s_cmap)
t_cmap = mpl.colormaps.get_cmap(t_cmap)
color = []
for i, k in enumerate(g.vs['label']):
if g.vs['type'][i]:
color.append(t_cmap(t_norm(obs[k].values[0])))
else:
color.append(s_cmap(s_norm(act[k].values[0])))
is_cmap = True
else:
color = [s_cmap if typ == 0. else t_cmap for typ in g.vs['type']]
is_cmap = False
g.vs['color'] = color
return is_cmap


def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_size=0.5, label_size=5, s_cmap='RdBu_r',
Expand Down Expand Up @@ -1806,9 +1811,9 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz
label_size : int
Size of the labels in the plot.
s_cmap : str
Colormap to use to color sources.
Color or colormap to use to color sources.
t_cmap : str
Colormap to use to color targets.
Color or colormap to use to color targets.
vcenter : bool
Whether to center colors around 0.
c_pos_w : str
Expand Down Expand Up @@ -1854,7 +1859,7 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz
# Get graph
g = get_g(fact, fobs, fnet)
g.es['color'] = [c_pos_w if w > 0 else c_neg_w for w in g.es['weight']]
add_colors(g, fact, fobs, s_norm, t_norm, s_cmap, t_cmap)
is_cmap = add_colors(g, fact, fobs, s_norm, t_norm, s_cmap, t_cmap)

# Build figure
fig = plt.figure(figsize=figsize, dpi=dpi, tight_layout=True)
Expand All @@ -1869,9 +1874,12 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz
vertex_size=node_size,
vertex_label_size=label_size
)

fig.colorbar(mpl.cm.ScalarMappable(norm=s_norm, cmap=s_cmap), cax=ax2, orientation="horizontal", label=s_label)
fig.colorbar(mpl.cm.ScalarMappable(norm=t_norm, cmap=t_cmap), cax=ax3, orientation="horizontal", label=t_label)
if is_cmap:
fig.colorbar(mpl.cm.ScalarMappable(norm=s_norm, cmap=s_cmap), cax=ax2, orientation="horizontal", label=s_label)
fig.colorbar(mpl.cm.ScalarMappable(norm=t_norm, cmap=t_cmap), cax=ax3, orientation="horizontal", label=t_label)
else:
ax2.axis("off")
ax3.axis("off")

save_plot(fig, None, save)

Expand Down
10 changes: 8 additions & 2 deletions decoupler/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,15 @@ def test_add_colors():
s_norm = get_norm(act, vcenter=False)
t_norm = get_norm(obs, vcenter=False)
s_cmap, t_cmap = 'RdBu_r', 'viridis'
r = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap)
assert r is None
is_cmap = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap)
assert is_cmap
assert (g.vs['color'][0][0] > 0.82) & (g.vs['color'][0][1] < 0.38) & (g.vs['color'][0][2] < 0.31)
assert (g.vs['color'][-1][0] > 0.98) & (g.vs['color'][-1][1] > 0.89) & (g.vs['color'][-1][2] < 0.15)
s_cmap, t_cmap = 'red', 'blue'
is_cmap = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap)
assert not is_cmap
assert g.vs['color'][0] == 'red'
assert g.vs['color'][-1] == 'blue'


def test_plot_network():
Expand All @@ -484,3 +489,4 @@ def test_plot_network():
['N3', 'N4'],
], columns=['source', 'target'])
plot_network(obs, act, net, figsize=(3, 3), node_size=0.25)
plot_network(obs, act, net, figsize=(3, 3), node_size=0.25, s_cmap='red', t_cmap='blue')

0 comments on commit b6b430e

Please sign in to comment.