Skip to content

Commit

Permalink
✨ ensure all symbols have similar final importance
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasParistech committed Jun 2, 2024
1 parent 1f33975 commit 0a75bf3
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions dobble/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@

RNG = np.random.default_rng(42)

SCALE_TARGETS_LIST = [
[0.5, 0.8, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0],
[0.6, 0.8, 0.9, 0.9, 5.0, 5.0, 5.0, 5.0],
[0.7, 0.7, 0.7, 2, 2., 7., 7., 7.],
[0.8, 0.8, 0.8, 0.8, 5., 5., 7., 7.]
SCALE_TARGETS_LIST: List[List[float]] = [
[0.8, 0.8, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0],
[0.8, 0.8, 0.9, 0.9, 5.0, 5.0, 5.0, 5.0],
[0.8, 0.8, 0.8, 2, 2., 5., 5., 5.]
]
# Account for the added margin around the image
SCALE_TARGETS_LIST = [[np.sqrt(2)*s for s in scales]
for scales in SCALE_TARGETS_LIST]

DISK_OCCUPANCY_TARGET = 0.8
DISK_OCCUPANCY_TARGET = 0.85

TRANSLATION_MAX_STEP = 60
ANGLE_MAX_STEP = 90
Expand Down Expand Up @@ -142,7 +141,8 @@ class Card:
enough to estimate the overlap.
"""

def __init__(self, masks: List[np.ndarray]) -> None:
def __init__(self, masks: List[np.ndarray],
scale_targets: List[float]) -> None:
"""Init from a list of 8 low-resolution symbol masks"""
assert len(masks) == 8

Expand All @@ -154,10 +154,6 @@ def __init__(self, masks: List[np.ndarray]) -> None:
RNG.shuffle(list_xy)
list_xy = list_xy[:8]

idx = RNG.integers(0, len(SCALE_TARGETS_LIST))
scale_targets = SCALE_TARGETS_LIST[idx]
RNG.shuffle(scale_targets)

total_target_area = sum(scale_target*np.count_nonzero(mask)
for mask, scale_target in zip(masks, scale_targets))
disk_area = np.pi * self.size_pix**2 / 4
Expand Down Expand Up @@ -222,6 +218,32 @@ def imshow(self, wait_key: int = 0):
cv2.waitKey(wait_key)


def allocate_scale_targets(cards: List[List[int]]) -> List[List[float]]:
"""Allocate scale targets while ensuring that each symbol appears at least once with a large scale."""
scales_per_symbol = [[] for _ in range(57)]

def compute_score(scales: List[float]) -> float:
if len(scales) == 0:
return -np.inf
return np.max(scales) * np.mean(scales)

all_scale_targets: List[List[float]] = []
for symbols in cards:

idx = RNG.integers(0, len(SCALE_TARGETS_LIST))
scale_targets = np.array(sorted(SCALE_TARGETS_LIST[idx], reverse=True))

scores = [compute_score(scales_per_symbol[s]) for s in symbols]
scale_targets[np.argsort(scores)] = scale_targets

for s, scale in zip(symbols, scale_targets):
scales_per_symbol[s].append(scale)

all_scale_targets.append(scale_targets.tolist())

return all_scale_targets


def main(masks_folder: str,
symbols_folder: str,
out_cards_folder: str,
Expand All @@ -247,14 +269,17 @@ def main(masks_folder: str,

new_folder(out_cards_folder)

for card_idx, symbols in enumerate(tqdm(cards, "Cards")):
scale_targets_per_card = allocate_scale_targets(cards)
assert_len(scale_targets_per_card, 57)
for card_idx, (symbols, scale_targets) in enumerate(tqdm(zip(cards, scale_targets_per_card),
desc="Cards", total=57)):
card_path = os.path.join(out_cards_folder, f"card_{card_idx}.png")

masks = [cv2.imread(os.path.join(masks_folder, names[symbol_idx]),
cv2.IMREAD_GRAYSCALE)
for symbol_idx in symbols]

card = Card(masks)
card = Card(masks, scale_targets)
for _ in range(n_iter):
card.next(DEBUG)

Expand Down

0 comments on commit 0a75bf3

Please sign in to comment.