Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change viz scripts to call functions in fsate; add consistent plot style theme #33

Merged
merged 8 commits into from
Dec 19, 2024
19 changes: 19 additions & 0 deletions pkg/fs_algo/fs_algo/RaFTS_theme.mplstyle
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Style theme for RaFTS data visualizations

axes.labelsize : 12
lines.linewidth : 2
xtick.labelsize : 11
ytick.labelsize : 11
legend.fontsize : 11
font.family : Arial

# viridis color codes: https://waldyrious.net/viridis-palette-generator/
# viridis with a slightly lighter purple:
axes.prop_cycle: cycler('color', ['7e3b8a', '21918c', 'fde725', '3b528b', '5ec962'])

# Other odd options -------
# viridis:
# axes.prop_cycle: cycler('color', ['440154', '21918c', 'fde725', '3b528b', '5ec962'])

# plasma (a sibling colormap of viridis):
# axes.prop_cycle: cycler('color', ['f89540', 'cc4778', '7e03a8', '0d0887', 'f0f921'])
31 changes: 16 additions & 15 deletions pkg/fs_algo/fs_algo/fs_algo_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,11 +1547,11 @@ def plot_pred_vs_obs_regr(y_pred, y_obs, ds:str, metr:str):
# Adapted from plot in bolotinl's fs_perf_viz.py

# Plot the observed vs. predicted module performance
plt.scatter(x=y_obs,y=y_pred, c='teal')
plt.scatter(x=y_obs,y=y_pred)
plt.axline((0, 0), (1, 1), color='black', linestyle='--')
plt.ylabel('Predicted {}'.format(metr))
plt.xlabel('Actual {}'.format(metr))
plt.title('Observed vs. Predicted Performance: {}'.format(ds))
plt.title('Observed vs. RaFTS Predicted Performance: {}'.format(ds))
fig = plt.gcf()
return fig

Expand All @@ -1567,15 +1567,15 @@ def plot_pred_vs_obs_wrap(y_pred, y_obs, dir_out_viz_base:str|os.PathLike,
plt.clf()
plt.close()

#%% Performance map visualization, adapted from plot in bolotinl's fs_perf_viz.py
def std_map_perf_path(dir_out_viz_base:str|os.PathLike, ds:str,
#%% Prediction map visualization, adapted from plot in bolotinl's fs_perf_viz.py
def std_map_pred_path(dir_out_viz_base:str|os.PathLike, ds:str,
metr:str,algo_str:str,
split_type:str='') -> pathlib.PosixPath:

# Generate a filepath of the prediction map plot:
path_perf_map_plot = Path(f"{dir_out_viz_base}/{ds}/performance_map_{ds}_{metr}_{algo_str}_{split_type}.png")
path_perf_map_plot.parent.mkdir(parents=True,exist_ok=True)
return path_perf_map_plot
path_pred_map_plot = Path(f"{dir_out_viz_base}/{ds}/prediction_map_{ds}_{metr}_{algo_str}_{split_type}.png")
path_pred_map_plot.parent.mkdir(parents=True,exist_ok=True)
return path_pred_map_plot

def gen_conus_basemap(dir_out_basemap, # This should be the data_visualizations directory
url = 'https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_500k.zip',
Expand Down Expand Up @@ -1607,14 +1607,15 @@ def gen_conus_basemap(dir_out_basemap, # This should be the data_visualizations
# geo_df['performance'] = data['prediction'].values
# geo_df.crs = ("EPSG:4326")

def plot_map_perf(geo_df, states,title,metr,colname_data='performance'):
def plot_map_pred(geo_df, states,title,metr,colname_data='performance'):
fig, ax = plt.subplots(1, 1, figsize=(20, 24))
base = states.boundary.plot(ax=ax,color="#555555", linewidth=1)
# Points
geo_df.plot(column=colname_data, ax=ax, markersize=150, cmap='viridis', legend=False, zorder=2) # delete zorder to plot points behind states boundaries
# States
states.boundary.plot(ax=ax, color="#555555", linewidth=1, zorder=1) # Plot states boundary again with lower zorder


# TODO: need to customize the colorbar min and max based on the metric
## cbar = plt.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0,vmax = 1), cmap='viridis')
cbar = plt.cm.ScalarMappable(cmap='viridis')
ax.tick_params(axis='x', labelsize= 24)
Expand All @@ -1630,26 +1631,26 @@ def plot_map_perf(geo_df, states,title,metr,colname_data='performance'):
fig = plt.gcf()
return fig

def plot_map_perf_wrap(test_gdf,dir_out_viz_base, ds,
def plot_map_pred_wrap(test_gdf,dir_out_viz_base, ds,
metr,algo_str,
split_type='test',
colname_data='performance'):

path_perf_map_plot = std_map_perf_path(dir_out_viz_base,ds,metr,algo_str,split_type)
dir_out_basemap = path_perf_map_plot.parent.parent
path_pred_map_plot = std_map_pred_path(dir_out_viz_base,ds,metr,algo_str,split_type)
dir_out_basemap = path_pred_map_plot.parent.parent
states = gen_conus_basemap(dir_out_basemap = dir_out_basemap)

# Ensure the gdf matches the 4326 epsg used for states:
test_gdf = test_gdf.to_crs(4326)

# Generate the map
plot_title = f"Predicted Performance: {metr} - {ds}"
plot_perf_map = plot_map_perf(geo_df=test_gdf, states=states,title=plot_title,
plot_pred_map = plot_map_pred(geo_df=test_gdf, states=states,title=plot_title,
metr=metr,colname_data=colname_data)

# Save the plot as a .png file
plot_perf_map.savefig(path_perf_map_plot, dpi=300, bbox_inches='tight')
print(f"Wrote performance map to \n{path_perf_map_plot}")
plot_pred_map.savefig(path_pred_map_plot, dpi=300, bbox_inches='tight')
print(f"Wrote performance map to \n{path_pred_map_plot}")
plt.clf()
plt.close()

Expand Down
58 changes: 19 additions & 39 deletions pkg/fs_algo/fs_algo/fs_perf_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import xarray as xr
import urllib.request
import zipfile
import pkg_resources


if __name__ == "__main__":
Expand Down Expand Up @@ -101,6 +102,11 @@

# Location for accessing existing outputs and saving plots
dir_out = fsate.fs_save_algo_dir_struct(dir_base).get('dir_out')
dir_out_viz_base = Path(dir_out/Path("data_visualizations"))

# Enforce style
style_path = pkg_resources.resource_filename('fs_algo', 'RaFTS_theme.mplstyle')
plt.style.use(style_path)

# Loop through all datasets
for ds in datasets:
Expand All @@ -120,20 +126,8 @@
# data.to_csv(f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_data.csv')

# Does the user want a map of the module performance predicted by RaFTS?
if 'perf_map' in true_keys:
url = 'https://www2.census.gov/geo/tiger/GENZ2018/shp/cb_2018_us_state_500k.zip'
zip_filename = f'{dir_out}/data_visualizations/cb_2018_us_state_500k.zip'
filename = f'{dir_out}/data_visualizations/cb_2018_us_state_500k.shp'

if not Path(zip_filename).exists():
print('Downloading shapefile...')
urllib.request.urlretrieve(url, zip_filename)
if not Path(filename).exists():
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
zip_ref.extractall(f'{dir_out}/data_visualizations')

states = gpd.read_file(filename)
states = states.to_crs("EPSG:4326")
if 'pred_map' in true_keys:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bolotinl this is the spot where pred_map should be changed to perf_map

states = fsate.gen_conus_basemap(f'{dir_out}/data_visualizations/')

# Plot performance on map
lat = data['Y']
Expand All @@ -143,27 +137,14 @@
geo_df['performance'] = data['prediction'].values
geo_df.crs = ("EPSG:4326")

fig, ax = plt.subplots(1, 1, figsize=(20, 24))
base = states.boundary.plot(ax=ax,color="#555555", linewidth=1)
# Points
geo_df.plot(column="performance", ax=ax, markersize=150, cmap='viridis', legend=False, zorder=2) # delete zorder to plot points behind states boundaries
# States
states.boundary.plot(ax=ax, color="#555555", linewidth=1, zorder=1) # Plot states boundary again with lower zorder

cbar = plt.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=-0.41,vmax = 1), cmap='viridis')
ax.tick_params(axis='x', labelsize= 24)
ax.tick_params(axis='y', labelsize= 24)
plt.xlabel('Latitude',fontsize = 26)
plt.ylabel('Longitude',fontsize = 26)
cbar_ax = plt.colorbar(cbar, ax=ax,fraction=0.02, pad=0.04)
cbar_ax.set_label(label=metric,size=24)
cbar_ax.ax.tick_params(labelsize=24) # Set colorbar tick labels size
plt.title("Predicted Performance: {}".format(ds), fontsize = 28)
ax.set_xlim(-126, -66)
ax.set_ylim(24, 50)
fsate.plot_map_pred(geo_df=geo_df, states=states,
title=f'RaFTS Predicted Performance Map: {ds}',
metr=metric, colname_data='performance')

# Save the plot as a .png file
output_path = f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_performance_map.png'
output_path = fsate.std_map_pred_path(dir_out_viz_base=dir_out_viz_base,
ds=ds, metr=metric, algo_str=algo,
split_type='prediction')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.clf()
plt.close()
Expand Down Expand Up @@ -196,14 +177,13 @@
data = pd.merge(data, obs, how = 'inner', on = 'identifier')

# Plot the observed vs. predicted module performance
plt.scatter(data['prediction'], data[metric], c='teal')
plt.axline((0, 0), (1, 1), color='black', linestyle='--')
plt.xlabel('Predicted {}'.format(metric))
plt.ylabel('Actual {}'.format(metric))
plt.title('Observed vs. Predicted Performance: {}'.format(ds))
fsate.plot_pred_vs_obs_regr(y_pred=data['prediction'], y_obs=data[metric],
ds = ds, metr=metric)

# Save the plot as a .png file
output_path = f'{dir_out}/data_visualizations/{ds}_{algo}_{metric}_obs_vs_sim_scatter.png'
output_path = fsate.std_regr_pred_obs_path(dir_out_viz_base=dir_out_viz_base,
ds=ds, metr=metric, algo_str=algo,
split_type='prediction')
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.clf()
plt.close()
Expand Down
2 changes: 1 addition & 1 deletion scripts/eval_ingest/xssa/xssa_viz_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ metrics: # All option could pull from the pred config file
- 'NSE'
plot_types:
- obs_vs_sim_scatter: True # NOTE: These plots can only be created if observed (actual) model performance values are available
- perf_map: True
- pred_map: True