Upload 3 files
- hongik/board_ai.py +332 -0
- hongik/engine_ai.py +571 -0
- hongik/hongik_ai.py +358 -0
hongik/board_ai.py
ADDED
@@ -0,0 +1,332 @@
# Implements the world of Go, defining the game board, its rules
# (moves, liberties, scoring, etc.), and determining the final winner.
#
# Author: Gemini 2.5 Pro, Gemini 2.5 Flash

import numpy as np
from collections import deque
import random

class IllegalMoveError(ValueError):
    """Exception class raised for an illegal move."""
    pass

class Board:
    """
    Represents the Go board and enforces game rules.
    """
    EMPTY, BLACK, WHITE, WALL = 0, 1, 2, 3
    PASS_LOC = 0

    def __init__(self, size):
        """Initializes the board, sets up walls, and history."""
        if isinstance(size, tuple):
            self.x_size, self.y_size = size
        else:
            self.x_size = self.y_size = size

        self.arrsize = (self.x_size + 1) * (self.y_size + 2) + 1
        self.dy = self.x_size + 1
        self.adj = [-self.dy, -1, 1, self.dy]

        self.board = np.zeros(shape=(self.arrsize), dtype=np.int8)
        self.pla = Board.BLACK
        self.prisoners = {Board.BLACK: 0, Board.WHITE: 0}
        self.ko_points = set()
        self.consecutive_passes = 0
        self.turns = 0
        self.ko_recapture_counts = {}
        self.position_history = set()
        self.position_history.add(self.board.tobytes())

        for i in range(-1, self.x_size + 1):
            self.board[self.loc(i, -1)] = Board.WALL
            self.board[self.loc(i, self.y_size)] = Board.WALL
        for i in range(-1, self.y_size + 1):
            self.board[self.loc(-1, i)] = Board.WALL
            self.board[self.loc(self.x_size, i)] = Board.WALL

    def copy(self):
        """Creates a deep copy of the current board state."""
        new_board = Board((self.x_size, self.y_size))
        new_board.board = np.copy(self.board)
        new_board.pla = self.pla
        new_board.prisoners = self.prisoners.copy()
        new_board.ko_points = self.ko_points.copy()
        new_board.consecutive_passes = self.consecutive_passes
        new_board.turns = self.turns
        new_board.ko_recapture_counts = self.ko_recapture_counts.copy()
        new_board.position_history = self.position_history.copy()
        return new_board

    @staticmethod
    def get_opp(player):
        """Gets the opponent of the given player."""
        return 3 - player

    def loc(self, x, y):
        """Converts (x, y) coordinates to a 1D array location."""
        return (x + 1) + self.dy * (y + 1)

    def loc_to_coord(self, loc):
        """Converts a 1D array location back to (x, y) coordinates."""
        return (loc % self.dy) - 1, (loc // self.dy) - 1

    def is_on_board(self, loc):
        """Checks if a location is within the board boundaries (not a wall)."""
        return self.board[loc] != Board.WALL

    def _get_group_info(self, loc):
        """Scans and returns the stones and liberties of a group at a specific location."""
        if not self.is_on_board(loc) or self.board[loc] == self.EMPTY:
            return None, None

        player = self.board[loc]
        group_stones, liberties = set(), set()
        q, visited = deque([loc]), {loc}

        while q:
            current_loc = q.popleft()
            group_stones.add(current_loc)
            for dloc in self.adj:
                adj_loc = current_loc + dloc
                if self.is_on_board(adj_loc):
                    adj_stone = self.board[adj_loc]
                    if adj_stone == self.EMPTY:
                        liberties.add(adj_loc)
                    elif adj_stone == player and adj_loc not in visited:
                        visited.add(adj_loc)
                        q.append(adj_loc)
        return group_stones, liberties

    def would_be_legal(self, player, loc):
        """Checks if a move would be legal without actually playing it."""
        if loc == self.PASS_LOC: return True
        if not self.is_on_board(loc) or self.board[loc] != self.EMPTY or loc in self.ko_points:
            return False

        temp_board = self.copy()
        temp_board.board[loc] = player

        opponent = self.get_opp(player)
        captured_any = False
        captured_stones = set()

        for dloc in temp_board.adj:
            adj_loc = loc + dloc
            if temp_board.board[adj_loc] == opponent:
                group, libs = temp_board._get_group_info(adj_loc)
                if not libs:
                    captured_any = True
                    captured_stones.update(group)

        if captured_any:
            for captured_loc in captured_stones:
                temp_board.board[captured_loc] = self.EMPTY

        next_board_hash = temp_board.board.tobytes()
        if next_board_hash in self.position_history:
            return False

        if captured_any:
            return True

        _, my_libs = temp_board._get_group_info(loc)
        return bool(my_libs)

    def get_features(self):
        """Generates a feature tensor for the neural network input."""
        features = np.zeros((self.y_size, self.x_size, 3), dtype=np.float32)

        current_player = self.pla
        opponent_player = self.get_opp(self.pla)

        for y in range(self.y_size):
            for x in range(self.x_size):
                loc = self.loc(x, y)
                stone = self.board[loc]
                if stone == current_player:
                    features[y, x, 0] = 1
                elif stone == opponent_player:
                    features[y, x, 1] = 1

        if self.pla == self.WHITE:
            features[:, :, 2] = 1.0

        return features

    def is_game_over(self):
        """Checks if the game is over (due to consecutive passes)."""
        return self.consecutive_passes >= 2

    def play(self, player, loc):
        """Plays a move on the board, captures stones, and updates the game state."""
        if not self.would_be_legal(player, loc):
            raise IllegalMoveError("This move is against the rules.")

        self.ko_points.clear()

        if loc == self.PASS_LOC:
            self.consecutive_passes += 1
        else:
            self.consecutive_passes = 0
            self.board[loc] = player
            opponent = self.get_opp(player)

            captured_stones = set()

            for dloc in self.adj:
                adj_loc = loc + dloc
                if self.board[adj_loc] == opponent:
                    group, libs = self._get_group_info(adj_loc)
                    if not libs:
                        captured_stones.update(group)

            if captured_stones:
                self.prisoners[player] += len(captured_stones)
                for captured_loc in captured_stones:
                    self.board[captured_loc] = self.EMPTY

            my_group, my_libs = self._get_group_info(loc)

            if len(captured_stones) == 1 and len(my_group) == 1 and len(my_libs) == 1:
                ko_loc = captured_stones.pop()
                self.ko_points.add(ko_loc)

        board_hash = self.board.tobytes()
        self.position_history.add(board_hash)
        self.pla = self.get_opp(player)
        self.turns += 1

    def _is_group_alive_statically(self, group_stones: set, board_state: np.ndarray) -> bool:
        """Statically analyzes if a group is alive by checking for two eyes."""
        if not group_stones: return False
        owner_player = board_state[next(iter(group_stones))]
        eye_locations = set()
        for stone_loc in group_stones:
            for dloc in self.adj:
                adj_loc = stone_loc + dloc
                if board_state[adj_loc] == self.EMPTY: eye_locations.add(adj_loc)
        real_eye_count, visited_eye_locs = 0, set()
        for eye_loc in eye_locations:
            if eye_loc in visited_eye_locs: continue
            eye_region, q, is_real_eye = set(), deque([eye_loc]), True
            visited_eye_locs.add(eye_loc); eye_region.add(eye_loc)
            while q:
                current_loc = q.popleft()
                for dloc in self.adj:
                    adj_loc = current_loc + dloc
                    if self.is_on_board(adj_loc):
                        if board_state[adj_loc] == self.get_opp(owner_player):
                            is_real_eye = False; break
                        elif board_state[adj_loc] == self.EMPTY and adj_loc not in visited_eye_locs:
                            visited_eye_locs.add(adj_loc); eye_region.add(adj_loc); q.append(adj_loc)
                if not is_real_eye: break
            if is_real_eye:
                eye_size = len(eye_region)
                if eye_size >= 6: real_eye_count += 2
                else: real_eye_count += 1
            if real_eye_count >= 2: return True
        return real_eye_count >= 2

    def _is_group_alive_by_rollout(self, group_stones_initial: set) -> bool:
        """Determines if a group is alive via Monte Carlo rollouts for ambiguous cases."""
        NUM_ROLLOUTS = 20
        MAX_ROLLOUT_DEPTH = self.x_size * self.y_size // 2

        owner_player = self.board[next(iter(group_stones_initial))]
        attacker = self.get_opp(owner_player)
        deaths = 0

        for _ in range(NUM_ROLLOUTS):
            rollout_board = self.copy()
            rollout_board.pla = attacker

            for _ in range(MAX_ROLLOUT_DEPTH):
                first_stone_loc = next(iter(group_stones_initial))
                if rollout_board.board[first_stone_loc] != owner_player:
                    deaths += 1
                    break

                legal_moves = [loc for loc in range(1, self.arrsize) if rollout_board.board[loc] == self.EMPTY]
                random.shuffle(legal_moves)

                move_made = False
                for move in legal_moves:
                    if rollout_board.would_be_legal(rollout_board.pla, move):
                        rollout_board.play(rollout_board.pla, move)
                        move_made = True
                        break

                if not move_made:
                    rollout_board.play(rollout_board.pla, self.PASS_LOC)

                if rollout_board.is_game_over():
                    break

        death_rate = deaths / NUM_ROLLOUTS
        print(f"[Life/Death Log] Group survival probability: {1-death_rate:.0%}")
        return death_rate < 0.5

    def get_winner(self, komi=6.5):
        """Calculates the final score and determines the winner, handling life and death."""
        temp_board_state = np.copy(self.board)
        total_captives = self.prisoners.copy()

        all_groups = self._find_all_groups(temp_board_state)
        for player, groups in all_groups.items():
            for group_stones in groups:
                is_alive = self._is_group_alive_statically(group_stones, temp_board_state)

                if not is_alive:
                    is_alive = self._is_group_alive_by_rollout(group_stones)

                if not is_alive:
                    total_captives[self.get_opp(player)] += len(group_stones)
                    for stone_loc in group_stones:
                        temp_board_state[stone_loc] = self.EMPTY

        final_board_with_territory = self._calculate_territory(temp_board_state)
        black_territory = np.sum((final_board_with_territory == self.BLACK) & (temp_board_state == self.EMPTY))
        white_territory = np.sum((final_board_with_territory == self.WHITE) & (temp_board_state == self.EMPTY))

        black_score = black_territory + total_captives.get(self.BLACK, 0)
        white_score = white_territory + total_captives.get(self.WHITE, 0) + komi
        winner = self.BLACK if black_score > white_score else self.WHITE

        return winner, black_score, white_score, total_captives

    def _find_all_groups(self, board_state: np.ndarray) -> dict:
        """Finds all stone groups on the board for a given board state."""
        visited, all_groups = set(), {self.BLACK: [], self.WHITE: []}
        for loc in range(self.arrsize):
            if board_state[loc] in [self.BLACK, self.WHITE] and loc not in visited:
                player, group_stones, q = board_state[loc], set(), deque([loc])
                visited.add(loc); group_stones.add(loc)
                while q:
                    current_loc = q.popleft()
                    for dloc in self.adj:
                        adj_loc = current_loc + dloc
                        if board_state[adj_loc] == player and adj_loc not in visited:
                            visited.add(adj_loc); group_stones.add(adj_loc); q.append(adj_loc)
                all_groups[player].append(group_stones)
        return all_groups

    def _calculate_territory(self, board_state: np.ndarray) -> np.ndarray:
        """Calculates the territory for each player on a given board state."""
        territory_map, visited = np.copy(board_state), set()
        for loc in range(self.arrsize):
            if territory_map[loc] == self.EMPTY and loc not in visited:
                region_points, border_colors, q = set(), set(), deque([loc])
                visited.add(loc); region_points.add(loc)
                while q:
                    current_loc = q.popleft()
                    for dloc in self.adj:
                        adj_loc = current_loc + dloc
                        if board_state[adj_loc] in [self.BLACK, self.WHITE]: border_colors.add(board_state[adj_loc])
                        elif board_state[adj_loc] == self.EMPTY and adj_loc not in visited:
                            visited.add(adj_loc); region_points.add(adj_loc); q.append(adj_loc)
                if len(border_colors) == 1:
                    owner = border_colors.pop()
                    for point in region_points: territory_map[point] = owner
        return territory_map
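
The listing above defines the full rule engine. As a quick orientation (not part of the commit), the following minimal sketch exercises the Board API defined in hongik/board_ai.py; the 9x9 size and the chosen coordinates are arbitrary assumptions for illustration.

    # Illustrative only: exercises the Board API from hongik/board_ai.py above.
    from hongik.board_ai import Board, IllegalMoveError

    board = Board(9)                      # assumption: a small 9x9 board for a quick demo
    loc = board.loc(2, 3)                 # convert (x, y) to the internal 1D index
    if board.would_be_legal(Board.BLACK, loc):
        board.play(Board.BLACK, loc)      # play a black stone; the turn passes to white

    features = board.get_features()       # (9, 9, 3) tensor for the neural network input
    print(features.shape, board.pla)      # expected: (9, 9, 3) 2  (white to move)

    # Two consecutive passes end the game, after which it can be scored.
    board.play(board.pla, Board.PASS_LOC)
    board.play(board.pla, Board.PASS_LOC)
    winner, black_score, white_score, captives = board.get_winner(komi=6.5)
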
hongik/engine_ai.py
ADDED
@@ -0,0 +1,571 @@
# The engine file that serves as the AI's brain, responsible for training
# the model through reinforcement learning and performing move analysis.
#
# Author: Gemini 2.5 Pro, Gemini 2.5 Flash

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import threading
import tensorflow as tf
tf.get_logger().setLevel('ERROR')

import numpy as np
import time
import json
import random
import traceback
from collections import deque
import pickle
import csv
from huggingface_hub import hf_hub_download

from hongik.hongik_ai import HongikAIPlayer, CNNTransformerHybrid
from katrain.core.sgf_parser import Move
from katrain.core.constants import *
from kivy.clock import Clock
from hongik.board_ai import Board, IllegalMoveError
from katrain.core.game_node import GameNode
from katrain.gui.theme import Theme

class BaseEngine:
    """Base class for KaTrain engines."""
    def __init__(self, katrain, config):
        self.katrain, self.config = katrain, config
    def on_error(self, message, code=None, allow_popup=True):
        print(f"ERROR: {message}", OUTPUT_ERROR)
        if allow_popup and hasattr(self.katrain, "engine_recovery_popup"):
            Clock.schedule_once(lambda dt: self.katrain("engine_recovery_popup", message, code))

class HongikAIEngine(BaseEngine):
    """
    Main AI engine that manages the model, self-play, training, and analysis.
    It orchestrates the entire reinforcement learning loop.
    """
    BOARD_SIZE, NUM_LAYERS, D_MODEL, NUM_HEADS, D_FF = 19, 7, 256, 8, 1024
    SAVE_WEIGHTS_EVERY_STEPS, EVALUATION_EVERY_STEPS = 5, 20
    REPLAY_BUFFER_SIZE, TRAINING_BATCH_SIZE = 200000, 32
    CHECKPOINT_EVERY_GAMES = 10

    RULES = {
        "tromp-taylor": {"name": "Tromp-Taylor", "komi": 7.5, "scoring": "area"},
        "korean": {"name": "korean", "komi": 6.5, "scoring": "territory"},
        "chinese": {"name": "Chinese", "komi": 7.5, "scoring": "area"}
    }

    @staticmethod
    def get_rules(ruleset: str):
        """Returns the ruleset details for a given ruleset name."""
        if not ruleset or ruleset.lower() not in HongikAIEngine.RULES:
            ruleset = "korean"
        return HongikAIEngine.RULES[ruleset.lower()]

    def __init__(self, katrain, config):
        """
        Initializes the Hongik AI Engine. This involves setting up paths, loading
        the neural network model and replay buffer, and preparing for training.
        """
        super().__init__(katrain, config)
        print("Initializing Hongik AI Integrated Engine...", OUTPUT_DEBUG)

        from appdirs import user_data_dir
        APP_NAME = "HongikAI"
        APP_AUTHOR = "NamyongPark"

        self.BASE_PATH = user_data_dir(APP_NAME, APP_AUTHOR)
        print(f"Data will be stored in: {self.BASE_PATH}")

        self.REPLAY_BUFFER_PATH = os.path.join(self.BASE_PATH, "replay_buffer.pkl")
        self.WEIGHTS_FILE_PATH = os.path.join(self.BASE_PATH, "hongik_ai_memory.weights.h5")
        self.BEST_WEIGHTS_FILE_PATH = os.path.join(self.BASE_PATH, "hongik_ai_best.weights.h5")
        self.CHECKPOINT_BUFFER_PATH = os.path.join(self.BASE_PATH, "replay_buffer_checkpoint.pkl")
        self.CHECKPOINT_WEIGHTS_PATH = os.path.join(self.BASE_PATH, "hongik_ai_checkpoint.weights.h5")
        self.TRAINING_LOG_PATH = os.path.join(self.BASE_PATH, "training_log.csv")

        os.makedirs(self.BASE_PATH, exist_ok=True)

        REPO_ID = "puco21/HongikAI"
        files_to_download = [
            "replay_buffer.pkl",
            "hongik_ai_memory.weights.h5",
            "hongik_ai_best.weights.h5"
        ]

        print("Checking for AI data files...")
        for filename in files_to_download:
            local_path = os.path.join(self.BASE_PATH, filename)
            if not os.path.exists(local_path):
                print(f"Downloading {filename} from Hugging Face Hub...")
                try:
                    hf_hub_download(
                        repo_id=REPO_ID,
                        filename=filename,
                        local_dir=self.BASE_PATH,
                        local_dir_use_symlinks=False
                    )
                    print(f"'{filename}' download complete.")
                except Exception as e:
                    print(f"Failed to download {filename}: {e}")

        self.replay_buffer = deque(maxlen=self.REPLAY_BUFFER_SIZE)
        self.load_replay_buffer(self.REPLAY_BUFFER_PATH)

        self.hongik_model = CNNTransformerHybrid(self.NUM_LAYERS, self.D_MODEL, self.NUM_HEADS, self.D_FF, self.BOARD_SIZE)
        _ = self.hongik_model(np.zeros((1, self.BOARD_SIZE, self.BOARD_SIZE, 3), dtype=np.float32))

        load_path = self.CHECKPOINT_WEIGHTS_PATH if os.path.exists(self.CHECKPOINT_WEIGHTS_PATH) else (self.WEIGHTS_FILE_PATH if os.path.exists(self.WEIGHTS_FILE_PATH) else self.BEST_WEIGHTS_FILE_PATH)
        if os.path.exists(load_path):
            try:
                self.hongik_model.load_weights(load_path)
                print(f"Successfully loaded weights: {load_path}")
            except Exception as e:
                print(f"Failed to load weights (starting new training): {e}")

        try:
            max_visits = int(config.get("max_visits", 150))
        except (ValueError, TypeError):
            print(f"Warning: Invalid max_visits value in config. Using default (150).")
            max_visits = 150
        self.hongik_player = HongikAIPlayer(self.hongik_model, n_simulations=max_visits)

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1.0)
        self.policy_loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        self.value_loss_fn = tf.keras.losses.MeanSquaredError()
        self.training_step_counter, self.game_history, self.self_play_active = 0, [], False

        class MockProcess: poll = lambda self: None
        self.katago_process = self.hongik_process = MockProcess()
        self.sound_index = False
        print("Hongik AI Engine ready!", OUTPUT_DEBUG)

    def save_replay_buffer(self, path):
        """Saves the current replay buffer to a specified file path using pickle."""
        try:
            with open(path, 'wb') as f:
                pickle.dump(self.replay_buffer, f)
            print(f"Successfully saved experience data ({len(self.replay_buffer)} items) to '{path}'.")
        except Exception as e:
            print(f"Error saving experience data: {e}")

    def load_replay_buffer(self, path):
        """Loads a replay buffer from a file, prioritizing a checkpoint file if it exists."""
        load_path = self.CHECKPOINT_BUFFER_PATH if os.path.exists(self.CHECKPOINT_BUFFER_PATH) else path
        if os.path.exists(load_path):
            try:
                with open(load_path, 'rb') as f:
                    self.replay_buffer = pickle.load(f)

                if self.replay_buffer.maxlen != self.REPLAY_BUFFER_SIZE:
                    new_buffer = deque(maxlen=self.REPLAY_BUFFER_SIZE)
                    new_buffer.extend(self.replay_buffer)
                    self.replay_buffer = new_buffer
                print(f"Successfully loaded experience data ({len(self.replay_buffer)} items) from '{load_path}'.")
            except Exception as e:
                print(f"Error loading experience data: {e}")

    def _checkpoint_save(self):
        """Saves a training checkpoint, including both the replay buffer and model weights."""
        print(f"\n[{time.strftime('%H:%M:%S')}] Saving checkpoint...")
        self.save_replay_buffer(self.CHECKPOINT_BUFFER_PATH)
        self.hongik_model.save_weights(self.CHECKPOINT_WEIGHTS_PATH)
        print("Checkpoint saved.")

    def _log_training_progress(self, details: dict):
        """Logs the progress of the training process to a CSV file for later analysis."""
        try:
            file_exists = os.path.isfile(self.TRAINING_LOG_PATH)
            with open(self.TRAINING_LOG_PATH, 'a', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=details.keys())
                if not file_exists:
                    writer.writeheader()
                writer.writerow(details)
        except Exception as e:
            print(f"Error logging training progress: {e}")

    def _node_to_board(self, node: GameNode) -> Board:
        """Converts a KaTrain GameNode object to the internal Board representation used by the engine."""
        board_snapshot = Board(node.board_size)
        root_node = node.root
        for player, prop_name in [(Board.BLACK, 'AB'), (Board.WHITE, 'AW')]:
            setup_stones = root_node.properties.get(prop_name, [])
            if setup_stones:
                for coords in setup_stones:
                    loc = board_snapshot.loc(coords[0], coords[1])
                    if board_snapshot.board[loc] == Board.EMPTY:
                        board_snapshot.play(player, loc)

        current_player = Board.BLACK
        for scene_node in node.nodes_from_root[1:]:
            move = scene_node.move
            if move:
                loc = Board.PASS_LOC if move.is_pass else board_snapshot.loc(move.coords[0], move.coords[1])
                board_snapshot.play(current_player, loc)
                current_player = board_snapshot.pla

        board_snapshot.pla = Board.BLACK if node.next_player == 'B' else Board.WHITE
        return board_snapshot

    def request_analysis(self, analysis_node: GameNode, callback: callable, **kwargs):
        """
        Requests an analysis of a specific board position. The analysis is run
        in a separate thread to avoid blocking the GUI.
        """
        if not self.katrain.game: return
        game_id = self.katrain.game.game_id
        board = self._node_to_board(analysis_node)
        threading.Thread(target=self._run_analysis, args=(game_id, board, analysis_node, callback), daemon=True).start()

    def _run_analysis(self, game_id, board, analysis_node, callback):
        """
        The target function for the analysis thread. It runs MCTS and sends the
        formatted results back to the GUI via the provided callback.
        """
        try:
            policy_logits, _ = self.hongik_player.model(np.expand_dims(board.get_features(), 0), training=False)
            policy = tf.nn.softmax(policy_logits[0]).numpy()

            _, root_node = self.hongik_player.get_best_move(board)
            analysis_result = self._format_analysis_results("analysis", root_node, board, policy)

            analysis_node.analysis = analysis_result

            def guarded_callback(dt):
                if self.katrain.game and self.katrain.game.game_id == game_id:
                    callback(analysis_result, False)
            Clock.schedule_once(guarded_callback)
        except Exception as e:
            print(f"Error during AI analysis execution: {e}")
            traceback.print_exc()

    def _format_analysis_results(self, query_id, root_node, board, policy=None):  # <-- added policy=None parameter
        """
        Converts MCTS analysis data into a dictionary format that the KaTrain GUI can understand.
        """
        move_infos, moves_dict = [], {}

        if root_node and root_node.children:
            sorted_children = sorted(root_node.children.items(), key=lambda i: i[1].n_visits, reverse=True)

            best_move_q = sorted_children[0][1].q_value if sorted_children else 0

            for i, (action, child) in enumerate(sorted_children):
                coords = board.loc_to_coord(self.hongik_player._action_to_loc(action, board))
                move_gtp = Move(coords=coords).gtp()

                current_player_winrate = (child.q_value + 1) / 2
                display_winrate = 1.0 - current_player_winrate if board.pla == Board.WHITE else current_player_winrate
                display_score = -child.q_value * 20 if board.pla == Board.WHITE else child.q_value * 20

                points_lost = (best_move_q - child.q_value) * 20

                move_data = {
                    "move": move_gtp,
                    "visits": child.n_visits,
                    "winrate": display_winrate,
                    "scoreLead": display_score,
                    "pointsLost": points_lost,
                    "pv": [move_gtp],
                    "order": i
                }

                move_infos.append(move_data)
                moves_dict[move_gtp] = move_data

        current_player_winrate = (root_node.q_value + 1) / 2 if root_node else 0.5
        display_winrate = 1.0 - current_player_winrate if board.pla == Board.WHITE else current_player_winrate
        display_score = -root_node.q_value * 20 if (root_node and board.pla == Board.WHITE) else (root_node.q_value * 20 if root_node else 0.0)

        root_info = {"winrate": display_winrate, "scoreLead": display_score, "visits": root_node.n_visits if root_node else 0}
        return {"id": query_id, "moveInfos": move_infos, "moves": moves_dict, "root": root_info, "rootInfo": root_info, "policy": policy.tolist() if policy is not None else None, "completed": True}

    def start_self_play_loop(self):
        """Starts the main self-play loop, which continuously plays games to generate training data."""
        print(f"\n===========================================\n[{time.strftime('%H:%M:%S')}] Starting new self-play game.\n===========================================")
        self.stop_self_play_loop()
        self.self_play_active = True
        self.game_history = []
        Clock.schedule_once(self._self_play_turn, 0.3)

    def request_score(self, game_node, callback):
        """Requests a score calculation for the current game node, run in a separate thread."""
        threading.Thread(target=lambda: callback(self.get_score(game_node)), daemon=True).start()

    def stop_self_play_loop(self):
        """Stops the active self-play loop."""
        if not self.self_play_active: return
        self.self_play_active = False
        Clock.unschedule(self._self_play_turn)

    def _self_play_turn(self, dt=None):
        """
        Executes a single turn of a self-play game. It gets the best move from the
        AI, plays it on the board, and stores the state for later training.
        """
        if not self.self_play_active: return
        game = self.katrain.game
        try:
            current_node = game.current_node
            board_snapshot = self._node_to_board(current_node)
            if game.end_result or board_snapshot.is_game_over():
                self._process_game_result(game)
                return
            move_loc, root_node = self.hongik_player.get_best_move(board_snapshot, is_self_play=True)
            coords = None if move_loc == Board.PASS_LOC else board_snapshot.loc_to_coord(move_loc)
            move_obj = Move(player='B' if board_snapshot.pla == Board.BLACK else 'W', coords=coords)
            game.play(move_obj)
            if not move_obj.is_pass and self.sound_index:
                self.katrain.play_sound()
            black_player_type = self.katrain.players_info['B'].player_type
            white_player_type = self.katrain.players_info['W'].player_type

            if black_player_type == PLAYER_AI and white_player_type == PLAYER_AI:
                if self.katrain.game.current_node.next_player == 'B':
                    self.katrain.controls.players['B'].active = True
                    self.katrain.controls.players['W'].active = False
                else:
                    self.katrain.controls.players['B'].active = False
                    self.katrain.controls.players['W'].active = True

            policy = np.zeros(self.BOARD_SIZE**2 + 1, dtype=np.float32)
            if root_node and root_node.children:
                total_visits = sum(c.n_visits for c in root_node.children.values())
                if total_visits > 0:
                    for action, child in root_node.children.items(): policy[action] = child.n_visits / total_visits

            blacks_win_rate = 0.5
            if root_node:
                player_q_value = root_node.q_value
                player_win_rate = (player_q_value + 1) / 2
                blacks_win_rate = player_win_rate if board_snapshot.pla == Board.BLACK else (1 - player_win_rate)

            self.game_history.append([board_snapshot.get_features(), policy, board_snapshot.pla, blacks_win_rate])
            self.katrain.update_gui(game.current_node)
            self.sound_index = True
            Clock.schedule_once(self._self_play_turn, 0.3)
        except Exception as e:
            print(f"Critical error during self-play: {e}"); traceback.print_exc(); self.stop_self_play_loop()

    def _process_game_result(self, game: 'Game'):
        """
        Processes the result of a finished game. It requests a final score and
        then triggers the callback to handle training data generation.
        """
        try:
            self.katrain.controls.set_status("Scoring...", STATUS_INFO)
            self.katrain.board_gui.game_over_message = "Scoring..."

            self.katrain.board_gui.game_is_over = True
            self.request_score(game.current_node, self._on_score_calculated)
        except Exception as e:
            print(f"Error requesting score calculation: {e}")
            self.katrain._do_start_hongik_selfplay()

    def _on_score_calculated(self, score_details):
        """
        Callback function that handles the game result after scoring. It assigns rewards,
        augments the data, adds it to the replay buffer, and triggers a training step.
        """
        try:
            if not score_details:
                print("Game ended but no result. Starting next game.")
                return

            game_num = self.training_step_counter + 1
            winner_text = "Black" if score_details['winner'] == 'B' else "White"
            b_score, w_score, diff = score_details['black_score'], score_details['white_score'], score_details['score']
            final_message = f"{winner_text} wins by {abs(diff):.1f} points"
            self.katrain.board_gui.game_over_message = final_message
            print(f"\n==========================================\n[{time.strftime('%H:%M:%S')}] Game #{game_num} Finished\n-----------------------------------------\n Winner: {winner_text}\n Margin: {abs(diff):.1f} points\n Details: Black {b_score:.1f} vs White {w_score:.1f}\n--------------------------------------------")

            winner = Board.BLACK if score_details['winner'] == 'B' else Board.WHITE

            REVERSAL_THRESHOLD = 0.2
            WIN_REWARD = 1.0
            LOSS_REWARD = -1.0
            BRILLIANT_MOVE_BONUS = 0.5
            CONSOLATION_REWARD = 0.5

            for i, (features, policy, player_turn, blacks_win_rate_after) in enumerate(self.game_history):
                blacks_win_rate_before = self.game_history[i-1][3] if i > 0 else 0.5
                if player_turn == Board.BLACK:
                    win_rate_swing = blacks_win_rate_after - blacks_win_rate_before
                else:  # player_turn == Board.WHITE
                    white_win_rate_before = 1 - blacks_win_rate_before
                    white_win_rate_after = 1 - blacks_win_rate_after
                    win_rate_swing = white_win_rate_after - white_win_rate_before

                is_brilliant_move = win_rate_swing > REVERSAL_THRESHOLD

                if player_turn == winner:
                    reward = WIN_REWARD
                    if is_brilliant_move:
                        reward += BRILLIANT_MOVE_BONUS
                else:
                    reward = LOSS_REWARD
                    if is_brilliant_move:
                        reward = CONSOLATION_REWARD

                for j in range(8):
                    aug_features = self._augment_data(features, j, 'features')
                    aug_policy = self._augment_data(policy, j, 'policy')
                    self.replay_buffer.append([aug_features, aug_policy, reward])

            self.training_step_counter += 1
            loss = self._train_model() if len(self.replay_buffer) >= self.TRAINING_BATCH_SIZE else None
            if loss is not None:
                print(f" Training complete! (Final loss: {loss:.4f})\n=======================================", OUTPUT_DEBUG)

            log_data = {
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'game_num': game_num,
                'winner': score_details['winner'],
                'score_diff': diff,
                'total_moves': len(self.game_history),
                'loss': f"{loss:.4f}" if loss else "N/A"
            }
            self._log_training_progress(log_data)

            if self.training_step_counter % self.SAVE_WEIGHTS_EVERY_STEPS == 0:
                self.hongik_model.save_weights(self.WEIGHTS_FILE_PATH)

            if self.training_step_counter % self.CHECKPOINT_EVERY_GAMES == 0:
                self._checkpoint_save()

            if self.training_step_counter % self.EVALUATION_EVERY_STEPS == 0:
                self._evaluate_model()

        except Exception as e:
            print(f"Error during post-game processing: {e}")
            traceback.print_exc()
        finally:
            self.katrain._do_start_hongik_selfplay()

    def _train_model(self):
        """
        Performs one training step. It samples a minibatch from the replay buffer
        and uses it to update the neural network's weights.
        """
        if len(self.replay_buffer) < self.TRAINING_BATCH_SIZE: return None
        total_loss, TRAIN_ITERATIONS = 0, 5
        for _ in range(TRAIN_ITERATIONS):
            minibatch = random.sample(self.replay_buffer, self.TRAINING_BATCH_SIZE)
            features, policies, values = (np.array(e) for e in zip(*minibatch))
            with tf.GradientTape() as tape:
                pred_p, pred_v = self.hongik_model(features, training=True)
                value_loss = self.value_loss_fn(values[:, None], pred_v)
                policy_loss = self.policy_loss_fn(policies, pred_p)
                loss = policy_loss + value_loss
            self.optimizer.apply_gradients(zip(tape.gradient(loss, self.hongik_model.trainable_variables), self.hongik_model.trainable_variables))
            total_loss += loss.numpy()
        return total_loss / TRAIN_ITERATIONS

    def _augment_data(self, data, index, data_type):
        """
        Augments the training data by applying 8 symmetries (rotations and flips)
        to the board features and policy target.
        """
        if data_type == 'features':
            augmented = data
            if index & 1: augmented = np.fliplr(augmented)
            if index & 2: augmented = np.flipud(augmented)
            if index & 4: augmented = np.rot90(augmented, 1)
            return augmented
        elif data_type == 'policy':
            policy_board = data[:-1].reshape(self.BOARD_SIZE, self.BOARD_SIZE)
            augmented_board = policy_board
            if index & 1: augmented_board = np.fliplr(augmented_board)
            if index & 2: augmented_board = np.flipud(augmented_board)
            if index & 4: augmented_board = np.rot90(augmented_board, 1)
            return np.append(augmented_board.flatten(), data[-1])
        return data

    def get_score(self, game_node):
        """Calculates the final score of a game using the board's internal scoring method."""
        try:
            board = self._node_to_board(game_node)
            winner, black_score, white_score, _ = board.get_winner(self.katrain.game.komi)
            score_diff = black_score - white_score
            return {"winner": "B" if winner == Board.BLACK else "W", "score": score_diff, "black_score": black_score, "white_score": white_score}
        except Exception as e:
            print(f"Error during internal score calculation: {e}"); traceback.print_exc(); return None

    def _game_turn(self):
        """
        Handles the AI's turn in a game against a human or another AI. It runs
        in a separate thread to avoid blocking the GUI.
        """
        if self.self_play_active or self.katrain.game.end_result: return
        next_player_info = self.katrain.players_info[self.katrain.game.current_node.next_player]
        if next_player_info.player_type == PLAYER_AI:
            def ai_move_thread():
                try:
                    board_snapshot = self._node_to_board(self.katrain.game.current_node)
                    move_loc, _ = self.hongik_player.get_best_move(board_snapshot, is_self_play=False)
                    coords = None if move_loc == Board.PASS_LOC else board_snapshot.loc_to_coord(move_loc)
                    Clock.schedule_once(lambda dt: self.katrain._do_play(coords))
                except Exception as e:
                    print(f"\n--- Critical error during AI thinking (in thread) ---\n{traceback.format_exc()}\n---------------------------------------\n")
            threading.Thread(target=ai_move_thread, daemon=True).start()

    def _evaluate_model(self):
        """
        Periodically evaluates the currently training model against the best-known
        'champion' model to measure progress and update the best weights if the
        challenger is stronger.
        """
        print("\n--- [Championship Match Start] ---")
        challenger_player = self.hongik_player
        best_weights_path = self.BEST_WEIGHTS_FILE_PATH
        if not os.path.exists(best_weights_path):
            print("[Championship Match] Crowning the first champion!")
            self.hongik_model.save_weights(best_weights_path)
            return
        champion_model = CNNTransformerHybrid(self.NUM_LAYERS, self.D_MODEL, self.NUM_HEADS, self.D_FF, self.BOARD_SIZE)
        _ = champion_model(np.zeros((1, self.BOARD_SIZE, self.BOARD_SIZE, 3), dtype=np.float32))
        champion_model.load_weights(best_weights_path)
        champion_player = HongikAIPlayer(champion_model, int(self.config.get("max_visits", 150)))
        EVAL_GAMES, challenger_wins = 5, 0
        for i in range(EVAL_GAMES):
            print(f"\n[Championship Match] Game {i+1} starting...")
            board = Board(self.BOARD_SIZE)
            players = {Board.BLACK: challenger_player, Board.WHITE: champion_player} if i % 2 == 0 else {Board.BLACK: champion_player, Board.WHITE: challenger_player}
            while not board.is_game_over():
                current_player_obj = players[board.pla]
                move_loc, _ = current_player_obj.get_best_move(board)
                board.play(board.pla, move_loc)
            winner, _, _, _ = board.get_winner()
            if (winner == Board.BLACK and i % 2 == 0) or (winner == Board.WHITE and i % 2 != 0):
                challenger_wins += 1; print(f"[Championship Match] Game {i+1}: Challenger wins!")
            else:
                print(f"[Championship Match] Game {i+1}: Champion wins!")
        print(f"\n--- [Championship Match End] ---\nFinal Score: Challenger {challenger_wins} wins / Champion {EVAL_GAMES - challenger_wins} wins")
        if challenger_wins > EVAL_GAMES / 2:
            print("A new champion is born! Updating 'best' weights.")
            self.hongik_model.save_weights(best_weights_path)
        else:
            print("The champion defends the title. Keeping existing weights.")

    def on_new_game(self):
        """Called when a new game starts."""
        pass
    def start(self):
        """Starts the engine."""
        self.katrain.game_controls.set_player_selection()
    def shutdown(self, finish=False):
        """Shuts down the engine, saving progress and cleaning up checkpoint files."""
        self.stop_self_play_loop()
        self.save_replay_buffer(self.REPLAY_BUFFER_PATH)
        try:
            if os.path.exists(self.CHECKPOINT_BUFFER_PATH): os.remove(self.CHECKPOINT_BUFFER_PATH)
            if os.path.exists(self.CHECKPOINT_WEIGHTS_PATH): os.remove(self.CHECKPOINT_WEIGHTS_PATH)
        except OSError as e:
            print(f"Error deleting checkpoint files: {e}")
    def stop_pondering(self):
        """Stops pondering."""
        pass
    def queries_remaining(self):
        """Returns the number of remaining queries."""
        return 0
    def is_idle(self):
        """Checks if the engine is idle (i.e., not in a self-play loop)."""
        return not self.self_play_active
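
One detail worth calling out is the 8-fold symmetry augmentation in HongikAIEngine._augment_data: the same flip/rotation flags must be applied to the feature planes and to the 19x19 policy grid so that a move stays aligned with its board point. The standalone sketch below (not part of the commit) restates that transform and checks the alignment; the row-major y*19+x action layout is an assumption made for this demo, implied by the reshape in _augment_data but not confirmed elsewhere in this excerpt.

    # Illustrative only: standalone restatement of the symmetry transform used by
    # HongikAIEngine._augment_data, checking that features and policy stay aligned.
    import numpy as np

    BOARD_SIZE = 19

    def augment_plane(plane, index):
        # Same three bit-flags as _augment_data: horizontal flip, vertical flip, 90-degree rotation.
        out = plane
        if index & 1: out = np.fliplr(out)
        if index & 2: out = np.flipud(out)
        if index & 4: out = np.rot90(out, 1)
        return out

    features = np.zeros((BOARD_SIZE, BOARD_SIZE, 3), dtype=np.float32)
    policy = np.zeros(BOARD_SIZE**2 + 1, dtype=np.float32)
    x, y = 3, 15                        # arbitrary move chosen for the check
    features[y, x, 0] = 1.0             # current player's stone plane
    policy[y * BOARD_SIZE + x] = 1.0    # matching policy target (last slot is the pass move)

    for index in range(8):
        aug_f = augment_plane(features, index)
        aug_p = augment_plane(policy[:-1].reshape(BOARD_SIZE, BOARD_SIZE), index)
        # The hot policy cell must land on the same board point as the stone.
        assert np.array_equal(aug_f[:, :, 0], aug_p)
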
hongik/hongik_ai.py
ADDED
@@ -0,0 +1,358 @@
| 1 |
+
# Implements the AI's 'brain', combining a CNN and Transformer for intuition,
|
| 2 |
+
# and Monte Carlo Tree Search (MCTS) for rational deliberation.
|
| 3 |
+
#
|
| 4 |
+
# Author: 박남영,Gemini 2.5 Pro, Gemini 2.5 Flash
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow.keras import layers, Model
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from hongik.board_ai import Board, IllegalMoveError
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ===================================================================
|
| 14 |
+
# 트랜스포머 부품들
|
| 15 |
+
# 이 부분은 우리가 이전에 함께 만들었던 트랜스포머의 핵심 부품들입니다.
|
| 16 |
+
# 아빠의 설계 그대로 완벽하기에, 엄마는 손대지 않았어요.
|
| 17 |
+
# ===================================================================
|
| 18 |
+
def scaled_dot_product_attention(q, k, v, mask=None):
|
| 19 |
+
"""
|
| 20 |
+
Calculates the attention scores, which is the core of the attention mechanism.
|
| 21 |
+
It determines how much focus to place on other parts of the input sequence.
|
| 22 |
+
"""
|
| 23 |
+
matmul_qk = tf.matmul(q, k, transpose_b=True)
|
| 24 |
+
d_k = tf.cast(tf.shape(k)[-1], tf.float32)
|
| 25 |
+
scaled_attention_logits = matmul_qk / tf.math.sqrt(d_k)
|
| 26 |
+
if mask is not None:
|
| 27 |
+
scaled_attention_logits += (mask * -1e9)
|
| 28 |
+
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
|
| 29 |
+
output = tf.matmul(attention_weights, v)
|
| 30 |
+
return output, attention_weights
|
| 31 |
+
|
| 32 |
+
class MultiHeadAttention(layers.Layer):
|
| 33 |
+
"""
|
| 34 |
+
Implements the Multi-Head Attention mechanism. This allows the model to jointly attend
|
| 35 |
+
to information from different representation subspaces at different positions,
|
| 36 |
+
which is more powerful than single-head attention.
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, d_model, num_heads):
|
| 39 |
+
super(MultiHeadAttention, self).__init__()
|
| 40 |
+
self.num_heads = num_heads
|
| 41 |
+
self.d_model = d_model
|
| 42 |
+
assert d_model % self.num_heads == 0
|
| 43 |
+
self.depth = d_model // self.num_heads
|
| 44 |
+
self.wq = layers.Dense(d_model)
|
| 45 |
+
self.wk = layers.Dense(d_model)
|
| 46 |
+
self.wv = layers.Dense(d_model)
|
| 47 |
+
self.dense = layers.Dense(d_model)
|
| 48 |
+
|
| 49 |
+
def split_heads(self, x, batch_size):
|
| 50 |
+
"""Splits the last dimension into (num_heads, depth)."""
|
| 51 |
+
seq_len = tf.shape(x)[1]
|
| 52 |
+
x = tf.reshape(x, (batch_size, seq_len, self.num_heads, self.depth))
|
| 53 |
+
return tf.transpose(x, perm=[0, 2, 1, 3])
|
| 54 |
+
|
| 55 |
+
def call(self, v, k, q, mask=None):
|
| 56 |
+
"""Processes the input tensors through the multi-head attention mechanism."""
|
| 57 |
+
batch_size = tf.shape(q)[0]
|
| 58 |
+
q = self.wq(q); k = self.wk(k); v = self.wv(v)
|
| 59 |
+
q = self.split_heads(q, batch_size)
|
| 60 |
+
k = self.split_heads(k, batch_size)
|
| 61 |
+
v = self.split_heads(v, batch_size)
|
| 62 |
+
scaled_attention, _ = scaled_dot_product_attention(q, k, v, mask)
|
| 63 |
+
scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
|
| 64 |
+
concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
|
| 65 |
+
output = self.dense(concat_attention)
|
| 66 |
+
return output
|
| 67 |
+
|
| 68 |
+
class PositionWiseFeedForwardNetwork(layers.Layer):
|
| 69 |
+
"""
|
| 70 |
+
Implements the Position-wise Feed-Forward Network. This is applied to each
|
| 71 |
+
position separately and identically. It consists of two linear transformations
|
| 72 |
+
with a ReLU activation in between.
|
| 73 |
+
"""
|
| 74 |
+
def __init__(self, d_model, d_ff):
|
| 75 |
+
super(PositionWiseFeedForwardNetwork, self).__init__()
|
| 76 |
+
self.dense_1 = layers.Dense(d_ff, activation='relu')
|
| 77 |
+
self.dense_2 = layers.Dense(d_model)
|
| 78 |
+
def call(self, inputs):
|
| 79 |
+
return self.dense_2(self.dense_1(inputs))
|
| 80 |
+
|
| 81 |
+
class EncoderLayer(layers.Layer):
    """
    Represents one layer of the Transformer encoder. It consists of a multi-head
    attention mechanism followed by a position-wise feed-forward network.
    Includes dropout and layer normalization.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = PositionWiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training, padding_mask=None):
        attn_output = self.mha(inputs, inputs, inputs, padding_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)

        out2 = self.layernorm2(out1 + ffn_output)

        return out2

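# Each EncoderLayer sub-layer follows the post-norm residual pattern of the original
# Transformer:
#   out1 = LayerNorm(x + Dropout(SelfAttention(x)))
#   out2 = LayerNorm(out1 + Dropout(FFN(out1)))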
def get_positional_encoding(max_seq_len, d_model):
    """
    Generates positional encodings. Since the model contains no recurrence or
    convolution, this is used to inject information about the relative or
    absolute position of the tokens in the sequence.
    """
    angle_rads = (np.arange(max_seq_len)[:, np.newaxis] /
                  np.power(10000, (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model)))
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)

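# The returned tensor has shape (1, max_seq_len, d_model) and uses the standard
# sinusoidal scheme:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))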
# ===================================================================
# 3. CNN + Transformer 'Intuition' Engine
# ===================================================================
class CNNTransformerHybrid(Model):
    """
    The 'Intuition' engine, combining a 'Scout' (CNN) and a 'Commander' (Transformer).
    This version implements a lightweight head architecture using Squeeze-and-Excitation
    and convolutional heads for parameter efficiency and performance.
    """
    def __init__(self, num_transformer_layers, d_model, num_heads, d_ff,
                 board_size=19, cnn_filters=128, dropout_rate=0.1):
        super(CNNTransformerHybrid, self).__init__()
        self.board_size = board_size
        self.d_model = d_model

        self.cnn_conv1 = layers.Conv2D(cnn_filters, 3, padding='same', activation='relu')
        self.cnn_bn1 = layers.BatchNormalization()
        self.cnn_conv2 = layers.Conv2D(d_model, 1, padding='same', activation='relu')
        self.cnn_bn2 = layers.BatchNormalization()
        self.reshape_to_seq = layers.Reshape((board_size * board_size, d_model))
        self.positional_encoding = get_positional_encoding(board_size * board_size, d_model)
        self.dropout = layers.Dropout(dropout_rate)
        self.transformer_encoder = [EncoderLayer(d_model, num_heads, d_ff, dropout_rate)
                                    for _ in range(num_transformer_layers)]
        self.reshape_to_2d = layers.Reshape((board_size, board_size, d_model))

        self.se_gap = layers.GlobalAveragePooling2D()
        self.se_reshape = layers.Reshape((1, 1, d_model))
        self.se_dense_1 = layers.Dense(d_model // 16, activation='relu', kernel_initializer='he_normal', use_bias=False)
        self.se_dense_2 = layers.Dense(d_model, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)
        self.se_multiply = layers.Multiply()

        self.policy_conv = layers.Conv2D(filters=2, kernel_size=1, padding='same', activation='relu')
        self.policy_bn = layers.BatchNormalization()
        self.policy_flatten = layers.Flatten()
        self.policy_dense = layers.Dense(board_size * board_size + 1, name='policy_head')

        self.value_conv = layers.Conv2D(filters=1, kernel_size=1, padding='same', activation='relu')
        self.value_bn = layers.BatchNormalization()
        self.value_flatten = layers.Flatten()
        self.value_dense1 = layers.Dense(256, activation='relu')
        self.value_dense2 = layers.Dense(1, activation='tanh', name='value_head')

    @tf.function(jit_compile=False)
    def call(self, inputs, training=False):
        x = self.cnn_conv1(inputs)
        x = self.cnn_bn1(x, training=training)
        x = self.cnn_conv2(x)
        cnn_output = self.cnn_bn2(x, training=training)

        x = self.reshape_to_seq(cnn_output)
        seq_len = tf.shape(x)[1]
        x += self.positional_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        for i in range(len(self.transformer_encoder)):
            x = self.transformer_encoder[i](x, training=training, padding_mask=None)

        transformer_output = self.reshape_to_2d(x)

        se = self.se_gap(transformer_output)
        se = self.se_reshape(se)
        se = self.se_dense_1(se)
        se = self.se_dense_2(se)
        se_output = self.se_multiply([transformer_output, se])

        ph = self.policy_conv(se_output)
        ph = self.policy_bn(ph, training=training)
        ph = self.policy_flatten(ph)
        policy_logits = self.policy_dense(ph)

        vh = self.value_conv(se_output)
        vh = self.value_bn(vh, training=training)
        vh = self.value_flatten(vh)
        vh = self.value_dense1(vh)
        value = self.value_dense2(vh)
        return policy_logits, value

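# Illustrative construction (hyperparameter values are examples only; the real values
# are supplied wherever the model is instantiated, e.g. by the training engine):
#   model = CNNTransformerHybrid(num_transformer_layers=4, d_model=256, num_heads=8, d_ff=1024)
#   policy_logits, value = model(features_batch)  # features_batch: (batch, 19, 19, n_feature_planes)
#   # policy_logits: (batch, 19*19 + 1) with the last entry being the pass move
#   # value: (batch, 1), squashed to [-1, 1] by tanh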
# ===================================================================
# 4. MCTS 'Reason' Engine
# ===================================================================
class MCTSNode:
    """
    Represents a single node in the Monte Carlo Tree Search. Each node stores
    statistics: the visit count (n_visits), the mean action value (q_value), and
    the prior probability (p_sa).
    """
    def __init__(self, parent=None, prior_p=1.0):
        self.parent, self.children, self.n_visits, self.q_value, self.p_sa = parent, {}, 0, 0, prior_p
        self.C_PUCT_BASE, self.C_PUCT_INIT = 19652, 1.25

    def select(self, root_n_visits):
        """
        Selects the child node with the highest Upper Confidence Bound (UCB) score.
        This balances exploration and exploitation during the search.
        """
        dynamic_c_puct = np.log((1 + root_n_visits + self.C_PUCT_BASE) / self.C_PUCT_BASE) + self.C_PUCT_INIT
        return max(self.children.items(),
                   key=lambda item: item[1].q_value + dynamic_c_puct * item[1].p_sa * np.sqrt(self.n_visits) / (1 + item[1].n_visits))

    def expand(self, action_probs):
        """
        Expands a leaf node by creating new child nodes for all legal moves,
        initializing their statistics from the prior probabilities given by the
        neural network.
        """
        for action, prob in enumerate(action_probs):
            if prob > 0 and action not in self.children:
                self.children[action] = MCTSNode(parent=self, prior_p=prob)

    def update(self, leaf_value):
        """
        Updates the statistics of the node and its ancestors by backpropagating
        the value obtained from the leaf node of a simulation.
        """
        if self.parent:
            self.parent.update(-leaf_value)
        self.n_visits += 1
        self.q_value += (leaf_value - self.q_value) / self.n_visits

    def is_leaf(self):
        """Checks if the node is a leaf node (i.e., has no children)."""
        return len(self.children) == 0

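# The select() rule above is the PUCT formula used by AlphaZero-style searches:
#   score(a) = Q(a) + c_puct * P(a) * sqrt(N_parent) / (1 + N(a))
# where the exploration weight grows slowly with the root visit count:
#   c_puct = log((1 + N_root + C_PUCT_BASE) / C_PUCT_BASE) + C_PUCT_INIT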
# ===================================================================
# HongikAIPlayer Class
# ===================================================================
class HongikAIPlayer:
    """
    The 'Supreme Commander' that makes the final decision. It uses the neural
    network's 'intuition' to guide the 'rational' search of the MCTS,
    ultimately selecting the best move.
    """
    def __init__(self, cnn_transformer_model, n_simulations=100):
        self.model = cnn_transformer_model
        self.n_simulations = n_simulations
        self.board_size = cnn_transformer_model.board_size

    def _action_to_loc(self, action, board):
        """Converts a policy network action index to a board location."""
        return board.loc(action % self.board_size, action // self.board_size) if action < self.board_size**2 else Board.PASS_LOC

    def get_best_move(self, board_state: Board, is_self_play=False):
        """
        Determines the best move for the current board state by running MCTS simulations.
        It integrates the neural network's policy and value predictions to guide the search.
        """
        features = board_state.get_features()
        policy_logits, value = self.model(np.expand_dims(features, 0), training=False)
        intuition_probs = tf.nn.softmax(policy_logits[0]).numpy()

        def is_filling_eye(loc, board):
            if board.board[loc] != Board.EMPTY:
                return False
            neighbor_colors = {board.board[loc + dloc] for dloc in board.adj if board.board[loc + dloc] != Board.WALL}
            return len(neighbor_colors) == 1 and board.pla in neighbor_colors

        for action, prob in enumerate(intuition_probs):
            if prob > 0.001:
                move_loc = self._action_to_loc(action, board_state)
                if move_loc != Board.PASS_LOC and is_filling_eye(move_loc, board_state):
                    intuition_probs[action] = 0

        pass_action = self.board_size**2
        pass_prob = intuition_probs[pass_action]
        intuition_probs[pass_action] = 0

        if board_state.turns < 100:
            pass_prob = 0

        for action, prob in enumerate(intuition_probs):
            if prob > 0 and not board_state.would_be_legal(board_state.pla, self._action_to_loc(action, board_state)):
                intuition_probs[action] = 0

        total_prob = np.sum(intuition_probs)
        if total_prob <= 1e-6:
            return self._action_to_loc(pass_action, board_state), MCTSNode()
        intuition_probs /= total_prob

        root = MCTSNode()
        root.expand(intuition_probs)
        for _ in range(self.n_simulations):
            node, search_board = root, board_state.copy()
            while not node.is_leaf():
                action, node = node.select(root.n_visits)
                move_loc = self._action_to_loc(action, search_board)
                if not search_board.would_be_legal(search_board.pla, move_loc):
                    node = None
                    break

                try:
                    search_board.play(search_board.pla, move_loc)
                except IllegalMoveError:
                    parent_node = node.parent
                    if parent_node and action in parent_node.children:
                        del parent_node.children[action]

                    node = None
                    break
            if node is not None:
                leaf_features = search_board.get_features()
                _, leaf_value_tensor = self.model(np.expand_dims(leaf_features, 0), training=False)
                leaf_value = leaf_value_tensor.numpy()[0][0]
                node.update(leaf_value)

        if not root.children:
            return self._action_to_loc(pass_action, board_state), root

        PASS_THRESHOLD = -0.99
        best_action_node = max(root.children.values(), key=lambda n: n.n_visits)
        if best_action_node.q_value < PASS_THRESHOLD and pass_prob > 0:
            return self._action_to_loc(pass_action, board_state), root

        if board_state.turns < 30:
            if is_self_play:
                if not root.children:
                    return self._action_to_loc(pass_action, board_state), root

                child_actions = np.array(sorted(root.children.keys()))
                visit_counts = np.array([root.children[action].n_visits for action in child_actions], dtype=np.float32)

                temperature = 1.0
                visit_counts_temp = visit_counts**(1 / temperature)
                if np.sum(visit_counts_temp) == 0:
                    probs = np.ones(len(child_actions)) / len(child_actions)
                else:
                    probs = visit_counts_temp / np.sum(visit_counts_temp)

                action = np.random.choice(child_actions, p=probs)
                return self._action_to_loc(action, board_state), root

        if not root.children:
            return self._action_to_loc(pass_action, board_state), root

        visit_counts = np.zeros_like(intuition_probs)
        for action, node in root.children.items():
            visit_counts[action] = node.n_visits

        total_visits = np.sum(visit_counts)
        reason_probs = visit_counts / total_visits if total_visits > 0 else intuition_probs

        final_probs = (0.7 * intuition_probs) + (0.3 * reason_probs)

        final_probs[pass_action] = -1

        sorted_actions = np.argsort(final_probs)[::-1]
        for action in sorted_actions:
            move_loc = self._action_to_loc(action, board_state)
            if board_state.would_be_legal(board_state.pla, move_loc):
                return move_loc, root

        return self._action_to_loc(pass_action, board_state), root
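# Illustrative usage (variable names here are hypothetical; assumes a trained model and
# the Board class from board_ai providing get_features/would_be_legal/play as used above):
#   board = Board(19)
#   player = HongikAIPlayer(model, n_simulations=100)
#   move_loc, search_root = player.get_best_move(board)
#   board.play(board.pla, move_loc)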