Skip to content

Commit

Permalink
Update the function names for wren utils given breaking changes in Av…
Browse files Browse the repository at this point in the history
…iary (#182)

* fea: rename wren.utils functions

* breaking: rename styled_html_tag to html_tag and add kwarg title: str = ""

* make bs4 + cssutils call-time rather than import-time deps of df_to_svg

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
CompRhys and janosh authored Jul 22, 2024
1 parent 19fe1c9 commit d60f3bd
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# %%
import plotly.express as px
from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib
from aviary.wren.utils import (
count_wyckoff_positions,
get_protostructure_label_from_spglib,
)
from matminer.datasets import load_dataset
from tqdm import tqdm

Expand Down Expand Up @@ -40,7 +43,7 @@
]

df_diel[Key.wyckoff] = [
get_aflow_label_from_spglib(struct)
get_protostructure_label_from_spglib(struct)
for struct in tqdm(df_diel[Key.structure], desc="Getting Wyckoff strings")
]
df_diel[Key.n_wyckoff] = df_diel.wyckoff.map(count_wyckoff_positions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import numpy as np
import plotly.express as px
from aviary.wren.utils import count_wyckoff_positions, get_aflow_label_from_spglib
from aviary.wren.utils import (
count_wyckoff_positions,
get_protostructure_label_from_spglib,
)
from matminer.datasets import load_dataset
from pymatgen.core import Structure
from tqdm import tqdm
Expand Down Expand Up @@ -47,7 +50,7 @@
]

df_grvh[Key.wyckoff] = [
get_aflow_label_from_spglib(struct)
get_protostructure_label_from_spglib(struct)
for struct in tqdm(
df_grvh[Key.structure], desc="Getting matbench_log_gvrh Wyckoff strings"
)
Expand Down
4 changes: 2 additions & 2 deletions examples/matbench_perovskites_eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import plotly.io as pio\n",
"\n",
"# see https://github.com/CompRhys/aviary/blob/main/aviary/wren/utils.py\n",
"from aviary.wren.utils import get_aflow_label_from_spglib\n",
"from aviary.wren.utils import get_protostructure_label_from_spglib\n",
"from matminer.datasets import load_dataset\n",
"from tqdm import tqdm\n",
"\n",
Expand Down Expand Up @@ -222,7 +222,7 @@
"# takes ~6h (when running uninterrupted)\n",
"for idx, struct in tqdm(df_perov[Key.structure].items(), total=len(df_perov)):\n",
" if pd.isna(df_perov.aflow_wyckoff[idx]):\n",
" df_perov.loc[idx, \"aflow_wyckoff\"] = get_aflow_label_from_spglib(\n",
" df_perov.loc[idx, \"aflow_wyckoff\"] = get_protostructure_label_from_spglib(\n",
" struct,\n",
" \"/path/to/aflow\", # defaults to aflow\n",
" )"
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
set_plotly_template,
)
from pymatviz.uncertainty import error_decay_with_uncert, qq_gaussian
from pymatviz.utils import PKG_DIR, ROOT, styled_html_tag
from pymatviz.utils import PKG_DIR, ROOT, html_tag
from pymatviz.xrd import plot_xrd_pattern


Expand Down
31 changes: 15 additions & 16 deletions pymatviz/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum, unique
from typing import TYPE_CHECKING

from pymatviz.utils import styled_html_tag
from pymatviz.utils import html_tag


if TYPE_CHECKING:
Expand Down Expand Up @@ -103,21 +103,20 @@ def __reduce_ex__(self, proto: object) -> tuple[type, tuple[str]]:
return str, (self.value,)


small_font = "font-size: 0.9em; font-weight: lighter;"
eV_per_atom = styled_html_tag("(eV/atom)", style=small_font) # noqa: N816
eV = styled_html_tag("(eV)", style=small_font) # noqa: N816
eV_per_angstrom = styled_html_tag("(eV/Å)", style=small_font) # noqa: N816
eV_per_kelvin = styled_html_tag("(eV/K)", style=small_font) # noqa: N816
angstrom = styled_html_tag("(Å)", style=small_font)
angstrom_per_atom = styled_html_tag("(Å/atom)", style=small_font)
cubic_angstrom = styled_html_tag("(Å<sup>3</sup>)", style=small_font)
gram_per_cm3 = styled_html_tag("(g/cm³)", style=small_font)
kelvin = styled_html_tag("(K)", style=small_font)
pascal = styled_html_tag("(Pa)", style=small_font)
giga_pascal = styled_html_tag("(GPa)", style=small_font)
joule = styled_html_tag("(J)", style=small_font)
joule_per_mol = styled_html_tag("(J/mol)", style=small_font)
joule_per_m2 = styled_html_tag("(J/m²)", style=small_font)
eV_per_atom = html_tag("(eV/atom)", style="small") # noqa: N816
eV = html_tag("(eV)", style="small") # noqa: N816
eV_per_angstrom = html_tag("(eV/Å)", style="small") # noqa: N816
eV_per_kelvin = html_tag("(eV/K)", style="small") # noqa: N816
angstrom = html_tag("(Å)", style="small")
angstrom_per_atom = html_tag("(Å/atom)", style="small")
cubic_angstrom = html_tag("(Å<sup>3</sup>)", style="small")
gram_per_cm3 = html_tag("(g/cm³)", style="small")
kelvin = html_tag("(K)", style="small")
pascal = html_tag("(Pa)", style="small")
giga_pascal = html_tag("(GPa)", style="small")
joule = html_tag("(J)", style="small")
joule_per_mol = html_tag("(J/mol)", style="small")
joule_per_m2 = html_tag("(J/m²)", style="small")


@unique
Expand Down
14 changes: 7 additions & 7 deletions pymatviz/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Final, Literal

import cssutils
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from bs4 import BeautifulSoup
from matplotlib import lines as mlines
from matplotlib import patches as mpatches
from matplotlib.backends.backend_agg import RendererAgg
Expand Down Expand Up @@ -467,16 +465,19 @@ def df_to_svg(
Raises:
subprocess.CalledProcessError: If SVG compression fails.
"""
import bs4
import cssutils

# TODO find a way to not have to hardcode these values
fig_width, fig_height, dpi = 20, 4, 72 # Using dpi=72 as a standard value

def parse_html(html: str) -> tuple[list[list[list[str | bool | int]]], int]:
html = html.replace("<br>", "\n")
soup = BeautifulSoup(html, features="lxml")
soup = bs4.BeautifulSoup(html, features="lxml")
style = soup.find("style")
sheet = cssutils.parseString(style.text) if style else []

def get_style_prop(element: BeautifulSoup, prop_name: str) -> str | None:
def get_style_prop(element: bs4.element.Tag, prop_name: str) -> str | None:
style = element.get("style", "").lower()
if prop_name in style:
return style.split(f"{prop_name}:")[1].split(";")[0].strip()
Expand Down Expand Up @@ -559,7 +560,7 @@ def print_table(
col_widths: list[float],
row_heights: list[float],
) -> Figure:
row_colors = ["#f5f5f5", "#ffffff"]
row_colors = ("#f5f5f5", "#ffffff")
padding = font_size / (fig_width * dpi) * 0.5
total_width = sum(col_widths)
fig_height = fig.get_figheight()
Expand All @@ -586,7 +587,7 @@ def print_table(
for xd, val in zip(col_widths, row):
text, weight, ha, bg_color, fg_color = val[:5]

if bg_color != "#ffffff":
if bg_color != row_colors[1]:
rect_bg = mpatches.Rectangle(
(x_i, y_i),
width=xd,
Expand Down Expand Up @@ -627,7 +628,6 @@ def print_table(

return fig

# Main logic
html = obj.to_html() if isinstance(obj, Styler) else obj.to_html(notebook=True)
text_fig = Figure()
renderer = RendererAgg(fig_width, fig_height, dpi)
Expand Down
26 changes: 21 additions & 5 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def get_cbar_label_formatter(
return FormatStrFormatter(f"%{cbar_label_fmt}")


def styled_html_tag(text: str, tag: str = "span", style: str = "") -> str:
def html_tag(text: str, tag: str = "span", style: str = "", title: str = "") -> str:
"""Wrap text in a span with custom style.
Style defaults to decreased font size and weight e.g. to display units
Expand All @@ -427,11 +427,27 @@ def styled_html_tag(text: str, tag: str = "span", style: str = "") -> str:
Args:
text (str): Text to wrap in span.
tag (str, optional): HTML tag name. Defaults to "span".
style (str, optional): CSS style string. Defaults to
"font-size: 0.8em; font-weight: lighter;".
style (str, optional): CSS style string. Defaults to "". Special keys:
"small": font-size: 0.8em; font-weight: lighter;
"bold": font-weight: bold;
"italic": font-style: italic;
"underline": text-decoration: underline;
title (str | None, optional): Title attribute which displays additional
information in a tooltip. Defaults to "".
Returns:
str: HTML string with tag-wrapped text.
"""
style = style or "font-size: 0.8em; font-weight: lighter;"
return f"<{tag} {style=}>{text}</{tag}>"
style = {
"small": "font-size: 0.8em; font-weight: lighter;",
"bold": "font-weight: bold;",
"italic": "font-style: italic;",
"underline": "text-decoration: underline;",
}.get(style, style)
attr_str = f" {title=}" if title else ""
if style:
attr_str += f" {style=}"
return f"<{tag}{attr_str}>{text}</{tag}>"


def validate_fig(func: Callable[P, R]) -> Callable[P, R]:
Expand Down
23 changes: 14 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
get_cbar_label_formatter,
get_fig_xy_range,
get_font_color,
html_tag,
luminance,
patch_dict,
pick_bw_for_contrast,
pretty_label,
si_fmt,
si_fmt_int,
styled_html_tag,
validate_fig,
)
from tests.conftest import y_pred, y_true
Expand Down Expand Up @@ -395,18 +395,23 @@ def test_values_show_mode(


@pytest.mark.parametrize(
"text, tag, style",
"text, tag, title, style",
[
("foo", "span", ""),
("bar", "small", "color: red;"),
("baz", "div", "font-size: 0.8em;"),
("", "strong", "font-size: 0.8em; font-weight: lighter;"),
("foo", "span", "", ""),
("bar", "small", "some title", "color: red;"),
("baz", "div", "long title " * 10, "font-size: 0.8em;"),
("", "strong", " ", "font-size: 0.8em; font-weight: lighter;"),
("", "strong", " ", "small"),
],
)
def test_styled_html_tag(text: str, tag: str, style: str) -> None:
style = style or "font-size: 0.8em; font-weight: lighter;"
def test_html_tag(text: str, tag: str, title: str, style: str) -> None:
orig_style = style
style = {"small": "font-size: 0.8em; font-weight: lighter;"}.get(style, style)
attrs = f" {title=} " if title else ""
attrs += f"{style=}" if style else ""
assert (
styled_html_tag(text, tag=tag, style=style) == f"<{tag} {style=}>{text}</{tag}>"
html_tag(text, tag=tag, title=title, style=orig_style)
== f"<{tag}{attrs}>{text}</{tag}>"
)


Expand Down

0 comments on commit d60f3bd

Please sign in to comment.