-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAINT: refactor the CLI using click (#175)
- Loading branch information
1 parent
c22e3ee
commit 0fc6a08
Showing
7 changed files
with
184 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
import ast | ||
|
||
import click | ||
import pandas as pd | ||
|
||
from ...externals.tabulate import tabulate | ||
|
||
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) | ||
|
||
|
||
class PythonLiteralOption(click.Option): | ||
|
||
def type_cast_value(self, ctx, value): | ||
try: | ||
return ast.literal_eval(value) | ||
except: # noqa: E722 | ||
raise click.BadParameter(value) | ||
|
||
|
||
def _load_score_submission(submission_path, metric, step): | ||
"""Load the score for a single submission.""" | ||
training_output_path = os.path.join(submission_path, 'training_output') | ||
folds_path = [ | ||
os.path.join(training_output_path, fold_name) | ||
for fold_name in os.listdir(training_output_path) | ||
if (os.path.isdir(os.path.join(training_output_path, fold_name)) and | ||
'fold_' in fold_name) | ||
] | ||
data = {} | ||
for fold_id, path in enumerate(folds_path): | ||
score_path = os.path.join(path, 'scores.csv') | ||
if not os.path.exists(score_path): | ||
return | ||
scores = pd.read_csv(score_path, index_col=0) | ||
scores.columns.name = 'score' | ||
data[fold_id] = scores | ||
df = pd.concat(data, names=['fold']) | ||
metric = metric if metric else slice(None) | ||
step = step if step else slice(None) | ||
return df.loc[(slice(None), step), metric] | ||
|
||
|
||
@click.group(context_settings=CONTEXT_SETTINGS) | ||
def main(): | ||
"""Command-line to show information about local submissions.""" | ||
pass | ||
|
||
|
||
@main.command() | ||
@click.option("--ramp-kit-dir", default='.', show_default=True, | ||
help='Root directory of the ramp-kit to retrieved the train ' | ||
'submission.') | ||
@click.option("--metric", cls=PythonLiteralOption, default="[]", | ||
show_default=True, | ||
help='A list of the metric to report') | ||
@click.option("--step", cls=PythonLiteralOption, default="[]", | ||
show_default=True, | ||
help='A list of the processing to report. Choices are ' | ||
'{"train" , "valid", "test"}') | ||
@click.option("--sort-by", cls=PythonLiteralOption, default="[]", | ||
show_default=True, | ||
help='Give the metric, step, and stat to use for sorting.') | ||
@click.option("--ascending/--descending", default=True, show_default=True, | ||
help='Sort in ascending or descending order.') | ||
@click.option("--precision", default=2, show_default=True, | ||
help='The precision for the different metrics reported.') | ||
def leaderboard(ramp_kit_dir, metric, step, sort_by, ascending, precision): | ||
"""Display the leaderboard for all the local submissions.""" | ||
path_submissions = os.path.join(ramp_kit_dir, 'submissions') | ||
all_submissions = { | ||
sub: os.path.join(path_submissions, sub) | ||
for sub in os.listdir(path_submissions) | ||
if os.path.isdir(os.path.join(path_submissions, sub)) | ||
} | ||
data = {} | ||
for sub_name, sub_path in all_submissions.items(): | ||
scores = _load_score_submission(sub_path, metric, step) | ||
if scores is None: | ||
continue | ||
data[sub_name] = scores | ||
df = pd.concat(data, names=['submission']) | ||
df = df.unstack(level=['step']) | ||
df = pd.concat([df.groupby('submission').mean(), | ||
df.groupby('submission').std()], | ||
keys=['mean', 'std'], axis=1, names=['stat']) | ||
df = df.round(precision).reorder_levels([1, 2, 0], axis=1) | ||
step = ['train', 'valid', 'test'] if not step else step | ||
df = (df.sort_index(axis=1, level=0) | ||
.reindex(labels=step, level='step', axis=1)) | ||
|
||
if sort_by: | ||
df = df.sort_values(tuple(sort_by), ascending=ascending, axis=0) | ||
|
||
headers = (["\n".join(df.columns.names)] + | ||
["\n".join(col_names) for col_names in df.columns.get_values()]) | ||
click.echo(tabulate(df, headers=headers, tablefmt='grid')) | ||
|
||
|
||
def start(): | ||
main() | ||
|
||
|
||
if __name__ == '__main__': | ||
start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
|
||
import click | ||
|
||
from ..testing import assert_notebook | ||
from ..testing import assert_submission | ||
|
||
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) | ||
|
||
|
||
@click.command(context_settings=CONTEXT_SETTINGS) | ||
@click.option('--submission', default='starting_kit', show_default=True, | ||
help='The kit to test. It should be located in the ' | ||
'"submissions" folder of the starting kit. If "ALL", all ' | ||
'submissions in the directory will be tested.') | ||
@click.option('--ramp-kit-dir', default='.', show_default=True, | ||
help='Root directory of the ramp-kit to test.') | ||
@click.option('--ramp-data-dir', default='.', show_default=True, | ||
help='Directory containing the data. This directory should ' | ||
'contain a "data" folder.') | ||
@click.option('--ramp-submission-dir', default='submissions', | ||
show_default=True, | ||
help='Directory where the submissions are stored. It is the ' | ||
'directory (typically called "submissions" in the ramp-kit) ' | ||
'that contains the individual submission subdirectories.') | ||
@click.option('--notebook', is_flag=True, show_default=True, | ||
help='Whether or not to test the notebook.') | ||
@click.option('--quick-test', is_flag=True, | ||
help='Specify this flag to test the submission on a small ' | ||
'subset of the data.') | ||
@click.option('--pickle', is_flag=True, | ||
help='Specify this flag to pickle the submission after ' | ||
'training.') | ||
@click.option('--save-output', is_flag=True, | ||
help='Specify this flag to save predictions, scores, eventual ' | ||
'error trace, and state after training.') | ||
@click.option('--retrain', is_flag=True, | ||
help='Specify this flag to retrain the submission on the full ' | ||
'training set after the CV loop.') | ||
def main(submission, ramp_kit_dir, ramp_data_dir, ramp_submission_dir, | ||
notebook, quick_test, pickle, save_output, retrain): | ||
"""Test a submission and/or a notebook before to submit on RAMP studio.""" | ||
if quick_test: | ||
os.environ['RAMP_TEST_MODE'] = '1' | ||
|
||
if submission == "ALL": | ||
ramp_submission_dir = os.path.join(ramp_kit_dir, 'submissions') | ||
submission = [ | ||
directory | ||
for directory in os.listdir(ramp_submission_dir) | ||
if os.path.isdir(os.path.join(ramp_submission_dir, directory)) | ||
] | ||
else: | ||
submission = [submission] | ||
|
||
for sub in submission: | ||
assert_submission(ramp_kit_dir=ramp_kit_dir, | ||
ramp_data_dir=ramp_data_dir, | ||
ramp_submission_dir=ramp_submission_dir, | ||
submission=sub, | ||
is_pickle=pickle, | ||
save_output=save_output, | ||
retrain=retrain) | ||
|
||
if notebook: | ||
assert_notebook(ramp_kit_dir=ramp_kit_dir) | ||
|
||
|
||
def start(): | ||
main() | ||
|
||
|
||
if __name__ == '__main__': | ||
start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters