From e132cf4265ff07b76b3d53e040f28f86a140e683 Mon Sep 17 00:00:00 2001
From: wildsnowman
Date: Sat, 10 Sep 2022 18:01:59 +0900
Subject: [PATCH 1/2] add n_comb options

---
 benchmarks/set_matching_pytorch/test.py     |  3 ++-
 benchmarks/set_matching_pytorch/train_sm.py | 12 +++++++++---
 shift15m/datasets/outfitfeature.py          | 14 +++++++-------
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/benchmarks/set_matching_pytorch/test.py b/benchmarks/set_matching_pytorch/test.py
index 9224f42..fb1f7d3 100644
--- a/benchmarks/set_matching_pytorch/test.py
+++ b/benchmarks/set_matching_pytorch/test.py
@@ -31,7 +31,7 @@ def get_test_loader(
 
     iqon_outfits = IQONOutfits(root=root, split=split)
 
-    test_examples = iqon_outfits.get_fitb_data(label_dir_name)
+    test_examples = iqon_outfits.get_fitb_data(label_dir_name, n_comb=args.n_comb)
     feature_dir = iqon_outfits.feature_dir
     dataset = FINBsDataset(
         test_examples,
@@ -113,6 +113,7 @@ def main(args):
     parser.add_argument("--valid_year", type=int)
     parser.add_argument("--split", type=int, choices=[0, 1, 2])
     parser.add_argument("--model_dir", "-d", type=str)
+    parser.add_argument("--n_comb", type=int, default=1)
 
     args = parser.parse_args()
     main(args)
diff --git a/benchmarks/set_matching_pytorch/train_sm.py b/benchmarks/set_matching_pytorch/train_sm.py
index 8613534..1db5249 100644
--- a/benchmarks/set_matching_pytorch/train_sm.py
+++ b/benchmarks/set_matching_pytorch/train_sm.py
@@ -24,6 +24,7 @@ def get_train_val_loader(
     valid_year: Union[str, int],
     split: int,
     batch_size: int,
+    n_comb: int,
     root: str = C.ROOT,
     num_workers: Optional[int] = None,
 ) -> Tuple[Any, Any]:
@@ -33,8 +34,8 @@ def get_train_val_loader(
     train, valid = iqon_outfits.get_trainval_data(label_dir_name)
     feature_dir = iqon_outfits.feature_dir
 
-    train_dataset = MultisetSplitDataset(train, feature_dir, n_sets=1, n_drops=None)
-    valid_dataset = MultisetSplitDataset(valid, feature_dir, n_sets=1, n_drops=None)
+    train_dataset = MultisetSplitDataset(train, feature_dir, n_comb=n_comb, n_drops=None)
+    valid_dataset = MultisetSplitDataset(valid, feature_dir, n_comb=n_comb, n_drops=None)
     return (
         get_loader(train_dataset, batch_size, num_workers=num_workers, is_train=True),
         get_loader(valid_dataset, batch_size, num_workers=num_workers, is_train=False),
@@ -70,7 +71,11 @@ def main(args):
 
     # dataset
     train_loader, valid_loader = get_train_val_loader(
-        args.train_year, args.valid_year, args.split, args.batchsize
+        train_year=args.train_year,
+        valid_year=args.valid_year,
+        split=args.split,
+        batch_size=args.batchsize,
+        n_comb=args.n_comb,
     )
 
     # logger
@@ -222,6 +227,7 @@ def eval_process(engine, batch):
     parser.add_argument("--valid_year", type=int)
     parser.add_argument("--split", type=int, choices=[0, 1, 2])
     parser.add_argument("--weight_path", "-w", type=str, default=None)
+    parser.add_argument("--n_comb", type=int, default=1)
 
     args = parser.parse_args()
 
diff --git a/shift15m/datasets/outfitfeature.py b/shift15m/datasets/outfitfeature.py
index af3ab5d..180133d 100644
--- a/shift15m/datasets/outfitfeature.py
+++ b/shift15m/datasets/outfitfeature.py
@@ -36,18 +36,18 @@ def __init__(
         self,
         sets: List,
         root: pathlib.Path,
-        n_sets: int,
+        n_comb: int,
         n_drops: Optional[int] = None,
         max_elementnum_per_set: Optional[int] = 8,
     ):
         self.sets = sets
         self.feat_dir = root
-        self.n_sets = n_sets
+        self.n_comb = n_comb
         self.n_drops = n_drops
         if n_drops is None:
             n_drops = max_elementnum_per_set // 2
-        setX_size = (max_elementnum_per_set - n_drops) * n_sets
-        setY_size = n_drops * n_sets
+        setX_size = (max_elementnum_per_set - n_drops) * n_comb
+        setY_size = n_drops * n_comb
         self.transform_x = FeatureListTransform(
             max_set_size=setX_size, apply_shuffle=True, apply_padding=True
         )
@@ -59,9 +59,9 @@ def __len__(self):
         return len(self.sets)
 
     def __getitem__(self, i):
-        if self.n_sets > 1:  # you can conduct "superset matching" by using n_sets > 1
+        if self.n_comb > 1:  # you can conduct "superset matching" by using n_comb > 1
             indices = np.delete(np.arange(len(self.sets)), i)
-            indices = np.random.choice(indices, self.n_sets - 1, replace=False)
+            indices = np.random.choice(indices, self.n_comb - 1, replace=False)
             indices = [i] + list(indices)
         else:
             indices = [i]
@@ -306,7 +306,7 @@ def _make_test_examples(
         n_cands: int = 8,
         seed: int = 0,
     ):
-        print("Make test dataset.")
+        print("Making test dataset.")
         np.random.seed(seed)
 
         test_sets = json.load(open(path / "test.json"))
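Note (illustration only, not part of either patch): the core of PATCH 1/2 is the sampling step in MultisetSplitDataset.__getitem__. With n_comb > 1, each example is built from the anchor outfit i plus n_comb - 1 distinct partner outfits, which the dataset then merges and re-splits into a query set X and an answer set Y ("superset matching"). A minimal, runnable sketch of just that sampling step; the helper name and the toy values are ours, not code from the repository:

    import numpy as np

    def sample_superset_indices(num_sets, i, n_comb):
        # Mirrors the __getitem__ logic in the diff above: keep the anchor
        # set i and draw (n_comb - 1) distinct partners from the rest.
        if n_comb > 1:
            candidates = np.delete(np.arange(num_sets), i)  # every index except i
            extra = np.random.choice(candidates, n_comb - 1, replace=False)
            return [i] + list(extra)
        return [i]  # n_comb == 1 keeps the original single-outfit behaviour

    np.random.seed(0)
    print(sample_superset_indices(num_sets=10, i=3, n_comb=3))
    # e.g. [3, 7, 1]: the anchor outfit plus two randomly drawn partners

Because both new command-line flags default to 1, existing training and FITB evaluation runs are unaffected.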
From 21daad98a6c378004b1897d81699b2b4ff40ff23 Mon Sep 17 00:00:00 2001
From: wildsnowman
Date: Sat, 10 Sep 2022 18:12:53 +0900
Subject: [PATCH 2/2] black

---
 benchmarks/set_matching_pytorch/train_sm.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/benchmarks/set_matching_pytorch/train_sm.py b/benchmarks/set_matching_pytorch/train_sm.py
index 1db5249..6fdc427 100644
--- a/benchmarks/set_matching_pytorch/train_sm.py
+++ b/benchmarks/set_matching_pytorch/train_sm.py
@@ -34,8 +34,12 @@ def get_train_val_loader(
     train, valid = iqon_outfits.get_trainval_data(label_dir_name)
     feature_dir = iqon_outfits.feature_dir
 
-    train_dataset = MultisetSplitDataset(train, feature_dir, n_comb=n_comb, n_drops=None)
-    valid_dataset = MultisetSplitDataset(valid, feature_dir, n_comb=n_comb, n_drops=None)
+    train_dataset = MultisetSplitDataset(
+        train, feature_dir, n_comb=n_comb, n_drops=None
+    )
+    valid_dataset = MultisetSplitDataset(
+        valid, feature_dir, n_comb=n_comb, n_drops=None
+    )
     return (
         get_loader(train_dataset, batch_size, num_workers=num_workers, is_train=True),
         get_loader(valid_dataset, batch_size, num_workers=num_workers, is_train=False),
@@ -71,10 +75,10 @@ def main(args):
 
     # dataset
     train_loader, valid_loader = get_train_val_loader(
-        train_year=args.train_year,
-        valid_year=args.valid_year,
-        split=args.split,
-        batch_size=args.batchsize,
+        train_year=args.train_year,
+        valid_year=args.valid_year,
+        split=args.split,
+        batch_size=args.batchsize,
         n_comb=args.n_comb,
     )
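Note (illustration only, not part of the patches): the size bookkeeping that PATCH 1/2 switches from n_sets to n_comb scales linearly. With the defaults visible in the diff, max_elementnum_per_set=8 and n_drops=None falling back to 8 // 2 = 4, a small self-contained check of the arithmetic from MultisetSplitDataset.__init__:

    # Reproduces the setX_size / setY_size computation with the defaults
    # shown in the diff; the loop values are illustrative, not repo code.
    max_elementnum_per_set = 8
    n_drops = max_elementnum_per_set // 2  # fallback used when n_drops is None

    for n_comb in (1, 2, 3):
        setX_size = (max_elementnum_per_set - n_drops) * n_comb
        setY_size = n_drops * n_comb
        print(f"n_comb={n_comb}: setX_size={setX_size}, setY_size={setY_size}")

    # n_comb=1: setX_size=4,  setY_size=4   (the default; behaviour unchanged)
    # n_comb=2: setX_size=8,  setY_size=8   (two outfits merged per example)
    # n_comb=3: setX_size=12, setY_size=12

So running train_sm.py or test.py with --n_comb 2 doubles max_set_size on both the X and Y FeatureListTransform, which gives the merged multi-outfit examples room on both sides.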