Skip to content

Commit

Permalink
fix "COM812", # trailing comma missing
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Sep 2, 2024
1 parent 476560b commit 5ab3390
Show file tree
Hide file tree
Showing 67 changed files with 2,927 additions and 2,550 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@

# %%
fig = ptable_heatmap(
count_elements(df_boltz[Key.formula]), log=True, return_type="figure"
count_elements(df_boltz[Key.formula]),
log=True,
return_type="figure",
)
fig.suptitle("Elements in BoltzTraP MP dataset")
pmv.save_fig(fig, "boltztrap_mp-ptable-heatmap.pdf")
Expand All @@ -65,6 +67,9 @@

# %%
df_boltz.sort_values("pf_n", ascending=False).head(1000).hist(
bins=50, log=True, layout=[2, 3], figsize=[18, 8]
bins=50,
log=True,
layout=[2, 3],
figsize=[18, 8],
)
plt.suptitle("BoltzTraP MP")
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
df_diel[Key.n_wyckoff] = df_diel.wyckoff.map(count_wyckoff_positions)

df_diel[Key.crystal_system] = df_diel[Key.spg_num].map(
pmv.utils.crystal_sys_from_spg_num
pmv.utils.crystal_sys_from_spg_num,
)

df_diel[Key.volume] = [x.volume for x in df_diel[Key.structure]]
Expand All @@ -49,7 +49,9 @@

# %%
fig = pmv.ptable_heatmap(
pmv.count_elements(df_diel[Key.formula]), log=True, return_type="figure"
pmv.count_elements(df_diel[Key.formula]),
log=True,
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Matbench dielectric dataset")
pmv.save_fig(fig, "dielectric-ptable-heatmap.pdf")
Expand Down Expand Up @@ -88,7 +90,8 @@

x_ticks = {} # custom x axis tick labels
for cry_sys, df_group in sorted(
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
df_diel.groupby(Key.crystal_system),
key=lambda x: pmv.crystal_sys_order.index(x[0]),
):
x_ticks[cry_sys] = (
f"<b>{cry_sys}</b><br>"
Expand All @@ -101,7 +104,8 @@
fig.layout.margin = dict(b=10, l=10, r=10, t=50)
fig.layout.showlegend = False
fig.layout.xaxis = reusable_x_axis = dict(
tickvals=list(range(len(pmv.crystal_sys_order))), ticktext=list(x_ticks.values())
tickvals=list(range(len(pmv.crystal_sys_order))),
ticktext=list(x_ticks.values()),
)


Expand Down Expand Up @@ -130,7 +134,8 @@ def rgb_color(val: float, max_val: float) -> str:

x_ticks = {}
for cry_sys, df_group in sorted(
df_diel.groupby(Key.crystal_system), key=lambda x: pmv.crystal_sys_order.index(x[0])
df_diel.groupby(Key.crystal_system),
key=lambda x: pmv.crystal_sys_order.index(x[0]),
):
n_wyckoff = df_group[Key.n_wyckoff].mean()
clr = rgb_color(n_wyckoff, 14)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@
df_grvh[[Key.spg_symbol, Key.spg_num]] = [
struct.get_space_group_info()
for struct in tqdm(
df_grvh[Key.structure], desc="Getting matbench_log_gvrh spacegroups"
df_grvh[Key.structure],
desc="Getting matbench_log_gvrh spacegroups",
)
]
df_grvh[Key.crystal_system] = df_grvh[Key.spg_num].map(
pmv.utils.crystal_sys_from_spg_num
pmv.utils.crystal_sys_from_spg_num,
)


df_grvh[Key.wyckoff] = [
get_protostructure_label_from_spglib(struct)
for struct in tqdm(
df_grvh[Key.structure], desc="Getting matbench_log_gvrh Wyckoff strings"
df_grvh[Key.structure],
desc="Getting matbench_log_gvrh Wyckoff strings",
)
]
df_grvh[Key.n_wyckoff] = df_grvh.wyckoff.map(count_wyckoff_positions)
Expand Down Expand Up @@ -133,7 +135,9 @@ def has_isolated_atom(crystal: Structure, radius: float = 5) -> bool:
df_grvh[Key.formula] = df_grvh[Key.structure].map(lambda struct: struct.formula)

fig = ptable_heatmap(
count_elements(df_grvh[Key.formula]), log=True, return_type="figure"
count_elements(df_grvh[Key.formula]),
log=True,
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Matbench bulk/shear modulus datasets")
pmv.save_fig(fig, "log_gvrh-ptable-heatmap.pdf")
Expand Down Expand Up @@ -172,7 +176,8 @@ def rgb_color(val: float, max_val: float) -> str:

x_ticks = {}
for cry_sys, df_group in sorted(
df_grvh.groupby(Key.crystal_system), key=lambda x: crystal_sys_order.index(x[0])
df_grvh.groupby(Key.crystal_system),
key=lambda x: crystal_sys_order.index(x[0]),
):
n_wyckoff_top = df_group[Key.n_wyckoff].mean()
clr = rgb_color(n_wyckoff_top, 14)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
df_perov[Key.formula] = df_perov[Key.structure].map(lambda cryst: cryst.formula)

df_perov[Key.crystal_system] = df_perov[Key.spg_num].map(
pmv.utils.crystal_sys_from_spg_num
pmv.utils.crystal_sys_from_spg_num,
)


Expand All @@ -49,7 +49,9 @@

# %%
fig = ptable_heatmap(
count_elements(df_perov[Key.formula]), log=True, return_type="figure"
count_elements(df_perov[Key.formula]),
log=True,
return_type="figure",
)
fig.suptitle("Elements in Matbench Perovskites dataset")
pmv.save_fig(fig, "perovskites-ptable-heatmap.pdf")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
df_phonon[Key.volume] = df_phonon[Key.structure].map(lambda cryst: cryst.volume)

fig = ptable_heatmap(
count_elements(df_phonon[Key.formula]), log=True, return_type="figure"
count_elements(df_phonon[Key.formula]),
log=True,
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Matbench phonons dataset")
pmv.save_fig(fig, "phonons-ptable-heatmap.pdf")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

# %%
fig = ptable_heatmap(
count_elements(df_steels[Key.composition]), log=True, return_type="figure"
count_elements(df_steels[Key.composition]),
log=True,
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Matbench steels dataset")
pmv.save_fig(fig, "steels-ptable-heatmap.pdf")
13 changes: 9 additions & 4 deletions examples/dataset_exploration/matpes/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
total_force_col = "Σ|force<sub>i</sub>| (eV/Å)"
df_pbe[total_force_col] = df_pbe[Key.forces].map(lambda arr: np.abs(arr).sum(axis=1))
df_r2scan[total_force_col] = df_r2scan[Key.forces].map(
lambda arr: np.abs(arr).sum(axis=1)
lambda arr: np.abs(arr).sum(axis=1),
)

fig = go.Figure()
Expand Down Expand Up @@ -149,7 +149,8 @@
}

fig = pmv.ptable_heatmap_splits(
per_elem_cohesive_energy, cbar_title=f"{col_name.label} (eV)"
per_elem_cohesive_energy,
cbar_title=f"{col_name.label} (eV)",
)


Expand Down Expand Up @@ -201,7 +202,9 @@

# %% spacegroup histogram
fig = pmv.spacegroup_bar(
df_r2scan[Key.spg_num], title="r2SCAN spacegroup histogram", log=True
df_r2scan[Key.spg_num],
title="r2SCAN spacegroup histogram",
log=True,
)
fig.show()
pmv.save_fig(fig, "r2scan-spacegroup-hist.pdf")
Expand Down Expand Up @@ -230,7 +233,9 @@
for site, force in zip(struct, forces, strict=True)
}
for struct, forces in zip(
df_r2scan[Key.structure], df_r2scan[Key.forces], strict=True
df_r2scan[Key.structure],
df_r2scan[Key.forces],
strict=True,
)
).mean()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

# %%
df_carrier = pd.concat(
[df_carrier, pd.json_normalize(df_carrier.data)], axis="columns"
[df_carrier, pd.json_normalize(df_carrier.data)],
axis="columns",
).drop(columns=["data", "is_public", "project"])

df_carrier = df_carrier.set_index("identifier")
Expand Down Expand Up @@ -99,5 +100,6 @@

# %%
df_carrier.to_json(
"cleaned_ricci_boltztrap_mp_tabular.json.gz", default_handler=lambda x: x.as_dict()
"cleaned_ricci_boltztrap_mp_tabular.json.gz",
default_handler=lambda x: x.as_dict(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@

# %%
fig = ptable_heatmap(
count_elements(df_carrier.pretty_formula.dropna()), log=True, return_type="figure"
count_elements(df_carrier.pretty_formula.dropna()),
log=True,
return_type="figure",
)
fig.suptitle("Elemental prevalence in the Ricci Carrier Transport dataset")
pmv.save_fig(fig, "carrier-transport-ptable-heatmap.pdf")
Expand All @@ -55,7 +57,7 @@
# %%
ax = df_carrier[["S.p [µV/K]", "S.n [µV/K]"]].hist(bins=50, log=True, figsize=[18, 8])
plt.suptitle(
"Ricci Carrier Transport dataset histograms for n- and p-type Seebeck coefficients"
"Ricci Carrier Transport dataset histograms for n- and p-type Seebeck coefficients",
)
pmv.save_fig(ax, "carrier-transport-seebeck-n+p.pdf")

Expand Down
6 changes: 4 additions & 2 deletions examples/dataset_exploration/wbm/explore_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

# %% download wbm-summary.csv (12 MB)
df_wbm = pd.read_csv("https://figshare.com/ndownloader/files/44225498").set_index(
Key.mat_id, drop=False
Key.mat_id,
drop=False,
)

df_wbm["batch_idx"] = df_wbm.index.str.split("-").str[2].astype(int)
Expand Down Expand Up @@ -69,7 +70,8 @@
fig.layout.margin = dict(b=10, l=10, r=10, t=50)
fig.layout.showlegend = False
fig.layout.xaxis = dict(
tickvals=list(range(len(crystal_sys_order))), ticktext=list(x_ticks.values())
tickvals=list(range(len(crystal_sys_order))),
ticktext=list(x_ticks.values()),
)
fig.update_traces(hoverinfo="skip", hovertemplate=None)

Expand Down
13 changes: 10 additions & 3 deletions examples/diatomics/mace_pair_repulsion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def timer(label: str = "") -> Generator[None, None, None]:


def generate_diatomics(
symbol0: str, symbol1: str, distances: list[float] | np.ndarray
symbol0: str,
symbol1: str,
distances: list[float] | np.ndarray,
) -> list[Atoms]:
"""Build diatomic molecules in vacuum for given distances.
Expand All @@ -56,7 +58,10 @@ def generate_diatomics(


def calc_one_pair(
z0: int, z1: int, calc: MACECalculator, distances: list[float] | np.ndarray
z0: int,
z1: int,
calc: MACECalculator,
distances: list[float] | np.ndarray,
) -> list[float]:
"""Calculate potential energy for a pair of elements at given distances.
Expand All @@ -72,7 +77,9 @@ def calc_one_pair(
return [
calc.get_potential_energy(at)
for at in generate_diatomics(
chemical_symbols[z0], chemical_symbols[z1], distances
chemical_symbols[z0],
chemical_symbols[z1],
distances,
)
]

Expand Down
16 changes: 11 additions & 5 deletions examples/diatomics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

# %%
def plot_on_ax(
ax: plt.Axes, distances: np.ndarray, energy: np.ndarray, formula: str
ax: plt.Axes,
distances: np.ndarray,
energy: np.ndarray,
formula: str,
) -> None:
"""Plot pair repulsion curve on a given axes.
Expand Down Expand Up @@ -54,7 +57,7 @@ def plot_homo_nuclear(model_size: str) -> None:
n_rows, n_columns, size_factor = 10, 18, 3

fig = plt.figure(
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor)
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor),
)
gs = plt.GridSpec(figure=fig, nrows=n_rows, ncols=n_columns)

Expand Down Expand Up @@ -93,12 +96,12 @@ def plot_hetero_nuclear(model_size: str) -> None:
[
int(fn.name.split("-")[2])
for fn in Path("simulations/").glob(f"results-{model_size}-*-X.json")
]
],
)
with PdfPages(f"{model_size}-hetero-nuclear.pdf") as pdf:
for z_main in z_calculated:
fig = plt.figure(
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor)
figsize=(0.75 * n_columns * size_factor, 0.7 * n_rows * size_factor),
)
gs = plt.GridSpec(figure=fig, nrows=n_rows, ncols=n_columns)
plot_element_heteronuclear(fig, gs, model_size, z_main)
Expand All @@ -113,7 +116,10 @@ def plot_hetero_nuclear(model_size: str) -> None:


def plot_element_heteronuclear(
fig: plt.Figure, gs: plt.GridSpec, model_size: str, atomic_number: int
fig: plt.Figure,
gs: plt.GridSpec,
model_size: str,
atomic_number: int,
) -> None:
"""Plot heteronuclear pair repulsion curves for a specific element.
Expand Down
6 changes: 5 additions & 1 deletion examples/make_assets/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@

# %% Histogram Plots
ax = pmv.elements_hist(
df_expt_gap[Key.composition], keep_top=15, v_offset=200, rotation=0, fontsize=12
df_expt_gap[Key.composition],
keep_top=15,
v_offset=200,
rotation=0,
fontsize=12,
)
pmv.io.save_and_compress_svg(ax, "elements-hist")

Expand Down
4 changes: 3 additions & 1 deletion examples/make_assets/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@

fig = pmv.phonon_bands_and_dos(ph_bands, ph_doses)
fig.layout.title = dict(
text=f"Phonon Bands and DOS of {formula} ({mp_id})", x=0.5, y=0.98
text=f"Phonon Bands and DOS of {formula} ({mp_id})",
x=0.5,
y=0.98,
)
fig.layout.margin = dict(l=0, r=0, b=0, t=40)
pmv.io.save_and_compress_svg(fig, f"phonon-bands-and-dos-{mp_id}")
5 changes: 4 additions & 1 deletion examples/make_assets/ptable/ptable_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@

# %%
fig = pmv.ptable_heatmap_ratio(
df_expt_gap[Key.composition], df_steels[Key.composition], log=True, value_fmt=".4g"
df_expt_gap[Key.composition],
df_steels[Key.composition],
log=True,
value_fmt=".4g",
)
title = "Element ratios in Matbench Experimental Band Gap vs Matbench Steel"
fig.suptitle(title, y=0.96, fontsize=16, fontweight="bold")
Expand Down
4 changes: 3 additions & 1 deletion examples/make_assets/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

# %%
fig = pmv.ptable_heatmap_plotly(
df_expt_gap[Key.composition], log=True, colorscale="viridis"
df_expt_gap[Key.composition],
log=True,
colorscale="viridis",
)
title = "Elements in Matbench Experimental Bandgap (log scale)"
fig.layout.title = dict(text=f"<b>{title}</b>", x=0.45, y=0.94, font_size=20)
Expand Down
Loading

0 comments on commit 5ab3390

Please sign in to comment.