Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fix on color handling for showProjection #1070

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 52 additions & 42 deletions prody/dynamics/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
and keyword arguments are passed to the Matplotlib functions."""

from collections import defaultdict

from numbers import Number
import numpy as np

from prody import LOGGER, SETTINGS, PY3K
from prody.utilities import showFigure, addEnds, showMatrix
from prody.utilities import showFigure, addEnds, showMatrix, isListLike
from prody.atomic import AtomGroup, Selection, Atomic, sliceAtoms, sliceAtomicData

from .nma import NMA
Expand Down Expand Up @@ -218,13 +219,20 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
Default is to use ensemble.getData('size')
:type weights: int, list, :class:`~numpy.ndarray`

:keyword color: a color name or a list of color names or values,
:keyword color: a color name or value or a list of length ensemble.numConfs() or projection.shape[0] of these,
or a dictionary with these with keys corresponding to labels provided by keyword label
default is ``'blue'``
Color values can have 1 element to be mapped with cmap or 3 as RGB or 4 as RGBA.
See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def
:type color: str, list

:keyword label: label or a list of labels
:type label: str, list

:keyword use_labels: whether to use labels for coloring subsets.
These can also be taken from an LDA or LRA model.
:type use_labels: bool

:keyword marker: a marker or a list of markers, default is ``'o'``
:type marker: str, list

Expand Down Expand Up @@ -278,31 +286,17 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
if labels is None and use_labels and modes is not None:
if isinstance(modes, (LDA, LRA)):
labels = modes._labels.tolist()
LOGGER.info('using labels from LDA modes')
LOGGER.info('using labels from {0} modes'.format(type(modes)))
elif isinstance(modes.getModel(), (LDA, LRA)):
labels = modes.getModel()._labels.tolist()
LOGGER.info('using labels from LDA model')
LOGGER.info('using labels from {0} modes'.format(type(modes.getModel())))

if labels is not None and len(labels) != num:
raise ValueError('label should have the same length as ensemble')

c = kwargs.pop('c', 'b')
colors = kwargs.pop('color', c)
colors_dict = {}
if isinstance(colors, np.ndarray):
colors = tuple(colors)
if isinstance(colors, (str, tuple)) or colors is None:
colors = [colors] * num
elif isinstance(colors, list):
if len(colors) != num:
raise ValueError('length of color must be {0}'.format(num))
elif isinstance(colors, dict):
if labels is None:
raise TypeError('color must be a string or a list unless labels are provided')
colors_dict = colors
colors = [colors_dict[label] for label in labels]
else:
raise TypeError('color must be a string or a list or a dict if labels are provided')
colors, colors_dict = checkColors(colors, num, labels, allowNumbers=True)

if labels is not None and len(colors_dict) == 0:
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Expand Down Expand Up @@ -507,7 +501,8 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):
:keyword scalar: scalar factor for projection onto selected mode
:type scalar: float

:keyword color: a color name or a list of color name, default is ``'blue'``
:keyword color: a color spec or a list of color specs, default is ``'blue'``
See https://matplotlib.org/stable/users/explain/colors/colors.html#colors-def
:type color: str, list

:keyword label: label or a list of labels
Expand Down Expand Up @@ -556,13 +551,6 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):
raise TypeError('marker must be a string or a list')

colors = kwargs.pop('color', 'blue')
if isinstance(colors, str) or colors is None:
colors = [colors] * num
elif isinstance(colors, list):
if len(colors) != num:
raise ValueError('length of color must be {0}'.format(num))
else:
raise TypeError('color must be a string or a list')

labels = kwargs.pop('label', None)
if isinstance(labels, str) or labels is None:
Expand All @@ -575,21 +563,7 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):

kwargs['ls'] = kwargs.pop('linestyle', None) or kwargs.pop('ls', 'None')

colors_dict = {}
if isinstance(colors, np.ndarray):
colors = tuple(colors)
if isinstance(colors, (str, tuple)) or colors is None:
colors = [colors] * num
elif isinstance(colors, list):
if len(colors) != num:
raise ValueError('length of color must be {0}'.format(num))
elif isinstance(colors, dict):
if labels is None:
raise TypeError('color must be a string or a list unless labels are provided')
colors_dict = colors
colors = [colors_dict[label] for label in labels]
else:
raise TypeError('color must be a string or a list or a dict if labels are provided')
colors, colors_dict = checkColors(colors, num, labels)

if labels is not None and len(colors_dict) == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this section can go away and checkColors doesn't need to return color_dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a case where we need colors_dict on line 317 where we make a line graph. We'll have to find a way to adjust that.

cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Expand Down Expand Up @@ -2381,3 +2355,39 @@ def showTree_networkx(tree, node_size=20, node_color='red', node_shape='o',
showFigure()

return mpl.gca()


def checkColors(colors, num, labels, allowNumbers=False):
"""Check colors and process them if needed"""

from matplotlib.colors import is_color_like

colors_dict = {}

if is_color_like(colors) or colors is None or isinstance(colors, Number):
Copy link
Collaborator

@SHZ66 SHZ66 Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should colors is None be part of this? Wouldn't a list of None's fail the check down below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it would fail so I added it down there too, because we do want to allow it

colors = [colors] * num
elif isListLike(colors):
colors = list(colors)
elif isinstance(colors, dict):
if labels is None:
raise TypeError('color cannot be a dict unless labels are provided')
colors_dict = colors
colors = [colors_dict[label] for label in labels]

if isinstance(colors, list):
if len(colors) != num:
raise ValueError('colors should have the length of the set to be colored or satisfy matplotlib color rules')

for color in colors:
if not is_color_like(color) and color is not None:
if not allowNumbers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if allowNumbers, you may need to convert the number to the color in cycle like this:

cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
color = cycle_colors[color % len(cycle_colors)]

Copy link
Collaborator

@SHZ66 SHZ66 Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the entire loop may look like this:

if isinstance(colors, list):
      if len(colors) != num:
          raise ValueError('colors should have the length of the set to be colored or satisfy matplotlib color rules')

      for i, color in enumerate(colors):
          if not is_color_like(color):
              if allowNumbers:
                  cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
                  colors[i] = cycle_colors[color % len(cycle_colors)]
              else:
                  raise ValueError('each element of colors should satisfy matplotlib color rules')
       ....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a conversion of the numbers with the color cycle somewhere else, but yes, we could maybe move it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we use matplotlib.colors.Normalize instead to link up with the cmap. This is on lines 356 and 428

This could somehow be incorporated too though

raise ValueError('each element of colors should satisfy matplotlib color rules')
elif not isinstance(color, Number):
raise ValueError('each element of colors should be a number or satisfy matplotlib color rules')

if not isinstance(color, type(colors[0])):
raise TypeError('each element of colors should have the same type')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check necessary? Couldn't the matplotlib function handle colors defined in different ways in the same list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I guess it probably could

else:
raise TypeError('colors should be a color spec or convertible to a list of color specs')

return colors, colors_dict
Loading