diff --git a/decoupler/plotting.py b/decoupler/plotting.py index b49a358..bf09b9a 100644 --- a/decoupler/plotting.py +++ b/decoupler/plotting.py @@ -1768,17 +1768,22 @@ def get_obs_act_net(act, obs, net, n_sources, n_targets, by_abs): def add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap): mpl = check_if_matplotlib(return_mpl=True) - s_cmap = mpl.colormaps.get_cmap(s_cmap) - t_cmap = mpl.colormaps.get_cmap(t_cmap) - - color = [] - for i, k in enumerate(g.vs['label']): - if g.vs['type'][i]: - color.append(t_cmap(t_norm(obs[k].values[0]))) - else: - color.append(s_cmap(s_norm(act[k].values[0]))) - + cmaps = mpl.colormaps.keys() + if (s_cmap in cmaps) and (t_cmap in cmaps): + s_cmap = mpl.colormaps.get_cmap(s_cmap) + t_cmap = mpl.colormaps.get_cmap(t_cmap) + color = [] + for i, k in enumerate(g.vs['label']): + if g.vs['type'][i]: + color.append(t_cmap(t_norm(obs[k].values[0]))) + else: + color.append(s_cmap(s_norm(act[k].values[0]))) + is_cmap = True + else: + color = [s_cmap if typ == 0. else t_cmap for typ in g.vs['type']] + is_cmap = False g.vs['color'] = color + return is_cmap def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_size=0.5, label_size=5, s_cmap='RdBu_r', @@ -1806,9 +1811,9 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz label_size : int Size of the labels in the plot. s_cmap : str - Colormap to use to color sources. + Color or colormap to use to color sources. t_cmap : str - Colormap to use to color targets. + Color or colormap to use to color targets. vcenter : bool Whether to center colors around 0. c_pos_w : str @@ -1854,7 +1859,7 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz # Get graph g = get_g(fact, fobs, fnet) g.es['color'] = [c_pos_w if w > 0 else c_neg_w for w in g.es['weight']] - add_colors(g, fact, fobs, s_norm, t_norm, s_cmap, t_cmap) + is_cmap = add_colors(g, fact, fobs, s_norm, t_norm, s_cmap, t_cmap) # Build figure fig = plt.figure(figsize=figsize, dpi=dpi, tight_layout=True) @@ -1869,9 +1874,12 @@ def plot_network(obs, act, net, n_sources=5, n_targets=10, by_abs=True, node_siz vertex_size=node_size, vertex_label_size=label_size ) - - fig.colorbar(mpl.cm.ScalarMappable(norm=s_norm, cmap=s_cmap), cax=ax2, orientation="horizontal", label=s_label) - fig.colorbar(mpl.cm.ScalarMappable(norm=t_norm, cmap=t_cmap), cax=ax3, orientation="horizontal", label=t_label) + if is_cmap: + fig.colorbar(mpl.cm.ScalarMappable(norm=s_norm, cmap=s_cmap), cax=ax2, orientation="horizontal", label=s_label) + fig.colorbar(mpl.cm.ScalarMappable(norm=t_norm, cmap=t_cmap), cax=ax3, orientation="horizontal", label=t_label) + else: + ax2.axis("off") + ax3.axis("off") save_plot(fig, None, save) diff --git a/decoupler/tests/test_plotting.py b/decoupler/tests/test_plotting.py index e995155..f76b5e9 100644 --- a/decoupler/tests/test_plotting.py +++ b/decoupler/tests/test_plotting.py @@ -467,10 +467,15 @@ def test_add_colors(): s_norm = get_norm(act, vcenter=False) t_norm = get_norm(obs, vcenter=False) s_cmap, t_cmap = 'RdBu_r', 'viridis' - r = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap) - assert r is None + is_cmap = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap) + assert is_cmap assert (g.vs['color'][0][0] > 0.82) & (g.vs['color'][0][1] < 0.38) & (g.vs['color'][0][2] < 0.31) assert (g.vs['color'][-1][0] > 0.98) & (g.vs['color'][-1][1] > 0.89) & (g.vs['color'][-1][2] < 0.15) + s_cmap, t_cmap = 'red', 'blue' + is_cmap = add_colors(g, act, obs, s_norm, t_norm, s_cmap, t_cmap) + assert not is_cmap + assert g.vs['color'][0] == 'red' + assert g.vs['color'][-1] == 'blue' def test_plot_network(): @@ -484,3 +489,4 @@ def test_plot_network(): ['N3', 'N4'], ], columns=['source', 'target']) plot_network(obs, act, net, figsize=(3, 3), node_size=0.25) + plot_network(obs, act, net, figsize=(3, 3), node_size=0.25, s_cmap='red', t_cmap='blue')