Skip to content

Commit

Permalink
Change viz scripts to call functions in fsate; add consistent plot st…
Browse files Browse the repository at this point in the history
…yle theme (#33)

* Create custom matplotlib stylesheet for RaFTS plots

* Flip axes on scatter; change perf to pred for clarity

* Change perf to pred for clarity

* Read in mplstyle file directly from fs_algo

* incorporate plotting functions into fs_perf_viz.py

* Use functions for creating file output paths

* Change perf_map to pred_map

---------

Co-authored-by: glitt13 <[email protected]>
  • Loading branch information
bolotinl and glitt13 authored Dec 19, 2024
1 parent ac3067a commit c6cf6cb
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 55 deletions.
19 changes: 19 additions & 0 deletions pkg/fs_algo/fs_algo/RaFTS_theme.mplstyle
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Style theme for RaFTS data visualizations

axes.labelsize : 12
lines.linewidth : 2
xtick.labelsize : 11
ytick.labelsize : 11
legend.fontsize : 11
font.family : Arial

# viridis color codes: https://waldyrious.net/viridis-palette-generator/
# viridis with a slightly lighter purple:
axes.prop_cycle: cycler('color', ['7e3b8a', '21918c', 'fde725', '3b528b', '5ec962'])

# Other odd options -------
# viridis:
# axes.prop_cycle: cycler('color', ['440154', '21918c', 'fde725', '3b528b', '5ec962'])

# viridis plasma:
# axes.prop_cycle: cycler('color', ['f89540', 'cc4778', '7e03a8', '0d0887', 'f0f921'])
31 changes: 16 additions & 15 deletions pkg/fs_algo/fs_algo/fs_algo_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,11 +1547,11 @@ def plot_pred_vs_obs_regr(y_pred, y_obs, ds:str, metr:str):
# Adapted from plot in bolotinl's fs_perf_viz.py

# Plot the observed vs. predicted module performance
plt.scatter(x=y_obs,y=y_pred, c='teal')
plt.scatter(x=y_obs,y=y_pred)
plt.axline((0, 0), (1, 1), color='black', linestyle='--')
plt.ylabel('Predicted {}'.format(metr))
plt.xlabel('Actual {}'.format(metr))
plt.title('Observed vs. Predicted Performance: {}'.format(ds))
plt.title('Observed vs. RaFTS Predicted Performance: {}'.format(ds))
fig = plt.gcf()
return fig

Expand All @@ -1567,15 +1567,15 @@ def plot_pred_vs_obs_wrap(y_pred, y_obs, dir_out_viz_base:str|os.PathLike,
plt.clf()
plt.close()

#%% Performance map visualization, adapted from plot in bolotinl's fs_perf_viz.py
def std_map_perf_path(dir_out_viz_base:str|os.PathLike, ds:str,
#%% Prediction map visualization, adapted from plot in bolotinl's fs_perf_viz.py
def std_map_pred_path(dir_out_viz_base:str|os.PathLike, ds:str,
metr:str,algo_str:str,
split_type:str='') -> pathlib.PosixPath:

# Generate a filepath of the feature_importance plot:
path_perf_map_plot = Path(f"{dir_out_viz_base}/{ds}/performance_map_{ds}_{metr}_{algo_str}_{split_type}.png")
path_perf_map_plot.parent.mkdir(parents=True,exist_ok=True)
return path_perf_map_plot
path_pred_map_plot = Path(f"{dir_out_viz_base}/{ds}/prediction_map_{ds}_{metr}_{algo_str}_{split_type}.png")
path_pred_map_plot.parent.mkdir(parents=True,exist_ok=True)
return path_pred_map_plot

def gen_conus_basemap(dir_out_basemap, # This should be the data_visualizations directory
url = 'https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_500k.zip',
Expand Down Expand Up @@ -1607,14 +1607,15 @@ def gen_conus_basemap(dir_out_basemap, # This should be the data_visualizations
# geo_df['performance'] = data['prediction'].values
# geo_df.crs = ("EPSG:4326")

def plot_map_perf(geo_df, states,title,metr,colname_data='performance'):
def plot_map_pred(geo_df, states,title,metr,colname_data='performance'):
fig, ax = plt.subplots(1, 1, figsize=(20, 24))
base = states.boundary.plot(ax=ax,color="#555555", linewidth=1)
# Points
geo_df.plot(column=colname_data, ax=ax, markersize=150, cmap='viridis', legend=False, zorder=2) # delete zorder to plot points behind states boundaries
# States
states.boundary.plot(ax=ax, color="#555555", linewidth=1, zorder=1) # Plot states boundary again with lower zorder


# TODO: need to customize the colorbar min and max based on the metric
## cbar = plt.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0,vmax = 1), cmap='viridis')
cbar = plt.cm.ScalarMappable(cmap='viridis')
ax.tick_params(axis='x', labelsize= 24)
Expand All @@ -1630,26 +1631,26 @@ def plot_map_perf(geo_df, states,title,metr,colname_data='performance'):
fig = plt.gcf()
return fig

def plot_map_perf_wrap(test_gdf,dir_out_viz_base, ds,
def plot_map_pred_wrap(test_gdf,dir_out_viz_base, ds,
metr,algo_str,
split_type='test',
colname_data='performance'):

path_perf_map_plot = std_map_perf_path(dir_out_viz_base,ds,metr,algo_str,split_type)
dir_out_basemap = path_perf_map_plot.parent.parent
path_pred_map_plot = std_map_pred_path(dir_out_viz_base,ds,metr,algo_str,split_type)
dir_out_basemap = path_pred_map_plot.parent.parent
states = gen_conus_basemap(dir_out_basemap = dir_out_basemap)

# Ensure the gdf matches the 4326 epsg used for states:
test_gdf = test_gdf.to_crs(4326)

# Generate the map
plot_title = f"Predicted Performance: {metr} - {ds}"
plot_perf_map = plot_map_perf(geo_df=test_gdf, states=states,title=plot_title,
plot_pred_map = plot_map_pred(geo_df=test_gdf, states=states,title=plot_title,
metr=metr,colname_data=colname_data)

# Save the plot as a .png file
plot_perf_map.savefig(path_perf_map_plot, dpi=300, bbox_inches='tight')
print(f"Wrote performance map to \n{path_perf_map_plot}")
plot_pred_map.savefig(path_pred_map_plot, dpi=300, bbox_inches='tight')
print(f"Wrote performance map to \n{path_pred_map_plot}")
plt.clf()
plt.close()

Expand Down
58 changes: 19 additions & 39 deletions pkg/fs_algo/fs_algo/fs_perf_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import xarray as xr
import urllib.request
import zipfile
import pkg_resources


if __name__ == "__main__":
Expand Down Expand Up @@ -101,6 +102,11 @@

# Location for accessing existing outputs and saving plots
dir_out = fsate.fs_save_algo_dir_struct(dir_base).get('dir_out')
dir_out_viz_base = Path(dir_out/Path("data_visualizations"))

# Enforce style
style_path = pkg_resources.resource_filename('fs_algo', 'RaFTS_theme.mplstyle')
plt.style.use(style_path)

# Loop through all datasets
for ds in datasets:
Expand All @@ -120,20 +126,8 @@
# data.to_csv(f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_data.csv')

# Does the user want a scatter plot comparing the observed module performance and the predicted module performance by RaFTS?
if 'perf_map' in true_keys:
url = 'https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_500k.zip'
zip_filename = f'{dir_out}/data_visualizations/cb_2018_us_state_500k.zip'
filename = f'{dir_out}/data_visualizations/cb_2018_us_state_500k.shp'

if not Path(zip_filename).exists():
print('Downloading shapefile...')
urllib.request.urlretrieve(url, zip_filename)
if not Path(filename).exists():
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
zip_ref.extractall(f'{dir_out}/data_visualizations')

states = gpd.read_file(filename)
states = states.to_crs("EPSG:4326")
if 'pred_map' in true_keys:
states = fsate.gen_conus_basemap(f'{dir_out}/data_visualizations/')

# Plot performance on map
lat = data['Y']
Expand All @@ -143,27 +137,14 @@
geo_df['performance'] = data['prediction'].values
geo_df.crs = ("EPSG:4326")

fig, ax = plt.subplots(1, 1, figsize=(20, 24))
base = states.boundary.plot(ax=ax,color="#555555", linewidth=1)
# Points
geo_df.plot(column="performance", ax=ax, markersize=150, cmap='viridis', legend=False, zorder=2) # delete zorder to plot points behind states boundaries
# States
states.boundary.plot(ax=ax, color="#555555", linewidth=1, zorder=1) # Plot states boundary again with lower zorder

cbar = plt.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=-0.41,vmax = 1), cmap='viridis')
ax.tick_params(axis='x', labelsize= 24)
ax.tick_params(axis='y', labelsize= 24)
plt.xlabel('Latitude',fontsize = 26)
plt.ylabel('Longitude',fontsize = 26)
cbar_ax = plt.colorbar(cbar, ax=ax,fraction=0.02, pad=0.04)
cbar_ax.set_label(label=metric,size=24)
cbar_ax.ax.tick_params(labelsize=24) # Set colorbar tick labels size
plt.title("Predicted Performance: {}".format(ds), fontsize = 28)
ax.set_xlim(-126, -66)
ax.set_ylim(24, 50)
fsate.plot_map_pred(geo_df=geo_df, states=states,
title=f'RaFTS Predicted Performance Map: {ds}',
metr=metric, colname_data='performance')

# Save the plot as a .png file
output_path = f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_performance_map.png'
output_path = fsate.std_map_pred_path(dir_out_viz_base=dir_out_viz_base,
ds=ds, metr=metric, algo_str=algo,
split_type='prediction')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.clf()
plt.close()
Expand Down Expand Up @@ -196,14 +177,13 @@
data = pd.merge(data, obs, how = 'inner', on = 'identifier')

# Plot the observed vs. predicted module performance
plt.scatter(data['prediction'], data[metric], c='teal')
plt.axline((0, 0), (1, 1), color='black', linestyle='--')
plt.xlabel('Predicted {}'.format(metric))
plt.ylabel('Actual {}'.format(metric))
plt.title('Observed vs. Predicted Performance: {}'.format(ds))
fsate.plot_pred_vs_obs_regr(y_pred=data['prediction'], y_obs=data[metric],
ds = ds, metr=metric)

# Save the plot as a .png file
output_path = f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_obs_vs_sim_scatter.png'
output_path = fsate.std_regr_pred_obs_path(dir_out_viz_base=dir_out_viz_base,
ds=ds, metr=metric, algo_str=algo,
split_type='prediction')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.clf()
plt.close()
Expand Down
2 changes: 1 addition & 1 deletion scripts/eval_ingest/xssa/xssa_viz_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ metrics: # All option could pull from the pred config file
- 'NSE'
plot_types:
- obs_vs_sim_scatter: True # NOTE: These plots can only be created if observed (actual) model performance values are available
- perf_map: True
- pred_map: True

0 comments on commit c6cf6cb

Please sign in to comment.