Skip to content

Commit

Permalink
Merge pull request #200 from st-tech/199-a-label-creation-bug-on-set-matching-with-multiple-splits
Browse files (browse the repository at this point in the history)

Use multiple split settings when creating the data
  • Loading branch information
wildsnowman authored Sep 7, 2022
2 parents ced68c8 + d632dfd commit 4de9f4a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion benchmarks/set_matching_pytorch/train_sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_train_val_loader(
) -> Tuple[Any, Any]:
label_dir_name = f"{train_year}-{valid_year}-split{split}"

iqon_outfits = IQONOutfits(root=root, split=split)
iqon_outfits = IQONOutfits(root=root)

train, valid = iqon_outfits.get_trainval_data(label_dir_name)
feature_dir = iqon_outfits.feature_dir
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/set_matching_pytorch/train_we.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_train_val_loader(
) -> Tuple[Any, Any]:
label_dir_name = f"{train_year}-{valid_year}-split{split}"

iqon_outfits = IQONOutfits(root=root, split=split)
iqon_outfits = IQONOutfits(root=root)

train, valid = iqon_outfits.get_trainval_data(label_dir_name)
feature_dir = iqon_outfits.feature_dir
Expand Down
9 changes: 5 additions & 4 deletions shift15m/datasets/outfitfeature.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(
self,
root: str = C.ROOT,
split: int = 0,
) -> None:
) -> None: # not used
self.root = pathlib.Path(root)
self.root.mkdir(parents=True, exist_ok=True)
if not (self.root / "iqon_outfits.json").exists():
Expand All @@ -191,8 +191,11 @@ def __init__(

self._label_dir = self.root / "set_matching/labels"
if not self._label_dir.exists():
print("Making train/val dataset.")
self._label_dir.mkdir(parents=True, exist_ok=True)
self._make_trainval_dataset(seed=split)
splits = [0, 1, 2]
for _s in splits:
self._make_trainval_dataset(seed=_s)

self._feature_dir = self.root / "features"
if not self._feature_dir.exists():
Expand Down Expand Up @@ -231,8 +234,6 @@ def _make_trainval_dataset(
min_like_num: int = 50,
seed: int = 0,
):
print("Make train/val dataset.")

np.random.seed(seed)
num_train, num_val, num_test = 30816, 3851, 3851 # max size

Expand Down

0 comments on commit 4de9f4a

Please sign in to comment.