From 5fcc7ae00434056829b884e8e7be4fdb85614605 Mon Sep 17 00:00:00 2001 From: Thanachai Soontornwutikul Date: Mon, 12 Oct 2015 20:34:49 +0900 Subject: [PATCH 1/3] Allow specifying cross-validation subsets of the dataset to be monitored. --- pylearn2/cross_validation/__init__.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pylearn2/cross_validation/__init__.py b/pylearn2/cross_validation/__init__.py index e5dbc60113..bc2957ae76 100644 --- a/pylearn2/cross_validation/__init__.py +++ b/pylearn2/cross_validation/__init__.py @@ -1,4 +1,4 @@ -""" +""" Cross validation module. Each fold of cross validation is a separate experiment, so we create a @@ -34,6 +34,8 @@ class TrainCV(object): Training model. If list, training model for each fold. algorithm : TrainingAlgorithm Training algorithm. + algorithm_monitoring_datasets : list or None + Subsets of the dataset to be monitored. Leave as None to monitor all subsets. save_path : str or None Output filename for trained models. Also used (with modification) for individual models if save_folds is True. @@ -48,7 +50,7 @@ class TrainCV(object): cv_extensions : list or None TrainCVExtension objects for the parent TrainCV object. 
""" - def __init__(self, dataset_iterator, model, algorithm=None, + def __init__(self, dataset_iterator, model, algorithm=None, algorithm_monitoring_datasets=None, save_path=None, save_freq=0, extensions=None, allow_overwrite=True, save_folds=False, cv_extensions=None): self.dataset_iterator = dataset_iterator @@ -76,7 +78,11 @@ def __init__(self, dataset_iterator, model, algorithm=None, # setup monitoring datasets this_algorithm = deepcopy(algorithm) - this_algorithm._set_monitoring_dataset(datasets) + if algorithm_monitoring_datasets is None: + monitoring_datasets = datasets + else: + monitoring_datasets = {k:v for (k,v) in datasets.items() if k in algorithm_monitoring_datasets} + this_algorithm._set_monitoring_dataset(monitoring_datasets) # extensions this_extensions = deepcopy(extensions) From 249019006c2b7ed1d6c9855a7f13593ec26a5af7 Mon Sep 17 00:00:00 2001 From: Thanachai Soontornwutikul Date: Mon, 12 Oct 2015 23:31:36 +0900 Subject: [PATCH 2/3] Modified pylearn2/cross_validation/__init__.py to ensure compatibility and conformance to formatting rules. --- pylearn2/cross_validation/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pylearn2/cross_validation/__init__.py b/pylearn2/cross_validation/__init__.py index bc2957ae76..daea4a8a2b 100644 --- a/pylearn2/cross_validation/__init__.py +++ b/pylearn2/cross_validation/__init__.py @@ -35,7 +35,8 @@ class TrainCV(object): algorithm : TrainingAlgorithm Training algorithm. algorithm_monitoring_datasets : list or None - Subsets of the dataset to be monitored. Leave as None to monitor all subsets. + Subsets of the dataset to be monitored. + Leave as None to monitor all subsets. save_path : str or None Output filename for trained models. Also used (with modification) for individual models if save_folds is True. @@ -50,7 +51,8 @@ class TrainCV(object): cv_extensions : list or None TrainCVExtension objects for the parent TrainCV object. 
""" - def __init__(self, dataset_iterator, model, algorithm=None, algorithm_monitoring_datasets=None, + def __init__(self, dataset_iterator, model, algorithm=None, + algorithm_monitoring_datasets=None, save_path=None, save_freq=0, extensions=None, allow_overwrite=True, save_folds=False, cv_extensions=None): self.dataset_iterator = dataset_iterator @@ -81,7 +83,9 @@ def __init__(self, dataset_iterator, model, algorithm=None, algorithm_monitoring if algorithm_monitoring_datasets is None: monitoring_datasets = datasets else: - monitoring_datasets = {k:v for (k,v) in datasets.items() if k in algorithm_monitoring_datasets} + monitoring_datasets = dict( + (k, v) for (k, v) in datasets.iteritems() + if k in algorithm_monitoring_datasets) this_algorithm._set_monitoring_dataset(monitoring_datasets) # extensions From 7e8231e055e0f07a3dd3302efcedcea72c4d7848 Mon Sep 17 00:00:00 2001 From: Thanachai Soontornwutikul Date: Tue, 13 Oct 2015 00:14:51 +0900 Subject: [PATCH 3/3] Removed inadvertently introduced UTF-8 BOM and whitespaces. --- pylearn2/cross_validation/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pylearn2/cross_validation/__init__.py b/pylearn2/cross_validation/__init__.py index daea4a8a2b..fcfbbbe451 100644 --- a/pylearn2/cross_validation/__init__.py +++ b/pylearn2/cross_validation/__init__.py @@ -1,4 +1,4 @@ -""" +""" Cross validation module. Each fold of cross validation is a separate experiment, so we create a @@ -51,7 +51,7 @@ class TrainCV(object): cv_extensions : list or None TrainCVExtension objects for the parent TrainCV object. 
""" - def __init__(self, dataset_iterator, model, algorithm=None, + def __init__(self, dataset_iterator, model, algorithm=None, algorithm_monitoring_datasets=None, save_path=None, save_freq=0, extensions=None, allow_overwrite=True, save_folds=False, cv_extensions=None): @@ -84,7 +84,7 @@ def __init__(self, dataset_iterator, model, algorithm=None, monitoring_datasets = datasets else: monitoring_datasets = dict( - (k, v) for (k, v) in datasets.iteritems() + (k, v) for (k, v) in datasets.iteritems() if k in algorithm_monitoring_datasets) this_algorithm._set_monitoring_dataset(monitoring_datasets)