puco21 committed on
Commit 20d417a · verified · 1 Parent(s): 80399e9

Upload 3 files

Files changed (3)
  1. hongik/board_ai.py +332 -0
  2. hongik/engine_ai.py +571 -0
  3. hongik/hongik_ai.py +358 -0
hongik/board_ai.py ADDED
@@ -0,0 +1,332 @@
1
+ # Implements the world of Go, defining the game board, its rules
2
+ # (moves, liberties, scoring, etc.), and determining the final winner.
3
+ #
4
+ # Author: Gemini 2.5 Pro, Gemini 2.5 Flash
5
+
6
+ import numpy as np
7
+ from collections import deque
8
+ import random
9
+
10
+ class IllegalMoveError(ValueError):
11
+ """Exception class raised for an illegal move."""
12
+ pass
13
+
14
+ class Board:
15
+ """
16
+ Represents the Go board and enforces game rules.
17
+ """
18
+ EMPTY, BLACK, WHITE, WALL = 0, 1, 2, 3
19
+ PASS_LOC = 0
20
+
21
+ def __init__(self, size):
22
+ """Initializes the board, sets up walls, and history."""
23
+ if isinstance(size, tuple):
24
+ self.x_size, self.y_size = size
25
+ else:
26
+ self.x_size = self.y_size = size
27
+
28
+ self.arrsize = (self.x_size + 1) * (self.y_size + 2) + 1
29
+ self.dy = self.x_size + 1
30
+ self.adj = [-self.dy, -1, 1, self.dy]
31
+
32
+ self.board = np.zeros(shape=(self.arrsize), dtype=np.int8)
33
+ self.pla = Board.BLACK
34
+ self.prisoners = {Board.BLACK: 0, Board.WHITE: 0}
35
+ self.ko_points = set()
36
+ self.consecutive_passes = 0
37
+ self.turns = 0
38
+ self.ko_recapture_counts = {}
39
+ self.position_history = set()
40
+ self.position_history.add(self.board.tobytes())
41
+
42
+ for i in range(-1, self.x_size + 1):
43
+ self.board[self.loc(i, -1)] = Board.WALL
44
+ self.board[self.loc(i, self.y_size)] = Board.WALL
45
+ for i in range(-1, self.y_size + 1):
46
+ self.board[self.loc(-1, i)] = Board.WALL
47
+ self.board[self.loc(self.x_size, i)] = Board.WALL
48
+
49
+ def copy(self):
50
+ """Creates a deep copy of the current board state."""
51
+ new_board = Board((self.x_size, self.y_size))
52
+ new_board.board = np.copy(self.board)
53
+ new_board.pla = self.pla
54
+ new_board.prisoners = self.prisoners.copy()
55
+ new_board.ko_points = self.ko_points.copy()
56
+ new_board.consecutive_passes = self.consecutive_passes
57
+ new_board.turns = self.turns
58
+ new_board.ko_recapture_counts = self.ko_recapture_counts.copy()
59
+ new_board.position_history = self.position_history.copy()
60
+ return new_board
61
+
62
+ @staticmethod
63
+ def get_opp(player):
64
+ """Gets the opponent of the given player."""
65
+ return 3 - player
66
+
67
+ def loc(self, x, y):
68
+ """Converts (x, y) coordinates to a 1D array location."""
69
+ return (x + 1) + self.dy * (y + 1)
70
+
71
+ def loc_to_coord(self, loc):
72
+ """Converts a 1D array location back to (x, y) coordinates."""
73
+ return (loc % self.dy) - 1, (loc // self.dy) - 1
74
+
75
+ def is_on_board(self, loc):
76
+ """Checks if a location is within the board boundaries (not a wall)."""
77
+ return self.board[loc] != Board.WALL
78
+
79
+ def _get_group_info(self, loc):
80
+ """Scans and returns the stones and liberties of a group at a specific location."""
81
+ if not self.is_on_board(loc) or self.board[loc] == self.EMPTY:
82
+ return None, None
83
+
84
+ player = self.board[loc]
85
+ group_stones, liberties = set(), set()
86
+ q, visited = deque([loc]), {loc}
87
+
88
+ while q:
89
+ current_loc = q.popleft()
90
+ group_stones.add(current_loc)
91
+ for dloc in self.adj:
92
+ adj_loc = current_loc + dloc
93
+ if self.is_on_board(adj_loc):
94
+ adj_stone = self.board[adj_loc]
95
+ if adj_stone == self.EMPTY:
96
+ liberties.add(adj_loc)
97
+ elif adj_stone == player and adj_loc not in visited:
98
+ visited.add(adj_loc)
99
+ q.append(adj_loc)
100
+ return group_stones, liberties
101
+
102
+ def would_be_legal(self, player, loc):
103
+ """Checks if a move would be legal without actually playing it."""
104
+ if loc == self.PASS_LOC: return True
105
+ if not self.is_on_board(loc) or self.board[loc] != self.EMPTY or loc in self.ko_points:
106
+ return False
107
+
108
+ temp_board = self.copy()
109
+ temp_board.board[loc] = player
110
+
111
+ opponent = self.get_opp(player)
112
+ captured_any = False
113
+ captured_stones = set()
114
+
115
+ for dloc in temp_board.adj:
116
+ adj_loc = loc + dloc
117
+ if temp_board.board[adj_loc] == opponent:
118
+ group, libs = temp_board._get_group_info(adj_loc)
119
+ if not libs:
120
+ captured_any = True
121
+ captured_stones.update(group)
122
+
123
+ if captured_any:
124
+ for captured_loc in captured_stones:
125
+ temp_board.board[captured_loc] = self.EMPTY
126
+
127
+ next_board_hash = temp_board.board.tobytes()
128
+ if next_board_hash in self.position_history:
129
+ return False
130
+
131
+ if captured_any:
132
+ return True
133
+
134
+ _, my_libs = temp_board._get_group_info(loc)
135
+ return bool(my_libs)
136
+
137
+ def get_features(self):
138
+ """Generates a feature tensor for the neural network input."""
139
+ features = np.zeros((self.y_size, self.x_size, 3), dtype=np.float32)
140
+
141
+ current_player = self.pla
142
+ opponent_player = self.get_opp(self.pla)
143
+
144
+ for y in range(self.y_size):
145
+ for x in range(self.x_size):
146
+ loc = self.loc(x, y)
147
+ stone = self.board[loc]
148
+ if stone == current_player:
149
+ features[y, x, 0] = 1
150
+ elif stone == opponent_player:
151
+ features[y, x, 1] = 1
152
+
153
+ if self.pla == self.WHITE:
154
+ features[:, :, 2] = 1.0
155
+
156
+ return features
157
+
158
+ def is_game_over(self):
159
+ """Checks if the game is over (due to consecutive passes)."""
160
+ return self.consecutive_passes >= 2
161
+
162
+ def play(self, player, loc):
163
+ """Plays a move on the board, captures stones, and updates the game state."""
164
+ if not self.would_be_legal(player, loc):
165
+ raise IllegalMoveError("This move is against the rules.")
166
+
167
+ self.ko_points.clear()
168
+
169
+ if loc == self.PASS_LOC:
170
+ self.consecutive_passes += 1
171
+ else:
172
+ self.consecutive_passes = 0
173
+ self.board[loc] = player
174
+ opponent = self.get_opp(player)
175
+
176
+ captured_stones = set()
177
+
178
+ for dloc in self.adj:
179
+ adj_loc = loc + dloc
180
+ if self.board[adj_loc] == opponent:
181
+ group, libs = self._get_group_info(adj_loc)
182
+ if not libs:
183
+ captured_stones.update(group)
184
+
185
+ if captured_stones:
186
+ self.prisoners[player] += len(captured_stones)
187
+ for captured_loc in captured_stones:
188
+ self.board[captured_loc] = self.EMPTY
189
+
190
+ my_group, my_libs = self._get_group_info(loc)
191
+
192
+ if len(captured_stones) == 1 and len(my_group) == 1 and len(my_libs) == 1:
193
+ ko_loc = captured_stones.pop()
194
+ self.ko_points.add(ko_loc)
195
+
196
+ board_hash = self.board.tobytes()
197
+ self.position_history.add(board_hash)
198
+ self.pla = self.get_opp(player)
199
+ self.turns += 1
200
+
201
+ def _is_group_alive_statically(self, group_stones: set, board_state: np.ndarray) -> bool:
202
+ """Statically analyzes if a group is alive by checking for two eyes."""
203
+ if not group_stones: return False
204
+ owner_player = board_state[next(iter(group_stones))]
205
+ eye_locations = set()
206
+ for stone_loc in group_stones:
207
+ for dloc in self.adj:
208
+ adj_loc = stone_loc + dloc
209
+ if board_state[adj_loc] == self.EMPTY: eye_locations.add(adj_loc)
210
+ real_eye_count, visited_eye_locs = 0, set()
211
+ for eye_loc in eye_locations:
212
+ if eye_loc in visited_eye_locs: continue
213
+ eye_region, q, is_real_eye = set(), deque([eye_loc]), True
214
+ visited_eye_locs.add(eye_loc); eye_region.add(eye_loc)
215
+ while q:
216
+ current_loc = q.popleft()
217
+ for dloc in self.adj:
218
+ adj_loc = current_loc + dloc
219
+ if self.is_on_board(adj_loc):
220
+ if board_state[adj_loc] == self.get_opp(owner_player):
221
+ is_real_eye = False; break
222
+ elif board_state[adj_loc] == self.EMPTY and adj_loc not in visited_eye_locs:
223
+ visited_eye_locs.add(adj_loc); eye_region.add(adj_loc); q.append(adj_loc)
224
+ if not is_real_eye: break
225
+ if is_real_eye:
226
+ eye_size = len(eye_region)
227
+ if eye_size >= 6: real_eye_count += 2
228
+ else: real_eye_count += 1
229
+ if real_eye_count >= 2: return True
230
+ return real_eye_count >= 2
231
+
232
+ def _is_group_alive_by_rollout(self, group_stones_initial: set) -> bool:
233
+ """Determines if a group is alive via Monte Carlo rollouts for ambiguous cases."""
234
+ NUM_ROLLOUTS = 20
235
+ MAX_ROLLOUT_DEPTH = self.x_size * self.y_size // 2
236
+
237
+ owner_player = self.board[next(iter(group_stones_initial))]
238
+ attacker = self.get_opp(owner_player)
239
+ deaths = 0
240
+
241
+ for _ in range(NUM_ROLLOUTS):
242
+ rollout_board = self.copy()
243
+ rollout_board.pla = attacker
244
+
245
+ for _ in range(MAX_ROLLOUT_DEPTH):
246
+ first_stone_loc = next(iter(group_stones_initial))
247
+ if rollout_board.board[first_stone_loc] != owner_player:
248
+ deaths += 1
249
+ break
250
+
251
+ legal_moves = [loc for loc in range(1, self.arrsize) if rollout_board.board[loc] == self.EMPTY]
252
+ random.shuffle(legal_moves)
253
+
254
+ move_made = False
255
+ for move in legal_moves:
256
+ if rollout_board.would_be_legal(rollout_board.pla, move):
257
+ rollout_board.play(rollout_board.pla, move)
258
+ move_made = True
259
+ break
260
+
261
+ if not move_made:
262
+ rollout_board.play(rollout_board.pla, self.PASS_LOC)
263
+
264
+ if rollout_board.is_game_over():
265
+ break
266
+
267
+ death_rate = deaths / NUM_ROLLOUTS
268
+ print(f"[Life/Death Log] Group survival probability: {1-death_rate:.0%}")
269
+ return death_rate < 0.5
270
+
271
+ def get_winner(self, komi=6.5):
272
+ """Calculates the final score and determines the winner, handling life and death."""
273
+ temp_board_state = np.copy(self.board)
274
+ total_captives = self.prisoners.copy()
275
+
276
+ all_groups = self._find_all_groups(temp_board_state)
277
+ for player, groups in all_groups.items():
278
+ for group_stones in groups:
279
+ is_alive = self._is_group_alive_statically(group_stones, temp_board_state)
280
+
281
+ if not is_alive:
282
+ is_alive = self._is_group_alive_by_rollout(group_stones)
283
+
284
+ if not is_alive:
285
+ total_captives[self.get_opp(player)] += len(group_stones)
286
+ for stone_loc in group_stones:
287
+ temp_board_state[stone_loc] = self.EMPTY
288
+
289
+ final_board_with_territory = self._calculate_territory(temp_board_state)
290
+ black_territory = np.sum((final_board_with_territory == self.BLACK) & (temp_board_state == self.EMPTY))
291
+ white_territory = np.sum((final_board_with_territory == self.WHITE) & (temp_board_state == self.EMPTY))
292
+
293
+ black_score = black_territory + total_captives.get(self.BLACK, 0)
294
+ white_score = white_territory + total_captives.get(self.WHITE, 0) + komi
295
+ winner = self.BLACK if black_score > white_score else self.WHITE
296
+
297
+ return winner, black_score, white_score, total_captives
298
+
299
+ def _find_all_groups(self, board_state: np.ndarray) -> dict:
300
+ """Finds all stone groups on the board for a given board state."""
301
+ visited, all_groups = set(), {self.BLACK: [], self.WHITE: []}
302
+ for loc in range(self.arrsize):
303
+ if board_state[loc] in [self.BLACK, self.WHITE] and loc not in visited:
304
+ player, group_stones, q = board_state[loc], set(), deque([loc])
305
+ visited.add(loc); group_stones.add(loc)
306
+ while q:
307
+ current_loc = q.popleft()
308
+ for dloc in self.adj:
309
+ adj_loc = current_loc + dloc
310
+ if board_state[adj_loc] == player and adj_loc not in visited:
311
+ visited.add(adj_loc); group_stones.add(adj_loc); q.append(adj_loc)
312
+ all_groups[player].append(group_stones)
313
+ return all_groups
314
+
315
+ def _calculate_territory(self, board_state: np.ndarray) -> np.ndarray:
316
+ """Calculates the territory for each player on a given board state."""
317
+ territory_map, visited = np.copy(board_state), set()
318
+ for loc in range(self.arrsize):
319
+ if territory_map[loc] == self.EMPTY and loc not in visited:
320
+ region_points, border_colors, q = set(), set(), deque([loc])
321
+ visited.add(loc); region_points.add(loc)
322
+ while q:
323
+ current_loc = q.popleft()
324
+ for dloc in self.adj:
325
+ adj_loc = current_loc + dloc
326
+ if board_state[adj_loc] in [self.BLACK, self.WHITE]: border_colors.add(board_state[adj_loc])
327
+ elif board_state[adj_loc] == self.EMPTY and adj_loc not in visited:
328
+ visited.add(adj_loc); region_points.add(adj_loc); q.append(adj_loc)
329
+ if len(border_colors) == 1:
330
+ owner = border_colors.pop()
331
+ for point in region_points: territory_map[point] = owner
332
+ return territory_map
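
The Board class above is self-contained apart from numpy, so it can be exercised directly. Below is a minimal usage sketch (not part of the commit), assuming the hongik package from this commit is on the Python path; the 9x9 size and the specific moves are illustrative only.

    from hongik.board_ai import Board, IllegalMoveError

    board = Board(9)                          # 9x9 board; Black (board.pla) moves first
    loc = board.loc(4, 4)                     # convert (x, y) coordinates to a 1D index
    if board.would_be_legal(board.pla, loc):
        board.play(board.pla, loc)            # Black plays at tengen

    try:
        board.play(board.pla, loc)            # the point is occupied, so this raises
    except IllegalMoveError as err:
        print("rejected:", err)

    board.play(board.pla, Board.PASS_LOC)     # two consecutive passes end the game
    board.play(board.pla, Board.PASS_LOC)
    if board.is_game_over():
        winner, black_score, white_score, captives = board.get_winner(komi=6.5)
        print("winner:", "Black" if winner == Board.BLACK else "White",
              black_score, white_score)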
hongik/engine_ai.py ADDED
@@ -0,0 +1,571 @@
1
+ # The engine file that serves as the AI's brain, responsible for training
2
+ # the model through reinforcement learning and performing move analysis.
3
+ #
4
+ # Author: Gemini 2.5 Pro, Gemini 2.5 Flash
5
+
6
+ import os
7
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
8
+
9
+ import threading
10
+ import tensorflow as tf
11
+ tf.get_logger().setLevel('ERROR')
12
+
13
+ import numpy as np
14
+ import time
15
+ import json
16
+ import random
17
+ import traceback
18
+ from collections import deque
19
+ import pickle
20
+ import csv
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ from hongik.hongik_ai import HongikAIPlayer,CNNTransformerHybrid
24
+ from katrain.core.sgf_parser import Move
25
+ from katrain.core.constants import *
26
+ from kivy.clock import Clock
27
+ from hongik.board_ai import Board, IllegalMoveError
28
+ from katrain.core.game_node import GameNode
29
+ from katrain.gui.theme import Theme
30
+
31
+ class BaseEngine:
32
+ """Base class for KaTrain engines."""
33
+ def __init__(self, katrain, config):
34
+ self.katrain, self.config = katrain, config
35
+ def on_error(self, message, code=None, allow_popup=True):
36
+ print(f"ERROR: {message}", OUTPUT_ERROR)
37
+ if allow_popup and hasattr(self.katrain, "engine_recovery_popup"):
38
+ Clock.schedule_once(lambda dt: self.katrain("engine_recovery_popup", message, code))
39
+
40
+ class HongikAIEngine(BaseEngine):
41
+ """
42
+ Main AI engine that manages the model, self-play, training, and analysis.
43
+ It orchestrates the entire reinforcement learning loop.
44
+ """
45
+ BOARD_SIZE, NUM_LAYERS, D_MODEL, NUM_HEADS, D_FF = 19, 7, 256, 8, 1024
46
+ SAVE_WEIGHTS_EVERY_STEPS, EVALUATION_EVERY_STEPS = 5, 20
47
+ REPLAY_BUFFER_SIZE, TRAINING_BATCH_SIZE = 200000, 32
48
+ CHECKPOINT_EVERY_GAMES = 10
49
+
50
+ RULES = {
51
+ "tromp-taylor": {"name": "Tromp-Taylor", "komi": 7.5, "scoring": "area"},
52
+ "korean": {"name": "korean", "komi": 6.5, "scoring": "territory"},
53
+ "chinese": {"name": "Chinese", "komi": 7.5, "scoring": "area"}
54
+ }
55
+
56
+ @staticmethod
57
+ def get_rules(ruleset: str):
58
+ """Returns the ruleset details for a given ruleset name."""
59
+ if not ruleset or ruleset.lower() not in HongikAIEngine.RULES:
60
+ ruleset = "korean"
61
+ return HongikAIEngine.RULES[ruleset.lower()]
62
+
63
+ def __init__(self, katrain, config):
64
+ """
65
+ Initializes the Hongik AI Engine. This involves setting up paths, loading
66
+ the neural network model and replay buffer, and preparing for training.
67
+ """
68
+ super().__init__(katrain, config)
69
+ print("Initializing Hongik AI Integrated Engine...", OUTPUT_DEBUG)
70
+
71
+ from appdirs import user_data_dir
72
+ APP_NAME = "HongikAI"
73
+ APP_AUTHOR = "NamyongPark"
74
+
75
+ self.BASE_PATH = user_data_dir(APP_NAME, APP_AUTHOR)
76
+ print(f"Data will be stored in: {self.BASE_PATH}")
77
+
78
+ self.REPLAY_BUFFER_PATH = os.path.join(self.BASE_PATH, "replay_buffer.pkl")
79
+ self.WEIGHTS_FILE_PATH = os.path.join(self.BASE_PATH, "hongik_ai_memory.weights.h5")
80
+ self.BEST_WEIGHTS_FILE_PATH = os.path.join(self.BASE_PATH, "hongik_ai_best.weights.h5")
81
+ self.CHECKPOINT_BUFFER_PATH = os.path.join(self.BASE_PATH, "replay_buffer_checkpoint.pkl")
82
+ self.CHECKPOINT_WEIGHTS_PATH = os.path.join(self.BASE_PATH, "hongik_ai_checkpoint.weights.h5")
83
+ self.TRAINING_LOG_PATH = os.path.join(self.BASE_PATH, "training_log.csv")
84
+
85
+ os.makedirs(self.BASE_PATH, exist_ok=True)
86
+
87
+ REPO_ID = "puco21/HongikAI"
88
+ files_to_download = [
89
+ "replay_buffer.pkl",
90
+ "hongik_ai_memory.weights.h5",
91
+ "hongik_ai_best.weights.h5"
92
+ ]
93
+
94
+ print("Checking for AI data files...")
95
+ for filename in files_to_download:
96
+ local_path = os.path.join(self.BASE_PATH, filename)
97
+ if not os.path.exists(local_path):
98
+ print(f"Downloading {filename} from Hugging Face Hub...")
99
+ try:
100
+ hf_hub_download(
101
+ repo_id=REPO_ID,
102
+ filename=filename,
103
+ local_dir=self.BASE_PATH,
104
+ local_dir_use_symlinks=False
105
+ )
106
+ print(f"'{filename}' download complete.")
107
+ except Exception as e:
108
+ print(f"Failed to download {filename}: {e}")
109
+
110
+ self.replay_buffer = deque(maxlen=self.REPLAY_BUFFER_SIZE)
111
+ self.load_replay_buffer(self.REPLAY_BUFFER_PATH)
112
+
113
+ self.hongik_model = CNNTransformerHybrid(self.NUM_LAYERS, self.D_MODEL, self.NUM_HEADS, self.D_FF, self.BOARD_SIZE)
114
+ _ = self.hongik_model(np.zeros((1, self.BOARD_SIZE, self.BOARD_SIZE, 3), dtype=np.float32))
115
+
116
+ load_path = self.CHECKPOINT_WEIGHTS_PATH if os.path.exists(self.CHECKPOINT_WEIGHTS_PATH) else (self.WEIGHTS_FILE_PATH if os.path.exists(self.WEIGHTS_FILE_PATH) else self.BEST_WEIGHTS_FILE_PATH)
117
+ if os.path.exists(load_path):
118
+ try:
119
+ self.hongik_model.load_weights(load_path)
120
+ print(f"Successfully loaded weights: {load_path}")
121
+ except Exception as e:
122
+ print(f"Failed to load weights (starting new training): {e}")
123
+
124
+ try:
125
+ max_visits = int(config.get("max_visits",150))
126
+ except (ValueError, TypeError):
127
+ print(f"Warning: Invalid max_visits value in config. Using default (150).")
128
+ max_visits = 150
129
+ self.hongik_player = HongikAIPlayer(self.hongik_model, n_simulations=max_visits)
130
+
131
+ self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1.0)
132
+ self.policy_loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
133
+ self.value_loss_fn = tf.keras.losses.MeanSquaredError()
134
+ self.training_step_counter, self.game_history, self.self_play_active = 0, [], False
135
+
136
+ class MockProcess: poll = lambda self: None
137
+ self.katago_process = self.hongik_process = MockProcess()
138
+ self.sound_index = False
139
+ print("Hongik AI Engine ready!", OUTPUT_DEBUG)
140
+
141
+ def save_replay_buffer(self, path):
142
+ """Saves the current replay buffer to a specified file path using pickle."""
143
+ try:
144
+ with open(path, 'wb') as f:
145
+ pickle.dump(self.replay_buffer, f)
146
+ print(f"Successfully saved experience data ({len(self.replay_buffer)} items) to '{path}'.")
147
+ except Exception as e:
148
+ print(f"Error saving experience data: {e}")
149
+
150
+ def load_replay_buffer(self, path):
151
+ """Loads a replay buffer from a file, prioritizing a checkpoint file if it exists."""
152
+ load_path = self.CHECKPOINT_BUFFER_PATH if os.path.exists(self.CHECKPOINT_BUFFER_PATH) else path
153
+ if os.path.exists(load_path):
154
+ try:
155
+ with open(load_path, 'rb') as f:
156
+ self.replay_buffer = pickle.load(f)
157
+
158
+ if self.replay_buffer.maxlen != self.REPLAY_BUFFER_SIZE:
159
+ new_buffer = deque(maxlen=self.REPLAY_BUFFER_SIZE)
160
+ new_buffer.extend(self.replay_buffer)
161
+ self.replay_buffer = new_buffer
162
+ print(f"Successfully loaded experience data ({len(self.replay_buffer)} items) from '{load_path}'.")
163
+ except Exception as e:
164
+ print(f"Error loading experience data: {e}")
165
+
166
+ def _checkpoint_save(self):
167
+ """Saves a training checkpoint, including both the replay buffer and model weights."""
168
+ print(f"\n[{time.strftime('%H:%M:%S')}] Saving checkpoint...")
169
+ self.save_replay_buffer(self.CHECKPOINT_BUFFER_PATH)
170
+ self.hongik_model.save_weights(self.CHECKPOINT_WEIGHTS_PATH)
171
+ print("Checkpoint saved.")
172
+
173
+ def _log_training_progress(self, details: dict):
174
+ """Logs the progress of the training process to a CSV file for later analysis."""
175
+ try:
176
+ file_exists = os.path.isfile(self.TRAINING_LOG_PATH)
177
+ with open(self.TRAINING_LOG_PATH, 'a', newline='', encoding='utf-8') as f:
178
+ writer = csv.DictWriter(f, fieldnames=details.keys())
179
+ if not file_exists:
180
+ writer.writeheader()
181
+ writer.writerow(details)
182
+ except Exception as e:
183
+ print(f"Error logging training progress: {e}")
184
+
185
+ def _node_to_board(self, node: GameNode) -> Board:
186
+ """Converts a KaTrain GameNode object to the internal Board representation used by the engine."""
187
+ board_snapshot = Board(node.board_size)
188
+ root_node = node.root
189
+ for player, prop_name in [(Board.BLACK, 'AB'), (Board.WHITE, 'AW')]:
190
+ setup_stones = root_node.properties.get(prop_name, [])
191
+ if setup_stones:
192
+ for coords in setup_stones:
193
+ loc = board_snapshot.loc(coords[0], coords[1])
194
+ if board_snapshot.board[loc] == Board.EMPTY:
195
+ board_snapshot.play(player, loc)
196
+
197
+ current_player = Board.BLACK
198
+ for scene_node in node.nodes_from_root[1:]:
199
+ move = scene_node.move
200
+ if move:
201
+ loc = Board.PASS_LOC if move.is_pass else board_snapshot.loc(move.coords[0], move.coords[1])
202
+ board_snapshot.play(current_player, loc)
203
+ current_player = board_snapshot.pla
204
+
205
+ board_snapshot.pla = Board.BLACK if node.next_player == 'B' else Board.WHITE
206
+ return board_snapshot
207
+
208
+ def request_analysis(self, analysis_node: GameNode, callback: callable, **kwargs):
209
+ """
210
+ Requests an analysis of a specific board position. The analysis is run
211
+ in a separate thread to avoid blocking the GUI.
212
+ """
213
+ if not self.katrain.game: return
214
+ game_id = self.katrain.game.game_id
215
+ board = self._node_to_board(analysis_node)
216
+ threading.Thread(target=self._run_analysis, args=(game_id, board, analysis_node, callback), daemon=True).start()
217
+
218
+ def _run_analysis(self, game_id, board, analysis_node, callback):
219
+ """
220
+ The target function for the analysis thread. It runs MCTS and sends the
221
+ formatted results back to the GUI via the provided callback.
222
+ """
223
+ try:
224
+ policy_logits, _ = self.hongik_player.model(np.expand_dims(board.get_features(), 0), training=False)
225
+ policy = tf.nn.softmax(policy_logits[0]).numpy()
226
+
227
+ _, root_node = self.hongik_player.get_best_move(board)
228
+ analysis_result = self._format_analysis_results("analysis", root_node, board, policy)
229
+
230
+ analysis_node.analysis = analysis_result
231
+
232
+ def guarded_callback(dt):
233
+ if self.katrain.game and self.katrain.game.game_id == game_id:
234
+ callback(analysis_result, False)
235
+ Clock.schedule_once(guarded_callback)
236
+ except Exception as e:
237
+ print(f"Error during AI analysis execution: {e}")
238
+ traceback.print_exc()
239
+
240
+ def _format_analysis_results(self, query_id, root_node, board, policy=None): # <-- policy=None argument added
241
+ """
242
+ Converts the MCTS analysis data into a dictionary format that the KaTrain GUI can understand.
243
+ """
244
+ move_infos, moves_dict = [], {}
245
+
246
+ if root_node and root_node.children:
247
+ sorted_children = sorted(root_node.children.items(), key=lambda i: i[1].n_visits, reverse=True)
248
+
249
+ best_move_q = sorted_children[0][1].q_value if sorted_children else 0
250
+
251
+ for i, (action, child) in enumerate(sorted_children):
252
+ coords = board.loc_to_coord(self.hongik_player._action_to_loc(action, board))
253
+ move_gtp = Move(coords=coords).gtp()
254
+
255
+ current_player_winrate = (child.q_value + 1) / 2
256
+ display_winrate = 1.0 - current_player_winrate if board.pla == Board.WHITE else current_player_winrate
257
+ display_score = -child.q_value * 20 if board.pla == Board.WHITE else child.q_value * 20
258
+
259
+ points_lost = (best_move_q - child.q_value) * 20
260
+
261
+ move_data = {
262
+ "move": move_gtp,
263
+ "visits": child.n_visits,
264
+ "winrate": display_winrate,
265
+ "scoreLead": display_score,
266
+ "pointsLost": points_lost,
267
+ "pv": [move_gtp],
268
+ "order": i
269
+ }
270
+
271
+ move_infos.append(move_data)
272
+ moves_dict[move_gtp] = move_data
273
+
274
+ current_player_winrate = (root_node.q_value + 1) / 2 if root_node else 0.5
275
+ display_winrate = 1.0 - current_player_winrate if board.pla == Board.WHITE else current_player_winrate
276
+ display_score = -root_node.q_value * 20 if (root_node and board.pla == Board.WHITE) else (root_node.q_value * 20 if root_node else 0.0)
277
+
278
+ root_info = {"winrate": display_winrate, "scoreLead": display_score, "visits": root_node.n_visits if root_node else 0}
279
+ return {"id": query_id, "moveInfos": move_infos, "moves": moves_dict, "root": root_info, "rootInfo": root_info, "policy": policy.tolist() if policy is not None else None, "completed": True}
280
+
281
+ def start_self_play_loop(self):
282
+ """Starts the main self-play loop, which continuously plays games to generate training data."""
283
+ print(f"\n===========================================\n[{time.strftime('%H:%M:%S')}] Starting new self-play game.\n===========================================")
284
+ self.stop_self_play_loop()
285
+ self.self_play_active = True
286
+ self.game_history = []
287
+ Clock.schedule_once(self._self_play_turn, 0.3)
288
+
289
+ def request_score(self, game_node, callback):
290
+ """Requests a score calculation for the current game node, run in a separate thread."""
291
+ threading.Thread(target=lambda: callback(self.get_score(game_node)), daemon=True).start()
292
+
293
+ def stop_self_play_loop(self):
294
+ """Stops the active self-play loop."""
295
+ if not self.self_play_active: return
296
+ self.self_play_active = False
297
+ Clock.unschedule(self._self_play_turn)
298
+
299
+ def _self_play_turn(self, dt=None):
300
+ """
301
+ Executes a single turn of a self-play game. It gets the best move from the
302
+ AI, plays it on the board, and stores the state for later training.
303
+ """
304
+ if not self.self_play_active: return
305
+ game = self.katrain.game
306
+ try:
307
+ current_node = game.current_node
308
+ board_snapshot = self._node_to_board(current_node)
309
+ if game.end_result or board_snapshot.is_game_over():
310
+ self._process_game_result(game)
311
+ return
312
+ move_loc, root_node = self.hongik_player.get_best_move(board_snapshot, is_self_play=True)
313
+ coords = None if move_loc == Board.PASS_LOC else board_snapshot.loc_to_coord(move_loc)
314
+ move_obj = Move(player='B' if board_snapshot.pla == Board.BLACK else 'W', coords=coords)
315
+ game.play(move_obj)
316
+ if not move_obj.is_pass and self.sound_index:
317
+ self.katrain.play_sound()
318
+ black_player_type = self.katrain.players_info['B'].player_type
319
+ white_player_type = self.katrain.players_info['W'].player_type
320
+
321
+ if black_player_type == PLAYER_AI and white_player_type == PLAYER_AI:
322
+ if self.katrain.game.current_node.next_player == 'B':
323
+ self.katrain.controls.players['B'].active = True
324
+ self.katrain.controls.players['W'].active = False
325
+ else:
326
+ self.katrain.controls.players['B'].active = False
327
+ self.katrain.controls.players['W'].active = True
328
+
329
+ policy = np.zeros(self.BOARD_SIZE**2 + 1, dtype=np.float32)
330
+ if root_node and root_node.children:
331
+ total_visits = sum(c.n_visits for c in root_node.children.values())
332
+ if total_visits > 0:
333
+ for action, child in root_node.children.items(): policy[action] = child.n_visits / total_visits
334
+
335
+ blacks_win_rate = 0.5
336
+ if root_node:
337
+ player_q_value = root_node.q_value
338
+ player_win_rate = (player_q_value + 1) / 2
339
+ blacks_win_rate = player_win_rate if board_snapshot.pla == Board.BLACK else (1 - player_win_rate)
340
+
341
+ self.game_history.append([board_snapshot.get_features(), policy, board_snapshot.pla, blacks_win_rate])
342
+ self.katrain.update_gui(game.current_node)
343
+ self.sound_index = True
344
+ Clock.schedule_once(self._self_play_turn, 0.3)
345
+ except Exception as e:
346
+ print(f"Critical error during self-play: {e}"); traceback.print_exc(); self.stop_self_play_loop()
347
+
348
+ def _process_game_result(self, game: 'Game'):
349
+ """
350
+ Processes the result of a finished game. It requests a final score and
351
+ then triggers the callback to handle training data generation.
352
+ """
353
+ try:
354
+ self.katrain.controls.set_status("Scoring...", STATUS_INFO)
355
+ self.katrain.board_gui.game_over_message = "Scoring..."
356
+
357
+ self.katrain.board_gui.game_is_over = True
358
+ self.request_score(game.current_node, self._on_score_calculated)
359
+ except Exception as e:
360
+ print(f"Error requesting score calculation: {e}")
361
+ self.katrain._do_start_hongik_selfplay()
362
+
363
+ def _on_score_calculated(self, score_details):
364
+ """
365
+ Callback function that handles the game result after scoring. It assigns rewards,
366
+ augments the data, adds it to the replay buffer, and triggers a training step.
367
+ """
368
+ try:
369
+ if not score_details:
370
+ print("Game ended but no result. Starting next game.")
371
+ return
372
+
373
+ game_num = self.training_step_counter + 1
374
+ winner_text = "Black" if score_details['winner'] == 'B' else "White"
375
+ b_score, w_score, diff = score_details['black_score'], score_details['white_score'], score_details['score']
376
+ final_message = f"{winner_text} wins by {abs(diff):.1f} points"
377
+ self.katrain.board_gui.game_over_message = final_message
378
+ print(f"\n==========================================\n[{time.strftime('%H:%M:%S')}] Game #{game_num} Finished\n-----------------------------------------\n Winner: {winner_text}\n Margin: {abs(diff):.1f} points\n Details: Black {b_score:.1f} vs White {w_score:.1f}\n--------------------------------------------")
379
+
380
+ winner = Board.BLACK if score_details['winner'] == 'B' else Board.WHITE
381
+
382
+ REVERSAL_THRESHOLD = 0.2
383
+ WIN_REWARD = 1.0
384
+ LOSS_REWARD = -1.0
385
+ BRILLIANT_MOVE_BONUS = 0.5
386
+ CONSOLATION_REWARD = 0.5
387
+
388
+ for i, (features, policy, player_turn, blacks_win_rate_after) in enumerate(self.game_history):
389
+ blacks_win_rate_before = self.game_history[i-1][3] if i > 0 else 0.5
390
+ if player_turn == Board.BLACK:
391
+ win_rate_swing = blacks_win_rate_after - blacks_win_rate_before
392
+ else: # player_turn == Board.WHITE
393
+ white_win_rate_before = 1 - blacks_win_rate_before
394
+ white_win_rate_after = 1 - blacks_win_rate_after
395
+ win_rate_swing = white_win_rate_after - white_win_rate_before
396
+
397
+ is_brilliant_move = win_rate_swing > REVERSAL_THRESHOLD
398
+
399
+ if player_turn == winner:
400
+ reward = WIN_REWARD
401
+ if is_brilliant_move:
402
+ reward += BRILLIANT_MOVE_BONUS
403
+ else:
404
+ reward = LOSS_REWARD
405
+ if is_brilliant_move:
406
+ reward = CONSOLATION_REWARD
407
+
408
+ for j in range(8):
409
+ aug_features = self._augment_data(features, j, 'features')
410
+ aug_policy = self._augment_data(policy, j, 'policy')
411
+ self.replay_buffer.append([aug_features, aug_policy, reward])
412
+
413
+ self.training_step_counter += 1
414
+ loss = self._train_model() if len(self.replay_buffer) >= self.TRAINING_BATCH_SIZE else None
415
+ if loss is not None:
416
+ print(f" Training complete! (Final loss: {loss:.4f})\n=======================================", OUTPUT_DEBUG)
417
+
418
+ log_data = {
419
+ 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
420
+ 'game_num': game_num,
421
+ 'winner': score_details['winner'],
422
+ 'score_diff': diff,
423
+ 'total_moves': len(self.game_history),
424
+ 'loss': f"{loss:.4f}" if loss is not None else "N/A"
425
+ }
426
+ self._log_training_progress(log_data)
427
+
428
+ if self.training_step_counter % self.SAVE_WEIGHTS_EVERY_STEPS == 0:
429
+ self.hongik_model.save_weights(self.WEIGHTS_FILE_PATH)
430
+
431
+ if self.training_step_counter % self.CHECKPOINT_EVERY_GAMES == 0:
432
+ self._checkpoint_save()
433
+
434
+ if self.training_step_counter % self.EVALUATION_EVERY_STEPS == 0:
435
+ self._evaluate_model()
436
+
437
+ except Exception as e:
438
+ print(f"Error during post-game processing: {e}")
439
+ traceback.print_exc()
440
+ finally:
441
+ self.katrain._do_start_hongik_selfplay()
442
+
443
+ def _train_model(self):
444
+ """
445
+ Performs one training step. It samples a minibatch from the replay buffer
446
+ and uses it to update the neural network's weights.
447
+ """
448
+ if len(self.replay_buffer) < self.TRAINING_BATCH_SIZE: return None
449
+ total_loss, TRAIN_ITERATIONS = 0, 5
450
+ for _ in range(TRAIN_ITERATIONS):
451
+ minibatch = random.sample(self.replay_buffer, self.TRAINING_BATCH_SIZE)
452
+ features, policies, values = (np.array(e) for e in zip(*minibatch))
453
+ with tf.GradientTape() as tape:
454
+ pred_p, pred_v = self.hongik_model(features, training=True)
455
+ value_loss = self.value_loss_fn(values[:, None], pred_v)
456
+ policy_loss = self.policy_loss_fn(policies, pred_p)
457
+ loss = policy_loss + value_loss
458
+ self.optimizer.apply_gradients(zip(tape.gradient(loss, self.hongik_model.trainable_variables), self.hongik_model.trainable_variables))
459
+ total_loss += loss.numpy()
460
+ return total_loss / TRAIN_ITERATIONS
461
+
462
+ def _augment_data(self, data, index, data_type):
463
+ """
464
+ Augments the training data by applying 8 symmetries (rotations and flips)
465
+ to the board features and policy target.
466
+ """
467
+ if data_type == 'features':
468
+ augmented = data
469
+ if index & 1: augmented = np.fliplr(augmented)
470
+ if index & 2: augmented = np.flipud(augmented)
471
+ if index & 4: augmented = np.rot90(augmented, 1)
472
+ return augmented
473
+ elif data_type == 'policy':
474
+ policy_board = data[:-1].reshape(self.BOARD_SIZE, self.BOARD_SIZE)
475
+ augmented_board = policy_board
476
+ if index & 1: augmented_board = np.fliplr(augmented_board)
477
+ if index & 2: augmented_board = np.flipud(augmented_board)
478
+ if index & 4: augmented_board = np.rot90(augmented_board, 1)
479
+ return np.append(augmented_board.flatten(), data[-1])
480
+ return data
481
+
482
+ def get_score(self, game_node):
483
+ """Calculates the final score of a game using the board's internal scoring method."""
484
+ try:
485
+ board = self._node_to_board(game_node)
486
+ winner, black_score, white_score, _ = board.get_winner(self.katrain.game.komi)
487
+ score_diff = black_score - white_score
488
+ return {"winner": "B" if winner == Board.BLACK else "W", "score": score_diff, "black_score": black_score, "white_score": white_score}
489
+ except Exception as e:
490
+ print(f"Error during internal score calculation: {e}"); traceback.print_exc(); return None
491
+
492
+ def _game_turn(self):
493
+ """
494
+ Handles the AI's turn in a game against a human or another AI. It runs
495
+ in a separate thread to avoid blocking the GUI.
496
+ """
497
+ if self.self_play_active or self.katrain.game.end_result: return
498
+ next_player_info = self.katrain.players_info[self.katrain.game.current_node.next_player]
499
+ if next_player_info.player_type == PLAYER_AI:
500
+ def ai_move_thread():
501
+ try:
502
+ board_snapshot = self._node_to_board(self.katrain.game.current_node)
503
+ move_loc, _ = self.hongik_player.get_best_move(board_snapshot, is_self_play=False)
504
+ coords = None if move_loc == Board.PASS_LOC else board_snapshot.loc_to_coord(move_loc)
505
+ Clock.schedule_once(lambda dt: self.katrain._do_play(coords))
506
+ except Exception as e:
507
+ print(f"\n--- Critical error during AI thinking (in thread) ---\n{traceback.format_exc()}\n---------------------------------------\n")
508
+ threading.Thread(target=ai_move_thread, daemon=True).start()
509
+
510
+ def _evaluate_model(self):
511
+ """
512
+ Periodically evaluates the currently training model against the best-known
513
+ 'champion' model to measure progress and update the best weights if the
514
+ challenger is stronger.
515
+ """
516
+ print("\n--- [Championship Match Start] ---")
517
+ challenger_player = self.hongik_player
518
+ best_weights_path = self.BEST_WEIGHTS_FILE_PATH
519
+ if not os.path.exists(best_weights_path):
520
+ print("[Championship Match] Crowning the first champion!")
521
+ self.hongik_model.save_weights(best_weights_path)
522
+ return
523
+ champion_model = CNNTransformerHybrid(self.NUM_LAYERS, self.D_MODEL, self.NUM_HEADS, self.D_FF, self.BOARD_SIZE)
524
+ _ = champion_model(np.zeros((1, self.BOARD_SIZE, self.BOARD_SIZE, 3), dtype=np.float32))
525
+ champion_model.load_weights(best_weights_path)
526
+ champion_player = HongikAIPlayer(champion_model, int(self.config.get("max_visits", 150)))
527
+ EVAL_GAMES, challenger_wins = 5, 0
528
+ for i in range(EVAL_GAMES):
529
+ print(f"\n[Championship Match] Game {i+1} starting...")
530
+ board = Board(self.BOARD_SIZE)
531
+ players = {Board.BLACK: challenger_player, Board.WHITE: champion_player} if i % 2 == 0 else {Board.BLACK: champion_player, Board.WHITE: challenger_player}
532
+ while not board.is_game_over():
533
+ current_player_obj = players[board.pla]
534
+ move_loc, _ = current_player_obj.get_best_move(board)
535
+ board.play(board.pla, move_loc)
536
+ winner, _, _, _ = board.get_winner()
537
+ if (winner == Board.BLACK and i % 2 == 0) or (winner == Board.WHITE and i % 2 != 0):
538
+ challenger_wins += 1; print(f"[Championship Match] Game {i+1}: Challenger wins!")
539
+ else:
540
+ print(f"[Championship Match] Game {i+1}: Champion wins!")
541
+ print(f"\n--- [Championship Match End] ---\nFinal Score: Challenger {challenger_wins} wins / Champion {EVAL_GAMES - challenger_wins} wins")
542
+ if challenger_wins > EVAL_GAMES / 2:
543
+ print("A new champion is born! Updating 'best' weights.")
544
+ self.hongik_model.save_weights(best_weights_path)
545
+ else:
546
+ print("The champion defends the title. Keeping existing weights.")
547
+
548
+ def on_new_game(self):
549
+ """Called when a new game starts."""
550
+ pass
551
+ def start(self):
552
+ """Starts the engine."""
553
+ self.katrain.game_controls.set_player_selection()
554
+ def shutdown(self, finish=False):
555
+ """Shuts down the engine, saving progress and cleaning up checkpoint files."""
556
+ self.stop_self_play_loop()
557
+ self.save_replay_buffer(self.REPLAY_BUFFER_PATH)
558
+ try:
559
+ if os.path.exists(self.CHECKPOINT_BUFFER_PATH): os.remove(self.CHECKPOINT_BUFFER_PATH)
560
+ if os.path.exists(self.CHECKPOINT_WEIGHTS_PATH): os.remove(self.CHECKPOINT_WEIGHTS_PATH)
561
+ except OSError as e:
562
+ print(f"Error deleting checkpoint files: {e}")
563
+ def stop_pondering(self):
564
+ """Stops pondering."""
565
+ pass
566
+ def queries_remaining(self):
567
+ """Returns the number of remaining queries."""
568
+ return 0
569
+ def is_idle(self):
570
+ """Checks if the engine is idle (i.e., not in a self-play loop)."""
571
+ return not self.self_play_active
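
One detail of the training pipeline above worth isolating is _augment_data: the three low bits of index independently select a horizontal flip, a vertical flip, and a 90-degree rotation, which together cover the 8 symmetries of the square board, and the pass entry at the end of the policy vector is deliberately left untouched. The following standalone sketch (numpy only, illustrative 9x9 shapes, not part of the commit) shows the same scheme and checks that the augmented policy keeps its total mass.

    import numpy as np

    BOARD_SIZE = 9  # illustrative; the engine above trains on 19x19

    def augment(plane, index):
        """Apply the flip/rotation combination selected by the low 3 bits of index."""
        out = plane
        if index & 1: out = np.fliplr(out)
        if index & 2: out = np.flipud(out)
        if index & 4: out = np.rot90(out, 1)
        return out

    features = np.random.rand(BOARD_SIZE, BOARD_SIZE, 3).astype(np.float32)
    policy = np.random.rand(BOARD_SIZE * BOARD_SIZE + 1).astype(np.float32)
    policy /= policy.sum()

    for index in range(8):
        aug_features = augment(features, index)
        board_part = augment(policy[:-1].reshape(BOARD_SIZE, BOARD_SIZE), index)
        aug_policy = np.append(board_part.flatten(), policy[-1])  # keep the pass probability
        assert aug_features.shape == features.shape
        assert np.isclose(aug_policy.sum(), policy.sum())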
hongik/hongik_ai.py ADDED
@@ -0,0 +1,358 @@
1
+ # Implements the AI's 'brain', combining a CNN and Transformer for intuition,
2
+ # and Monte Carlo Tree Search (MCTS) for rational deliberation.
3
+ #
4
+ # Author: 박남영, Gemini 2.5 Pro, Gemini 2.5 Flash
5
+
6
+ import tensorflow as tf
7
+ from tensorflow.keras import layers, Model
8
+ import numpy as np
9
+
10
+ from hongik.board_ai import Board, IllegalMoveError
11
+
12
+
13
+ # ===================================================================
14
+ # Transformer components
15
+ # These are the core Transformer building blocks we built together earlier.
16
+ # Since Dad's design is perfect as it is, Mom did not touch it.
17
+ # ===================================================================
18
+ def scaled_dot_product_attention(q, k, v, mask=None):
19
+ """
20
+ Calculates the attention scores, which is the core of the attention mechanism.
21
+ It determines how much focus to place on other parts of the input sequence.
22
+ """
23
+ matmul_qk = tf.matmul(q, k, transpose_b=True)
24
+ d_k = tf.cast(tf.shape(k)[-1], tf.float32)
25
+ scaled_attention_logits = matmul_qk / tf.math.sqrt(d_k)
26
+ if mask is not None:
27
+ scaled_attention_logits += (mask * -1e9)
28
+ attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
29
+ output = tf.matmul(attention_weights, v)
30
+ return output, attention_weights
31
+
32
+ class MultiHeadAttention(layers.Layer):
33
+ """
34
+ Implements the Multi-Head Attention mechanism. This allows the model to jointly attend
35
+ to information from different representation subspaces at different positions,
36
+ which is more powerful than single-head attention.
37
+ """
38
+ def __init__(self, d_model, num_heads):
39
+ super(MultiHeadAttention, self).__init__()
40
+ self.num_heads = num_heads
41
+ self.d_model = d_model
42
+ assert d_model % self.num_heads == 0
43
+ self.depth = d_model // self.num_heads
44
+ self.wq = layers.Dense(d_model)
45
+ self.wk = layers.Dense(d_model)
46
+ self.wv = layers.Dense(d_model)
47
+ self.dense = layers.Dense(d_model)
48
+
49
+ def split_heads(self, x, batch_size):
50
+ """Splits the last dimension into (num_heads, depth)."""
51
+ seq_len = tf.shape(x)[1]
52
+ x = tf.reshape(x, (batch_size, seq_len, self.num_heads, self.depth))
53
+ return tf.transpose(x, perm=[0, 2, 1, 3])
54
+
55
+ def call(self, v, k, q, mask=None):
56
+ """Processes the input tensors through the multi-head attention mechanism."""
57
+ batch_size = tf.shape(q)[0]
58
+ q = self.wq(q); k = self.wk(k); v = self.wv(v)
59
+ q = self.split_heads(q, batch_size)
60
+ k = self.split_heads(k, batch_size)
61
+ v = self.split_heads(v, batch_size)
62
+ scaled_attention, _ = scaled_dot_product_attention(q, k, v, mask)
63
+ scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
64
+ concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
65
+ output = self.dense(concat_attention)
66
+ return output
67
+
68
+ class PositionWiseFeedForwardNetwork(layers.Layer):
69
+ """
70
+ Implements the Position-wise Feed-Forward Network. This is applied to each
71
+ position separately and identically. It consists of two linear transformations
72
+ with a ReLU activation in between.
73
+ """
74
+ def __init__(self, d_model, d_ff):
75
+ super(PositionWiseFeedForwardNetwork, self).__init__()
76
+ self.dense_1 = layers.Dense(d_ff, activation='relu')
77
+ self.dense_2 = layers.Dense(d_model)
78
+ def call(self, inputs):
79
+ return self.dense_2(self.dense_1(inputs))
80
+
81
+ class EncoderLayer(layers.Layer):
82
+ """
83
+ Represents one layer of the Transformer encoder. It consists of a multi-head
84
+ attention mechanism followed by a position-wise feed-forward network.
85
+ Includes dropout and layer normalization.
86
+ """
87
+ def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
88
+ super(EncoderLayer, self).__init__()
89
+ self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
90
+ self.ffn = PositionWiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff)
91
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
92
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
93
+ self.dropout1 = layers.Dropout(dropout_rate)
94
+ self.dropout2 = layers.Dropout(dropout_rate)
95
+
96
+ def call(self, inputs, training, padding_mask=None):
97
+ attn_output = self.mha(inputs, inputs, inputs, padding_mask)
98
+ attn_output = self.dropout1(attn_output, training=training)
99
+ out1 = self.layernorm1(inputs + attn_output)
100
+
101
+ ffn_output = self.ffn(out1)
102
+ ffn_output = self.dropout2(ffn_output, training=training)
103
+
104
+ out2 = self.layernorm2(out1 + ffn_output)
105
+
106
+ return out2
107
+
108
+ def get_positional_encoding(max_seq_len, d_model):
109
+ """
110
+ Generates positional encodings. Since the model contains no recurrence or
111
+ convolution, this is used to inject information about the relative or
112
+ absolute position of the tokens in the sequence.
113
+ """
114
+ angle_rads = (np.arange(max_seq_len)[:, np.newaxis] /
115
+ np.power(10000, (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model)))
116
+ angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
117
+ angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
118
+ pos_encoding = angle_rads[np.newaxis, ...]
119
+ return tf.cast(pos_encoding, dtype=tf.float32)
120
+
121
+ # ===================================================================
122
+ # 3. CNN + 트랜스포머 '직관' 엔진
123
+ # ===================================================================
124
+ class CNNTransformerHybrid(Model):
125
+ """
126
+ The 'Intuition' engine, combining a 'Scout' (CNN) and a 'Commander' (Transformer).
127
+ This version implements a lightweight head architecture using Squeeze-and-Excitation
128
+ and Convolutional Heads for parameter efficiency and performance.
129
+ """
130
+ def __init__(self, num_transformer_layers, d_model, num_heads, d_ff,
131
+ board_size=19, cnn_filters=128, dropout_rate=0.1):
132
+ super(CNNTransformerHybrid, self).__init__()
133
+ self.board_size = board_size
134
+ self.d_model = d_model
135
+
136
+ self.cnn_conv1 = layers.Conv2D(cnn_filters, 3, padding='same', activation='relu')
137
+ self.cnn_bn1 = layers.BatchNormalization()
138
+ self.cnn_conv2 = layers.Conv2D(d_model, 1, padding='same', activation='relu')
139
+ self.cnn_bn2 = layers.BatchNormalization()
140
+ self.reshape_to_seq = layers.Reshape((board_size * board_size, d_model))
141
+ self.positional_encoding = get_positional_encoding(board_size * board_size, d_model)
142
+ self.dropout = layers.Dropout(dropout_rate)
143
+ self.transformer_encoder = [EncoderLayer(d_model, num_heads, d_ff, dropout_rate) for _ in range(num_transformer_layers)]
144
+ self.reshape_to_2d = layers.Reshape((board_size, board_size, d_model))
145
+
146
+ self.se_gap = layers.GlobalAveragePooling2D()
147
+ self.se_reshape = layers.Reshape((1, 1, d_model))
148
+ self.se_dense_1 = layers.Dense(d_model // 16, activation='relu', kernel_initializer='he_normal', use_bias=False)
149
+ self.se_dense_2 = layers.Dense(d_model, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)
150
+ self.se_multiply = layers.Multiply()
151
+
152
+ self.policy_conv = layers.Conv2D(filters=2, kernel_size=1, padding='same', activation='relu')
153
+ self.policy_bn = layers.BatchNormalization()
154
+ self.policy_flatten = layers.Flatten()
155
+ self.policy_dense = layers.Dense(board_size * board_size + 1, name='policy_head')
156
+
157
+ self.value_conv = layers.Conv2D(filters=1, kernel_size=1, padding='same', activation='relu')
158
+ self.value_bn = layers.BatchNormalization()
159
+ self.value_flatten = layers.Flatten()
160
+ self.value_dense1 = layers.Dense(256, activation='relu')
161
+ self.value_dense2 = layers.Dense(1, activation='tanh', name='value_head')
162
+
163
+ @tf.function(jit_compile=False)
164
+ def call(self, inputs, training=False):
165
+ x = self.cnn_conv1(inputs)
166
+ x = self.cnn_bn1(x, training=training)
167
+ x = self.cnn_conv2(x)
168
+ cnn_output = self.cnn_bn2(x, training=training)
169
+
170
+ x = self.reshape_to_seq(cnn_output)
171
+ seq_len = tf.shape(x)[1]
172
+ x += self.positional_encoding[:, :seq_len, :]
173
+ x = self.dropout(x, training=training)
174
+
175
+ for i in range(len(self.transformer_encoder)):
176
+ x = self.transformer_encoder[i](x, training=training, padding_mask=None)
177
+
178
+ transformer_output = self.reshape_to_2d(x)
179
+
180
+ se = self.se_gap(transformer_output)
181
+ se = self.se_reshape(se)
182
+ se = self.se_dense_1(se)
183
+ se = self.se_dense_2(se)
184
+ se_output = self.se_multiply([transformer_output, se])
185
+
186
+ ph = self.policy_conv(se_output)
187
+ ph = self.policy_bn(ph, training=training)
188
+ ph = self.policy_flatten(ph)
189
+ policy_logits = self.policy_dense(ph)
190
+
191
+ vh = self.value_conv(se_output)
192
+ vh = self.value_bn(vh, training=training)
193
+ vh = self.value_flatten(vh)
194
+ vh = self.value_dense1(vh)
195
+ value = self.value_dense2(vh)
196
+ return policy_logits, value
197
+
198
+ # ===================================================================
199
+ # 4. MCTS '이성' 엔진
200
+ # ===================================================================
201
+ class MCTSNode:
202
+ """
203
+ Represents a single node in the Monte Carlo Tree Search. Each node stores
204
+ statistics like visit count (n_visits), total action value (q_value), and
205
+ prior probability (p_sa).
206
+ """
207
+ def __init__(self, parent=None, prior_p=1.0):
208
+ self.parent, self.children, self.n_visits, self.q_value, self.p_sa = parent, {}, 0, 0, prior_p
209
+ self.C_PUCT_BASE, self.C_PUCT_INIT = 19652, 1.25
210
+
211
+ def select(self, root_n_visits):
212
+ """
213
+ Selects the child node with the highest Upper Confidence Bound (UCB) score.
214
+ This balances exploration and exploitation during the search.
215
+ """
216
+ dynamic_c_puct = np.log((1 + root_n_visits + self.C_PUCT_BASE) / self.C_PUCT_BASE) + self.C_PUCT_INIT
217
+ return max(self.children.items(),
218
+ key=lambda item: item[1].q_value + dynamic_c_puct * item[1].p_sa * np.sqrt(self.n_visits) / (1 + item[1].n_visits))
219
+
220
+ def expand(self, action_probs):
221
+ """
222
+ Expands a leaf node by creating new child nodes for all legal moves,
223
+ initializing their statistics from the prior probabilities given by the
224
+ neural network.
225
+ """
226
+ for action, prob in enumerate(action_probs):
227
+ if prob > 0 and action not in self.children: self.children[action] = MCTSNode(parent=self, prior_p=prob)
228
+
229
+ def update(self, leaf_value):
230
+ """
231
+ Updates the statistics of the node and its ancestors by backpropagating
232
+ the value obtained from the leaf node of a simulation.
233
+ """
234
+ if self.parent: self.parent.update(-leaf_value)
235
+ self.n_visits += 1; self.q_value += (leaf_value - self.q_value) / self.n_visits
236
+
237
+ def is_leaf(self):
238
+ """Checks if the node is a leaf node (i.e., has no children)."""
239
+ return len(self.children) == 0
240
+
241
+ # ===================================================================
242
+ # HongikAIPlayer 클래스
243
+ # ===================================================================
244
+ class HongikAIPlayer:
245
+ """
246
+ The 'Supreme Commander' that makes the final decision. It uses the neural
247
+ network's 'intuition' to guide the 'rational' search of the MCTS,
248
+ ultimately selecting the best move.
249
+ """
250
+ def __init__(self, cnn_transformer_model, n_simulations=100):
251
+ self.model, self.n_simulations, self.board_size = cnn_transformer_model, n_simulations, cnn_transformer_model.board_size
252
+
253
+ def _action_to_loc(self, action, board):
254
+ """Converts a policy network action index to a board location."""
255
+ return board.loc(action % self.board_size, action // self.board_size) if action < self.board_size**2 else Board.PASS_LOC
256
+
257
+ def get_best_move(self, board_state: Board, is_self_play=False):
258
+ """
259
+ Determines the best move for the current board state by running MCTS simulations.
260
+ It integrates the neural network's policy and value predictions to guide the search.
261
+ """
262
+ features = board_state.get_features()
263
+ policy_logits, value = self.model(np.expand_dims(features, 0), training=False)
264
+ intuition_probs = tf.nn.softmax(policy_logits[0]).numpy()
265
+
266
+ def is_filling_eye(loc, board):
267
+ if board.board[loc] != Board.EMPTY: return False
268
+ neighbor_colors = {board.board[loc + dloc] for dloc in board.adj if board.board[loc + dloc] != Board.WALL}
269
+ return len(neighbor_colors) == 1 and board.pla in neighbor_colors
270
+
271
+ for action, prob in enumerate(intuition_probs):
272
+ if prob > 0.001:
273
+ move_loc = self._action_to_loc(action, board_state)
274
+ if move_loc != Board.PASS_LOC and is_filling_eye(move_loc, board_state): intuition_probs[action] = 0
275
+
276
+ pass_action = self.board_size**2
277
+ pass_prob = intuition_probs[pass_action]
278
+ intuition_probs[pass_action] = 0
279
+
280
+ if board_state.turns < 100: pass_prob = 0
281
+
282
+ for action, prob in enumerate(intuition_probs):
283
+ if prob > 0 and not board_state.would_be_legal(board_state.pla, self._action_to_loc(action, board_state)): intuition_probs[action] = 0
284
+
285
+ total_prob = np.sum(intuition_probs)
286
+ if total_prob <= 1e-6: return self._action_to_loc(pass_action, board_state), MCTSNode()
287
+ intuition_probs /= total_prob
288
+
289
+ root = MCTSNode(); root.expand(intuition_probs)
290
+ for _ in range(self.n_simulations):
291
+ node, search_board = root, board_state.copy()
292
+ while not node.is_leaf():
293
+ action, node = node.select(root.n_visits)
294
+ move_loc = self._action_to_loc(action, search_board)
295
+ if not search_board.would_be_legal(search_board.pla, move_loc):
296
+ node = None; break
297
+
298
+ try:
299
+ search_board.play(search_board.pla, move_loc)
300
+ except IllegalMoveError:
301
+ parent_node = node.parent
302
+ if parent_node and action in parent_node.children:
303
+ del parent_node.children[action]
304
+
305
+ node = None
306
+ break
307
+ if node is not None:
308
+ leaf_features = search_board.get_features()
309
+ _, leaf_value_tensor = self.model(np.expand_dims(leaf_features, 0), training=False)
310
+ leaf_value = leaf_value_tensor.numpy()[0][0]
311
+ node.update(leaf_value)
312
+
313
+ if not root.children: return self._action_to_loc(pass_action, board_state), root
314
+
315
+ PASS_THRESHOLD = -0.99
316
+ best_action_node = max(root.children.values(), key=lambda n: n.n_visits)
317
+ if best_action_node.q_value < PASS_THRESHOLD and pass_prob > 0:
318
+ return self._action_to_loc(pass_action, board_state), root
319
+
320
+ if board_state.turns < 30:
321
+ if is_self_play:
322
+ if not root.children:
323
+ return self._action_to_loc(pass_action, board_state), root
324
+
325
+ child_actions = np.array(sorted(root.children.keys()))
326
+ visit_counts = np.array([root.children[action].n_visits for action in child_actions], dtype=np.float32)
327
+
328
+ temperature = 1.0
329
+ visit_counts_temp = visit_counts**(1/temperature)
330
+ if np.sum(visit_counts_temp) == 0:
331
+ probs = np.ones(len(child_actions)) / len(child_actions)
332
+ else:
333
+ probs = visit_counts_temp / np.sum(visit_counts_temp)
334
+
335
+ action = np.random.choice(child_actions, p=probs)
336
+ return self._action_to_loc(action, board_state), root
337
+
338
+ if not root.children:
339
+ return self._action_to_loc(pass_action, board_state), root
340
+
341
+ visit_counts = np.zeros_like(intuition_probs)
342
+ for action, node in root.children.items():
343
+ visit_counts[action] = node.n_visits
344
+
345
+ total_visits = np.sum(visit_counts)
346
+ reason_probs = visit_counts / total_visits if total_visits > 0 else intuition_probs
347
+
348
+ final_probs = (0.7 * intuition_probs) + (0.3 * reason_probs)
349
+
350
+ final_probs[pass_action] = -1
351
+
352
+ sorted_actions = np.argsort(final_probs)[::-1]
353
+ for action in sorted_actions:
354
+ move_loc = self._action_to_loc(action, board_state)
355
+ if board_state.would_be_legal(board_state.pla, move_loc):
356
+ return move_loc, root
357
+
358
+ return self._action_to_loc(pass_action, board_state), root
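
As a rough smoke test (not part of the commit), the network and the MCTS player above can be wired to the Board class directly. The layer sizes below are deliberately tiny so the test runs quickly and are not the engine's defaults; the engine itself uses num_transformer_layers=7, d_model=256, num_heads=8, d_ff=1024 on a 19x19 board.

    import numpy as np
    from hongik.board_ai import Board
    from hongik.hongik_ai import CNNTransformerHybrid, HongikAIPlayer

    model = CNNTransformerHybrid(num_transformer_layers=2, d_model=64,
                                 num_heads=4, d_ff=128, board_size=9)
    _ = model(np.zeros((1, 9, 9, 3), dtype=np.float32))  # build weights with a dummy call

    player = HongikAIPlayer(model, n_simulations=8)
    board = Board(9)

    move_loc, root = player.get_best_move(board)          # MCTS guided by the (untrained) net
    print("chosen move (x, y):", board.loc_to_coord(move_loc),
          "| root visits:", root.n_visits)
    board.play(board.pla, move_loc)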