Skip to content

Commit

Permalink
adjust the _adjust_axes_labels to be more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
MAfarrag committed Aug 28, 2024
1 parent 1ba0f90 commit 085404a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 53 deletions.
104 changes: 54 additions & 50 deletions statista/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,46 @@ def _get_ax_fig(
return fig, ax

@staticmethod
def _adjust_axes_labels(
ax: Axes, tick_labels: List[str], title: str, xlabel: str, ylabel: str
):
def _adjust_axes_labels(ax: Axes, tick_labels: List[str] = None, **kwargs):
"""Adjust the labels of the axes."""
ax.set_xticklabels(tick_labels)
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if tick_labels is not None:
ax.set_xticklabels(tick_labels)

ax.set_title(
kwargs.get("title"),
fontsize=kwargs.get("title_fontsize", 18),
fontweight="bold",
)
ax.set_xlabel(
kwargs.get("xlabel"),
fontsize=kwargs.get("xlabel_fontsize", 14),
)
ax.set_ylabel(
kwargs.get("ylabel"),
fontsize=kwargs.get("ylabel_fontsize", 14),
)

ax.grid(
kwargs.get("grid", True),
axis=kwargs.get("grid_axis", "both"),
linestyle=kwargs.get("grid_line_style", "-."),
linewidth=kwargs.get("grid_line_width", 0.3),
)

# Customize ticks and their labels
ax.tick_params(
axis="both", which="major", labelsize=kwargs.get("tick_fontsize", 12)
)

# Add a legend if needed
if "legend" in kwargs:
ax.legend(
[kwargs.get("legend")], fontsize=kwargs.get("legend_fontsize", 12)
)

# Adjust layout for better spacing
plt.tight_layout()

return ax

def box_plot(
Expand Down Expand Up @@ -251,6 +283,8 @@ def box_plot(
:align: center
"""
fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax"))
kwargs.pop("fig", None)
kwargs.pop("ax", None)
color = kwargs.get("color", None)
data = [self[col].dropna() for col in self.columns]
ax.boxplot(
Expand All @@ -268,12 +302,9 @@ def box_plot(
ax = self._adjust_axes_labels(
ax,
self.columns,
kwargs.get("title"),
kwargs.get("xlabel"),
kwargs.get("ylabel"),
**kwargs,
)

ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3)
plt.show()
return fig, ax

Expand Down Expand Up @@ -402,7 +433,8 @@ def violin(
:align: center
"""
fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax"))

kwargs.pop("fig", None)
kwargs.pop("ax", None)
# positions where violins are plotted (1, 3, 5, ...)ing labels
positions = np.arange(1, len(self.columns) * (spacing + 1) + 1, spacing + 1)

Expand All @@ -425,12 +457,9 @@ def violin(
ax = self._adjust_axes_labels(
ax,
self.columns,
kwargs.get("title"),
kwargs.get("xlabel"),
kwargs.get("ylabel"),
**kwargs,
)

ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3)
plt.show()
return fig, ax

Expand Down Expand Up @@ -503,6 +532,8 @@ def raincloud(
>>> fig, ax = ts_2d.raincloud(mean=True, grid=True)
"""
fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax"))
kwargs.pop("fig", None)
kwargs.pop("ax", None)
if order is None:
order = ["violin", "scatter", "box"]

Expand Down Expand Up @@ -574,13 +605,9 @@ def raincloud(
ax = self._adjust_axes_labels(
ax,
self.columns,
kwargs.get("title"),
kwargs.get("xlabel"),
kwargs.get("ylabel"),
**kwargs,
)

ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3)

# Add grid lines for better readability
# ax.yaxis.grid(True)

Expand Down Expand Up @@ -637,6 +664,8 @@ def histogram(self, bins=10, **kwargs) -> Tuple[Figure, Axes]:
# plt.style.use('ggplot')

fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax"))
kwargs.pop("fig", None)
kwargs.pop("ax", None)
color = kwargs.get("color") if "color" in kwargs else VIOLIN_PROP
ax.hist(
self.values,
Expand All @@ -646,36 +675,11 @@ def histogram(self, bins=10, **kwargs) -> Tuple[Figure, Axes]:
alpha=color.get("alpha"),
)

# Set title and labels with larger font sizes
ax.set_title(
kwargs.get("title", "Histogram"),
fontsize=kwargs.get("title_fontsize", 18),
fontweight="bold",
)
ax.set_xlabel(
kwargs.get("xlabel", "X-axis Label"),
fontsize=kwargs.get("label_fontsize", 14),
)
ax.set_ylabel(
kwargs.get("ylabel", "Y-axis Label"),
fontsize=kwargs.get("label_fontsize", 14),
)

ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3)

# Customize ticks and their labels
ax.tick_params(
axis="both", which="major", labelsize=kwargs.get("tick_fontsize", 12)
ax = self._adjust_axes_labels(
ax,
self.columns,
**kwargs,
)

# Add a legend if needed
if "legend" in kwargs:
ax.legend(
[kwargs.get("legend")], fontsize=kwargs.get("legend_fontsize", 12)
)

# Adjust layout for better spacing
plt.tight_layout()

plt.show()
return fig, ax
6 changes: 3 additions & 3 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def test_default(self):
# Test with default parameters
ts = TimeSeries(np.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4]))
fig, ax = ts.histogram()
assert ax.get_title() == "Histogram"
assert ax.get_xlabel() == "X-axis Label"
assert ax.get_ylabel() == "Y-axis Label"
assert ax.get_title() == ""
assert ax.get_xlabel() == ""
assert ax.get_ylabel() == ""
plt.close()

def test_custom_labels(self):
Expand Down

0 comments on commit 085404a

Please sign in to comment.