generated from NOAA-OWP/owp-open-source-project-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prediction algorithms, hyperparameterization, evaluation (#22)
* feat: create script to generate algo prediction data for testing * feat: generating predictions from trained algos under dev * feat: add processing of xssa locations, randomly selecting a subset to use for algo prediction * feat: develop algo prediction's config ingest, and determine paths to prediction locations and trained algos * feat: create metric prediction and write results to file * feat: build unit test for build_cfig_path() * feat: build unit test for build_cfig_path() * feat: add unit tests for std_pred_path and _read_pred_comid; test coverage now at 92% * feat: add oob = True as default for RandomForestRegressor * feat: add hyperparameterization capability using grid search and associated unit tests * feat: add unit testing for train_eval() * chore: change algo config for testing out hyperparameterization * chore: add UserWarning category specification to warnings.warn * fix: algo config assignment accidentally only looked at first line of params * fix: make sure that hyperparameter key:value pairings contained inside dict, not list * fix: adjust unit test's algo_config formats to represent the issue of a dict of a list, which the list_to_dict() function then converts * fix: _check_attributes_exist now appropriately reports missing attributes and comids * fix: ensure algo and pipeline keys contain algo and pipeline object types in the grid search case
- Loading branch information
Showing
10 changed files
with
753 additions
and
116 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import argparse | ||
import yaml | ||
import joblib | ||
import fs_algo.fs_algo_train_eval as fsate | ||
import pandas as pd | ||
from pathlib import Path | ||
import ast | ||
import warnings | ||
import os | ||
import numpy as np | ||
|
||
# TODO create a function that's flexible/converts user formatted checks (a la fsds_proc)


# Predict values and evaluate predictions.
#
# Reads a prediction YAML config, resolves the attribute config it points to,
# loads trained algorithm pipelines from disk, predicts each response variable
# (metric) for the requested comids, and writes one parquet file per
# dataset/metric/algo combination.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = 'process the prediction config file')
    parser.add_argument('path_pred_config', type=str,
                        help='Path to the YAML configuration file specific for prediction.')
    # NOTE pred_config should contain the path for path_algo_config
    args = parser.parse_args()

    home_dir = Path.home()
    path_pred_config = Path(args.path_pred_config)
    with open(path_pred_config, 'r') as file:
        pred_cfg = yaml.safe_load(file)

    #%% READ CONTENTS FROM THE ATTRIBUTE CONFIG
    # The prediction config references the attribute config by name; resolve
    # its full path relative to the prediction config's location.
    path_attr_config = fsate.build_cfig_path(path_pred_config,
                                             pred_cfg.get('name_attr_config', None))
    attr_cfig = fsate.AttrConfigAndVars(path_attr_config)
    attr_cfig._read_attr_config()

    dir_base = attr_cfig.attrs_cfg_dict.get('dir_base')
    dir_std_base = attr_cfig.attrs_cfg_dict.get('dir_std_base')
    dir_db_attrs = attr_cfig.attrs_cfg_dict.get('dir_db_attrs')
    datasets = attr_cfig.attrs_cfg_dict.get('datasets')  # Identify datasets of interest
    attrs_sel = attr_cfig.attrs_cfg_dict.get('attrs_sel', None)

    #%% ESTABLISH ALGORITHM FILE I/O
    # Call the directory-structure helper once and reuse the result (the
    # original called it twice, repeating any directory-creation side effects).
    dirs_alg = fsate.fs_save_algo_dir_struct(dir_base)
    dir_out = dirs_alg.get('dir_out')
    dir_out_alg_base = dirs_alg.get('dir_out_alg_base')

    #%% PREDICTION FILE'S COMIDS
    # pred_file_in may contain {placeholders} filled from the attribute config.
    path_pred_locs = pred_cfg.get('pred_file_in').format(**attr_cfig.attrs_cfg_dict)
    comid_pred_col = pred_cfg.get('pred_file_comid_colname')

    comids_pred = fsate._read_pred_comid(path_pred_locs, comid_pred_col)

    #%% prediction config
    resp_vars = pred_cfg.get('algo_response_vars')  # metrics to predict
    algos = pred_cfg.get('algo_type')               # trained algorithms to apply

    #%% Read in predictor variable data (aka basin attributes)
    # Read the predictor variable data (basin attributes) generated by fsds.attr.hydfab
    df_attr = fsate.fs_read_attr_comid(dir_db_attrs, comids_pred, attrs_sel=attrs_sel,
                                       _s3=None, storage_options=None)
    # Convert into wide format for model prediction: one row per featureID,
    # one column per attribute.
    df_attr_wide = df_attr.pivot(index='featureID', columns='attribute', values='value')

    #%% Run prediction
    for ds in datasets:
        dir_out_alg_ds = Path(dir_out_alg_base) / ds
        print(f"PREDICTING algorithm for {ds}")
        for metric in resp_vars:
            for algo in algos:
                path_algo = fsate.std_algo_path(dir_out_alg_ds, algo=algo,
                                                metric=metric, dataset_id=ds)
                # Fail fast if training did not produce this algo/metric file.
                if not Path(path_algo).exists():
                    raise FileNotFoundError(f"The following algorithm path does not exist: \n{path_algo}")

                # Read in the algorithm's pipeline and subset the wide attribute
                # table to the features the pipeline was trained on (also fixes
                # column order to match training).
                pipe = joblib.load(path_algo)
                feat_names = list(pipe.feature_names_in_)
                df_attr_sub = df_attr_wide[feat_names]

                # Perform prediction
                resp_pred = pipe.predict(df_attr_sub)

                # Compile prediction results.
                # NOTE(review): assumes df_attr_wide rows align 1:1 and in order
                # with comids_pred — verify fs_read_attr_comid/pivot preserve
                # the comid ordering before trusting the comid column here.
                df_pred = pd.DataFrame({'comid': comids_pred,
                                        'prediction': resp_pred,
                                        'metric': metric,
                                        'dataset': ds,
                                        'algo': algo,
                                        'name_algo': Path(path_algo).name})

                # Write prediction results
                path_pred_out = fsate.std_pred_path(dir_out, algo=algo,
                                                    metric=metric, dataset_id=ds)
                df_pred.to_parquet(path_pred_out)
                print(f" Completed {algo} prediction of {metric}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.