-
Notifications
You must be signed in to change notification settings - Fork 0
/
hexagon.py
217 lines (187 loc) · 8.77 KB
/
hexagon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from __future__ import annotations
import sys
from typing import Dict, List, Tuple, Set
class Hex:
    """A single cell of a hexagonal grid, addressed by cube coordinates.

    The three coordinates are expected to satisfy ``q + r + s == 0``.
    Instances compare and hash by their coordinates, so equal cells can be
    used interchangeably as dictionary keys or set members.
    """

    # Unit steps towards the six adjacent cells, indexed by direction number.
    NEIGHBOURS: List[Tuple[int, int, int]] = [
        (1, 0, -1),  # right
        (0, 1, -1),  # bottom right
        (-1, 1, 0),  # bottom left
        (-1, 0, 1),  # left
        (0, -1, 1),  # top left
        (1, -1, 0),  # top right
    ]

    def __init__(self, q: int, r: int, s: int):
        """Creates a cell at cube coordinates (q, r, s).

        Raises:
            AssertionError: If the coordinates do not sum to zero. As with
                any ``assert``, the check is skipped when Python runs with -O.
        """
        assert q + r + s == 0, "cube coordinates must satisfy q + r + s == 0"
        self.q = q
        self.r = r
        self.s = s

    def __eq__(self, other: object):
        """Checks if 2 hexagons in a map are the same."""
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, Hex):
            return NotImplemented
        return self.__key() == other.__key()

    def __key(self):
        """Returns the coordinate triple used for equality and hashing."""
        return self.q, self.r, self.s

    def __hash__(self):
        return hash(self.__key())

    def __repr__(self):
        return f"{type(self).__name__}({self.q}, {self.r}, {self.s})"

    @classmethod
    def hex_add(cls, a: Hex, b: Hex):
        """Adds coordinates of 2 hexagons."""
        return Hex(a.q + b.q, a.r + b.r, a.s + b.s)

    def neighbour(self, direction: int):
        """Calculates the next hexagon in the specified direction.

        Args:
            direction: Index into NEIGHBOURS. Values outside 0..5 wrap around
                modulo 6, so ``direction + 3`` always means the opposite way.

        Returns:
            A new Hex one step away; this instance is left unchanged.
        """
        step = self.NEIGHBOURS[direction % len(self.NEIGHBOURS)]
        return self.hex_add(self, Hex(*step))
class HexMap:
"""Represents a hexagonal map of hexagons."""
def __init__(self, radius: int):
self._map: Dict[Hex, int] = {}
self.radius: int = radius
self.diameter: int = 2 * radius + 1
# Create a hexagon shape from hexagons
for q in range(-radius, radius + 1):
r1 = max(-radius, -q - radius)
r2 = min(radius, -q + radius)
for r in range(r1, r2 + 1):
self._map[Hex(q, r, -q - r)] = 0
# Hash maps used for quickly checking whether the map satisfies the condition
self.rows: List[Set[int]] = [set() for _ in range(self.diameter)]
self.diags1: List[Set[int]] = [set() for _ in range(self.diameter)]
self.diags2: List[Set[int]] = [set() for _ in range(self.diameter)]
def get_direction(self, start: Hex, direction: int):
"""Creates an iterator over a row/diagonal in the given direction."""
# Rollback to the start of the row
while start in self._map:
start = start.neighbour(direction + 3)
start = start.neighbour(direction)
# Yield the row
while start in self._map:
yield start
start = start.neighbour(direction)
def print(self):
"""Prints the text representation of the map."""
min_row_len = current_row_len = self.diameter - self.radius
# start in the top left corner
starting_hexagon = Hex(0, -self.radius, self.radius)
shift = 1
while current_row_len > min_row_len - 1:
starting_spaces = self.diameter - current_row_len
print(starting_spaces * ' ', end='')
for hexagon in self.get_direction(starting_hexagon, 0):
print(self._map[hexagon], end=' ')
print()
if current_row_len == self.diameter:
# We've reached the middle row, the rows are shrinking
shift = -1
current_row_len += shift
if shift == 1:
starting_hexagon = starting_hexagon.neighbour(2)
else:
starting_hexagon = starting_hexagon.neighbour(1)
print()
def check_condition(self, choices: List[int]):
"""Checks whether the current state satisfies the task condition."""
for sets in (self.rows, self.diags1, self.diags2):
for value_set in sets:
if len(value_set) != len(choices):
return False
return True
def set_value(self, current_node: Hex, new_value: int, row_index: int, diag1_index: int, diag2_index: int):
"""Sets the value of the current_node to new_value.
Returns whether the value was set successfully (returns False if
the number has already been used in the current row/diagonal).
"""
if new_value != 0:
if new_value in self.rows[row_index] or new_value in self.diags1[diag1_index] or new_value in \
self.diags2[diag2_index]:
return False
self.rows[row_index].add(new_value)
self.diags1[diag1_index].add(new_value)
self.diags2[diag2_index].add(new_value)
self._map[current_node] = new_value
return True
def reset_value(self, current_node: Hex, row_index: int, diag1_index: int, diag2_index: int):
"""Resets a node to its initial state."""
current_value = self._map[current_node]
self._map[current_node] = 0
if current_value != 0:
self.rows[row_index].remove(current_value)
self.diags1[diag1_index].remove(current_value)
self.diags2[diag2_index].remove(current_value)
def is_invalid_row_or_diag(self, direction: int, choices: List[int], row_index: int,
diag1_index: int, diag2_index: int):
"""Checks if the recently finished row or diagonal is valid."""
invalid = False
if direction % 3 == 0:
# We just finished a row, check the rows
if len(choices) != len(self.rows[row_index]):
invalid = True
elif direction % 3 == 1:
# We just finished top right - bottom left diag
if len(choices) != len(self.diags2[diag2_index]):
invalid = True
elif direction % 3 == 2:
# We just finished top left - bottom right diag
if len(choices) != len(self.diags1[diag1_index]):
invalid = True
return invalid
def _walk(self, current_node: Hex, direction: int, remaining_steps: int, level: int,
choices: List[int], choices_index: int):
row_index = current_node.r + self.radius
diag1_index = current_node.s + self.radius
diag2_index = current_node.q + self.radius
# Base case for recursion
if current_node == Hex(0, 0, 0):
for new_value in (0, choices[choices_index]):
if not self.set_value(current_node, new_value, row_index, diag1_index, diag2_index):
continue
if self.check_condition(choices):
self.print()
self.reset_value(current_node, row_index, diag1_index, diag2_index)
return
# We can either leave the current cell empty (0) or fill it with the next number
for new_value in (0, choices[choices_index]):
new_choices_index = choices_index
new_level = level
new_direction = direction
new_remaining_steps = remaining_steps
next_node = current_node
if new_value != 0:
new_choices_index = (choices_index + 1) % len(choices)
if not self.set_value(current_node, new_value, row_index, diag1_index, diag2_index):
continue
if remaining_steps == 0:
# One direction finished, lets try to cut down the search tree
if self.is_invalid_row_or_diag(direction, choices, row_index, diag1_index, diag2_index):
self.reset_value(current_node, row_index, diag1_index, diag2_index)
continue
# Change direction
new_direction = (direction + 1) % 6
new_remaining_steps = self.radius - level - (1 if new_direction == 5 else 0)
# Last element before center, there is no steps to go to in top right direction, just skip it.
if new_direction == 5 and new_remaining_steps == 0:
new_direction = 0
new_remaining_steps = 1
if new_direction == 0:
new_level += 1
if new_remaining_steps != 0:
next_node = current_node.neighbour(new_direction)
new_remaining_steps -= 1
self._walk(next_node, new_direction, new_remaining_steps, new_level, choices, new_choices_index)
self.reset_value(current_node, row_index, diag1_index, diag2_index)
def solve_problem(self, solve_for: int):
"""Solves the hexagon problem trying to fit `solve_for` numbers into the hexagon."""
current_node = Hex(0, -self.radius, self.radius)
self._walk(current_node, 0, self.radius, 0, [x + 1 for x in range(solve_for)], 0)
def main(argv: List[str]) -> int:
    """Parses the command line, runs the solver and returns the exit code.

    Args:
        argv: Full argument vector including the program name, i.e.
            ``[program, <grid radius>, <number of choices>]``.

    Returns:
        0 after the solver has run, 1 if the arguments were rejected.
    """
    usage = "Usage: python3 hexagon.py <grid radius> <number of choices>"
    if len(argv) != 3:
        print("Invalid number of arguments!\n" + usage)
        return 1
    try:
        radius = int(argv[1], 10)
        choices = int(argv[2], 10)
    except ValueError:
        print("Arguments must be numbers!\n" + usage)
        return 1
    # A negative radius makes the solver recurse without end, and fewer than
    # one choice leaves it with nothing to place, so reject both up front.
    if radius < 0 or choices < 1:
        print("Grid radius must not be negative and number of choices must be positive!\n" + usage)
        return 1
    hex_map = HexMap(radius)
    hex_map.solve_problem(choices)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))