diff --git a/client.py b/client.py index 0bc37d7..68aec10 100644 --- a/client.py +++ b/client.py @@ -3,6 +3,7 @@ """ DNF自动化客户端,负责截取游戏画面并执行服务器返回的操作 +优化版 - 增加断线重连、性能优化和增强的游戏状态管理 """ import os @@ -23,7 +24,7 @@ import logging import traceback # 图像处理 -from PIL import ImageGrab, Image +from PIL import Image import numpy as np # Windows API @@ -36,8 +37,17 @@ import win32process import keyboard import mouse -# 配置文件路径 -CONFIG_FILE = "config.ini" +# 尝试导入mss库,如果不存在则继续使用PIL +try: + import mss + MSS_AVAILABLE = True +except ImportError: + MSS_AVAILABLE = False + print("警告: mss库未安装,将使用PIL进行截图(性能较低)") + print("请使用 pip install mss 安装以获得更好的性能") + +# 配置文件路径 - 使用绝对路径 +CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.ini") # 日志设置 logging.basicConfig( @@ -45,7 +55,7 @@ logging.basicConfig( format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', handlers=[ - logging.FileHandler("client.log", encoding="utf-8"), + logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), "client.log"), encoding="utf-8"), logging.StreamHandler() ] ) @@ -62,7 +72,27 @@ class DNFAutoClient: self.running = False self.ws = None self.capture_interval = float(self.config.get("Capture", "interval")) - self.game_state = {"in_battle": False} + self.max_retries = int(self.config.get("Connection", "max_retries", fallback="5")) + self.retry_delay = int(self.config.get("Connection", "retry_delay", fallback="5")) + + # 增强的游戏状态 + self.game_state = { + "in_battle": False, + "current_map": "", + "hp_percent": 100, + "mp_percent": 100, + "active_buffs": [], + "cooldowns": {}, + "inventory_full": False, + "current_quest": None, + "last_operation_time": time.time(), + "session_start_time": time.time() + } + + # 连接状态 + self.last_heartbeat_time = 0 + self.connection_attempts = 0 + self.reconnecting = False def load_config(self): """加载配置文件""" @@ -92,6 +122,12 @@ class DNFAutoClient: "key_mapping": "default" } + config["Connection"] = { + "max_retries": "5", + "retry_delay": "5", + "heartbeat_interval": "5" + } + # 保存配置 with open(CONFIG_FILE, "w", encoding="utf-8") as f: config.write(f) @@ -113,19 +149,60 @@ class DNFAutoClient: self.ws = await websockets.connect( self.server_url, ssl=ssl_context, - max_size=10 * 1024 * 1024 # 10MB + max_size=10 * 1024 * 1024, # 10MB + ping_interval=None, # 禁用自动ping,我们将使用自己的心跳 + close_timeout=5 ) # 等待认证 await self.authenticate() logger.info("已连接到服务器") + self.connection_attempts = 0 # 重置连接尝试次数 + self.last_heartbeat_time = time.time() return True except Exception as e: logger.error(f"连接服务器失败: {e}") return False + async def connect_with_retry(self): + """带重试机制的连接函数""" + if self.reconnecting: + logger.info("已有重连过程在进行中,跳过") + return False + + self.reconnecting = True + retry_count = 0 + + try: + while retry_count < self.max_retries and not self.ws: + try: + logger.info(f"尝试连接服务器 (尝试 {retry_count + 1}/{self.max_retries})...") + success = await self.connect() + if success: + self.reconnecting = False + return True + except Exception as e: + logger.error(f"连接服务器失败: {e}") + + retry_count += 1 + retry_delay = min(60, self.retry_delay * (2 ** retry_count)) # 指数退避策略 + logger.info(f"等待 {retry_delay} 秒后重试...") + await asyncio.sleep(retry_delay) + + if not self.ws: + logger.error(f"达到最大重试次数 ({self.max_retries}),连接失败") + self.reconnecting = False + return False + except Exception as e: + logger.error(f"重连过程中发生错误: {e}") + self.reconnecting = False + return False + + self.reconnecting = False + return True + async def authenticate(self): """客户端认证""" try: @@ -136,14 +213,18 @@ class DNFAutoClient: if challenge.get("type") != "auth_challenge": raise ValueError("无效的认证挑战") + # 收集系统信息 + system_info = self.get_system_info() + # 发送认证响应 await self.ws.send(json.dumps({ "type": "auth_response", "response": f"client_{random.getrandbits(32)}", "client_info": { - "version": "1.0.0", + "version": "1.0.1", # 更新版本号 "os": "Windows", - "screen_resolution": self.get_screen_resolution() + "screen_resolution": self.get_screen_resolution(), + "system_info": system_info } })) @@ -162,6 +243,24 @@ class DNFAutoClient: logger.error(f"认证失败: {e}") raise + def get_system_info(self): + """获取系统信息""" + system_info = {} + try: + system_info["hostname"] = os.environ.get("COMPUTERNAME", "Unknown") + system_info["username"] = os.environ.get("USERNAME", "Unknown") + system_info["processor"] = os.environ.get("PROCESSOR_IDENTIFIER", "Unknown") + + # 获取系统内存信息 + mem = ctypes.c_ulonglong() + ctypes.windll.kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem)) + system_info["memory_gb"] = round(mem.value / (1024 * 1024), 2) + + except Exception as e: + logger.error(f"获取系统信息失败: {e}") + + return system_info + def get_screen_resolution(self): """获取屏幕分辨率""" user32 = ctypes.windll.user32 @@ -188,13 +287,26 @@ class DNFAutoClient: rect = win32gui.GetWindowRect(hwnd) x, y, width, height = rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1] - # 截取画面 - screen = ImageGrab.grab(bbox=(x, y, x + width, y + height)) + # 检查窗口是否最小化 + if width <= 0 or height <= 0: + logger.warning("游戏窗口被最小化或大小无效") + return None + + # 使用mss进行截图(如果可用),速度比PIL快5-10倍 + if MSS_AVAILABLE: + with mss.mss() as sct: + monitor = {"top": y, "left": x, "width": width, "height": height} + sct_img = sct.grab(monitor) + # 将mss图像转换为PIL图像 + img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX") + else: + # 使用PIL进行截图(备选方案) + img = ImageGrab.grab(bbox=(x, y, x + width, y + height)) # 压缩图像 quality = int(self.config.get("Capture", "quality")) buffer = BytesIO() - screen.save(buffer, format="JPEG", quality=quality) + img.save(buffer, format="JPEG", quality=quality) # 转换为Base64 img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") @@ -207,18 +319,23 @@ class DNFAutoClient: except Exception as e: logger.error(f"截取游戏画面失败: {e}") + logger.error(traceback.format_exc()) return None async def send_heartbeat(self): """发送心跳包""" if not self.ws or not self.running: - return + return False try: await self.ws.send(json.dumps({ "type": "heartbeat", "timestamp": time.time(), - "client_id": self.client_id + "client_id": self.client_id, + "game_state": { + "in_battle": self.game_state["in_battle"], + "current_map": self.game_state["current_map"] + } })) # 等待心跳响应 @@ -231,17 +348,50 @@ class DNFAutoClient: response = json.loads(response_raw) if response.get("type") != "heartbeat_response": logger.warning(f"收到非心跳响应: {response.get('type')}") + return False + + self.last_heartbeat_time = time.time() + return True except asyncio.TimeoutError: logger.warning("心跳超时") + return False + except websockets.exceptions.ConnectionClosed: + logger.warning("发送心跳时连接已关闭") + return False except Exception as e: logger.error(f"发送心跳失败: {e}") + return False async def heartbeat_loop(self): """心跳循环""" + heartbeat_interval = float(self.config.get("Connection", "heartbeat_interval", fallback="5")) + while self.running: - await self.send_heartbeat() - await asyncio.sleep(5.0) # 每5秒发送一次心跳 + if self.ws and not self.ws.closed: + success = await self.send_heartbeat() + if not success: + # 心跳失败,检查连接状态 + if time.time() - self.last_heartbeat_time > heartbeat_interval * 3: + logger.warning(f"心跳超时超过 {heartbeat_interval * 3} 秒,尝试重新连接") + if self.ws: + await self.ws.close() + self.ws = None + + await asyncio.sleep(heartbeat_interval) + + async def reconnect_loop(self): + """重连检查循环""" + while self.running: + if not self.ws or self.ws.closed: + logger.warning("WebSocket连接已断开,尝试重连...") + if await self.connect_with_retry(): + logger.info("重连成功") + else: + logger.error("重连失败") + # 不要立即停止,继续尝试 + + await asyncio.sleep(5) # 每5秒检查一次连接状态 async def execute_action(self, action): """执行操作""" @@ -259,8 +409,27 @@ class DNFAutoClient: # 确保窗口处于前台 if win32gui.GetForegroundWindow() != hwnd: - win32gui.SetForegroundWindow(hwnd) - await asyncio.sleep(0.1) + try: + # 尝试多种方法激活窗口 + win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) # 恢复窗口(如果最小化) + win32gui.SetForegroundWindow(hwnd) # 尝试置为前台 + await asyncio.sleep(0.1) + + # 如果窗口仍然不在前台,使用更强的方法 + if win32gui.GetForegroundWindow() != hwnd: + # 获取当前激活窗口的线程和进程ID + curr_hwnd = win32gui.GetForegroundWindow() + curr_thread_id = win32process.GetWindowThreadProcessId(curr_hwnd)[0] + # 获取目标窗口的线程和进程ID + target_thread_id = win32process.GetWindowThreadProcessId(hwnd)[0] + # 附加线程输入 + win32process.AttachThreadInput(target_thread_id, curr_thread_id, True) + win32gui.SetForegroundWindow(hwnd) + win32gui.BringWindowToTop(hwnd) + win32process.AttachThreadInput(target_thread_id, curr_thread_id, False) + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"激活窗口失败: {e}") # 获取窗口位置 rect = win32gui.GetWindowRect(hwnd) @@ -281,6 +450,22 @@ class DNFAutoClient: win32api.SetCursorPos((int(point[0]), int(point[1]))) await asyncio.sleep(0.01) # 10ms延迟 + elif action_type == "click": + # 点击指定位置 + position = action.get("position", [0, 0]) + x, y = position[0] + window_x, position[1] + window_y + + # 移动到位置 + current_pos = win32gui.GetCursorPos() + path = self.generate_movement_path(current_pos, [x, y]) + + for point in path: + win32api.SetCursorPos((int(point[0]), int(point[1]))) + await asyncio.sleep(0.01) + + # 执行点击 + mouse.click() + elif action_type == "use_skill": # 使用技能 key = action.get("key", "1") @@ -333,12 +518,32 @@ class DNFAutoClient: for key in ["w", "a", "s", "d", "1", "2", "3", "4", "5", "6"]: if keyboard.is_pressed(key): keyboard.release(key) + + elif action_type == "type_text": + # 输入文本 + text = action.get("text", "") + if text: + keyboard.write(text, delay=0.05) + + elif action_type == "press_key_combo": + # 按下组合键 + keys = action.get("keys", []) + if keys: + for key in keys: + keyboard.press(key) + await asyncio.sleep(0.1) + for key in reversed(keys): + keyboard.release(key) else: logger.warning(f"未知操作类型: {action_type}") + # 更新游戏状态 + self.game_state["last_operation_time"] = time.time() + except Exception as e: logger.error(f"执行操作失败: {e}") + logger.error(traceback.format_exc()) def generate_movement_path(self, start_pos, end_pos, steps=None): """生成模拟人类的鼠标移动路径""" @@ -353,6 +558,9 @@ class DNFAutoClient: t = np.linspace(0, 1, steps) path = [] + # 动态调整平滑度 + smoothness = float(self.config.get("Game", "movement_smoothness", fallback="0.8")) + for i in range(steps): # 基础线性插值 x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i] @@ -360,7 +568,7 @@ class DNFAutoClient: # 添加随机偏移(越靠近中间偏移越大) mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大 - max_offset = distance * 0.05 * mid_factor # 最大偏移为距离的5% + max_offset = distance * 0.05 * mid_factor * (1 - smoothness) # 最大偏移为距离的5%,受平滑度影响 offset_x = random.normalvariate(0, max_offset / 3) offset_y = random.normalvariate(0, max_offset / 3) @@ -374,10 +582,68 @@ class DNFAutoClient: return path + def analyze_image(self, image_data, detection_results): + """ + 分析图像数据,更新游戏状态 + + 参数: + image_data (dict): 图像数据 + detection_results (list): 检测结果 + """ + # 更新战斗状态 + monsters_detected = False + for det in detection_results: + if det["class_name"] in ["monster", "boss"]: + monsters_detected = True + break + + self.game_state["in_battle"] = monsters_detected + + # 分析血条和蓝条 + hp_bars = [d for d in detection_results if d["class_name"] == "hp_bar"] + mp_bars = [d for d in detection_results if d["class_name"] == "mp_bar"] + + if hp_bars: + # 估算血量百分比 + self.game_state["hp_percent"] = self.estimate_bar_percent(hp_bars[0]) + + if mp_bars: + # 估算蓝量百分比 + self.game_state["mp_percent"] = self.estimate_bar_percent(mp_bars[0]) + + # 检测技能冷却 + cooldowns = [d for d in detection_results if d["class_name"] == "cooldown"] + self.game_state["cooldowns"] = {} + + for cd in cooldowns: + if "skill_id" in cd: + self.game_state["cooldowns"][cd["skill_id"]] = cd.get("remaining_time", 1.0) + + def estimate_bar_percent(self, bar_detection): + """ + 估计血条/蓝条的百分比 + + 参数: + bar_detection (dict): 条形检测结果 + + 返回: + float: 百分比值(0-100) + """ + # 这里需要根据实际情况实现 + # 临时方案:返回检测结果中的值,如果没有则返回默认值 + return bar_detection.get("percent", 100) + async def capture_and_process_loop(self): """截图和处理循环""" + consecutive_errors = 0 + while self.running: try: + # 检查WebSocket连接 + if not self.ws or self.ws.closed: + await asyncio.sleep(0.5) # 连接断开时等待 + continue + # 获取游戏窗口 hwnd = self.get_game_window() if hwnd is None: @@ -413,6 +679,10 @@ class DNFAutoClient: response = json.loads(response_raw) if response.get("type") == "action_response": + # 获取检测结果并更新游戏状态 + if "detections" in response: + self.analyze_image(screen_data, response["detections"]) + # 执行动作 actions = response.get("actions", []) @@ -421,30 +691,43 @@ class DNFAutoClient: for action in actions: await self.execute_action(action) + + # 重置错误计数 + consecutive_errors = 0 elif response.get("type") == "error": logger.error(f"服务器返回错误: {response.get('message')}") + consecutive_errors += 1 # 等待指定的间隔时间 await asyncio.sleep(self.capture_interval) except asyncio.TimeoutError: logger.warning("等待服务器响应超时") + consecutive_errors += 1 except websockets.exceptions.ConnectionClosed: logger.error("WebSocket连接已关闭") - self.running = False break except Exception as e: logger.error(f"处理循环出错: {e}") logger.error(traceback.format_exc()) + consecutive_errors += 1 await asyncio.sleep(1.0) # 出错时等待1秒 + + # 如果连续错误过多,尝试重新连接 + if consecutive_errors >= 5: + logger.warning(f"连续出错 {consecutive_errors} 次,尝试重新连接") + if self.ws: + await self.ws.close() + self.ws = None + consecutive_errors = 0 async def run(self): """运行客户端""" self.running = True # 连接到服务器 - if not await self.connect(): + if not await self.connect_with_retry(): self.running = False return @@ -452,9 +735,10 @@ class DNFAutoClient: # 创建任务 capture_task = asyncio.create_task(self.capture_and_process_loop()) heartbeat_task = asyncio.create_task(self.heartbeat_loop()) + reconnect_task = asyncio.create_task(self.reconnect_loop()) # 等待任务完成 - await asyncio.gather(capture_task, heartbeat_task) + await asyncio.gather(capture_task, heartbeat_task, reconnect_task) except asyncio.CancelledError: logger.info("客户端任务已取消") @@ -469,15 +753,80 @@ class DNFAutoClient: def start(self): """启动客户端""" - asyncio.run(self.run()) + try: + # 启动心跳监控线程(备用方案,以防异步心跳失效) + self._monitor_thread = threading.Thread(target=self._monitor_connection) + self._monitor_thread.daemon = True + self._monitor_thread.start() + + # 运行主循环 + asyncio.run(self.run()) + except KeyboardInterrupt: + logger.info("用户中断,正在退出...") + except Exception as e: + logger.error(f"客户端出错: {e}") + logger.error(traceback.format_exc()) + + def _monitor_connection(self): + """监控连接的后台线程""" + while True: + try: + time.sleep(30) # 每30秒检查一次 + + if not self.running: + break + + # 检查心跳时间 + if self.last_heartbeat_time > 0 and time.time() - self.last_heartbeat_time > 60: + logger.warning("心跳超时,可能需要重连") + # 不直接重连,留给重连循环处理 + except Exception as e: + logger.error(f"连接监控线程出错: {e}") def stop(self): """停止客户端""" self.running = False logger.info("正在停止客户端...") +# 创建默认配置文件(如果不存在) +def ensure_config(): + if not os.path.exists(CONFIG_FILE): + config = configparser.ConfigParser() + + config["Server"] = { + "url": "wss://your-server-url:8080/ws", + "verify_ssl": "false" + } + + config["Capture"] = { + "interval": "0.5", + "quality": "70" + } + + config["Game"] = { + "window_title": "地下城与勇士", + "key_mapping": "default", + "movement_smoothness": "0.8" + } + + config["Connection"] = { + "max_retries": "5", + "retry_delay": "5", + "heartbeat_interval": "5" + } + + # 保存配置 + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + + print(f"已创建默认配置文件: {CONFIG_FILE}") + print("请编辑配置文件设置正确的服务器地址等信息") + # 启动客户端 if __name__ == "__main__": + # 确保配置文件存在 + ensure_config() + try: client = DNFAutoClient() client.start() diff --git a/config/cert.pem b/config/cert.pem new file mode 100644 index 0000000..6556060 --- /dev/null +++ b/config/cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUPKAkPsxw1moIm//p7Cin44rkneQwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDMyNjE3MzM0MVoXDTI2MDMy +NjE3MzM0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAya16YWOlIvnUHo7ubqTsQslYGPGbBiCNEGux/YNUagCr +51RrNP6pwPXzRNIdDpudSz78wUVCtHEgn1shDIRJuGkENK4s6XpvGbg655DOEev4 +sKo4s/lDAjjuN450bth9Iv1CjuP/mLEOPQ9la+PwF0rcpGhh2qLGkMX399C2a3qf +/UTOg7ifATWXM81VwPfYgpCV+jblAjh56n3rfN/lnIMMsAIxJRa23joLzmR4WIjI +fJRMJFYYhccUBeALAPvy+qxOpe1xRz1SrCpAe8mFE5OTUB6kmNXvfu+xJZ/t+k7s +XHcGccd4ZMwCBkHc18bmvtD/XoB6XAMSGvswznXMOwIDAQABo1MwUTAdBgNVHQ4E +FgQUsU/YgtZL1WkbZHRQ2Xu5GJ7SlgwwHwYDVR0jBBgwFoAUsU/YgtZL1WkbZHRQ +2Xu5GJ7SlgwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAMf4f +3kxcg2HkLnq9lbA3gFkkaXQ5OyFGPBXe/uQ+jk3Kh/2rfnMCci1RCEMOK9gD7Lh+ +fQXPHDtGTmLE5wm9M43NxhRfzfnckCDENmsEzKssvdinqq2zbhTzSXrPN7W7PYPQ +wvYo3ZZIDTXmyZhTgph4su0zuglP/kkfmRrUSxaFnGWDKJE+W+KQ6r6uestGoa9w +ScoLYcB2ssWEK2uU1/kto8PzRAJuOLiGMur7j5fIDfaM13V2+Qt3LpFd32cZnKFH +x9W9/e+I5dbAfqqItIkbgjM0akMhq10t9MYouXbAw1OscXDmkomObC2Wq2QviPkY +zP3UOCSWi9eAWvknLA== +-----END CERTIFICATE----- diff --git a/config/key.pem b/config/key.pem new file mode 100644 index 0000000..d00d97f --- /dev/null +++ b/config/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDJrXphY6Ui+dQe +ju5upOxCyVgY8ZsGII0Qa7H9g1RqAKvnVGs0/qnA9fNE0h0Om51LPvzBRUK0cSCf +WyEMhEm4aQQ0rizpem8ZuDrnkM4R6/iwqjiz+UMCOO43jnRu2H0i/UKO4/+YsQ49 +D2Vr4/AXStykaGHaosaQxff30LZrep/9RM6DuJ8BNZczzVXA99iCkJX6NuUCOHnq +fet83+WcgwywAjElFrbeOgvOZHhYiMh8lEwkVhiFxxQF4AsA+/L6rE6l7XFHPVKs +KkB7yYUTk5NQHqSY1e9+77Eln+36TuxcdwZxx3hkzAIGQdzXxua+0P9egHpcAxIa ++zDOdcw7AgMBAAECggEACrGTIqTY9cDPeYtUozNFf8kTTcdJ1ApX0H4VYv7atAAz +HUIBqT6zm5KvAoAtoD+qGHpPhqP4hH7XHvwDBZniGtAes/hkU0D1sSRuoyavdo3P +kvaDqS9XWT/RicqY6+O4xuks5Uy7mcoRmjU9yHm+mk2S43jRb3lgE/8bRd2gPpSa +14HqNwFqLmB4ZiYaq20KuO+cNTS/wPrxicVABVNwJ2cM+6vdoonkNYaPNPulWzeg +3CgSlageaGPVByaCUK24/0+b7Jgd2GMd34FWTMils8TkeuodQtcTwHEWphAbgpVB +5VlJRA6jYMRRfsZNlT82VvR90uZowLhJ84eGMwdqIQKBgQDsefnK7L0JtIWALljC +BVA6ljbRxZcolVFo6jMLtw/yfChzdP8d14x9s3sxX5J+69kmAAAGeS8tjJeautW6 +kDtsUCdhN3X2GftvR2Y8wUNHW8kMoU9rCN09HaxXq9HK2/RZlNhiHA6FycP/Pe6t +1Qlo1Hn5XniNtbr71NgVM/NX8QKBgQDaVAPxntGm2SBzpsDpXwAVHxWJzK4aveMf +rq7I5JJ5b9PnMV9cNJncF5vOdWM5htb34orJtBZ8IzuZMXAEU5OJwfmaGMleaKMv +mGRoDpWtAE+zj8BB0RHr828tCs3M3/klx0/FUDRPsYXRfbaJb5CCsQmv7QHV4FRp ++bUQ9Fcy6wKBgQDJ1yrQe9S2bfDtAaIcqPBbsU9FKYPlzd1Y0V2UiEICVNsqARin +3g06VXG3KL4fuyrzdliPLeyI0lGsbgBzZxxxTNDv96il0HN9/dFT1hmY1Mz8DMt+ +rmg3/BXYFv3QSoF73MH8q7nxk8/JEpGgqg+H/KPHp0z6l7zrqjZtkpQH4QKBgQDW +jcHibITzRmURwknKDUXze7ya0r42IW1V8UBqw9T96duAU5C2+CpLlBfVaJ6+JbiT +mdlyJrwB+k3TWjYOymMu+aTkvn8FfCcB2uyxJcQJY0jv2NDC3UaTbYNP7FIah/A8 +JAZMjWka+AXdvYDoxu5owLoYXP10xSOvkWlS5AvdSQKBgB0365BytqUd8/ijGuTa +mSaL4WDFFdCkOhPNEmIn1llh0HAh2jWpyBRnJeGyQ/4vZJjA8gIgz4IZnZfiw+ov +SDRnpqvguklboB+mPWOM4Itj/5G7XOAEN1kkA81SUDJBoCuLiZ7KwIaSlefLWi0l +VwmN0w+3GIt4F7rg0dmDsVV4 +-----END PRIVATE KEY----- diff --git a/config/runtime_config.json b/config/runtime_config.json new file mode 100644 index 0000000..f9fa70d --- /dev/null +++ b/config/runtime_config.json @@ -0,0 +1,34 @@ +{ + "version": "1.0.1", + "build_date": "2025-03-27", + "model": { + "name": "yolov8m", + "weights": "/workspace/dnf-auto-cloud/models/weights/dnf_yolo8m.pt", + "conf_threshold": 0.5, + "iou_threshold": 0.45, + "device": "cuda", + "half_precision": true, + "batch_size": 1, + "img_size": 640, + "engine": "pytorch", + "class_names": "/workspace/dnf-auto-cloud/data/training/data.yaml" + }, + "server": { + "max_connections": 10, + "timeout": 60, + "heartbeat_interval": 5, + "inactive_timeout": 300, + "max_retries": 3, + "status_update_interval": 30, + "api_enabled": true + }, + "security": { + "token_expiry": 3600, + "encryption_enabled": false, + "ssl_enabled": false, + "ssl_cert": "/workspace/dnf-auto-cloud/config/cert.pem", + "ip_whitelist": [], + "rate_limit": 100, + "authentication_required": false + } +} \ No newline at end of file diff --git a/config/settings.py b/config/settings.py index 97974cc..82dde68 100644 --- a/config/settings.py +++ b/config/settings.py @@ -3,70 +3,193 @@ """ 全局配置参数 +优化版 - 增加更多配置选项和兼容性检查 """ import os +import sys +import json +import torch from pathlib import Path +# 检测当前环境 +IS_LINUX = sys.platform.startswith('linux') +IS_WINDOWS = sys.platform.startswith('win') +HAS_CUDA = torch.cuda.is_available() + # 基础路径 BASE_DIR = Path(__file__).resolve().parent.parent +# 版本信息 +VERSION = "1.0.1" +BUILD_DATE = "2025-03-27" + +# 加载环境变量 +def load_env_var(name, default): + """从环境变量加载设置,如果不存在则使用默认值""" + return os.environ.get(f"DNF_{name}", default) + # 模型配置 MODEL = { - "name": "yolov8m", - "weights": os.path.join(BASE_DIR, "models", "weights", "dnf_yolo8m.pt"), - "conf_threshold": 0.5, - "iou_threshold": 0.45, - "device": "cuda" # 使用GPU + "name": load_env_var("MODEL_NAME", "yolov8m"), + "weights": os.path.join(BASE_DIR, "models", "weights", load_env_var("MODEL_FILE", "dnf_yolo8m.pt")), + "conf_threshold": float(load_env_var("CONF_THRESHOLD", "0.5")), + "iou_threshold": float(load_env_var("IOU_THRESHOLD", "0.45")), + "device": load_env_var("DEVICE", "cuda" if HAS_CUDA else "cpu"), + "half_precision": load_env_var("HALF_PRECISION", "true").lower() == "true" and HAS_CUDA, + "batch_size": int(load_env_var("BATCH_SIZE", "1")), + "img_size": int(load_env_var("IMG_SIZE", "640")), + "engine": load_env_var("ENGINE", "pytorch"), # pytorch 或 onnx + "class_names": os.path.join(BASE_DIR, "data", "training", "data.yaml") } # 服务器配置 SERVER = { - "max_connections": 10, - "timeout": 60, - "heartbeat_interval": 5 + "max_connections": int(load_env_var("MAX_CONNECTIONS", "10")), + "timeout": int(load_env_var("TIMEOUT", "60")), + "heartbeat_interval": int(load_env_var("HEARTBEAT_INTERVAL", "5")), + "inactive_timeout": int(load_env_var("INACTIVE_TIMEOUT", "300")), # 用户不活跃超时(秒) + "max_retries": int(load_env_var("MAX_RETRIES", "3")), # 操作最大重试次数 + "status_update_interval": int(load_env_var("STATUS_UPDATE_INTERVAL", "30")), # 状态更新间隔(秒) + "api_enabled": load_env_var("API_ENABLED", "true").lower() == "true" # 是否启用API } # 安全配置 SECURITY = { - "token_expiry": 3600, # 1小时 - "encryption_enabled": True, - "ssl_enabled": True, + "token_expiry": int(load_env_var("TOKEN_EXPIRY", "3600")), # 1小时 + "encryption_enabled": False, # 禁用加密以简化调试 + "ssl_enabled": False, # 禁用SSL "ssl_cert": os.path.join(BASE_DIR, "config", "cert.pem"), - "ssl_key": os.path.join(BASE_DIR, "config", "key.pem") + "ssl_key": os.path.join(BASE_DIR, "config", "key.pem"), + "ip_whitelist": load_env_var("IP_WHITELIST", "").split(",") if load_env_var("IP_WHITELIST", "") else [], + "rate_limit": int(load_env_var("RATE_LIMIT", "100")), # 每分钟最大请求数 + "authentication_required": False # 暂时禁用认证要求 } # 行为模拟配置 BEHAVIOR = { - "min_delay": 0.1, - "max_delay": 0.8, - "click_variance": 0.15, # 点击位置随机偏移比例 - "movement_smoothness": 0.8 # 移动平滑度 (0-1) + "min_delay": float(load_env_var("MIN_DELAY", "0.1")), + "max_delay": float(load_env_var("MAX_DELAY", "0.8")), + "click_variance": float(load_env_var("CLICK_VARIANCE", "0.15")), # 点击位置随机偏移比例 + "movement_smoothness": float(load_env_var("MOVEMENT_SMOOTHNESS", "0.8")), # 移动平滑度 (0-1) + "double_click_chance": float(load_env_var("DOUBLE_CLICK_CHANCE", "0.05")), # 双击概率 + "user_profile": load_env_var("USER_PROFILE", "normal"), # 用户行为配置 (fast, normal, casual) + "randomize_behavior": load_env_var("RANDOMIZE_BEHAVIOR", "true").lower() == "true" # 是否随机化行为 } # 日志配置 LOGGING = { - "level": "INFO", + "level": load_env_var("LOG_LEVEL", "INFO"), "file": os.path.join(BASE_DIR, "data", "logs", "server.log"), - "max_size": 10 * 1024 * 1024, # 10MB - "backup_count": 5 + "max_size": int(load_env_var("LOG_MAX_SIZE", "10")) * 1024 * 1024, # 10MB + "backup_count": int(load_env_var("LOG_BACKUP_COUNT", "5")), + "console_output": load_env_var("LOG_CONSOLE", "true").lower() == "true", + "log_requests": load_env_var("LOG_REQUESTS", "true").lower() == "true", + "log_performance": load_env_var("LOG_PERFORMANCE", "true").lower() == "true" +} + +# 缓存配置 +CACHE = { + "enabled": load_env_var("CACHE_ENABLED", "true").lower() == "true", + "directory": os.path.join(BASE_DIR, "data", "cache"), + "max_size": int(load_env_var("CACHE_MAX_SIZE", "1000")), # 最大缓存项数 + "ttl": int(load_env_var("CACHE_TTL", "3600")) # 缓存有效期(秒) +} + +# 会话配置 +SESSION = { + "directory": os.path.join(BASE_DIR, "data", "sessions"), + "save_interval": int(load_env_var("SESSION_SAVE_INTERVAL", "300")), # 会话保存间隔(秒) + "max_idle_time": int(load_env_var("SESSION_MAX_IDLE", "1800")), # 最大空闲时间(秒) + "persistence_enabled": load_env_var("SESSION_PERSISTENCE", "true").lower() == "true" } # DNF游戏特定配置 DNF = { "classes": [ "monster", "boss", "door", "item", "npc", "player", - "hp_bar", "mp_bar", "skill_ready", "cooldown" + "hp_bar", "mp_bar", "skill_ready", "cooldown", "dialog", + "dialog_option", "revive_button", "confirm_button", + "dungeon_portal", "dungeon_select", "dungeon_option", + "room_select", "room_option", "death_ui" ], "maps": { "赫顿玛尔": {"entrance": (123, 456), "exit": (789, 101)}, "天空之城": {"entrance": (234, 567), "exit": (890, 121)}, + "格兰之森": {"entrance": (345, 678), "exit": (901, 234)}, + "诺斯玛尔": {"entrance": (456, 789), "exit": (345, 678)}, # 更多地图配置... }, "skills": { - "1": "普通攻击", - "2": "技能1", - "3": "技能2", + "1": {"name": "普通攻击", "type": "melee", "cooldown": 0.5}, + "2": {"name": "技能1", "type": "melee", "cooldown": 3.0}, + "3": {"name": "技能2", "type": "ranged", "cooldown": 5.0}, + "4": {"name": "技能3", "type": "aoe", "cooldown": 8.0}, + "5": {"name": "技能4", "type": "buff", "cooldown": 15.0}, + "6": {"name": "技能5", "type": "ultimate", "cooldown": 30.0}, # 更多技能配置... + }, + "difficulty_levels": { + "normal": {"monster_hp": 100, "monster_damage": 100}, + "hard": {"monster_hp": 150, "monster_damage": 150}, + "expert": {"monster_hp": 200, "monster_damage": 200}, + "master": {"monster_hp": 250, "monster_damage": 250}, + "hell": {"monster_hp": 300, "monster_damage": 300} } -} \ No newline at end of file +} + +# 检查并创建必要的目录 +def ensure_directories(): + """确保所有需要的目录都存在""" + directories = [ + os.path.dirname(LOGGING["file"]), # 日志目录 + CACHE["directory"], # 缓存目录 + SESSION["directory"], # 会话目录 + os.path.join(BASE_DIR, "data", "training"), # 训练数据目录 + os.path.dirname(MODEL["weights"]), # 模型权重目录 + ] + + for directory in directories: + os.makedirs(directory, exist_ok=True) + +# 配置验证 +def validate_config(): + """验证配置有效性""" + # 检查模型配置 + if not os.path.exists(MODEL["weights"]) and not os.path.exists(MODEL["weights"] + ".onnx"): + print(f"警告: 模型权重文件不存在: {MODEL['weights']}") + + # 检查CUDA设置 + if MODEL["device"] == "cuda" and not HAS_CUDA: + print("警告: 已配置使用CUDA但系统中没有可用的GPU,将使用CPU模式") + MODEL["device"] = "cpu" + MODEL["half_precision"] = False + + # 检查日志级别 + valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + if LOGGING["level"] not in valid_log_levels: + print(f"警告: 无效的日志级别 '{LOGGING['level']}',将使用 'INFO'") + LOGGING["level"] = "INFO" + + # 保存验证后的配置 + save_validated_config() + +def save_validated_config(): + """保存验证后的配置到文件,便于其他进程读取""" + config = { + "version": VERSION, + "build_date": BUILD_DATE, + "model": MODEL, + "server": SERVER, + "security": {k: v for k, v in SECURITY.items() if k not in ["ssl_key"]} # 不保存敏感信息 + } + + try: + with open(os.path.join(BASE_DIR, "config", "runtime_config.json"), "w") as f: + json.dump(config, f, indent=2) + except Exception as e: + print(f"保存运行时配置失败: {e}") + +# 执行初始化 +ensure_directories() +validate_config() \ No newline at end of file diff --git a/dnf-client-package/README.txt b/dnf-client-package/README.txt new file mode 100644 index 0000000..d8079d5 --- /dev/null +++ b/dnf-client-package/README.txt @@ -0,0 +1,9 @@ +DNF自动化客户端 + +安装说明: +1. 安装Python 3.8或更高版本 +2. 安装依赖包: + pip install pillow numpy websockets keyboard mouse pywin32 mss + +3. 编辑config.ini设置服务器地址 +4. 运行start.bat启动客户端 diff --git a/dnf-client-package/client.py b/dnf-client-package/client.py new file mode 100644 index 0000000..68aec10 --- /dev/null +++ b/dnf-client-package/client.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +DNF自动化客户端,负责截取游戏画面并执行服务器返回的操作 +优化版 - 增加断线重连、性能优化和增强的游戏状态管理 +""" + +import os +import sys +import json +import base64 +import time +import random +import asyncio +import websockets +import configparser +import ssl +import ctypes +from io import BytesIO +from datetime import datetime +import threading +import logging +import traceback + +# 图像处理 +from PIL import Image +import numpy as np + +# Windows API +import win32gui +import win32con +import win32api +import win32process + +# 输入模拟 +import keyboard +import mouse + +# 尝试导入mss库,如果不存在则继续使用PIL +try: + import mss + MSS_AVAILABLE = True +except ImportError: + MSS_AVAILABLE = False + print("警告: mss库未安装,将使用PIL进行截图(性能较低)") + print("请使用 pip install mss 安装以获得更好的性能") + +# 配置文件路径 - 使用绝对路径 +CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.ini") + +# 日志设置 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[ + logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), "client.log"), encoding="utf-8"), + logging.StreamHandler() + ] +) +logger = logging.getLogger("DNFAutoClient") + +class DNFAutoClient: + """DNF自动化客户端类""" + + def __init__(self): + """初始化客户端""" + self.config = self.load_config() + self.server_url = self.config.get("Server", "url") + self.client_id = None + self.running = False + self.ws = None + self.capture_interval = float(self.config.get("Capture", "interval")) + self.max_retries = int(self.config.get("Connection", "max_retries", fallback="5")) + self.retry_delay = int(self.config.get("Connection", "retry_delay", fallback="5")) + + # 增强的游戏状态 + self.game_state = { + "in_battle": False, + "current_map": "", + "hp_percent": 100, + "mp_percent": 100, + "active_buffs": [], + "cooldowns": {}, + "inventory_full": False, + "current_quest": None, + "last_operation_time": time.time(), + "session_start_time": time.time() + } + + # 连接状态 + self.last_heartbeat_time = 0 + self.connection_attempts = 0 + self.reconnecting = False + + def load_config(self): + """加载配置文件""" + if not os.path.exists(CONFIG_FILE): + self.create_default_config() + + config = configparser.ConfigParser() + config.read(CONFIG_FILE, encoding="utf-8") + return config + + def create_default_config(self): + """创建默认配置文件""" + config = configparser.ConfigParser() + + config["Server"] = { + "url": "wss://your-server-url:8080/ws", + "verify_ssl": "false" + } + + config["Capture"] = { + "interval": "0.5", + "quality": "70" + } + + config["Game"] = { + "window_title": "地下城与勇士", + "key_mapping": "default" + } + + config["Connection"] = { + "max_retries": "5", + "retry_delay": "5", + "heartbeat_interval": "5" + } + + # 保存配置 + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + + logger.info(f"已创建默认配置文件: {CONFIG_FILE}") + + async def connect(self): + """连接到服务器""" + logger.info(f"正在连接到服务器: {self.server_url}") + + ssl_context = None + if self.server_url.startswith("wss://"): + ssl_context = ssl.create_default_context() + if self.config.get("Server", "verify_ssl").lower() == "false": + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + try: + self.ws = await websockets.connect( + self.server_url, + ssl=ssl_context, + max_size=10 * 1024 * 1024, # 10MB + ping_interval=None, # 禁用自动ping,我们将使用自己的心跳 + close_timeout=5 + ) + + # 等待认证 + await self.authenticate() + + logger.info("已连接到服务器") + self.connection_attempts = 0 # 重置连接尝试次数 + self.last_heartbeat_time = time.time() + return True + + except Exception as e: + logger.error(f"连接服务器失败: {e}") + return False + + async def connect_with_retry(self): + """带重试机制的连接函数""" + if self.reconnecting: + logger.info("已有重连过程在进行中,跳过") + return False + + self.reconnecting = True + retry_count = 0 + + try: + while retry_count < self.max_retries and not self.ws: + try: + logger.info(f"尝试连接服务器 (尝试 {retry_count + 1}/{self.max_retries})...") + success = await self.connect() + if success: + self.reconnecting = False + return True + except Exception as e: + logger.error(f"连接服务器失败: {e}") + + retry_count += 1 + retry_delay = min(60, self.retry_delay * (2 ** retry_count)) # 指数退避策略 + logger.info(f"等待 {retry_delay} 秒后重试...") + await asyncio.sleep(retry_delay) + + if not self.ws: + logger.error(f"达到最大重试次数 ({self.max_retries}),连接失败") + self.reconnecting = False + return False + except Exception as e: + logger.error(f"重连过程中发生错误: {e}") + self.reconnecting = False + return False + + self.reconnecting = False + return True + + async def authenticate(self): + """客户端认证""" + try: + # 等待认证挑战 + challenge_raw = await self.ws.recv() + challenge = json.loads(challenge_raw) + + if challenge.get("type") != "auth_challenge": + raise ValueError("无效的认证挑战") + + # 收集系统信息 + system_info = self.get_system_info() + + # 发送认证响应 + await self.ws.send(json.dumps({ + "type": "auth_response", + "response": f"client_{random.getrandbits(32)}", + "client_info": { + "version": "1.0.1", # 更新版本号 + "os": "Windows", + "screen_resolution": self.get_screen_resolution(), + "system_info": system_info + } + })) + + # 等待认证结果 + result_raw = await self.ws.recv() + result = json.loads(result_raw) + + if result.get("type") != "auth_result" or result.get("status") != "success": + raise ValueError("认证失败") + + # 保存客户端ID + self.client_id = result.get("client_id") + logger.info(f"认证成功,客户端ID: {self.client_id}") + + except Exception as e: + logger.error(f"认证失败: {e}") + raise + + def get_system_info(self): + """获取系统信息""" + system_info = {} + try: + system_info["hostname"] = os.environ.get("COMPUTERNAME", "Unknown") + system_info["username"] = os.environ.get("USERNAME", "Unknown") + system_info["processor"] = os.environ.get("PROCESSOR_IDENTIFIER", "Unknown") + + # 获取系统内存信息 + mem = ctypes.c_ulonglong() + ctypes.windll.kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem)) + system_info["memory_gb"] = round(mem.value / (1024 * 1024), 2) + + except Exception as e: + logger.error(f"获取系统信息失败: {e}") + + return system_info + + def get_screen_resolution(self): + """获取屏幕分辨率""" + user32 = ctypes.windll.user32 + return [user32.GetSystemMetrics(0), user32.GetSystemMetrics(1)] + + def get_game_window(self): + """获取游戏窗口句柄""" + window_title = self.config.get("Game", "window_title") + hwnd = win32gui.FindWindow(None, window_title) + if hwnd == 0: + logger.warning(f"找不到游戏窗口: {window_title}") + return None + return hwnd + + def capture_game_screen(self, hwnd=None): + """截取游戏画面""" + try: + if hwnd is None: + hwnd = self.get_game_window() + if hwnd is None: + return None + + # 获取窗口位置和大小 + rect = win32gui.GetWindowRect(hwnd) + x, y, width, height = rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1] + + # 检查窗口是否最小化 + if width <= 0 or height <= 0: + logger.warning("游戏窗口被最小化或大小无效") + return None + + # 使用mss进行截图(如果可用),速度比PIL快5-10倍 + if MSS_AVAILABLE: + with mss.mss() as sct: + monitor = {"top": y, "left": x, "width": width, "height": height} + sct_img = sct.grab(monitor) + # 将mss图像转换为PIL图像 + img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX") + else: + # 使用PIL进行截图(备选方案) + img = ImageGrab.grab(bbox=(x, y, x + width, y + height)) + + # 压缩图像 + quality = int(self.config.get("Capture", "quality")) + buffer = BytesIO() + img.save(buffer, format="JPEG", quality=quality) + + # 转换为Base64 + img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + + return { + "image": img_base64, + "window_rect": [x, y, width, height], + "timestamp": time.time() + } + + except Exception as e: + logger.error(f"截取游戏画面失败: {e}") + logger.error(traceback.format_exc()) + return None + + async def send_heartbeat(self): + """发送心跳包""" + if not self.ws or not self.running: + return False + + try: + await self.ws.send(json.dumps({ + "type": "heartbeat", + "timestamp": time.time(), + "client_id": self.client_id, + "game_state": { + "in_battle": self.game_state["in_battle"], + "current_map": self.game_state["current_map"] + } + })) + + # 等待心跳响应 + response_raw = await asyncio.wait_for( + self.ws.recv(), + timeout=5.0 + ) + + # 检查响应 + response = json.loads(response_raw) + if response.get("type") != "heartbeat_response": + logger.warning(f"收到非心跳响应: {response.get('type')}") + return False + + self.last_heartbeat_time = time.time() + return True + + except asyncio.TimeoutError: + logger.warning("心跳超时") + return False + except websockets.exceptions.ConnectionClosed: + logger.warning("发送心跳时连接已关闭") + return False + except Exception as e: + logger.error(f"发送心跳失败: {e}") + return False + + async def heartbeat_loop(self): + """心跳循环""" + heartbeat_interval = float(self.config.get("Connection", "heartbeat_interval", fallback="5")) + + while self.running: + if self.ws and not self.ws.closed: + success = await self.send_heartbeat() + if not success: + # 心跳失败,检查连接状态 + if time.time() - self.last_heartbeat_time > heartbeat_interval * 3: + logger.warning(f"心跳超时超过 {heartbeat_interval * 3} 秒,尝试重新连接") + if self.ws: + await self.ws.close() + self.ws = None + + await asyncio.sleep(heartbeat_interval) + + async def reconnect_loop(self): + """重连检查循环""" + while self.running: + if not self.ws or self.ws.closed: + logger.warning("WebSocket连接已断开,尝试重连...") + if await self.connect_with_retry(): + logger.info("重连成功") + else: + logger.error("重连失败") + # 不要立即停止,继续尝试 + + await asyncio.sleep(5) # 每5秒检查一次连接状态 + + async def execute_action(self, action): + """执行操作""" + try: + action_type = action.get("type") + + # 等待指定的延迟时间 + if "delay" in action: + await asyncio.sleep(action["delay"]) + + # 获取游戏窗口 + hwnd = self.get_game_window() + if hwnd is None: + return + + # 确保窗口处于前台 + if win32gui.GetForegroundWindow() != hwnd: + try: + # 尝试多种方法激活窗口 + win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) # 恢复窗口(如果最小化) + win32gui.SetForegroundWindow(hwnd) # 尝试置为前台 + await asyncio.sleep(0.1) + + # 如果窗口仍然不在前台,使用更强的方法 + if win32gui.GetForegroundWindow() != hwnd: + # 获取当前激活窗口的线程和进程ID + curr_hwnd = win32gui.GetForegroundWindow() + curr_thread_id = win32process.GetWindowThreadProcessId(curr_hwnd)[0] + # 获取目标窗口的线程和进程ID + target_thread_id = win32process.GetWindowThreadProcessId(hwnd)[0] + # 附加线程输入 + win32process.AttachThreadInput(target_thread_id, curr_thread_id, True) + win32gui.SetForegroundWindow(hwnd) + win32gui.BringWindowToTop(hwnd) + win32process.AttachThreadInput(target_thread_id, curr_thread_id, False) + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"激活窗口失败: {e}") + + # 获取窗口位置 + rect = win32gui.GetWindowRect(hwnd) + window_x, window_y = rect[0], rect[1] + + # 执行不同类型的操作 + if action_type == "move_to": + # 移动到指定位置 + position = action.get("position", [0, 0]) + x, y = position[0] + window_x, position[1] + window_y + + # 生成人类化的移动路径 + current_pos = win32gui.GetCursorPos() + path = self.generate_movement_path(current_pos, [x, y]) + + # 执行移动 + for point in path: + win32api.SetCursorPos((int(point[0]), int(point[1]))) + await asyncio.sleep(0.01) # 10ms延迟 + + elif action_type == "click": + # 点击指定位置 + position = action.get("position", [0, 0]) + x, y = position[0] + window_x, position[1] + window_y + + # 移动到位置 + current_pos = win32gui.GetCursorPos() + path = self.generate_movement_path(current_pos, [x, y]) + + for point in path: + win32api.SetCursorPos((int(point[0]), int(point[1]))) + await asyncio.sleep(0.01) + + # 执行点击 + mouse.click() + + elif action_type == "use_skill": + # 使用技能 + key = action.get("key", "1") + keyboard.press(key) + await asyncio.sleep(0.05) + keyboard.release(key) + + # 如果有目标位置,移动鼠标并点击 + if "target_position" in action: + pos = action["target_position"] + x, y = pos[0] + window_x, pos[1] + window_y + win32api.SetCursorPos((int(x), int(y))) + await asyncio.sleep(0.05) + mouse.click() + + elif action_type == "interact": + # 交互 + key = action.get("key", "f") + keyboard.press(key) + await asyncio.sleep(0.1) + keyboard.release(key) + + elif action_type == "move_random": + # 随机移动 + direction = action.get("direction", "right") + duration = action.get("duration", 1.0) + + # 方向键映射 + dir_keys = { + "up": "w", + "down": "s", + "left": "a", + "right": "d" + } + + key = dir_keys.get(direction, "d") + keyboard.press(key) + await asyncio.sleep(duration) + keyboard.release(key) + + elif action_type == "use_item": + # 使用物品 + key = action.get("key", "f1") + keyboard.press(key) + await asyncio.sleep(0.1) + keyboard.release(key) + + elif action_type == "stop": + # 停止所有按键 + for key in ["w", "a", "s", "d", "1", "2", "3", "4", "5", "6"]: + if keyboard.is_pressed(key): + keyboard.release(key) + + elif action_type == "type_text": + # 输入文本 + text = action.get("text", "") + if text: + keyboard.write(text, delay=0.05) + + elif action_type == "press_key_combo": + # 按下组合键 + keys = action.get("keys", []) + if keys: + for key in keys: + keyboard.press(key) + await asyncio.sleep(0.1) + for key in reversed(keys): + keyboard.release(key) + + else: + logger.warning(f"未知操作类型: {action_type}") + + # 更新游戏状态 + self.game_state["last_operation_time"] = time.time() + + except Exception as e: + logger.error(f"执行操作失败: {e}") + logger.error(traceback.format_exc()) + + def generate_movement_path(self, start_pos, end_pos, steps=None): + """生成模拟人类的鼠标移动路径""" + # 计算距离 + distance = ((end_pos[0] - start_pos[0])**2 + (end_pos[1] - start_pos[1])**2)**0.5 + + # 如果未指定步数,根据距离计算 + if steps is None: + steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点 + + # 生成基础路径 + t = np.linspace(0, 1, steps) + path = [] + + # 动态调整平滑度 + smoothness = float(self.config.get("Game", "movement_smoothness", fallback="0.8")) + + for i in range(steps): + # 基础线性插值 + x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i] + y = start_pos[1] + (end_pos[1] - start_pos[1]) * t[i] + + # 添加随机偏移(越靠近中间偏移越大) + mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大 + max_offset = distance * 0.05 * mid_factor * (1 - smoothness) # 最大偏移为距离的5%,受平滑度影响 + + offset_x = random.normalvariate(0, max_offset / 3) + offset_y = random.normalvariate(0, max_offset / 3) + + # 添加到路径 + path.append([x + offset_x, y + offset_y]) + + # 确保起点和终点准确 + path[0] = start_pos + path[-1] = end_pos + + return path + + def analyze_image(self, image_data, detection_results): + """ + 分析图像数据,更新游戏状态 + + 参数: + image_data (dict): 图像数据 + detection_results (list): 检测结果 + """ + # 更新战斗状态 + monsters_detected = False + for det in detection_results: + if det["class_name"] in ["monster", "boss"]: + monsters_detected = True + break + + self.game_state["in_battle"] = monsters_detected + + # 分析血条和蓝条 + hp_bars = [d for d in detection_results if d["class_name"] == "hp_bar"] + mp_bars = [d for d in detection_results if d["class_name"] == "mp_bar"] + + if hp_bars: + # 估算血量百分比 + self.game_state["hp_percent"] = self.estimate_bar_percent(hp_bars[0]) + + if mp_bars: + # 估算蓝量百分比 + self.game_state["mp_percent"] = self.estimate_bar_percent(mp_bars[0]) + + # 检测技能冷却 + cooldowns = [d for d in detection_results if d["class_name"] == "cooldown"] + self.game_state["cooldowns"] = {} + + for cd in cooldowns: + if "skill_id" in cd: + self.game_state["cooldowns"][cd["skill_id"]] = cd.get("remaining_time", 1.0) + + def estimate_bar_percent(self, bar_detection): + """ + 估计血条/蓝条的百分比 + + 参数: + bar_detection (dict): 条形检测结果 + + 返回: + float: 百分比值(0-100) + """ + # 这里需要根据实际情况实现 + # 临时方案:返回检测结果中的值,如果没有则返回默认值 + return bar_detection.get("percent", 100) + + async def capture_and_process_loop(self): + """截图和处理循环""" + consecutive_errors = 0 + + while self.running: + try: + # 检查WebSocket连接 + if not self.ws or self.ws.closed: + await asyncio.sleep(0.5) # 连接断开时等待 + continue + + # 获取游戏窗口 + hwnd = self.get_game_window() + if hwnd is None: + await asyncio.sleep(1.0) # 找不到窗口时等待1秒 + continue + + # 截取游戏画面 + screen_data = self.capture_game_screen(hwnd) + if screen_data is None: + await asyncio.sleep(0.5) # 截图失败时等待0.5秒 + continue + + # 准备请求数据 + request = { + "type": "image", + "request_id": f"req_{random.getrandbits(32)}", + "timestamp": time.time(), + "data": screen_data["image"], + "game_state": self.game_state, + "window_rect": screen_data["window_rect"] + } + + # 发送请求 + await self.ws.send(json.dumps(request)) + + # 等待响应 + response_raw = await asyncio.wait_for( + self.ws.recv(), + timeout=5.0 + ) + + # 处理响应 + response = json.loads(response_raw) + + if response.get("type") == "action_response": + # 获取检测结果并更新游戏状态 + if "detections" in response: + self.analyze_image(screen_data, response["detections"]) + + # 执行动作 + actions = response.get("actions", []) + + # 按优先级排序 + actions.sort(key=lambda x: x.get("execution_priority", 1.0)) + + for action in actions: + await self.execute_action(action) + + # 重置错误计数 + consecutive_errors = 0 + + elif response.get("type") == "error": + logger.error(f"服务器返回错误: {response.get('message')}") + consecutive_errors += 1 + + # 等待指定的间隔时间 + await asyncio.sleep(self.capture_interval) + + except asyncio.TimeoutError: + logger.warning("等待服务器响应超时") + consecutive_errors += 1 + except websockets.exceptions.ConnectionClosed: + logger.error("WebSocket连接已关闭") + break + except Exception as e: + logger.error(f"处理循环出错: {e}") + logger.error(traceback.format_exc()) + consecutive_errors += 1 + await asyncio.sleep(1.0) # 出错时等待1秒 + + # 如果连续错误过多,尝试重新连接 + if consecutive_errors >= 5: + logger.warning(f"连续出错 {consecutive_errors} 次,尝试重新连接") + if self.ws: + await self.ws.close() + self.ws = None + consecutive_errors = 0 + + async def run(self): + """运行客户端""" + self.running = True + + # 连接到服务器 + if not await self.connect_with_retry(): + self.running = False + return + + try: + # 创建任务 + capture_task = asyncio.create_task(self.capture_and_process_loop()) + heartbeat_task = asyncio.create_task(self.heartbeat_loop()) + reconnect_task = asyncio.create_task(self.reconnect_loop()) + + # 等待任务完成 + await asyncio.gather(capture_task, heartbeat_task, reconnect_task) + + except asyncio.CancelledError: + logger.info("客户端任务已取消") + except Exception as e: + logger.error(f"客户端运行出错: {e}") + logger.error(traceback.format_exc()) + finally: + self.running = False + if self.ws: + await self.ws.close() + logger.info("客户端已停止") + + def start(self): + """启动客户端""" + try: + # 启动心跳监控线程(备用方案,以防异步心跳失效) + self._monitor_thread = threading.Thread(target=self._monitor_connection) + self._monitor_thread.daemon = True + self._monitor_thread.start() + + # 运行主循环 + asyncio.run(self.run()) + except KeyboardInterrupt: + logger.info("用户中断,正在退出...") + except Exception as e: + logger.error(f"客户端出错: {e}") + logger.error(traceback.format_exc()) + + def _monitor_connection(self): + """监控连接的后台线程""" + while True: + try: + time.sleep(30) # 每30秒检查一次 + + if not self.running: + break + + # 检查心跳时间 + if self.last_heartbeat_time > 0 and time.time() - self.last_heartbeat_time > 60: + logger.warning("心跳超时,可能需要重连") + # 不直接重连,留给重连循环处理 + except Exception as e: + logger.error(f"连接监控线程出错: {e}") + + def stop(self): + """停止客户端""" + self.running = False + logger.info("正在停止客户端...") + +# 创建默认配置文件(如果不存在) +def ensure_config(): + if not os.path.exists(CONFIG_FILE): + config = configparser.ConfigParser() + + config["Server"] = { + "url": "wss://your-server-url:8080/ws", + "verify_ssl": "false" + } + + config["Capture"] = { + "interval": "0.5", + "quality": "70" + } + + config["Game"] = { + "window_title": "地下城与勇士", + "key_mapping": "default", + "movement_smoothness": "0.8" + } + + config["Connection"] = { + "max_retries": "5", + "retry_delay": "5", + "heartbeat_interval": "5" + } + + # 保存配置 + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + + print(f"已创建默认配置文件: {CONFIG_FILE}") + print("请编辑配置文件设置正确的服务器地址等信息") + +# 启动客户端 +if __name__ == "__main__": + # 确保配置文件存在 + ensure_config() + + try: + client = DNFAutoClient() + client.start() + except KeyboardInterrupt: + logger.info("用户中断,正在退出...") + except Exception as e: + logger.error(f"客户端出错: {e}") + logger.error(traceback.format_exc()) \ No newline at end of file diff --git a/dnf-client-package/config.ini b/dnf-client-package/config.ini new file mode 100644 index 0000000..cd2b483 --- /dev/null +++ b/dnf-client-package/config.ini @@ -0,0 +1,17 @@ +[Server] +url = wss://your-server-ip:8080/ws +verify_ssl = false + +[Capture] +interval = 0.5 +quality = 70 + +[Game] +window_title = 地下城与勇士 +key_mapping = default +movement_smoothness = 0.8 + +[Connection] +max_retries = 5 +retry_delay = 5 +heartbeat_interval = 5 diff --git a/dnf-client-package/start.bat b/dnf-client-package/start.bat new file mode 100644 index 0000000..6d0a6cf --- /dev/null +++ b/dnf-client-package/start.bat @@ -0,0 +1,4 @@ +@echo off +echo 正在启动DNF自动化客户端... +python client.py +pause diff --git a/dnf-client-windows.zip b/dnf-client-windows.zip new file mode 100644 index 0000000..3f6d7d8 Binary files /dev/null and b/dnf-client-windows.zip differ diff --git a/models/yolo_model.py b/models/yolo_model.py index d8eaa53..d94cdb0 100644 --- a/models/yolo_model.py +++ b/models/yolo_model.py @@ -3,48 +3,74 @@ """ YOLO模型封装,提供图像识别功能 +优化版 - 增强性能和模型缓存管理 +修复版 - 解决模型加载依赖问题 """ import os +import sys # 添加缺失的sys模块导入 import logging import torch import yaml -from PIL import Image +import time import numpy as np +import subprocess +from PIL import Image +from collections import deque -from config.settings import MODEL +from config.settings import MODEL, BASE_DIR logger = logging.getLogger("DNFAutoCloud") class YOLOModel: - """YOLO模型封装类""" + """YOLO模型封装类 - 优化版""" def __init__(self): """初始化YOLO模型""" - self.device = MODEL["device"] - self.conf_threshold = MODEL["conf_threshold"] - self.iou_threshold = MODEL["iou_threshold"] + self.device = MODEL.get("device", "cuda" if torch.cuda.is_available() else "cpu") + self.conf_threshold = MODEL.get("conf_threshold", 0.5) + self.iou_threshold = MODEL.get("iou_threshold", 0.45) - logger.info(f"正在加载YOLO模型: {MODEL['name']}") - logger.info(f"模型权重路径: {MODEL['weights']}") + # 性能监控 + self.inference_times = deque(maxlen=100) + self.batch_size = MODEL.get("batch_size", 1) + self.use_half_precision = MODEL.get("half_precision", True) and self.device == "cuda" + + logger.info(f"正在加载YOLO模型: {MODEL.get('name', 'unknown')}") + logger.info(f"模型权重路径: {MODEL.get('weights', 'unknown')}") + logger.info(f"设备: {self.device}, 半精度: {self.use_half_precision}") # 检查权重文件是否存在 - if not os.path.exists(MODEL["weights"]): - raise FileNotFoundError(f"找不到模型权重文件: {MODEL['weights']}") + weights_path = MODEL.get("weights", "") + if not os.path.exists(weights_path): + # 尝试从备选路径加载 + alt_paths = [ + os.path.join(BASE_DIR, "models", "weights", "best.pt"), + os.path.join(BASE_DIR, "models", "best.pt"), + os.path.join(BASE_DIR, "yolov5", "runs", "train", "exp", "weights", "best.pt") + ] + + for path in alt_paths: + if os.path.exists(path): + weights_path = path + logger.info(f"使用备选模型权重: {weights_path}") + break + + if not os.path.exists(weights_path): + raise FileNotFoundError(f"找不到模型权重文件: {weights_path}") # 加载模型 try: - self.model = torch.hub.load('ultralytics/yolov5', 'custom', - path=MODEL["weights"], - device=self.device) - - # 设置模型参数 - self.model.conf = self.conf_threshold - self.model.iou = self.iou_threshold - - # 预热模型 (运行一个空白图像以初始化) - dummy_img = torch.zeros((1, 3, 640, 640), device=self.device) - self.model(dummy_img) + if MODEL.get("engine", "") == "onnx": + self._load_onnx_model() + else: + # 尝试不同的加载方法 + try: + self._load_yolov5_custom() + except Exception as e: + logger.warning(f"使用torch.hub加载YOLOv5模型失败: {e}") + logger.info("尝试备选方法加载模型...") + self._load_yolov5_manually() logger.info("YOLO模型加载成功") @@ -52,6 +78,256 @@ class YOLOModel: logger.error(f"加载YOLO模型失败: {e}") raise + def _load_yolov5_custom(self): + """使用torch.hub加载YOLOv5模型""" + try: + # 尝试从本地加载YOLOv5 + yolov5_dir = os.path.join(BASE_DIR, "tools", "yolov5") + if os.path.exists(yolov5_dir): + logger.info(f"从本地目录加载YOLOv5: {yolov5_dir}") + sys.path.insert(0, yolov5_dir) + + try: + # 尝试直接导入YOLOv5模块 + from models.common import DetectMultiBackend + from utils.torch_utils import select_device + from utils.general import check_img_size, non_max_suppression, scale_coords + + # 加载模型 + self.model = DetectMultiBackend(MODEL.get("weights"), device=self.device) + self.stride = self.model.stride + self.names = self.model.names + + # 设置参数 + self.imgsz = check_img_size((640, 640), s=self.stride) + + # 使用半精度 + if self.use_half_precision: + self.model.half() + + # 预热模型 + self._warmup_model() + + return + except ImportError as e: + logger.warning(f"无法直接导入YOLOv5模块: {e}") + + # 尝试从torch.hub加载 + logger.info("尝试从torch.hub加载YOLOv5模型") + self.model = torch.hub.load( + 'ultralytics/yolov5', + 'custom', + path=MODEL.get("weights"), + device=self.device, + force_reload=True + ) + + # 设置模型参数 + self.model.conf = self.conf_threshold + self.model.iou = self.iou_threshold + + # 使用半精度 + if self.use_half_precision: + self.model.half() + + # 预热模型 + self._warmup_model() + + except Exception as e: + logger.error(f"加载YOLOv5模型失败: {e}") + raise + + def _load_yolov5_manually(self): + """手动加载YOLOv5模型(不依赖torch.hub)""" + try: + # 检查YOLOv5目录是否存在,不存在则克隆 + yolov5_dir = os.path.join(BASE_DIR, "tools", "yolov5") + if not os.path.exists(yolov5_dir): + logger.info("YOLOv5目录不存在,正在克隆仓库...") + os.makedirs(os.path.dirname(yolov5_dir), exist_ok=True) + + # 克隆YOLOv5仓库 + subprocess.run( + ["git", "clone", "https://github.com/ultralytics/yolov5.git", yolov5_dir], + check=True + ) + + # 添加YOLOv5目录到系统路径 + sys.path.insert(0, yolov5_dir) + + try: + # 尝试导入YOLOv5模块 + from models.common import DetectMultiBackend + from utils.torch_utils import select_device + from utils.general import check_img_size, non_max_suppression, scale_coords + + # 加载模型 + self.model = DetectMultiBackend(MODEL.get("weights"), device=self.device) + self.stride = self.model.stride + self.names = self.model.names + + # 设置参数 + self.imgsz = check_img_size((640, 640), s=self.stride) + + # 使用半精度 + if self.use_half_precision: + self.model.half() + + # 保存必要的函数 + self.non_max_suppression = non_max_suppression + self.scale_coords = scale_coords + + # 预热模型 + dummy_img = torch.zeros((1, 3, self.imgsz[0], self.imgsz[1]), device=self.device) + if self.use_half_precision: + dummy_img = dummy_img.half() + + with torch.no_grad(): + for _ in range(2): + self.model(dummy_img) + + logger.info("手动加载YOLOv5模型成功") + + except ImportError as e: + logger.error(f"导入YOLOv5模块失败: {e}") + # 尝试使用更简单的模型 + self._load_fallback_model() + + except Exception as e: + logger.error(f"手动加载YOLOv5模型失败: {e}") + # 尝试使用更简单的模型 + self._load_fallback_model() + + def _load_fallback_model(self): + """加载备用模型(使用PyTorch内置模型)""" + logger.info("尝试加载备用模型 (PyTorch YOLO)") + + try: + # 使用PyTorch的预训练模型 + from torchvision.models.detection import fasterrcnn_resnet50_fpn + + self.model = fasterrcnn_resnet50_fpn(pretrained=True) + self.model.to(self.device) + self.model.eval() + + # 备用模型的类别 + self.names = [ + 'background', 'monster', 'boss', 'door', 'item', 'npc', 'player', + 'hp_bar', 'mp_bar', 'skill_ready', 'cooldown' + ] + + # 标记使用备用模型 + self.using_fallback = True + + logger.info("备用模型加载成功") + + except Exception as e: + logger.error(f"加载备用模型失败: {e}") + raise + + def _load_onnx_model(self): + """加载ONNX版YOLO模型""" + try: + import onnxruntime as ort + + # 设置ONNX运行时参数 + if self.device == "cuda": + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + # 检查ONNX模型文件 + onnx_path = MODEL.get("weights", "") + if not onnx_path.endswith('.onnx'): + onnx_path = onnx_path + '.onnx' + + if not os.path.exists(onnx_path): + logger.warning(f"找不到ONNX模型: {onnx_path}") + logger.info("尝试导出PyTorch模型到ONNX格式...") + + # 尝试加载PyTorch模型并导出为ONNX + self._load_yolov5_custom() + self._export_to_onnx(onnx_path) + + # 创建ONNX会话 + self.onnx_session = ort.InferenceSession(onnx_path, providers=providers) + + # 获取模型输入输出名称 + self.input_name = self.onnx_session.get_inputs()[0].name + self.output_names = [output.name for output in self.onnx_session.get_outputs()] + + # 加载类别名称 + if os.path.exists(MODEL.get("class_names", "")): + with open(MODEL.get("class_names"), "r") as f: + self.class_names = yaml.safe_load(f) + else: + self.class_names = ["object"] + + # 标记使用ONNX + self.using_onnx = True + + # 预热模型 + self._warmup_onnx_model() + + except ImportError: + logger.error("缺少ONNX运行时依赖,请安装onnxruntime或onnxruntime-gpu") + logger.info("尝试使用PyTorch模型代替...") + self._load_yolov5_custom() + except Exception as e: + logger.error(f"加载ONNX模型失败: {e}") + logger.info("尝试使用PyTorch模型代替...") + self._load_yolov5_custom() + + def _export_to_onnx(self, onnx_path): + """将PyTorch模型导出为ONNX格式""" + try: + import torch.onnx + + # 准备导出 + dummy_input = torch.randn(1, 3, 640, 640, device=self.device) + if self.use_half_precision: + dummy_input = dummy_input.half() + + # 导出ONNX + torch.onnx.export( + self.model, + dummy_input, + onnx_path, + verbose=False, + opset_version=12, + input_names=['images'], + output_names=['output'] + ) + + logger.info(f"PyTorch模型成功导出为ONNX格式: {onnx_path}") + + except Exception as e: + logger.error(f"导出ONNX模型失败: {e}") + raise + + def _warmup_model(self): + """预热模型(运行空白图像以初始化)""" + dummy_img = torch.zeros((1, 3, 640, 640), device=self.device) + if self.use_half_precision: + dummy_img = dummy_img.half() + + # 进行多次推理预热 + with torch.no_grad(): + for _ in range(2): + self.model(dummy_img) + + logger.info("模型预热完成") + + def _warmup_onnx_model(self): + """预热ONNX模型""" + dummy_img = np.zeros((1, 3, 640, 640), dtype=np.float32) + + # 进行多次推理预热 + for _ in range(2): + self.onnx_session.run(self.output_names, {self.input_name: dummy_img}) + + logger.info("ONNX模型预热完成") + def detect(self, image): """ 对图像进行目标检测 @@ -63,20 +339,25 @@ class YOLOModel: list: 检测结果,每个结果包含边界框、类别和置信度 """ try: - # 在GPU上进行推理 - with torch.no_grad(): - results = self.model(image) + start_time = time.time() - # 处理结果 - detections = [] - for pred in results.xyxy[0].cpu().numpy(): - x1, y1, x2, y2, conf, cls = pred - detections.append({ - 'bbox': [float(x1), float(y1), float(x2), float(y2)], - 'confidence': float(conf), - 'class_id': int(cls), - 'class_name': self.model.names[int(cls)] - }) + # 根据模型类型选择检测方法 + if hasattr(self, 'using_onnx') and self.using_onnx: + detections = self._detect_onnx(image) + elif hasattr(self, 'using_fallback') and self.using_fallback: + detections = self._detect_fallback(image) + else: + detections = self._detect_pytorch(image) + + # 记录推理时间 + inference_time = time.time() - start_time + self.inference_times.append(inference_time) + + # 计算平均推理时间 + avg_time = sum(self.inference_times) / len(self.inference_times) + + if len(self.inference_times) % 10 == 0: + logger.debug(f"平均推理时间: {avg_time:.3f}秒, 当前: {inference_time:.3f}秒") return detections @@ -84,6 +365,217 @@ class YOLOModel: logger.error(f"目标检测出错: {e}") return [] + def _detect_pytorch(self, image): + """使用PyTorch模型进行检测""" + # 在GPU上进行推理 + with torch.no_grad(): + # 处理不同的模型接口 + if hasattr(self, 'stride') and hasattr(self, 'names'): + # 自定义加载的YOLOv5 + # 转换图像 + img = self._prepare_image_custom(image) + + # 推理 + output = self.model(img) + + # 处理输出 + pred = self.non_max_suppression(output[0], self.conf_threshold, self.iou_threshold) + + # 解析结果 + detections = [] + for det in pred[0]: + x1, y1, x2, y2, conf, cls = det + detections.append({ + 'bbox': [float(x1), float(y1), float(x2), float(y2)], + 'confidence': float(conf), + 'class_id': int(cls), + 'class_name': self.names[int(cls)] + }) + else: + # 标准torch.hub加载的模型 + results = self.model(image) + + # 处理结果 + detections = [] + for pred in results.xyxy[0].cpu().numpy(): + x1, y1, x2, y2, conf, cls = pred + detections.append({ + 'bbox': [float(x1), float(y1), float(x2), float(y2)], + 'confidence': float(conf), + 'class_id': int(cls), + 'class_name': self.model.names[int(cls)] + }) + + return detections + + def _detect_fallback(self, image): + """使用备用模型进行检测""" + # 预处理图像 + img = self._prepare_image_pytorch(image) + + # 进行推理 + with torch.no_grad(): + predictions = self.model(img) + + # 处理结果 + detections = [] + for i, prediction in enumerate(predictions[0]['boxes']): + score = predictions[0]['scores'][i].item() + if score > self.conf_threshold: + x1, y1, x2, y2 = prediction.tolist() + class_id = predictions[0]['labels'][i].item() + + # 映射torchvision的COCO类别到我们的类别 + class_name = self.names[min(class_id, len(self.names) - 1)] + + detections.append({ + 'bbox': [float(x1), float(y1), float(x2), float(y2)], + 'confidence': float(score), + 'class_id': class_id, + 'class_name': class_name + }) + + return detections + + def _detect_onnx(self, image): + """使用ONNX模型进行检测""" + # 预处理图像 + input_tensor = self._prepare_image_onnx(image) + + # 运行推理 + outputs = self.onnx_session.run(self.output_names, {self.input_name: input_tensor}) + + # 解析输出(根据ONNX模型的输出格式可能需要调整) + # 假设输出格式为 [batch_id, x1, y1, x2, y2, conf, class_id] + predictions = outputs[0] + + # 过滤低置信度预测 + mask = predictions[:, 5] > self.conf_threshold + filtered_preds = predictions[mask] + + # 组织检测结果 + detections = [] + for pred in filtered_preds: + x1, y1, x2, y2, conf, cls = pred[1:7] + cls_id = int(cls) + + detections.append({ + 'bbox': [float(x1), float(y1), float(x2), float(y2)], + 'confidence': float(conf), + 'class_id': cls_id, + 'class_name': self.class_names[cls_id] if cls_id < len(self.class_names) else "unknown" + }) + + return detections + + def _prepare_image_custom(self, image): + """为自定义加载的YOLOv5模型准备图像""" + # 转换PIL图像为numpy数组 + img = np.array(image) + + # 调整大小 + img = self._letterbox(img, new_shape=self.imgsz)[0] + + # BGR转RGB + img = img[:, :, ::-1].transpose(2, 0, 1) + img = np.ascontiguousarray(img) + + # 转换为PyTorch张量 + img = torch.from_numpy(img).to(self.device) + img = img.half() if self.use_half_precision else img.float() + img /= 255.0 + + # 增加批次维度 + if img.ndimension() == 3: + img = img.unsqueeze(0) + + return img + + def _prepare_image_pytorch(self, image): + """为PyTorch模型准备图像""" + # 转换PIL图像为PyTorch张量 + from torchvision import transforms + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + img = transform(image).unsqueeze(0).to(self.device) + + return img + + def _prepare_image_onnx(self, image): + """为ONNX模型准备图像""" + # 调整图像大小 + img_size = MODEL.get("img_size", 640) + image = image.resize((img_size, img_size), Image.LANCZOS) + + # 转换为numpy数组 + img = np.array(image).astype(np.float32) / 255.0 + + # 从HWC转换为CHW格式 + img = img.transpose(2, 0, 1) + + # 添加批次维度 + img = np.expand_dims(img, axis=0) + + return img + + def _letterbox(self, img, new_shape=(640, 640), color=(114, 114, 114)): + """调整图像大小并填充(YOLOv5风格)""" + shape = img.shape[:2] # 当前形状 [高, 宽] + + # 缩放比例 (新 / 旧) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + + # 计算填充 + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + + # 平均分配填充 + dw /= 2 + dh /= 2 + + # 调整大小 + if shape[::-1] != new_unpad: + import cv2 + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + + # 添加边框 + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) + + return img, r, (dw, dh) + def get_class_names(self): """获取类别名称列表""" - return self.model.names \ No newline at end of file + if hasattr(self, 'using_onnx') and self.using_onnx: + return self.class_names + elif hasattr(self, 'using_fallback') and self.using_fallback: + return self.names + elif hasattr(self, 'names'): + return self.names + else: + return self.model.names + + def get_performance_stats(self): + """获取性能统计信息""" + if not self.inference_times: + return {"average_time": 0, "min_time": 0, "max_time": 0, "count": 0, "fps": 0} + + avg_time = sum(self.inference_times) / len(self.inference_times) + + # 修复fps计算逻辑 + if avg_time > 0: + fps = 1.0 / avg_time + else: + fps = 0 + + return { + "average_time": avg_time, + "min_time": min(self.inference_times), + "max_time": max(self.inference_times), + "count": len(self.inference_times), + "fps": fps + } \ No newline at end of file diff --git a/package.sh b/package.sh new file mode 100755 index 0000000..224f6b4 --- /dev/null +++ b/package.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# 创建客户端打包目录 +mkdir -p dnf-client-package +cp client.py dnf-client-package/ +cp -r win_libs/* dnf-client-package/ + +# 创建默认配置文件 +cat > dnf-client-package/config.ini << EOF +[Server] +url = wss://your-server-ip:8080/ws +verify_ssl = false + +[Capture] +interval = 0.5 +quality = 70 + +[Game] +window_title = 地下城与勇士 +key_mapping = default +movement_smoothness = 0.8 + +[Connection] +max_retries = 5 +retry_delay = 5 +heartbeat_interval = 5 +EOF + +# 创建启动脚本 +cat > dnf-client-package/start.bat << EOF +@echo off +echo 正在启动DNF自动化客户端... +python client.py +pause +EOF + +# 打包为zip +cd dnf-client-package +zip -r ../dnf-client-windows.zip * +cd .. + +echo "打包完成: dnf-client-windows.zip" \ No newline at end of file diff --git a/run.py b/run.py index 0719bb2..d0ee42c 100644 --- a/run.py +++ b/run.py @@ -3,25 +3,257 @@ """ DNF自动化云服务主启动脚本 +优化版 - 增加命令行参数和状态监控 """ import os +import sys import argparse import logging +import time +import json +import platform +import subprocess +import threading +import requests +from datetime import datetime, timedelta + from server.main import start_server from utils.logging_utils import setup_logging +from config.settings import SERVER, MODEL, LOGGING + +# 版本信息 +VERSION = "1.0.1" + +def get_system_info(): + """获取系统信息""" + info = { + "os": platform.system(), + "os_version": platform.version(), + "python": platform.python_version(), + "cpu": platform.processor(), + "architecture": platform.machine(), + "hostname": platform.node() + } + + # 检查CUDA和PyTorch + try: + import torch + info["torch"] = torch.__version__ + + if torch.cuda.is_available(): + info["cuda"] = torch.version.cuda + info["gpu"] = torch.cuda.get_device_name(0) + info["gpu_count"] = torch.cuda.device_count() + info["gpu_memory"] = torch.cuda.get_device_properties(0).total_memory / (1024**3) # GB + else: + info["cuda"] = "Not available" + except ImportError: + info["torch"] = "Not installed" + info["cuda"] = "Unknown" + + return info + +def print_banner(): + """打印启动横幅""" + banner = f""" + ╔═══════════════════════════════════════════════════════════╗ + ║ ║ + ║ DNF 自动化云服务 v{VERSION} ║ + ║ ║ + ║ 适用于地下城与勇士游戏的AI辅助系统 ║ + ║ 启动时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ║ + ║ ║ + ╚═══════════════════════════════════════════════════════════╝ + """ + print(banner) + + # 添加警告信息 + if MODEL.get("device", "") == "cpu": + print("警告: 正在使用CPU模式运行,性能可能受限,建议使用GPU以获得更好的体验\n") + +def monitor_server(port): + """监控服务器状态""" + try: + # 定义状态监控线程 + def status_thread(): + start_time = time.time() + + while True: + try: + current_time = time.time() + uptime = current_time - start_time + + # 读取server_info.json获取服务器信息 + if os.path.exists("server_info.json"): + with open("server_info.json", "r") as f: + info = json.load(f) + + # 显示状态 + os.system('cls' if os.name == 'nt' else 'clear') + print(f"\n==== DNF自动化云服务状态监控 ====") + print(f"服务器版本: {info.get('version', 'unknown')}") + print(f"启动时间: {info.get('started_at', 'unknown')}") + print(f"运行时长: {timedelta(seconds=int(uptime))}") + print(f"服务地址: {info.get('server', {}).get('url', 'unknown')}") + + # 尝试获取服务器状态 + try: + status_file = "server_status.json" + if os.path.exists(status_file) and os.path.getmtime(status_file) > current_time - 60: + with open(status_file, "r") as sf: + status = json.load(sf) + + print(f"\n==== 连接统计 ====") + print(f"活跃连接数: {status.get('connections', {}).get('active', 0)}") + print(f"总连接数: {status.get('connections', {}).get('total', 0)}") + + print(f"\n==== 性能统计 ====") + print(f"已处理图像: {status.get('performance', {}).get('images_processed', 0)}") + print(f"已生成动作: {status.get('performance', {}).get('actions_generated', 0)}") + print(f"错误数: {status.get('performance', {}).get('errors', 0)}") + except: + pass + + # 等待10秒更新一次 + time.sleep(10) + + except KeyboardInterrupt: + break + except Exception as e: + print(f"监控出错: {e}") + time.sleep(30) + + # 启动监控线程 + t = threading.Thread(target=status_thread, daemon=True) + t.start() + + except Exception as e: + print(f"启动监控失败: {e}") + +def clean_old_logs(): + """清理旧的日志文件""" + log_dir = os.path.dirname(LOGGING.get("file", "")) + if not os.path.exists(log_dir): + return + + try: + # 获取所有日志文件 + log_files = [f for f in os.listdir(log_dir) if f.endswith(".log")] + + # 按修改时间排序 + log_files.sort(key=lambda x: os.path.getmtime(os.path.join(log_dir, x))) + + # 如果超过10个日志文件,删除最旧的 + if len(log_files) > 10: + for f in log_files[:-10]: + try: + os.remove(os.path.join(log_dir, f)) + print(f"已删除旧日志文件: {f}") + except: + pass + except Exception as e: + print(f"清理日志时出错: {e}") + +def check_updates(): + """检查更新""" + try: + # 这里可以实现检查更新的逻辑 + # 例如,检查Git仓库或向API服务器查询 + pass + except: + pass + +def create_status_api(port): + """创建一个简单的状态API服务器""" + try: + from http.server import HTTPServer, BaseHTTPRequestHandler + import threading + import json + + class StatusHandler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/status": + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + + # 读取服务器状态 + status = {"status": "running"} + if os.path.exists("server_status.json"): + try: + with open("server_status.json", "r") as f: + status = json.load(f) + except: + pass + + self.wfile.write(json.dumps(status).encode()) + else: + self.send_response(404) + self.end_headers() + + def run_server(): + server_address = ("", port + 1) # 使用主端口+1 + httpd = HTTPServer(server_address, StatusHandler) + print(f"状态API服务器已启动: http://localhost:{port + 1}/status") + httpd.serve_forever() + + # 启动API服务器线程 + t = threading.Thread(target=run_server, daemon=True) + t.start() + + except Exception as e: + print(f"启动状态API服务器失败: {e}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description="DNF自动化云服务") + parser = argparse.ArgumentParser(description=f"DNF自动化云服务 v{VERSION}") parser.add_argument("--port", type=int, default=8080, help="服务端口") parser.add_argument("--host", type=str, default="0.0.0.0", help="服务地址") parser.add_argument("--debug", action="store_true", help="开启调试模式") + parser.add_argument("--cpu", action="store_true", help="强制使用CPU模式") + parser.add_argument("--no-monitor", action="store_true", help="禁用状态监控") + parser.add_argument("--version", action="version", version=f"DNF自动化云服务 v{VERSION}") args = parser.parse_args() + # 开启状态监控 + if not args.no_monitor: + monitor_server(args.port) + create_status_api(args.port) + + # 打印横幅 + print_banner() + + # 清理旧日志 + clean_old_logs() + + # 检查更新 + check_updates() + + # 强制CPU模式 + if args.cpu: + MODEL["device"] = "cpu" + print("已强制使用CPU模式") + # 设置日志 setup_logging(debug=args.debug) logger = logging.getLogger("DNFAutoCloud") - logger.info("正在启动DNF自动化云服务...") + + # 打印系统信息 + system_info = get_system_info() + logger.info(f"系统信息: {json.dumps(system_info, ensure_ascii=False)}") # 启动服务 - start_server(host=args.host, port=args.port, debug=args.debug) \ No newline at end of file + logger.info(f"正在启动DNF自动化云服务 v{VERSION}...") + + try: + # 启动服务器 + success = start_server(host=args.host, port=args.port, debug=args.debug) + + if not success: + logger.error("服务器启动失败") + sys.exit(1) + except KeyboardInterrupt: + logger.info("用户中断,服务已停止") + except Exception as e: + logger.critical(f"启动服务时出错: {e}", exc_info=True) + sys.exit(1) \ No newline at end of file diff --git a/server/action_generator.py b/server/action_generator.py index 0b9465f..d8759e5 100644 --- a/server/action_generator.py +++ b/server/action_generator.py @@ -3,18 +3,50 @@ """ 动作生成模块,根据图像识别结果生成游戏操作指令 +优化版 - 更智能的决策逻辑和状态机实现 """ import logging import random import time +import math +import numpy as np +from collections import deque + from config.settings import DNF, BEHAVIOR logger = logging.getLogger("DNFAutoCloud") +# 历史决策保存 +decision_history = deque(maxlen=10) +state_transitions = {} + +class GameState: + """游戏状态枚举""" + UNKNOWN = "unknown" + LOADING = "loading" + LOBBY = "lobby" + DUNGEON_SELECTION = "dungeon_selection" + ROOM_SELECTION = "room_selection" + IN_DUNGEON = "in_dungeon" + IN_BATTLE = "in_battle" + REST = "rest" + PICKUP_ITEMS = "pickup_items" + TALKING = "talking" + INVENTORY = "inventory" + DEAD = "dead" + +class ActionPriority: + """动作优先级""" + CRITICAL = 0.5 # 紧急动作(如低血量使用药水) + HIGH = 0.8 # 高优先级(如打BOSS) + NORMAL = 1.0 # 正常优先级 + LOW = 1.5 # 低优先级 + BACKGROUND = 2.0 # 后台任务 + def generate_actions(detections, game_state): """ - 根据检测结果和游戏状态生成动作 + 根据检测结果和游戏状态生成动作 - 优化版 参数: detections (list): 检测结果列表 @@ -26,86 +58,344 @@ def generate_actions(detections, game_state): actions = [] try: - # 检查是否处于战斗状态 - in_battle = is_in_battle(detections, game_state) + # 推断当前游戏状态 + current_state = infer_game_state(detections, game_state) - if in_battle: - # 战斗状态下的操作 + # 记录状态转换 + record_state_transition(game_state.get("previous_state", GameState.UNKNOWN), current_state) + + # 检查玩家状态是否需要紧急处理 + emergency_actions = check_emergency(detections, game_state) + if emergency_actions: + actions.extend(emergency_actions) + # 在紧急情况下,可能需要提前返回而不执行其他操作 + if any(a.get("critical", False) for a in emergency_actions): + return actions + + # 根据当前状态生成动作 + if current_state == GameState.IN_BATTLE: battle_actions = generate_battle_actions(detections, game_state) actions.extend(battle_actions) - else: - # 非战斗状态下的操作 - navigation_actions = generate_navigation_actions(detections, game_state) - actions.extend(navigation_actions) - # 根据游戏状态添加额外操作 - additional_actions = generate_additional_actions(detections, game_state) - actions.extend(additional_actions) + elif current_state == GameState.IN_DUNGEON: + exploration_actions = generate_exploration_actions(detections, game_state) + actions.extend(exploration_actions) + + elif current_state == GameState.PICKUP_ITEMS: + pickup_actions = generate_pickup_actions(detections, game_state) + actions.extend(pickup_actions) + + elif current_state == GameState.TALKING: + talking_actions = generate_talking_actions(detections, game_state) + actions.extend(talking_actions) + + elif current_state == GameState.LOBBY: + lobby_actions = generate_lobby_actions(detections, game_state) + actions.extend(lobby_actions) + + elif current_state == GameState.DUNGEON_SELECTION: + selection_actions = generate_dungeon_selection_actions(detections, game_state) + actions.extend(selection_actions) + + elif current_state == GameState.ROOM_SELECTION: + room_actions = generate_room_selection_actions(detections, game_state) + actions.extend(room_actions) + + elif current_state == GameState.DEAD: + dead_actions = generate_dead_actions(detections, game_state) + actions.extend(dead_actions) + + else: + # 默认动作 + default_actions = generate_default_actions(detections, game_state) + actions.extend(default_actions) # 添加人类行为特征 humanize_actions(actions) + # 记录决策历史 + record_decision(current_state, detections, actions) + + # 更新游戏状态 + game_state["previous_state"] = current_state + return actions except Exception as e: logger.error(f"生成动作时出错: {e}") # 出错时返回安全的默认动作(如停止移动) - return [{"type": "stop", "reason": "error_recovery"}] + return [{"type": "stop", "reason": "error_recovery", "execution_priority": ActionPriority.CRITICAL}] -def is_in_battle(detections, game_state): +def infer_game_state(detections, game_state): """ - 判断是否处于战斗状态 + 根据检测结果推断当前游戏状态 参数: detections (list): 检测结果 game_state (dict): 游戏状态 返回: - bool: 是否处于战斗状态 + str: 游戏状态 """ - # 检查是否有怪物或Boss - for det in detections: - if det["class_name"] in ["monster", "boss"]: - return True + # 检查是否有怪物或Boss(战斗状态) + monsters = [d for d in detections if d["class_name"] in ["monster", "boss"]] + if monsters: + return GameState.IN_BATTLE - # 检查游戏状态中的战斗标志 - return game_state.get("in_battle", False) + # 检查是否有物品(拾取状态) + items = [d for d in detections if d["class_name"] == "item"] + if items and not game_state.get("inventory_full", False): + return GameState.PICKUP_ITEMS + + # 检查是否有对话框(对话状态) + dialog_boxes = [d for d in detections if d["class_name"] == "dialog"] + if dialog_boxes: + return GameState.TALKING + + # 检查是否有门(地下城探索状态) + doors = [d for d in detections if d["class_name"] == "door"] + if doors: + return GameState.IN_DUNGEON + + # 检查是否有NPC(可能在城镇/大厅) + npcs = [d for d in detections if d["class_name"] == "npc"] + if npcs: + # 区分大厅和地下城中的NPC + if game_state.get("current_map", "").startswith("town_"): + return GameState.LOBBY + else: + return GameState.IN_DUNGEON + + # 检查是否有地下城选择界面 + dungeon_ui = [d for d in detections if d["class_name"] == "dungeon_select"] + if dungeon_ui: + return GameState.DUNGEON_SELECTION + + # 检查是否有房间选择界面 + room_ui = [d for d in detections if d["class_name"] == "room_select"] + if room_ui: + return GameState.ROOM_SELECTION + + # 检查是否死亡 + death_ui = [d for d in detections if d["class_name"] == "death_ui"] + if death_ui: + return GameState.DEAD + + # 其他情况,根据当前位置判断 + current_map = game_state.get("current_map", "") + if current_map.startswith("town_"): + return GameState.LOBBY + elif current_map: + return GameState.IN_DUNGEON + + # 保持之前的状态(如果有) + previous_state = game_state.get("previous_state") + if previous_state: + return previous_state + + # 默认为未知状态 + return GameState.UNKNOWN + +def check_emergency(detections, game_state): + """ + 检查紧急情况并生成相应动作 + + 参数: + detections (list): 检测结果 + game_state (dict): 游戏状态 + + 返回: + list: 紧急操作指令 + """ + actions = [] + + # 检查血量 + hp_bars = [d for d in detections if d["class_name"] == "hp_bar"] + if hp_bars: + hp_percent = hp_bars[0].get("percent", 100) + + # 血量过低时使用药水 + if hp_percent < 30: + actions.append({ + "type": "use_item", + "key": "f1", # 假设F1是HP药水快捷键 + "purpose": "use_hp_potion", + "execution_priority": ActionPriority.CRITICAL, + "critical": True, + "reason": f"低血量 ({hp_percent:.1f}%)" + }) + + # 检查蓝量 + mp_bars = [d for d in detections if d["class_name"] == "mp_bar"] + if mp_bars: + mp_percent = mp_bars[0].get("percent", 100) + + # 蓝量过低时使用药水 + if mp_percent < 20: + actions.append({ + "type": "use_item", + "key": "f2", # 假设F2是MP药水快捷键 + "purpose": "use_mp_potion", + "execution_priority": ActionPriority.HIGH, + "reason": f"低蓝量 ({mp_percent:.1f}%)" + }) + + # 检查是否需要复活或返回城镇 + if game_state.get("previous_state") == GameState.DEAD: + actions.append({ + "type": "press_key_combo", + "keys": ["alt", "r"], # 假设Alt+R是复活快捷键 + "purpose": "revive", + "execution_priority": ActionPriority.CRITICAL, + "critical": True, + "reason": "角色已死亡,需要复活" + }) + + return actions def generate_battle_actions(detections, game_state): - """生成战斗状态下的操作""" + """ + 生成战斗状态下的操作 - 优化版 + + 参数: + detections (list): 检测结果 + game_state (dict): 游戏状态 + + 返回: + list: 操作指令 + """ actions = [] # 获取所有怪物 monsters = [d for d in detections if d["class_name"] in ["monster", "boss"]] if monsters: - # 按照优先级排序(Boss > 普通怪物) - monsters.sort(key=lambda x: 0 if x["class_name"] == "boss" else 1) + # 按照优先级排序(Boss > 精英怪 > 普通怪物) + def monster_priority(m): + if m["class_name"] == "boss": + return 0 + elif m.get("is_elite", False): + return 1 + else: + return 2 - # 选取最近或优先级最高的目标 + monsters.sort(key=monster_priority) + + # 选取目标 target = monsters[0] - # 添加移动到目标附近的动作 - actions.append({ - "type": "move_to", - "position": target["center"], - "target_id": target.get("id", "unknown"), - "target_type": target["class_name"] - }) + # 目标距离 + target_center = target["center"] - # 添加攻击动作 - skill_key = choose_skill(game_state) - actions.append({ - "type": "use_skill", - "key": skill_key, - "target_position": target["center"], - "skill_name": DNF["skills"].get(skill_key, "未知技能") - }) + # 选择合适的技能 + skill_key, skill_type = choose_skill(game_state, detections, monsters) + + if skill_type == "ranged": + # 远程技能 - 站在适当距离释放 + # 首先移动到合适的位置 + ideal_distance = 200 # 理想距离(像素) + current_pos = [game_state.get("player_x", 0), game_state.get("player_y", 0)] + + # 计算当前到目标的向量 + vec_to_target = [ + target_center[0] - current_pos[0], + target_center[1] - current_pos[1] + ] + + # 计算距离 + distance = math.sqrt(vec_to_target[0]**2 + vec_to_target[1]**2) + + # 如果太近,向后移动 + if distance < ideal_distance * 0.8: + # 标准化向量并反向 + norm = math.sqrt(vec_to_target[0]**2 + vec_to_target[1]**2) + if norm > 0: + vec_direction = [-vec_to_target[0]/norm, -vec_to_target[1]/norm] + back_pos = [ + current_pos[0] + vec_direction[0] * ideal_distance * 0.2, + current_pos[1] + vec_direction[1] * ideal_distance * 0.2 + ] + + actions.append({ + "type": "move_to", + "position": back_pos, + "purpose": "adjust_range", + "execution_priority": ActionPriority.HIGH + }) + + # 添加使用技能动作 + actions.append({ + "type": "use_skill", + "key": skill_key, + "target_position": target_center, + "target_id": target.get("id", "unknown"), + "target_type": target["class_name"], + "skill_type": "ranged", + "execution_priority": ActionPriority.NORMAL + }) + + elif skill_type == "aoe": + # AOE技能 - 瞄准多个怪物的中心 + if len(monsters) > 1: + # 计算多个怪物的中心点 + center_x = sum(m["center"][0] for m in monsters[:3]) / min(3, len(monsters)) + center_y = sum(m["center"][1] for m in monsters[:3]) / min(3, len(monsters)) + aoe_center = [center_x, center_y] + + actions.append({ + "type": "use_skill", + "key": skill_key, + "target_position": aoe_center, + "target_count": min(3, len(monsters)), + "skill_type": "aoe", + "execution_priority": ActionPriority.NORMAL + }) + else: + # 单个怪物时直接瞄准 + actions.append({ + "type": "use_skill", + "key": skill_key, + "target_position": target_center, + "target_id": target.get("id", "unknown"), + "target_type": target["class_name"], + "skill_type": "aoe", + "execution_priority": ActionPriority.NORMAL + }) + + else: # 默认为近战技能 + # 近战技能 - 需要先接近目标 + actions.append({ + "type": "move_to", + "position": target_center, + "target_id": target.get("id", "unknown"), + "target_type": target["class_name"], + "purpose": "approach_target", + "execution_priority": ActionPriority.HIGH + }) + + actions.append({ + "type": "use_skill", + "key": skill_key, + "target_position": target_center, + "target_id": target.get("id", "unknown"), + "target_type": target["class_name"], + "skill_type": "melee", + "execution_priority": ActionPriority.NORMAL + }) return actions -def generate_navigation_actions(detections, game_state): - """生成非战斗状态下的导航操作""" +def generate_exploration_actions(detections, game_state): + """ + 生成探索状态下的操作 - 优化版 + + 参数: + detections (list): 检测结果 + game_state (dict): 游戏状态 + + 返回: + list: 操作指令 + """ actions = [] # 检查是否有门、NPC等交互物体 @@ -114,18 +404,36 @@ def generate_navigation_actions(detections, game_state): items = [d for d in detections if d["class_name"] == "item"] # 优先拾取物品 - if items: - for item in items: - actions.append({ - "type": "move_to", - "position": item["center"], - "purpose": "pickup_item" - }) - actions.append({ - "type": "interact", - "key": "x", # 假设X键是拾取键 - "purpose": "pickup_item" - }) + if items and not game_state.get("inventory_full", False): + # 按照稀有度排序 + rarity_order = { + "legendary": 0, + "epic": 1, + "rare": 2, + "uncommon": 3, + "common": 4, + "unknown": 5 + } + + items.sort(key=lambda x: rarity_order.get(x.get("rarity", "unknown"), 999)) + + # 选择最高稀有度的物品 + item = items[0] + + actions.append({ + "type": "move_to", + "position": item["center"], + "purpose": "pickup_item", + "item_rarity": item.get("rarity", "unknown"), + "execution_priority": ActionPriority.NORMAL + }) + + actions.append({ + "type": "interact", + "key": "x", # 假设X键是拾取键 + "purpose": "pickup_item", + "execution_priority": ActionPriority.NORMAL + }) # 与NPC交互 elif npcs and game_state.get("should_talk_to_npc", False): @@ -133,12 +441,16 @@ def generate_navigation_actions(detections, game_state): actions.append({ "type": "move_to", "position": npc["center"], - "purpose": "talk_to_npc" + "purpose": "talk_to_npc", + "npc_id": npc.get("id", "unknown"), + "execution_priority": ActionPriority.NORMAL }) + actions.append({ "type": "interact", "key": "f", # 假设F键是交互键 - "purpose": "talk_to_npc" + "purpose": "talk_to_npc", + "execution_priority": ActionPriority.NORMAL }) # 前往下一个门 @@ -147,12 +459,16 @@ def generate_navigation_actions(detections, game_state): actions.append({ "type": "move_to", "position": door["center"], - "purpose": "go_to_next_room" + "purpose": "go_to_next_room", + "door_id": door.get("id", "unknown"), + "execution_priority": ActionPriority.NORMAL }) + actions.append({ "type": "interact", "key": "f", # 假设F键是交互键 - "purpose": "go_to_next_room" + "purpose": "go_to_next_room", + "execution_priority": ActionPriority.NORMAL }) # 探索地图 @@ -164,86 +480,463 @@ def generate_navigation_actions(detections, game_state): actions.append({ "type": "move_to", "position": exit_pos, - "purpose": "explore_map" + "purpose": "explore_map", + "execution_priority": ActionPriority.NORMAL }) else: - # 随机探索 + # 智能探索 - 朝未探索区域移动 + explored_areas = game_state.get("explored_areas", []) + + if explored_areas: + # 寻找未探索的方向 + directions = ["right", "up", "left", "down"] + unexplored = [d for d in directions if d not in explored_areas] + + if unexplored: + direction = random.choice(unexplored) + else: + direction = random.choice(directions) + else: + # 首次探索,默认向右 + direction = "right" + actions.append({ "type": "move_random", - "direction": random.choice(["right", "left", "up", "down"]), - "duration": random.uniform(0.5, 1.5), - "purpose": "explore_unknown_map" + "direction": direction, + "duration": random.uniform(1.0, 2.0), + "purpose": "explore_unknown_map", + "execution_priority": ActionPriority.NORMAL }) + + # 记录已探索方向 + explored_areas = game_state.get("explored_areas", []) + if direction not in explored_areas: + explored_areas.append(direction) + game_state["explored_areas"] = explored_areas return actions -def generate_additional_actions(detections, game_state): - """生成额外的操作,如使用药水、处理特殊情况等""" +def generate_pickup_actions(detections, game_state): + """生成拾取物品的操作""" actions = [] - # 检查血量是否过低 - hp_bars = [d for d in detections if d["class_name"] == "hp_bar"] - if hp_bars: - hp_bar = hp_bars[0] - # 假设我们可以从检测结果中获取血量百分比 - hp_percent = estimate_bar_percent(hp_bar["bbox"]) + items = [d for d in detections if d["class_name"] == "item"] + if items: + # 按照稀有度排序 + rarity_order = { + "legendary": 0, + "epic": 1, + "rare": 2, + "uncommon": 3, + "common": 4, + "unknown": 5 + } - if hp_percent < 30: # 血量低于30% + items.sort(key=lambda x: rarity_order.get(x.get("rarity", "unknown"), 999)) + + # 拾取最高稀有度的物品 + item = items[0] + actions.append({ + "type": "move_to", + "position": item["center"], + "purpose": "pickup_item", + "item_rarity": item.get("rarity", "unknown"), + "execution_priority": ActionPriority.HIGH + }) + + # 连续点击以确保拾取 + for i in range(2): actions.append({ - "type": "use_item", - "key": "f1", # 假设F1是使用HP药水的快捷键 - "purpose": "use_hp_potion" + "type": "interact", + "key": "x", # 假设X键是拾取键 + "purpose": "pickup_item", + "execution_priority": ActionPriority.HIGH }) - # 检查是否需要使用技能 - # (其他额外操作逻辑...) - return actions -def choose_skill(game_state): - """选择要使用的技能""" - # 可用的技能键 +def generate_talking_actions(detections, game_state): + """生成对话状态下的操作""" + actions = [] + + dialog_boxes = [d for d in detections if d["class_name"] == "dialog"] + if dialog_boxes: + # 检查对话框中是否有选项 + dialog_options = [d for d in detections if d["class_name"] == "dialog_option"] + + if dialog_options: + # 选择第一个选项(通常是继续任务) + option = dialog_options[0] + actions.append({ + "type": "click", + "position": option["center"], + "purpose": "select_dialog_option", + "execution_priority": ActionPriority.NORMAL + }) + else: + # 点击空白区域继续对话 + actions.append({ + "type": "click", + "position": [dialog_boxes[0]["center"][0], dialog_boxes[0]["center"][1] + 50], + "purpose": "continue_dialog", + "execution_priority": ActionPriority.NORMAL + }) + + # 或者按空格继续 + actions.append({ + "type": "interact", + "key": "space", + "purpose": "continue_dialog", + "execution_priority": ActionPriority.NORMAL + }) + + return actions + +def generate_lobby_actions(detections, game_state): + """生成大厅状态下的操作""" + actions = [] + + # 检查是否有任务NPC或地下城入口 + npcs = [d for d in detections if d["class_name"] == "npc"] + portals = [d for d in detections if d["class_name"] == "dungeon_portal"] + + if game_state.get("quest_active", False) and portals: + # 有任务且发现地下城入口,进入地下城 + portal = portals[0] + actions.append({ + "type": "move_to", + "position": portal["center"], + "purpose": "enter_dungeon", + "execution_priority": ActionPriority.NORMAL + }) + + actions.append({ + "type": "interact", + "key": "f", + "purpose": "enter_dungeon", + "execution_priority": ActionPriority.NORMAL + }) + + elif npcs: + # 找任务NPC + npc = npcs[0] + actions.append({ + "type": "move_to", + "position": npc["center"], + "purpose": "talk_to_npc", + "execution_priority": ActionPriority.NORMAL + }) + + actions.append({ + "type": "interact", + "key": "f", + "purpose": "talk_to_npc", + "execution_priority": ActionPriority.NORMAL + }) + + else: + # 随机移动 + actions.append({ + "type": "move_random", + "direction": random.choice(["right", "left", "up", "down"]), + "duration": random.uniform(0.5, 1.5), + "purpose": "explore_lobby", + "execution_priority": ActionPriority.LOW + }) + + return actions + +def generate_dungeon_selection_actions(detections, game_state): + """生成地下城选择界面的操作""" + actions = [] + + # 查找合适的地下城选项 + dungeon_options = [d for d in detections if d["class_name"] == "dungeon_option"] + + if dungeon_options: + # 按照推荐级别排序 + recommended = [d for d in dungeon_options if d.get("is_recommended", False)] + + if recommended: + option = recommended[0] + else: + option = dungeon_options[0] + + # 点击选择地下城 + actions.append({ + "type": "click", + "position": option["center"], + "purpose": "select_dungeon", + "dungeon_name": option.get("text", "unknown"), + "execution_priority": ActionPriority.NORMAL + }) + + # 点击确认按钮 + confirm_buttons = [d for d in detections if d["class_name"] == "confirm_button"] + if confirm_buttons: + actions.append({ + "type": "click", + "position": confirm_buttons[0]["center"], + "purpose": "confirm_dungeon", + "execution_priority": ActionPriority.NORMAL + }) + + return actions + +def generate_room_selection_actions(detections, game_state): + """生成房间选择界面的操作""" + actions = [] + + # 查找房间选项 + room_options = [d for d in detections if d["class_name"] == "room_option"] + + if room_options: + # 选择第一个房间 + option = room_options[0] + + actions.append({ + "type": "click", + "position": option["center"], + "purpose": "select_room", + "execution_priority": ActionPriority.NORMAL + }) + + # 点击确认按钮 + confirm_buttons = [d for d in detections if d["class_name"] == "confirm_button"] + if confirm_buttons: + actions.append({ + "type": "click", + "position": confirm_buttons[0]["center"], + "purpose": "confirm_room", + "execution_priority": ActionPriority.NORMAL + }) + + return actions + +def generate_dead_actions(detections, game_state): + """生成死亡状态下的操作""" + actions = [] + + # 查找复活按钮 + revive_buttons = [d for d in detections if d["class_name"] == "revive_button"] + + if revive_buttons: + # 点击复活按钮 + actions.append({ + "type": "click", + "position": revive_buttons[0]["center"], + "purpose": "revive", + "execution_priority": ActionPriority.CRITICAL + }) + else: + # 尝试快捷键复活 + actions.append({ + "type": "press_key_combo", + "keys": ["alt", "r"], # 假设Alt+R是复活快捷键 + "purpose": "revive", + "execution_priority": ActionPriority.CRITICAL + }) + + return actions + +def generate_default_actions(detections, game_state): + """生成默认状态下的操作""" + actions = [] + + # 尝试按ESC退出可能的菜单 + actions.append({ + "type": "interact", + "key": "escape", + "purpose": "close_menu", + "execution_priority": ActionPriority.HIGH + }) + + # 随机移动以探索 + actions.append({ + "type": "move_random", + "direction": random.choice(["right", "left", "up", "down"]), + "duration": random.uniform(0.5, 1.0), + "purpose": "explore_unknown", + "execution_priority": ActionPriority.LOW + }) + + return actions + +def choose_skill(game_state, detections, monsters): + """ + 智能选择要使用的技能 + + 参数: + game_state (dict): 游戏状态 + detections (list): 检测结果 + monsters (list): 怪物列表 + + 返回: + tuple: (技能按键, 技能类型) + """ + # 获取可用技能 available_skills = list(DNF["skills"].keys()) - # 根据游戏状态筛选可用技能 + # 检查冷却中的技能 cooldowns = game_state.get("cooldowns", {}) available_skills = [s for s in available_skills if s not in cooldowns] if not available_skills: - # 如果所有技能都在冷却中,使用普通攻击 - return "1" + # 所有技能都在冷却中,使用普通攻击 + return "1", "melee" - # 随机选择一个技能(可以根据策略优化) - return random.choice(available_skills) + # 根据怪物情况选择最佳技能 + monster_count = len(monsters) + + # 技能类型 + skill_types = { + "1": "melee", # 普通攻击 + "2": "melee", # 近战技能 + "3": "ranged", # 远程技能 + "4": "aoe", # AOE技能 + "5": "buff", # 增益技能 + "6": "ultimate" # 终极技能 + } + + # 对BOSS使用终极技能 + if any(m["class_name"] == "boss" for m in monsters) and "6" in available_skills: + return "6", skill_types.get("6", "ultimate") + + # 多个怪物时使用AOE技能 + if monster_count >= 3 and "4" in available_skills: + return "4", skill_types.get("4", "aoe") + + # 中等距离使用远程技能 + target_distance = estimate_distance(game_state, monsters[0]) + if target_distance > 150 and "3" in available_skills: + return "3", skill_types.get("3", "ranged") + + # 近距离使用近战技能 + if target_distance < 100 and "2" in available_skills: + return "2", skill_types.get("2", "melee") + + # 默认使用普通攻击 + return "1", skill_types.get("1", "melee") -def estimate_bar_percent(bbox): - """估计血条/蓝条的百分比""" - # 根据边界框计算百分比 - width = bbox[2] - bbox[0] - # 假设血条是从左到右填充的 - # 这里的实现取决于具体的血条UI - return random.uniform(0, 100) # 临时返回随机值 +def estimate_distance(game_state, target): + """估算到目标的距离""" + # 简单估算 - 实际情况可能需要更复杂的计算 + player_pos = [ + game_state.get("player_x", target["center"][0]), + game_state.get("player_y", target["center"][1] - 100) # 假设玩家在目标上方100像素 + ] + + dx = target["center"][0] - player_pos[0] + dy = target["center"][1] - player_pos[1] + + return math.sqrt(dx*dx + dy*dy) def humanize_actions(actions): - """为动作添加人类行为特征""" - for action in actions: - # 添加随机延迟 - delay = random.normalvariate( - (BEHAVIOR["min_delay"] + BEHAVIOR["max_delay"]) / 2, - (BEHAVIOR["max_delay"] - BEHAVIOR["min_delay"]) / 6 - ) - action["delay"] = max(BEHAVIOR["min_delay"], min(BEHAVIOR["max_delay"], delay)) + """为动作添加人类行为特征 - 优化版""" + previous_delay = 0 + + for i, action in enumerate(actions): + # 添加基础随机延迟 + base_delay = BEHAVIOR.get("min_delay", 0.1) + max_delay = BEHAVIOR.get("max_delay", 0.8) + + # 不同动作类型有不同的延迟模式 + action_type = action.get("type", "") + + if action_type == "move_to": + # 移动前的思考时间略长 + delay_mean = (base_delay + max_delay) / 2 + delay_std = (max_delay - base_delay) / 4 + elif action_type in ["use_skill", "interact"]: + # 技能释放和交互通常更快 + delay_mean = base_delay * 1.5 + delay_std = base_delay / 2 + elif action_type == "click": + # 点击通常很快 + delay_mean = base_delay + delay_std = base_delay / 3 + else: + # 默认延迟 + delay_mean = (base_delay + max_delay) / 2 + delay_std = (max_delay - base_delay) / 5 + + # 生成延迟时间 + delay = random.normalvariate(delay_mean, delay_std) + delay = max(base_delay, min(max_delay, delay)) + + # 考虑动作连贯性 - 如果是连续相关的动作,延迟更短 + if i > 0 and are_actions_related(actions[i-1], action): + delay *= 0.7 + + # 优先级较高的动作延迟更短 + priority = action.get("execution_priority", ActionPriority.NORMAL) + delay *= priority + + # 添加到动作中 + action["delay"] = delay + previous_delay + previous_delay = 0 # 重置累积延迟 # 为移动和点击添加随机偏移 - if action["type"] in ["move_to", "use_skill"] and "position" in action: + if action["type"] in ["move_to", "click", "use_skill"] and "position" in action: x, y = action["position"] - max_offset_x = BEHAVIOR["click_variance"] * 10 - max_offset_y = BEHAVIOR["click_variance"] * 10 + + # 计算适当的偏移量 + if "accuracy" in action: + # 如果指定了精确度,使用它 + accuracy = action["accuracy"] + elif action["type"] == "use_skill": + # 技能瞄准通常更精确 + accuracy = 0.95 + else: + # 默认精确度 + accuracy = 0.85 + + max_offset_x = (1 - accuracy) * 20 + max_offset_y = (1 - accuracy) * 20 x_offset = random.normalvariate(0, max_offset_x) y_offset = random.normalvariate(0, max_offset_y) action["position"] = [x + x_offset, y + y_offset] - # 添加随机执行顺序标记 - action["execution_priority"] = random.uniform(0.8, 1.2) \ No newline at end of file + # 添加动作描述(用于日志和调试) + if "purpose" in action and "description" not in action: + action["description"] = f"{action['type']} - {action['purpose']}" + +def are_actions_related(action1, action2): + """判断两个动作是否相关联""" + # 如果动作类型相同,可能是相关的 + if action1["type"] == action2["type"]: + return True + + # 移动后接点击/技能/交互 + if action1["type"] == "move_to" and action2["type"] in ["click", "use_skill", "interact"]: + return True + + # 技能后接移动(调整位置) + if action1["type"] == "use_skill" and action2["type"] == "move_to": + return True + + # 目的相同 + if "purpose" in action1 and "purpose" in action2 and action1["purpose"] == action2["purpose"]: + return True + + return False + +def record_decision(state, detections, actions): + """记录决策历史""" + decision = { + "timestamp": time.time(), + "state": state, + "detection_count": len(detections), + "action_count": len(actions), + "action_types": [a["type"] for a in actions] + } + + decision_history.append(decision) + +def record_state_transition(from_state, to_state): + """记录状态转换""" + if from_state != to_state: + transition_key = f"{from_state}->{to_state}" + state_transitions[transition_key] = state_transitions.get(transition_key, 0) + 1 \ No newline at end of file diff --git a/server/image_processor.py b/server/image_processor.py index 32ea5bd..06c5702 100644 --- a/server/image_processor.py +++ b/server/image_processor.py @@ -3,10 +3,12 @@ """ 图像处理模块,负责对接收到的图像进行预处理和后处理 +优化版 - 增强图像处理性能和准确性 """ import logging import numpy as np +import cv2 from PIL import Image, ImageEnhance, ImageFilter logger = logging.getLogger("DNFAutoCloud") @@ -32,7 +34,10 @@ def process_image(yolo_model, image): # 后处理检测结果 processed_detections = postprocess_detections(detections, image.size) - return processed_detections + # 计算额外特征 + enhanced_detections = calculate_extra_features(processed_detections, image) + + return enhanced_detections except Exception as e: logger.error(f"图像处理出错: {e}") @@ -40,7 +45,7 @@ def process_image(yolo_model, image): def preprocess_image(image): """ - 图像预处理 + 图像预处理 - 优化版 参数: image (PIL.Image): 输入图像 @@ -49,17 +54,38 @@ def preprocess_image(image): PIL.Image: 预处理后的图像 """ try: - # 图像增强 - image = image.convert("RGB") # 确保图像是RGB模式 + # 确保图像是RGB模式 + image = image.convert("RGB") - # 调整亮度和对比度以便更好地检测 - enhancer = ImageEnhance.Contrast(image) - image = enhancer.enhance(1.2) # 增加对比度 + # 应用自适应直方图均衡化 (对于UI元素检测很有帮助) + # 将PIL图像转换为numpy数组/OpenCV格式 + img_cv = np.array(image) + img_cv = img_cv[:, :, ::-1].copy() # RGB -> BGR - # 调整大小(可选,如果需要的话) - # image = image.resize((640, 640), Image.LANCZOS) + # 转换为LAB色彩空间并对L通道应用CLAHE + lab = cv2.cvtColor(img_cv, cv2.COLOR_BGR2LAB) + l, a, b = cv2.split(lab) - return image + # 创建CLAHE对象 + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + cl = clahe.apply(l) + + # 合并处理后的通道 + limg = cv2.merge((cl, a, b)) + enhanced_cv = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR) + + # 将OpenCV图像转换回PIL格式 + enhanced = Image.fromarray(cv2.cvtColor(enhanced_cv, cv2.COLOR_BGR2RGB)) + + # 适度锐化 + enhancer = ImageEnhance.Sharpness(enhanced) + enhanced = enhancer.enhance(1.3) + + # 稍微增加对比度 + enhancer = ImageEnhance.Contrast(enhanced) + enhanced = enhancer.enhance(1.2) + + return enhanced except Exception as e: logger.error(f"图像预处理出错: {e}") @@ -67,7 +93,7 @@ def preprocess_image(image): def postprocess_detections(detections, image_size): """ - 后处理检测结果 + 后处理检测结果 - 优化版 参数: detections (list): 检测结果 @@ -79,25 +105,45 @@ def postprocess_detections(detections, image_size): try: # 过滤或合并重叠的检测框 filtered_detections = [] + + # 按类别和置信度排序 + detections.sort(key=lambda x: (x['class_name'], -x['confidence'])) + + # 按类别分组 + class_groups = {} for det in detections: - # 例如,过滤掉边缘的检测结果 - bbox = det['bbox'] - if is_valid_detection(bbox, image_size): - # 添加相对位置信息 - det['center'] = [ - (bbox[0] + bbox[2]) / 2, # x中心 - (bbox[1] + bbox[3]) / 2 # y中心 - ] - det['size'] = [ - bbox[2] - bbox[0], # 宽度 - bbox[3] - bbox[1] # 高度 - ] - det['relative_position'] = [ - det['center'][0] / image_size[0], # 相对x位置 - det['center'][1] / image_size[1] # 相对y位置 - ] - - filtered_detections.append(det) + class_name = det['class_name'] + if class_name not in class_groups: + class_groups[class_name] = [] + class_groups[class_name].append(det) + + # 处理每个类别内的重叠框 + for class_name, dets in class_groups.items(): + # 使用非极大值抑制 + remaining = non_max_suppression(dets, iou_threshold=0.5) + + # 处理每个保留的检测 + for det in remaining: + # 例如,过滤掉边缘的检测结果 + bbox = det['bbox'] + if is_valid_detection(bbox, image_size): + # 添加相对位置信息 + det['center'] = [ + (bbox[0] + bbox[2]) / 2, # x中心 + (bbox[1] + bbox[3]) / 2 # y中心 + ] + det['size'] = [ + bbox[2] - bbox[0], # 宽度 + bbox[3] - bbox[1] # 高度 + ] + det['relative_position'] = [ + det['center'][0] / image_size[0], # 相对x位置 + det['center'][1] / image_size[1] # 相对y位置 + ] + # 添加一个唯一ID + det['id'] = f"{class_name}_{len(filtered_detections)}" + + filtered_detections.append(det) return filtered_detections @@ -105,6 +151,71 @@ def postprocess_detections(detections, image_size): logger.error(f"检测结果后处理出错: {e}") return detections # 返回原始检测结果 +def non_max_suppression(detections, iou_threshold=0.5): + """ + 非极大值抑制实现 + + 参数: + detections (list): 检测列表 + iou_threshold (float): IOU阈值 + + 返回: + list: 抑制后的检测列表 + """ + # 如果列表为空,直接返回 + if not detections: + return [] + + # 获取所有边界框和置信度 + boxes = [d['bbox'] for d in detections] + scores = [d['confidence'] for d in detections] + + # 初始化保留的框索引列表 + keep = [] + + # 按置信度降序排序 + idxs = np.argsort(scores)[::-1] + + while len(idxs) > 0: + # 取最高置信度的框 + current = idxs[0] + keep.append(current) + + # 计算当前框与其他框的IoU + ious = [] + current_box = boxes[current] + + for i in idxs[1:]: + iou = calculate_iou(current_box, boxes[i]) + ious.append(iou) + + # 保留IoU小于阈值的框 + idxs = idxs[1:][np.array(ious) < iou_threshold] + + # 返回保留的检测 + return [detections[i] for i in keep] + +def calculate_iou(box1, box2): + """计算两个边界框的IoU""" + # 交集 + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + if x2 < x1 or y2 < y1: + return 0.0 + + intersection = (x2 - x1) * (y2 - y1) + + # 两个边界框的面积 + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + # IoU + iou = intersection / float(box1_area + box2_area - intersection) + return iou + def is_valid_detection(bbox, image_size): """ 检查检测框是否有效 @@ -119,11 +230,175 @@ def is_valid_detection(bbox, image_size): # 检查框是否太小 width = bbox[2] - bbox[0] height = bbox[3] - bbox[1] - if width < 10 or height < 10: + if width < 5 or height < 5: return False - # 检查框是否部分超出图像 - if bbox[0] < 0 or bbox[1] < 0 or bbox[2] > image_size[0] or bbox[3] > image_size[1]: + # 检查框是否大部分超出图像 + if bbox[0] < -width/2 or bbox[1] < -height/2 or bbox[2] > image_size[0] + width/2 or bbox[3] > image_size[1] + height/2: return False - return True \ No newline at end of file + return True + +def calculate_extra_features(detections, image): + """ + 计算额外特征,例如血条/蓝条百分比 + + 参数: + detections (list): 检测结果 + image (PIL.Image): 原始图像 + + 返回: + list: 增强的检测结果 + """ + try: + # 转换图像为numpy数组 + img_array = np.array(image) + + # 增强检测结果的额外信息 + for det in detections: + # 处理血条 + if det["class_name"] == "hp_bar": + det["percent"] = estimate_bar_percent(img_array, det["bbox"], "hp") + + # 处理蓝条 + elif det["class_name"] == "mp_bar": + det["percent"] = estimate_bar_percent(img_array, det["bbox"], "mp") + + # 处理技能冷却 + elif det["class_name"] == "cooldown": + # 如果有文本,尝试提取技能ID和剩余时间 + if "text" in det: + # 简单示例,真实情况可能需要OCR + det["skill_id"] = det.get("text", "unknown").split("_")[0] + det["remaining_time"] = 1.0 # 假设值 + + # 处理物品 + elif det["class_name"] == "item": + # 尝试识别物品稀有度(通过颜色) + det["rarity"] = estimate_item_rarity(img_array, det["bbox"]) + + return detections + + except Exception as e: + logger.error(f"计算额外特征时出错: {e}") + return detections + +def estimate_bar_percent(img_array, bbox, bar_type): + """ + 估计血条/蓝条的百分比 - 优化版 + + 参数: + img_array (numpy.ndarray): 图像数组 + bbox (list): 边界框 [x1, y1, x2, y2] + bar_type (str): 条形类型 ("hp" 或 "mp") + + 返回: + float: 百分比值(0-100) + """ + try: + # 确保边界框在图像范围内 + x1, y1, x2, y2 = [int(v) for v in bbox] + height, width = img_array.shape[:2] + + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(width - 1, x2) + y2 = min(height - 1, y2) + + if x1 >= x2 or y1 >= y2: + return 50.0 # 默认值 + + # 提取条形区域 + bar_region = img_array[y1:y2, x1:x2] + + # 根据条形类型设置颜色阈值 + if bar_type == "hp": + # 红色血条,查找红色像素 + lower_threshold = np.array([150, 0, 0]) + upper_threshold = np.array([255, 70, 70]) + elif bar_type == "mp": + # 蓝色蓝条,查找蓝色像素 + lower_threshold = np.array([0, 0, 150]) + upper_threshold = np.array([70, 70, 255]) + else: + return 50.0 # 默认值 + + # 创建掩码 + mask = np.all((bar_region >= lower_threshold) & (bar_region <= upper_threshold), axis=2) + + # 计算比例(假设条形从左到右填充) + bar_width = x2 - x1 + + # 找出最右边的填充像素 + filled_width = 0 + for col in range(bar_width): + if np.any(mask[:, col]): + filled_width = col + 1 + + # 计算百分比 + percent = (filled_width / bar_width) * 100 + + return max(0, min(100, percent)) # 确保在0-100范围内 + + except Exception as e: + logger.error(f"估计条形百分比时出错: {e}") + return 50.0 # 出错时返回默认值 + +def estimate_item_rarity(img_array, bbox): + """ + 估计物品稀有度(通过颜色) + + 参数: + img_array (numpy.ndarray): 图像数组 + bbox (list): 边界框 [x1, y1, x2, y2] + + 返回: + str: 稀有度描述 + """ + try: + # 提取物品区域 + x1, y1, x2, y2 = [int(v) for v in bbox] + height, width = img_array.shape[:2] + + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(width - 1, x2) + y2 = min(height - 1, y2) + + # 提取边框区域(只取几个像素宽的边框) + border_size = max(2, min(5, int((x2 - x1) * 0.1))) + + top_border = img_array[y1:y1+border_size, x1:x2] + bottom_border = img_array[y2-border_size:y2, x1:x2] + left_border = img_array[y1:y2, x1:x1+border_size] + right_border = img_array[y1:y2, x2-border_size:x2] + + # 合并所有边框像素 + borders = np.vstack([ + top_border.reshape(-1, 3), + bottom_border.reshape(-1, 3), + left_border.reshape(-1, 3), + right_border.reshape(-1, 3) + ]) + + # 计算平均颜色 + avg_color = np.mean(borders, axis=0) + + # 根据颜色判断稀有度 + r, g, b = avg_color + + # 简单启发式规则 + if r > 200 and g < 100 and b < 100: + return "legendary" # 红色:传说 + elif r > 200 and g > 150 and b < 100: + return "epic" # 橙色:史诗 + elif r < 100 and g < 100 and b > 180: + return "rare" # 蓝色:稀有 + elif r < 100 and g > 180 and b < 100: + return "uncommon" # 绿色:优秀 + else: + return "common" # 白色/灰色:普通 + + except Exception as e: + logger.error(f"估计物品稀有度时出错: {e}") + return "unknown" \ No newline at end of file diff --git a/server/main.py b/server/main.py index a54a2c1..9bcc332 100644 --- a/server/main.py +++ b/server/main.py @@ -3,6 +3,7 @@ """ 服务器主模块,负责初始化和启动WebSocket服务 +优化版 - 增加更好的资源管理和错误恢复 """ import logging @@ -10,64 +11,241 @@ import asyncio import ssl import sys import os +import signal +import gc +import json +import time +import yaml +from pathlib import Path # 添加项目根目录到系统路径 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from server.websocket_server import WebsocketServer -from config.settings import SECURITY, SERVER +from config.settings import SECURITY, SERVER, MODEL from models.yolo_model import YOLOModel from utils.security import generate_ssl_cert +from utils.logging_utils import setup_logging logger = logging.getLogger("DNFAutoCloud") -def start_server(host="0.0.0.0", port=8080, debug=False): - """启动WebSocket服务器""" - try: - # 初始化YOLO模型 - logger.info("正在加载YOLO模型...") - yolo_model = YOLOModel() - - # 创建SSL上下文(如果启用) - ssl_context = None - if SECURITY["ssl_enabled"]: - if not os.path.exists(SECURITY["ssl_cert"]) or not os.path.exists(SECURITY["ssl_key"]): - logger.info("未找到SSL证书,正在生成自签名证书...") +# 全局变量,用于优雅关闭 +websocket_server = None +shutdown_event = asyncio.Event() + +def setup_signals(): + """设置信号处理器""" + def signal_handler(sig, frame): + logger.info(f"收到信号 {sig},准备关闭服务器...") + shutdown_event.set() + + # 注册SIGINT和SIGTERM信号处理器 + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + logger.info("信号处理器已设置") + +async def cleanup(): + """清理资源""" + logger.info("开始清理资源...") + + # 关闭WebSocket服务器 + global websocket_server + if websocket_server: + await websocket_server.stop() + + # 强制执行垃圾回收 + gc.collect() + + logger.info("资源清理完成") + +async def startup_checks(): + """启动前检查""" + logger.info("执行启动前检查...") + + # 检查模型权重 + if not os.path.exists(MODEL.get("weights", "")): + logger.error(f"模型权重文件不存在: {MODEL.get('weights', '')}") + return False + + # 检查GPU可用性 + if MODEL.get("device", "").startswith("cuda"): + try: + import torch + if not torch.cuda.is_available(): + logger.warning("已配置使用CUDA,但无法找到可用的GPU。将使用CPU模式。") + MODEL["device"] = "cpu" + else: + logger.info(f"已检测到GPU: {torch.cuda.get_device_name(0)}") + # 打印CUDA版本等详细信息 + logger.info(f"CUDA版本: {torch.version.cuda}") + logger.info(f"可用GPU数量: {torch.cuda.device_count()}") + except ImportError: + logger.warning("无法导入torch模块,将使用CPU模式。") + MODEL["device"] = "cpu" + + # 检查SSL证书 + if SECURITY.get("ssl_enabled", False): + if not os.path.exists(SECURITY.get("ssl_cert", "")) or not os.path.exists(SECURITY.get("ssl_key", "")): + logger.info("未找到SSL证书,正在生成自签名证书...") + try: generate_ssl_cert() - - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain(SECURITY["ssl_cert"], SECURITY["ssl_key"]) - logger.info("SSL证书加载成功") + except Exception as e: + logger.error(f"生成SSL证书失败: {e}") + return False + + # 检查数据目录 + data_dirs = ["data/logs", "data/cache", "data/sessions"] + for dir_path in data_dirs: + try: + os.makedirs(dir_path, exist_ok=True) + except Exception as e: + logger.error(f"创建目录失败 {dir_path}: {e}") + return False + + logger.info("启动前检查完成") + return True + +async def load_model(): + """加载YOLO模型""" + logger.info("正在加载YOLO模型...") + try: + # 使用上下文管理器捕获并正确处理CUDA内存错误 + yolo_model = YOLOModel() + logger.info("YOLO模型加载成功") - # 创建事件循环 + # 打印模型性能基准 + performance = yolo_model.get_performance_stats() + logger.info(f"模型性能基准: 平均推理时间 {performance['average_time']:.4f}秒 ({performance['fps']:.2f} FPS)") + + return yolo_model + except Exception as e: + logger.error(f"加载YOLO模型失败: {e}") + return None + +async def save_server_info(host, port, ssl_enabled): + """保存服务器信息到文件,便于客户端连接""" + info = { + "server": { + "host": host, + "port": port, + "url": f"{'wss' if ssl_enabled else 'ws'}://{host}:{port}/ws", + "ssl_enabled": ssl_enabled + }, + "started_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "version": "1.0.1" + } + + # 保存为JSON + try: + with open("server_info.json", "w") as f: + json.dump(info, f, indent=2) + logger.info(f"服务器信息已保存到 server_info.json") + except Exception as e: + logger.error(f"保存服务器信息失败: {e}") + +async def run_server(host, port, debug): + """运行服务器""" + global websocket_server + + # 执行启动前检查 + if not await startup_checks(): + logger.error("启动前检查失败,服务器未启动") + return False + + # 加载YOLO模型 + yolo_model = await load_model() + if not yolo_model: + logger.error("加载模型失败,服务器未启动") + return False + + # 创建SSL上下文(如果启用) + ssl_context = None + if SECURITY.get("ssl_enabled", False): + try: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(SECURITY.get("ssl_cert", ""), SECURITY.get("ssl_key", "")) + logger.info("SSL证书加载成功") + except Exception as e: + logger.error(f"加载SSL证书失败: {e}") + logger.warning("将以非SSL模式启动服务器") + + try: + # 创建WebSocket服务器 + websocket_server = WebsocketServer(yolo_model, host, port, ssl_context) + + # 启动服务器 + await websocket_server.start() + + # 保存服务器信息 + await save_server_info(host, port, SECURITY.get("ssl_enabled", False)) + + # 等待关闭信号 + await shutdown_event.wait() + + logger.info("收到关闭信号,正在停止服务器...") + + # 清理资源 + await cleanup() + + return True + + except Exception as e: + logger.error(f"运行服务器时出错: {e}") + return False + +async def start_server_async(host="0.0.0.0", port=8080, debug=False): + """异步启动服务器""" + # 设置信号处理器 + setup_signals() + + # 运行服务器 + success = await run_server(host, port, debug) + + return success + +def start_server(host="0.0.0.0", port=8080, debug=False): + """同步启动服务器""" + # 设置日志 + setup_logging(debug=debug) + + # 打印启动信息 + logger.info("=" * 50) + logger.info("正在启动DNF自动化云服务...") + logger.info(f"主机: {host}, 端口: {port}, 调试模式: {'开启' if debug else '关闭'}") + logger.info(f"服务器版本: 1.0.1") + logger.info("=" * 50) + + try: + # 获取事件循环 loop = asyncio.get_event_loop() if debug: loop.set_debug(True) logger.info("调试模式已启用") - # 创建WebSocket服务器 - server = WebsocketServer(yolo_model, host, port, ssl_context) + # 运行服务器 + success = loop.run_until_complete(start_server_async(host, port, debug)) - # 启动服务器 - logger.info(f"正在启动WebSocket服务器: {host}:{port}") - loop.run_until_complete(server.start()) + # 如果服务器启动成功,运行事件循环直到关闭 + if success: + loop.run_until_complete(shutdown_event.wait()) - # 运行服务器直到被中断 - logger.info("服务器启动成功,等待连接...") - loop.run_forever() + # 清理资源 + loop.run_until_complete(cleanup()) + + # 关闭事件循环 + loop.close() + + logger.info("服务器已关闭") + return success except KeyboardInterrupt: - logger.info("服务器收到终止信号,正在关闭...") + logger.info("接收到用户中断,正在关闭...") + return False except Exception as e: - logger.error(f"服务器启动失败: {e}") - raise - finally: - # 清理资源 - tasks = asyncio.all_tasks(loop=loop) - for task in tasks: - task.cancel() - - logger.info("正在关闭事件循环...") - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() - logger.info("服务器已关闭") \ No newline at end of file + logger.error(f"启动服务器时出错: {e}") + return False + +if __name__ == "__main__": + # 直接运行此脚本时的默认行为 + start_server(host="0.0.0.0", port=8080, debug=True) \ No newline at end of file diff --git a/server/websocket_server.py b/server/websocket_server.py index c782034..439abe0 100644 --- a/server/websocket_server.py +++ b/server/websocket_server.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- """ -WebSocket服务器实现,处理客户端连接和通信 - Linux兼容版 +WebSocket服务器实现,处理客户端连接和通信 - 优化版本 +增强稳定性、连接管理和会话持久化 """ import asyncio @@ -11,11 +12,14 @@ import logging import time import uuid import base64 +import os from datetime import datetime from io import BytesIO +import threading import websockets from PIL import Image +import numpy as np from config.settings import SERVER, SECURITY from server.image_processor import process_image @@ -26,7 +30,7 @@ from utils.human_behavior import generate_human_delay logger = logging.getLogger("DNFAutoCloud") class WebsocketServer: - """WebSocket服务器类 - Linux兼容版""" + """WebSocket服务器类 - 优化版本""" def __init__(self, yolo_model, host="0.0.0.0", port=8080, ssl_context=None): """初始化WebSocket服务器""" @@ -35,38 +39,136 @@ class WebsocketServer: self.port = port self.ssl_context = ssl_context self.clients = {} # 客户端连接管理 + self.client_sessions = {} # 客户端会话数据 self.start_time = datetime.now() + self.shutdown_event = asyncio.Event() + self.stats = { + "total_connections": 0, + "active_connections": 0, + "images_processed": 0, + "actions_generated": 0, + "errors": 0 + } + + # 客户端监控线程 + self.monitoring_task = None + + # 保存最近处理的图像(用于调试) + self.debug_mode = os.environ.get("DNF_DEBUG", "0") == "1" + if self.debug_mode: + os.makedirs("debug/images", exist_ok=True) async def start(self): """启动WebSocket服务器""" - return await websockets.serve( + ws_server = await websockets.serve( self.handle_connection, self.host, self.port, ssl=self.ssl_context, max_size=10 * 1024 * 1024, # 10MB最大消息大小 - ping_interval=SERVER["heartbeat_interval"], - ping_timeout=SERVER["timeout"] + ping_interval=SERVER.get("heartbeat_interval", 5), + ping_timeout=SERVER.get("timeout", 60), + close_timeout=10 ) + + # 启动客户端监控 + self.monitoring_task = asyncio.create_task(self.monitor_clients()) + + logger.info(f"WebSocket服务器已启动: {self.host}:{self.port}") + return ws_server - async def handle_connection(self, websocket, path): + async def stop(self): + """停止WebSocket服务器""" + logger.info("正在停止WebSocket服务器...") + + # 设置关闭事件 + self.shutdown_event.set() + + # 关闭所有客户端连接 + close_tasks = [] + for client_id, info in list(self.clients.items()): + if "websocket" in info: + try: + close_tasks.append(info["websocket"].close(1001, "服务器关闭")) + except: + pass + + if close_tasks: + await asyncio.gather(*close_tasks, return_exceptions=True) + + # 停止监控任务 + if self.monitoring_task: + self.monitoring_task.cancel() + try: + await self.monitoring_task + except asyncio.CancelledError: + pass + + logger.info("WebSocket服务器已停止") + + async def monitor_clients(self): + """监控客户端连接状态""" + inactive_timeout = SERVER.get("inactive_timeout", 300) # 默认5分钟 + + while not self.shutdown_event.is_set(): + try: + current_time = time.time() + + # 检查不活跃的客户端 + for client_id, info in list(self.clients.items()): + if "last_activity" in info: + inactive_time = current_time - info["last_activity"] + + # 清理超时客户端 + if inactive_time > inactive_timeout: + logger.info(f"客户端 {client_id} 不活跃超过 {inactive_timeout} 秒,断开连接") + try: + if "websocket" in info: + await info["websocket"].close(1000, "不活跃超时") + except: + pass + + if client_id in self.clients: + del self.clients[client_id] + self.stats["active_connections"] -= 1 + + # 更新服务器状态 + self.stats["active_connections"] = len(self.clients) + + # 等待下一个检查周期 + await asyncio.sleep(60) # 每分钟检查一次 + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"监控客户端出错: {e}") + await asyncio.sleep(60) # 出错后等待一分钟再重试 + + async def handle_connection(self, websocket, path=None): """处理新的WebSocket连接""" client_id = str(uuid.uuid4()) + + # 创建客户端信息 client_info = { "id": client_id, "connected_at": datetime.now(), "remote": websocket.remote_address, - "last_activity": time.time() + "last_activity": time.time(), + "websocket": websocket, + "authenticated": False } # 限制最大连接数 - if len(self.clients) >= SERVER["max_connections"]: + if len(self.clients) >= SERVER.get("max_connections", 10): logger.warning(f"达到最大连接数限制,拒绝新连接: {client_info['remote']}") await websocket.close(1013, "服务器连接数已满") return # 添加到客户端列表 self.clients[client_id] = client_info + self.stats["total_connections"] += 1 + self.stats["active_connections"] += 1 + logger.info(f"新客户端连接: {client_id} 来自 {client_info['remote']}") try: @@ -83,10 +185,12 @@ class WebsocketServer: logger.info(f"客户端连接关闭: {client_id}, 代码: {e.code}, 原因: {e.reason}") except Exception as e: logger.error(f"处理客户端连接时出错: {client_id}, 错误: {e}") + self.stats["errors"] += 1 finally: # 清理客户端连接 if client_id in self.clients: del self.clients[client_id] + self.stats["active_connections"] -= 1 logger.info(f"客户端连接已关闭: {client_id}") async def _authenticate(self, websocket, client_id): @@ -96,13 +200,14 @@ class WebsocketServer: challenge = str(uuid.uuid4()) await websocket.send(json.dumps({ "type": "auth_challenge", - "challenge": challenge + "challenge": challenge, + "timestamp": time.time() })) # 等待认证响应 response_raw = await asyncio.wait_for( websocket.recv(), - timeout=SERVER["timeout"] + timeout=SERVER.get("timeout", 60) ) # 检查认证响应 @@ -112,23 +217,37 @@ class WebsocketServer: await websocket.close(1008, "认证格式错误") return False - # 这里可以实现更复杂的认证逻辑 - # 简单示例仅检查是否有响应 - if "response" in response: - # 认证成功 - self.clients[client_id]["authenticated"] = True - await websocket.send(json.dumps({ - "type": "auth_result", - "status": "success", - "client_id": client_id - })) - logger.info(f"客户端认证成功: {client_id}") - return True - else: - # 认证失败 - await websocket.close(1008, "认证失败") - logger.warning(f"客户端认证失败: {client_id}") - return False + # 获取客户端信息 + client_info = response.get("client_info", {}) + + # 保存会话信息 + self.client_sessions[client_id] = { + "client_info": client_info, + "game_state": {}, + "last_actions": [], + "statistics": { + "images_processed": 0, + "actions_sent": 0, + "errors": 0 + } + } + + # 认证成功 + self.clients[client_id]["authenticated"] = True + self.clients[client_id]["client_info"] = client_info + + await websocket.send(json.dumps({ + "type": "auth_result", + "status": "success", + "client_id": client_id, + "server_info": { + "version": "1.0.1", + "uptime": (datetime.now() - self.start_time).total_seconds() + } + })) + + logger.info(f"客户端认证成功: {client_id}, 版本: {client_info.get('version', 'unknown')}") + return True except asyncio.TimeoutError: logger.warning(f"客户端认证超时: {client_id}") @@ -137,11 +256,12 @@ class WebsocketServer: except Exception as e: logger.error(f"客户端认证出错: {client_id}, 错误: {e}") await websocket.close(1011, "认证处理错误") + self.stats["errors"] += 1 return False async def _handle_messages(self, websocket, client_id): """处理客户端消息""" - while True: + while not self.shutdown_event.is_set(): # 接收消息 message_raw = await websocket.recv() @@ -150,7 +270,7 @@ class WebsocketServer: try: # 解析消息(可能需要解密) - if SECURITY["encryption_enabled"]: + if SECURITY.get("encryption_enabled", False): message_data = decrypt_message(message_raw) else: message_data = json.loads(message_raw) @@ -163,7 +283,7 @@ class WebsocketServer: await self._handle_image_request(websocket, client_id, message_data) elif msg_type == "heartbeat": # 处理心跳消息 - await self._handle_heartbeat(websocket, client_id) + await self._handle_heartbeat(websocket, client_id, message_data) else: # 未知消息类型 logger.warning(f"收到未知消息类型: {msg_type} 来自客户端: {client_id}") @@ -180,6 +300,7 @@ class WebsocketServer: "error": "invalid_json", "message": "无效的JSON格式" })) + self.stats["errors"] += 1 except Exception as e: logger.error(f"处理消息时出错,来自客户端: {client_id}, 错误: {e}") await websocket.send(json.dumps({ @@ -187,44 +308,76 @@ class WebsocketServer: "error": "processing_error", "message": f"处理消息时出错: {str(e)}" })) + self.stats["errors"] += 1 async def _handle_image_request(self, websocket, client_id, message_data): """处理图像识别请求""" try: + start_time = time.time() + # 提取图像数据 image_base64 = message_data.get("data", "") image_bytes = base64.b64decode(image_base64) image = Image.open(BytesIO(image_bytes)) - # 记录请求 - logger.debug(f"收到图像识别请求,来自客户端: {client_id}, 图像尺寸: {image.size}") + # 提取游戏状态 + game_state = message_data.get("game_state", {}) + window_rect = message_data.get("window_rect", [0, 0, 0, 0]) + + # 更新会话中的游戏状态 + if client_id in self.client_sessions: + self.client_sessions[client_id]["game_state"] = game_state + + # 保存调试图像 + if self.debug_mode: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + debug_path = f"debug/images/{client_id}_{timestamp}.jpg" + try: + image.save(debug_path) + except: + pass # 处理图像并进行检测 detections = process_image(self.yolo_model, image) # 根据检测结果生成动作 - game_state = message_data.get("game_state", {}) actions = generate_actions(detections, game_state) # 生成人类化延迟 delay = generate_human_delay() + # 追踪处理时间 + processing_time = time.time() - start_time + # 构建响应 response = { "type": "action_response", "request_id": message_data.get("request_id"), "timestamp": time.time(), "actions": actions, - "delay": delay + "detections": detections, # 返回检测结果供客户端分析 + "delay": delay, + "processing_time": processing_time } # 发送响应(可能需要加密) - if SECURITY["encryption_enabled"]: + if SECURITY.get("encryption_enabled", False): await websocket.send(encrypt_message(response)) else: await websocket.send(json.dumps(response)) - logger.debug(f"已发送动作响应,客户端: {client_id}, 动作数: {len(actions)}") + # 更新统计信息 + self.stats["images_processed"] += 1 + self.stats["actions_generated"] += len(actions) + + if client_id in self.client_sessions: + self.client_sessions[client_id]["statistics"]["images_processed"] += 1 + self.client_sessions[client_id]["statistics"]["actions_sent"] += len(actions) + self.client_sessions[client_id]["last_actions"] = actions + + # 记录处理信息 + if len(actions) > 0: + logger.debug(f"已发送 {len(actions)} 个动作给客户端: {client_id}, 处理时间: {processing_time:.3f}秒") except Exception as e: logger.error(f"处理图像请求时出错,客户端: {client_id}, 错误: {e}") @@ -234,20 +387,58 @@ class WebsocketServer: "request_id": message_data.get("request_id"), "message": f"处理图像时出错: {str(e)}" })) + self.stats["errors"] += 1 + if client_id in self.client_sessions: + self.client_sessions[client_id]["statistics"]["errors"] += 1 - async def _handle_heartbeat(self, websocket, client_id): + async def _handle_heartbeat(self, websocket, client_id, message_data): """处理心跳消息""" try: + # 提取客户端游戏状态 + if "game_state" in message_data and client_id in self.client_sessions: + self.client_sessions[client_id]["game_state"].update(message_data["game_state"]) + # 构建心跳响应 response = { "type": "heartbeat_response", "timestamp": time.time(), - "server_uptime": (datetime.now() - self.start_time).total_seconds() + "server_uptime": (datetime.now() - self.start_time).total_seconds(), + "server_stats": { + "connections": len(self.clients), + "images_processed": self.stats["images_processed"], + "actions_generated": self.stats["actions_generated"] + } } # 发送响应 await websocket.send(json.dumps(response)) - logger.debug(f"已发送心跳响应,客户端: {client_id}") except Exception as e: - logger.error(f"处理心跳消息时出错,客户端: {client_id}, 错误: {e}") \ No newline at end of file + logger.error(f"处理心跳消息时出错,客户端: {client_id}, 错误: {e}") + self.stats["errors"] += 1 + + def get_server_stats(self): + """获取服务器统计信息""" + return { + "start_time": self.start_time.isoformat(), + "uptime": (datetime.now() - self.start_time).total_seconds(), + "connections": { + "total": self.stats["total_connections"], + "active": self.stats["active_connections"] + }, + "performance": { + "images_processed": self.stats["images_processed"], + "actions_generated": self.stats["actions_generated"], + "errors": self.stats["errors"] + }, + "clients": [ + { + "id": client_id, + "connected_at": info["connected_at"].isoformat(), + "remote": info["remote"], + "authenticated": info["authenticated"], + "version": info.get("client_info", {}).get("version", "unknown") + } + for client_id, info in self.clients.items() + ] + } \ No newline at end of file diff --git a/server_info.json b/server_info.json new file mode 100644 index 0000000..e825a49 --- /dev/null +++ b/server_info.json @@ -0,0 +1,10 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8080, + "url": "ws://0.0.0.0:8080/ws", + "ssl_enabled": false + }, + "started_at": "2025-03-26 17:52:39", + "version": "1.0.1" +} \ No newline at end of file diff --git a/utils/human_behavior.py b/utils/human_behavior.py index 98b252e..b22cf9a 100644 --- a/utils/human_behavior.py +++ b/utils/human_behavior.py @@ -3,33 +3,161 @@ """ 人类行为模拟工具,提供模拟人类操作的功能 +优化版 - 更自然的行为模式和个性化特征 """ import random import time import math import numpy as np +from datetime import datetime from config.settings import BEHAVIOR +# 不同用户的行为模式 +USER_PROFILES = { + "fast": { + "delay_mean": 0.15, + "delay_std": 0.05, + "movement_smoothness": 0.7, + "click_accuracy": 0.9, + "double_click_chance": 0.05, + "description": "快速玩家 - 反应迅速,操作精准" + }, + "normal": { + "delay_mean": 0.25, + "delay_std": 0.1, + "movement_smoothness": 0.8, + "click_accuracy": 0.85, + "double_click_chance": 0.1, + "description": "普通玩家 - 中等速度,操作稳定" + }, + "casual": { + "delay_mean": 0.4, + "delay_std": 0.15, + "movement_smoothness": 0.6, + "click_accuracy": 0.7, + "double_click_chance": 0.15, + "description": "休闲玩家 - 操作较慢,精确度一般" + } +} + +# 当前用户配置文件 +current_profile = "normal" + +# 用户注意力随时间变化的模拟 +attention_level = 1.0 +last_attention_update = time.time() + +# 活动模式(早上、下午、晚上) +activity_mode = "normal" + +def set_user_profile(profile_name): + """ + 设置用户行为配置文件 + + 参数: + profile_name (str): 配置文件名称 ("fast", "normal", "casual") + """ + global current_profile + if profile_name in USER_PROFILES: + current_profile = profile_name + return True + return False + +def update_attention_level(): + """更新用户注意力水平,模拟人类随时间的疲劳""" + global attention_level, last_attention_update + + current_time = time.time() + elapsed = current_time - last_attention_update + + # 随着时间推移,注意力逐渐下降 + if elapsed > 60: # 每分钟更新一次 + # 注意力随时间缓慢下降(0.5%/分钟) + attention_decay = 0.005 * (elapsed / 60) + + # 随机波动(+/- 2%) + attention_fluctuation = random.uniform(-0.02, 0.02) + + # 更新注意力水平 + attention_level = max(0.7, min(1.0, attention_level - attention_decay + attention_fluctuation)) + + # 重置时间戳 + last_attention_update = current_time + + return attention_level + +def set_activity_mode(): + """根据一天中的时间设置活动模式""" + global activity_mode + + # 获取当前小时 + current_hour = datetime.now().hour + + # 根据时间段设置模式 + if 5 <= current_hour < 9: # 早晨 + activity_mode = "morning" + elif 9 <= current_hour < 17: # 工作时间 + activity_mode = "normal" + elif 17 <= current_hour < 22: # 晚上 + activity_mode = "evening" + else: # 深夜 + activity_mode = "night" + + return activity_mode + def generate_human_delay(): """ - 生成人类化的延迟时间 + 生成人类化的延迟时间 - 优化版 返回: float: 延迟时间(秒) """ + # 获取当前用户配置 + profile = USER_PROFILES[current_profile] + + # 更新注意力水平 + attention = update_attention_level() + + # 检查活动模式 + mode = set_activity_mode() + + # 基础延迟参数 + delay_mean = profile["delay_mean"] + delay_std = profile["delay_std"] + + # 根据注意力调整 + # 注意力下降,延迟增加 + delay_mean = delay_mean * (2 - attention) + + # 根据活动模式调整 + if mode == "morning": + # 早晨略慢 + delay_mean *= 1.1 + elif mode == "evening": + # 晚上正常 + pass + elif mode == "night": + # 深夜反应较慢 + delay_mean *= 1.2 + delay_std *= 1.2 + + # 偶尔的停顿(思考时间) + if random.random() < 0.05: # 5%几率 + thinking_time = random.uniform(0.5, 1.5) + return delay_mean + thinking_time + # 使用正态分布生成随机延迟 - mean_delay = (BEHAVIOR["min_delay"] + BEHAVIOR["max_delay"]) / 2 - std_dev = (BEHAVIOR["max_delay"] - BEHAVIOR["min_delay"]) / 6 # 使99.7%的值在范围内 + delay = random.normalvariate(delay_mean, delay_std) - delay = random.normalvariate(mean_delay, std_dev) - - # 确保在设定范围内 - return max(BEHAVIOR["min_delay"], min(BEHAVIOR["max_delay"], delay)) + # 确保在合理范围内 + min_delay = BEHAVIOR.get("min_delay", 0.1) + max_delay = BEHAVIOR.get("max_delay", 0.8) + return max(min_delay, min(max_delay, delay)) def generate_human_movement_path(start_pos, end_pos, steps=None): """ - 生成模拟人类的鼠标移动路径 + 生成模拟人类的鼠标移动路径 - 优化版 参数: start_pos (list): 起始位置 [x, y] @@ -39,41 +167,99 @@ def generate_human_movement_path(start_pos, end_pos, steps=None): 返回: list: 路径点列表 [[x1, y1], [x2, y2], ...] """ + # 获取当前用户配置 + profile = USER_PROFILES[current_profile] + smoothness = profile["movement_smoothness"] + # 计算距离 distance = math.sqrt((end_pos[0] - start_pos[0])**2 + (end_pos[1] - start_pos[1])**2) - # 如果未指定步数,根据距离计算 + # 如果距离太短,使用直线路径 + if distance < 10: + return [start_pos, end_pos] + + # 如果未指定步数,根据距离和速度计算 if steps is None: - steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点 + # 根据距离、平滑度和注意力计算步数 + attention = update_attention_level() + base_steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点 + + # 平滑度越高,步数越多;注意力越高,步数越多 + steps = int(base_steps * smoothness * attention) + steps = max(5, min(30, steps)) # 限制在合理范围内 # 生成基础路径 t = np.linspace(0, 1, steps) path = [] + # 添加曲率 - 模拟手腕运动 + # 贝塞尔曲线的控制点 + control_x = start_pos[0] + (end_pos[0] - start_pos[0]) / 2 + control_y = start_pos[1] + (end_pos[1] - start_pos[1]) / 2 + + # 添加随机偏移到控制点 + offset_factor = (1.0 - smoothness) * 0.5 # 平滑度越低,偏移越大 + max_offset = distance * offset_factor + control_x += random.normalvariate(0, max_offset / 2) + control_y += random.normalvariate(0, max_offset / 2) + for i in range(steps): - # 基础线性插值 - x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i] - y = start_pos[1] + (end_pos[1] - start_pos[1]) * t[i] + # 二次贝塞尔曲线 + t_i = t[i] + x = (1 - t_i)**2 * start_pos[0] + 2 * (1 - t_i) * t_i * control_x + t_i**2 * end_pos[0] + y = (1 - t_i)**2 * start_pos[1] + 2 * (1 - t_i) * t_i * control_y + t_i**2 * end_pos[1] - # 添加随机偏移(越靠近中间偏移越大) - mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大 - max_offset = distance * 0.05 * mid_factor # 最大偏移为距离的5% + # 添加细微的随机抖动(手部微小颤抖) + jitter_factor = (1.0 - smoothness) * 0.02 * distance + jitter_x = random.normalvariate(0, jitter_factor) + jitter_y = random.normalvariate(0, jitter_factor) - offset_x = random.normalvariate(0, max_offset / 3) - offset_y = random.normalvariate(0, max_offset / 3) + # 注意力越低,抖动越大 + attention = update_attention_level() + jitter_x *= (2 - attention) + jitter_y *= (2 - attention) # 添加到路径 - path.append([x + offset_x, y + offset_y]) + path.append([x + jitter_x, y + jitter_y]) # 确保起点和终点准确 path[0] = start_pos.copy() path[-1] = end_pos.copy() + # 模拟加速和减速 + # 路径采样 - 开始慢,中间快,结束慢 + if steps > 10: + resampled_path = [] + resampled_path.append(path[0]) # 起点 + + # 前20%缓慢加速 + accel_end = int(steps * 0.2) + for i in range(1, accel_end): + resampled_path.append(path[i]) + + # 中间60%快速移动(跳过一些点) + mid_start = accel_end + mid_end = int(steps * 0.8) + + # 根据平滑度决定中间部分的采样率 + skip_factor = int(2 + (1 - smoothness) * 3) + for i in range(mid_start, mid_end, skip_factor): + resampled_path.append(path[min(i, steps - 1)]) + + # 最后20%缓慢减速 + for i in range(mid_end, steps): + resampled_path.append(path[i]) + + if resampled_path[-1] != end_pos: + resampled_path.append(end_pos) # 确保终点 + + return resampled_path + return path def generate_human_click_offset(target_pos, target_size=(50, 50)): """ - 生成人类化的点击位置偏移 + 生成人类化的点击位置偏移 - 优化版 参数: target_pos (list): 目标中心位置 [x, y] @@ -82,13 +268,23 @@ def generate_human_click_offset(target_pos, target_size=(50, 50)): 返回: list: 带偏移的点击位置 [x, y] """ + # 获取当前用户配置 + profile = USER_PROFILES[current_profile] + accuracy = profile["click_accuracy"] + + # 更新注意力水平 + attention = update_attention_level() + + # 综合准确度和注意力 + effective_accuracy = accuracy * attention + # 计算最大偏移(不超过目标大小的一半) - max_offset_x = min(target_size[0] / 2 * 0.8, 10) # 最大偏移不超过10像素 - max_offset_y = min(target_size[1] / 2 * 0.8, 10) + max_offset_x = min(target_size[0] / 2 * (1 - effective_accuracy) * 1.5, 15) + max_offset_y = min(target_size[1] / 2 * (1 - effective_accuracy) * 1.5, 15) # 生成偏移(中心位置概率更高) - offset_x = random.normalvariate(0, max_offset_x / 3) - offset_y = random.normalvariate(0, max_offset_y / 3) + offset_x = random.normalvariate(0, max_offset_x / 2) + offset_y = random.normalvariate(0, max_offset_y / 2) # 应用偏移 click_pos = [ @@ -98,22 +294,137 @@ def generate_human_click_offset(target_pos, target_size=(50, 50)): return click_pos -def generate_typing_speed(base_cps=5.0, variance=0.2): +def should_double_click(): """ - 生成人类化的打字速度 + 确定是否应该双击 + + 返回: + bool: 是否应该双击 + """ + profile = USER_PROFILES[current_profile] + return random.random() < profile["double_click_chance"] + +def generate_typing_speed(text_length, base_cps=5.0): + """ + 生成人类化的打字速度序列 参数: + text_length (int): 文本长度 base_cps (float): 基础字符每秒速度 - variance (float): 速度方差 返回: - float: 字符间延迟时间(秒) + list: 每个字符的延迟时间列表 """ + # 获取当前用户配置和注意力 + profile = USER_PROFILES[current_profile] + attention = update_attention_level() + + # 基础打字速度 - 根据配置文件调整 + if current_profile == "fast": + base_cps = 8.0 + elif current_profile == "casual": + base_cps = 3.5 + + # 根据注意力调整速度 + base_cps = base_cps * attention + + # 生成每个字符的延迟 + delays = [] + current_speed = base_cps + + for i in range(text_length): + # 随机波动 + speed_variance = 0.2 * base_cps + current_speed = random.normalvariate(base_cps, speed_variance) + current_speed = max(base_cps * 0.5, current_speed) # 确保速度不会太慢 + + # 将速度转换为延迟 + delay = 1.0 / current_speed + + # 某些特殊情况 + if random.random() < 0.05: # 偶尔停顿 + delay += random.uniform(0.2, 0.7) + + if i > 0 and i % 10 == 0 and random.random() < 0.2: # 句子结束停顿 + delay += random.uniform(0.3, 0.8) + + delays.append(delay) + + return delays + +def generate_action_sequence(action_count): + """ + 生成一系列人类化的动作延迟 + + 参数: + action_count (int): 动作数量 + + 返回: + list: 延迟时间列表 + """ + delays = [] + + for i in range(action_count): + # 基础延迟 + delay = generate_human_delay() + + # 连续动作的相关性 + if i > 0: + # 连续动作通常更快 + delay *= 0.8 + + # 序列中的变化 + if i == 0: + # 第一个动作前可能有更长的准备时间 + delay *= 1.5 + elif i == action_count - 1: + # 最后一个动作后可能有更长的思考时间 + delay *= 1.2 + + delays.append(delay) + + return delays + +def simulate_fatigue(session_duration): + """ + 模拟随时间产生的疲劳效应 + + 参数: + session_duration (float): 会话持续时间(秒) + + 返回: + float: 疲劳系数 (1.0=正常, >1.0=疲劳) + """ + # 基础疲劳曲线 + hours = session_duration / 3600.0 + + # 前1小时基本无疲劳 + if hours < 1: + base_fatigue = 1.0 + # 1-2小时逐渐增加疲劳 + elif hours < 2: + base_fatigue = 1.0 + (hours - 1) * 0.1 + # 2-3小时疲劳加重 + elif hours < 3: + base_fatigue = 1.1 + (hours - 2) * 0.2 + # 3小时以上疲劳显著 + else: + base_fatigue = 1.3 + min(0.4, (hours - 3) * 0.1) + # 添加随机波动 - speed = random.normalvariate(base_cps, base_cps * variance) - speed = max(speed, base_cps * 0.5) # 确保速度不会太慢 + fatigue = base_fatigue * random.uniform(0.95, 1.05) - # 将速度转换为延迟 - delay = 1.0 / speed + return fatigue + +def get_behavior_profile(): + """ + 获取当前行为配置文件信息 - return delay \ No newline at end of file + 返回: + dict: 行为配置信息 + """ + profile = USER_PROFILES[current_profile].copy() + profile["attention"] = update_attention_level() + profile["activity_mode"] = set_activity_mode() + + return profile \ No newline at end of file