Merge pull request #202 from st-tech/superset_matching
add n_comb options
wildsnowman authored Sep 10, 2022
2 parents 4de9f4a + 21daad9 commit 39101d9
Showing 3 changed files with 22 additions and 11 deletions.
3 changes: 2 additions & 1 deletion benchmarks/set_matching_pytorch/test.py
@@ -31,7 +31,7 @@ def get_test_loader(
 
     iqon_outfits = IQONOutfits(root=root, split=split)
 
-    test_examples = iqon_outfits.get_fitb_data(label_dir_name)
+    test_examples = iqon_outfits.get_fitb_data(label_dir_name, n_comb=args.n_comb)
     feature_dir = iqon_outfits.feature_dir
     dataset = FINBsDataset(
         test_examples,
@@ -113,6 +113,7 @@ def main(args):
     parser.add_argument("--valid_year", type=int)
     parser.add_argument("--split", type=int, choices=[0, 1, 2])
     parser.add_argument("--model_dir", "-d", type=str)
+    parser.add_argument("--n_comb", type=int, default=1)
     args = parser.parse_args()
 
     main(args)
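
With the new flag, a hypothetical evaluation run might look like the following (illustrative values; any flags defined above the argparse lines shown in this hunk are omitted and may also be required):

    python benchmarks/set_matching_pytorch/test.py --valid_year 2017 --split 0 -d models/superset --n_comb 2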
16 changes: 13 additions & 3 deletions benchmarks/set_matching_pytorch/train_sm.py
@@ -24,6 +24,7 @@ def get_train_val_loader(
     valid_year: Union[str, int],
     split: int,
     batch_size: int,
+    n_comb: int,
     root: str = C.ROOT,
     num_workers: Optional[int] = None,
 ) -> Tuple[Any, Any]:
@@ -33,8 +34,12 @@
 
     train, valid = iqon_outfits.get_trainval_data(label_dir_name)
     feature_dir = iqon_outfits.feature_dir
-    train_dataset = MultisetSplitDataset(train, feature_dir, n_sets=1, n_drops=None)
-    valid_dataset = MultisetSplitDataset(valid, feature_dir, n_sets=1, n_drops=None)
+    train_dataset = MultisetSplitDataset(
+        train, feature_dir, n_comb=n_comb, n_drops=None
+    )
+    valid_dataset = MultisetSplitDataset(
+        valid, feature_dir, n_comb=n_comb, n_drops=None
+    )
     return (
         get_loader(train_dataset, batch_size, num_workers=num_workers, is_train=True),
         get_loader(valid_dataset, batch_size, num_workers=num_workers, is_train=False),
@@ -70,7 +75,11 @@ def main(args):
 
     # dataset
     train_loader, valid_loader = get_train_val_loader(
-        args.train_year, args.valid_year, args.split, args.batchsize
+        train_year=args.train_year,
+        valid_year=args.valid_year,
+        split=args.split,
+        batch_size=args.batchsize,
+        n_comb=args.n_comb,
     )
 
     # logger
@@ -222,6 +231,7 @@ def eval_process(engine, batch):
     parser.add_argument("--valid_year", type=int)
     parser.add_argument("--split", type=int, choices=[0, 1, 2])
     parser.add_argument("--weight_path", "-w", type=str, default=None)
+    parser.add_argument("--n_comb", type=int, default=1)
 
     args = parser.parse_args()
 
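A matching training invocation might be (illustrative values; the --train_year and --valid_year flags are assumed from main's use of args.train_year and args.valid_year, and a batch-size flag is assumed from args.batchsize):

    python benchmarks/set_matching_pytorch/train_sm.py --train_year 2016 --valid_year 2017 --split 0 --n_comb 2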
14 changes: 7 additions & 7 deletions shift15m/datasets/outfitfeature.py
@@ -36,18 +36,18 @@ def __init__(
         self,
         sets: List,
         root: pathlib.Path,
-        n_sets: int,
+        n_comb: int,
         n_drops: Optional[int] = None,
         max_elementnum_per_set: Optional[int] = 8,
     ):
         self.sets = sets
         self.feat_dir = root
-        self.n_sets = n_sets
+        self.n_comb = n_comb
         self.n_drops = n_drops
         if n_drops is None:
             n_drops = max_elementnum_per_set // 2
-        setX_size = (max_elementnum_per_set - n_drops) * n_sets
-        setY_size = n_drops * n_sets
+        setX_size = (max_elementnum_per_set - n_drops) * n_comb
+        setY_size = n_drops * n_comb
         self.transform_x = FeatureListTransform(
             max_set_size=setX_size, apply_shuffle=True, apply_padding=True
         )
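A quick check of the sizing arithmetic above, using the defaults visible in this hunk: with max_elementnum_per_set=8 and n_drops=None, n_drops falls back to 8 // 2 = 4, so n_comb=2 gives setX_size = (8 - 4) * 2 = 8 and setY_size = 4 * 2 = 8, while the default n_comb=1 reproduces the previous n_sets=1 sizes of 4 and 4.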
@@ -59,9 +59,9 @@ def __len__(self):
         return len(self.sets)
 
     def __getitem__(self, i):
-        if self.n_sets > 1:  # you can conduct "superset matching" by using n_sets > 1
+        if self.n_comb > 1:  # you can conduct "superset matching" by using n_comb > 1
             indices = np.delete(np.arange(len(self.sets)), i)
-            indices = np.random.choice(indices, self.n_sets - 1, replace=False)
+            indices = np.random.choice(indices, self.n_comb - 1, replace=False)
             indices = [i] + list(indices)
         else:
             indices = [i]
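To make the sampling step above concrete, here is a minimal standalone sketch (the helper name and signature are ours for illustration, not part of the commit):

    import numpy as np

    def sample_superset_indices(i: int, total: int, n_comb: int) -> list:
        # Superset matching: pair the anchor set i with n_comb - 1 other
        # sets drawn without replacement from the remaining indices.
        if n_comb > 1:
            candidates = np.delete(np.arange(total), i)  # all indices except i
            partners = np.random.choice(candidates, n_comb - 1, replace=False)
            return [i] + list(partners)
        return [i]  # n_comb == 1: plain single-set matching, as before

    # e.g. sample_superset_indices(3, total=10, n_comb=2) -> [3, k] with k != 3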
@@ -306,7 +306,7 @@ def _make_test_examples(
         n_cands: int = 8,
         seed: int = 0,
     ):
-        print("Make test dataset.")
+        print("Making test dataset.")
         np.random.seed(seed)
 
         test_sets = json.load(open(path / "test.json"))
