mirror of
https://github.com/yynps737/dnf-auto-cloud.git
synced 2026-05-07 02:35:46 +08:00
INIT
This commit is contained in:
393
client.py
393
client.py
@@ -3,6 +3,7 @@
|
||||
|
||||
"""
|
||||
DNF自动化客户端,负责截取游戏画面并执行服务器返回的操作
|
||||
优化版 - 增加断线重连、性能优化和增强的游戏状态管理
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -23,7 +24,7 @@ import logging
|
||||
import traceback
|
||||
|
||||
# 图像处理
|
||||
from PIL import ImageGrab, Image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Windows API
|
||||
@@ -36,8 +37,17 @@ import win32process
|
||||
import keyboard
|
||||
import mouse
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = "config.ini"
|
||||
# 尝试导入mss库,如果不存在则继续使用PIL
|
||||
try:
|
||||
import mss
|
||||
MSS_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSS_AVAILABLE = False
|
||||
print("警告: mss库未安装,将使用PIL进行截图(性能较低)")
|
||||
print("请使用 pip install mss 安装以获得更好的性能")
|
||||
|
||||
# 配置文件路径 - 使用绝对路径
|
||||
CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.ini")
|
||||
|
||||
# 日志设置
|
||||
logging.basicConfig(
|
||||
@@ -45,7 +55,7 @@ logging.basicConfig(
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=[
|
||||
logging.FileHandler("client.log", encoding="utf-8"),
|
||||
logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), "client.log"), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
@@ -62,7 +72,27 @@ class DNFAutoClient:
|
||||
self.running = False
|
||||
self.ws = None
|
||||
self.capture_interval = float(self.config.get("Capture", "interval"))
|
||||
self.game_state = {"in_battle": False}
|
||||
self.max_retries = int(self.config.get("Connection", "max_retries", fallback="5"))
|
||||
self.retry_delay = int(self.config.get("Connection", "retry_delay", fallback="5"))
|
||||
|
||||
# 增强的游戏状态
|
||||
self.game_state = {
|
||||
"in_battle": False,
|
||||
"current_map": "",
|
||||
"hp_percent": 100,
|
||||
"mp_percent": 100,
|
||||
"active_buffs": [],
|
||||
"cooldowns": {},
|
||||
"inventory_full": False,
|
||||
"current_quest": None,
|
||||
"last_operation_time": time.time(),
|
||||
"session_start_time": time.time()
|
||||
}
|
||||
|
||||
# 连接状态
|
||||
self.last_heartbeat_time = 0
|
||||
self.connection_attempts = 0
|
||||
self.reconnecting = False
|
||||
|
||||
def load_config(self):
|
||||
"""加载配置文件"""
|
||||
@@ -92,6 +122,12 @@ class DNFAutoClient:
|
||||
"key_mapping": "default"
|
||||
}
|
||||
|
||||
config["Connection"] = {
|
||||
"max_retries": "5",
|
||||
"retry_delay": "5",
|
||||
"heartbeat_interval": "5"
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
@@ -113,19 +149,60 @@ class DNFAutoClient:
|
||||
self.ws = await websockets.connect(
|
||||
self.server_url,
|
||||
ssl=ssl_context,
|
||||
max_size=10 * 1024 * 1024 # 10MB
|
||||
max_size=10 * 1024 * 1024, # 10MB
|
||||
ping_interval=None, # 禁用自动ping,我们将使用自己的心跳
|
||||
close_timeout=5
|
||||
)
|
||||
|
||||
# 等待认证
|
||||
await self.authenticate()
|
||||
|
||||
logger.info("已连接到服务器")
|
||||
self.connection_attempts = 0 # 重置连接尝试次数
|
||||
self.last_heartbeat_time = time.time()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接服务器失败: {e}")
|
||||
return False
|
||||
|
||||
async def connect_with_retry(self):
|
||||
"""带重试机制的连接函数"""
|
||||
if self.reconnecting:
|
||||
logger.info("已有重连过程在进行中,跳过")
|
||||
return False
|
||||
|
||||
self.reconnecting = True
|
||||
retry_count = 0
|
||||
|
||||
try:
|
||||
while retry_count < self.max_retries and not self.ws:
|
||||
try:
|
||||
logger.info(f"尝试连接服务器 (尝试 {retry_count + 1}/{self.max_retries})...")
|
||||
success = await self.connect()
|
||||
if success:
|
||||
self.reconnecting = False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"连接服务器失败: {e}")
|
||||
|
||||
retry_count += 1
|
||||
retry_delay = min(60, self.retry_delay * (2 ** retry_count)) # 指数退避策略
|
||||
logger.info(f"等待 {retry_delay} 秒后重试...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
if not self.ws:
|
||||
logger.error(f"达到最大重试次数 ({self.max_retries}),连接失败")
|
||||
self.reconnecting = False
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"重连过程中发生错误: {e}")
|
||||
self.reconnecting = False
|
||||
return False
|
||||
|
||||
self.reconnecting = False
|
||||
return True
|
||||
|
||||
async def authenticate(self):
|
||||
"""客户端认证"""
|
||||
try:
|
||||
@@ -136,14 +213,18 @@ class DNFAutoClient:
|
||||
if challenge.get("type") != "auth_challenge":
|
||||
raise ValueError("无效的认证挑战")
|
||||
|
||||
# 收集系统信息
|
||||
system_info = self.get_system_info()
|
||||
|
||||
# 发送认证响应
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "auth_response",
|
||||
"response": f"client_{random.getrandbits(32)}",
|
||||
"client_info": {
|
||||
"version": "1.0.0",
|
||||
"version": "1.0.1", # 更新版本号
|
||||
"os": "Windows",
|
||||
"screen_resolution": self.get_screen_resolution()
|
||||
"screen_resolution": self.get_screen_resolution(),
|
||||
"system_info": system_info
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -162,6 +243,24 @@ class DNFAutoClient:
|
||||
logger.error(f"认证失败: {e}")
|
||||
raise
|
||||
|
||||
def get_system_info(self):
|
||||
"""获取系统信息"""
|
||||
system_info = {}
|
||||
try:
|
||||
system_info["hostname"] = os.environ.get("COMPUTERNAME", "Unknown")
|
||||
system_info["username"] = os.environ.get("USERNAME", "Unknown")
|
||||
system_info["processor"] = os.environ.get("PROCESSOR_IDENTIFIER", "Unknown")
|
||||
|
||||
# 获取系统内存信息
|
||||
mem = ctypes.c_ulonglong()
|
||||
ctypes.windll.kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem))
|
||||
system_info["memory_gb"] = round(mem.value / (1024 * 1024), 2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取系统信息失败: {e}")
|
||||
|
||||
return system_info
|
||||
|
||||
def get_screen_resolution(self):
|
||||
"""获取屏幕分辨率"""
|
||||
user32 = ctypes.windll.user32
|
||||
@@ -188,13 +287,26 @@ class DNFAutoClient:
|
||||
rect = win32gui.GetWindowRect(hwnd)
|
||||
x, y, width, height = rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1]
|
||||
|
||||
# 截取画面
|
||||
screen = ImageGrab.grab(bbox=(x, y, x + width, y + height))
|
||||
# 检查窗口是否最小化
|
||||
if width <= 0 or height <= 0:
|
||||
logger.warning("游戏窗口被最小化或大小无效")
|
||||
return None
|
||||
|
||||
# 使用mss进行截图(如果可用),速度比PIL快5-10倍
|
||||
if MSS_AVAILABLE:
|
||||
with mss.mss() as sct:
|
||||
monitor = {"top": y, "left": x, "width": width, "height": height}
|
||||
sct_img = sct.grab(monitor)
|
||||
# 将mss图像转换为PIL图像
|
||||
img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")
|
||||
else:
|
||||
# 使用PIL进行截图(备选方案)
|
||||
img = ImageGrab.grab(bbox=(x, y, x + width, y + height))
|
||||
|
||||
# 压缩图像
|
||||
quality = int(self.config.get("Capture", "quality"))
|
||||
buffer = BytesIO()
|
||||
screen.save(buffer, format="JPEG", quality=quality)
|
||||
img.save(buffer, format="JPEG", quality=quality)
|
||||
|
||||
# 转换为Base64
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
@@ -207,18 +319,23 @@ class DNFAutoClient:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"截取游戏画面失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
async def send_heartbeat(self):
|
||||
"""发送心跳包"""
|
||||
if not self.ws or not self.running:
|
||||
return
|
||||
return False
|
||||
|
||||
try:
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "heartbeat",
|
||||
"timestamp": time.time(),
|
||||
"client_id": self.client_id
|
||||
"client_id": self.client_id,
|
||||
"game_state": {
|
||||
"in_battle": self.game_state["in_battle"],
|
||||
"current_map": self.game_state["current_map"]
|
||||
}
|
||||
}))
|
||||
|
||||
# 等待心跳响应
|
||||
@@ -231,17 +348,50 @@ class DNFAutoClient:
|
||||
response = json.loads(response_raw)
|
||||
if response.get("type") != "heartbeat_response":
|
||||
logger.warning(f"收到非心跳响应: {response.get('type')}")
|
||||
return False
|
||||
|
||||
self.last_heartbeat_time = time.time()
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("心跳超时")
|
||||
return False
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("发送心跳时连接已关闭")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"发送心跳失败: {e}")
|
||||
return False
|
||||
|
||||
async def heartbeat_loop(self):
|
||||
"""心跳循环"""
|
||||
heartbeat_interval = float(self.config.get("Connection", "heartbeat_interval", fallback="5"))
|
||||
|
||||
while self.running:
|
||||
await self.send_heartbeat()
|
||||
await asyncio.sleep(5.0) # 每5秒发送一次心跳
|
||||
if self.ws and not self.ws.closed:
|
||||
success = await self.send_heartbeat()
|
||||
if not success:
|
||||
# 心跳失败,检查连接状态
|
||||
if time.time() - self.last_heartbeat_time > heartbeat_interval * 3:
|
||||
logger.warning(f"心跳超时超过 {heartbeat_interval * 3} 秒,尝试重新连接")
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
await asyncio.sleep(heartbeat_interval)
|
||||
|
||||
async def reconnect_loop(self):
|
||||
"""重连检查循环"""
|
||||
while self.running:
|
||||
if not self.ws or self.ws.closed:
|
||||
logger.warning("WebSocket连接已断开,尝试重连...")
|
||||
if await self.connect_with_retry():
|
||||
logger.info("重连成功")
|
||||
else:
|
||||
logger.error("重连失败")
|
||||
# 不要立即停止,继续尝试
|
||||
|
||||
await asyncio.sleep(5) # 每5秒检查一次连接状态
|
||||
|
||||
async def execute_action(self, action):
|
||||
"""执行操作"""
|
||||
@@ -259,8 +409,27 @@ class DNFAutoClient:
|
||||
|
||||
# 确保窗口处于前台
|
||||
if win32gui.GetForegroundWindow() != hwnd:
|
||||
win32gui.SetForegroundWindow(hwnd)
|
||||
await asyncio.sleep(0.1)
|
||||
try:
|
||||
# 尝试多种方法激活窗口
|
||||
win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) # 恢复窗口(如果最小化)
|
||||
win32gui.SetForegroundWindow(hwnd) # 尝试置为前台
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 如果窗口仍然不在前台,使用更强的方法
|
||||
if win32gui.GetForegroundWindow() != hwnd:
|
||||
# 获取当前激活窗口的线程和进程ID
|
||||
curr_hwnd = win32gui.GetForegroundWindow()
|
||||
curr_thread_id = win32process.GetWindowThreadProcessId(curr_hwnd)[0]
|
||||
# 获取目标窗口的线程和进程ID
|
||||
target_thread_id = win32process.GetWindowThreadProcessId(hwnd)[0]
|
||||
# 附加线程输入
|
||||
win32process.AttachThreadInput(target_thread_id, curr_thread_id, True)
|
||||
win32gui.SetForegroundWindow(hwnd)
|
||||
win32gui.BringWindowToTop(hwnd)
|
||||
win32process.AttachThreadInput(target_thread_id, curr_thread_id, False)
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
logger.error(f"激活窗口失败: {e}")
|
||||
|
||||
# 获取窗口位置
|
||||
rect = win32gui.GetWindowRect(hwnd)
|
||||
@@ -281,6 +450,22 @@ class DNFAutoClient:
|
||||
win32api.SetCursorPos((int(point[0]), int(point[1])))
|
||||
await asyncio.sleep(0.01) # 10ms延迟
|
||||
|
||||
elif action_type == "click":
|
||||
# 点击指定位置
|
||||
position = action.get("position", [0, 0])
|
||||
x, y = position[0] + window_x, position[1] + window_y
|
||||
|
||||
# 移动到位置
|
||||
current_pos = win32gui.GetCursorPos()
|
||||
path = self.generate_movement_path(current_pos, [x, y])
|
||||
|
||||
for point in path:
|
||||
win32api.SetCursorPos((int(point[0]), int(point[1])))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# 执行点击
|
||||
mouse.click()
|
||||
|
||||
elif action_type == "use_skill":
|
||||
# 使用技能
|
||||
key = action.get("key", "1")
|
||||
@@ -333,12 +518,32 @@ class DNFAutoClient:
|
||||
for key in ["w", "a", "s", "d", "1", "2", "3", "4", "5", "6"]:
|
||||
if keyboard.is_pressed(key):
|
||||
keyboard.release(key)
|
||||
|
||||
elif action_type == "type_text":
|
||||
# 输入文本
|
||||
text = action.get("text", "")
|
||||
if text:
|
||||
keyboard.write(text, delay=0.05)
|
||||
|
||||
elif action_type == "press_key_combo":
|
||||
# 按下组合键
|
||||
keys = action.get("keys", [])
|
||||
if keys:
|
||||
for key in keys:
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(0.1)
|
||||
for key in reversed(keys):
|
||||
keyboard.release(key)
|
||||
|
||||
else:
|
||||
logger.warning(f"未知操作类型: {action_type}")
|
||||
|
||||
# 更新游戏状态
|
||||
self.game_state["last_operation_time"] = time.time()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行操作失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def generate_movement_path(self, start_pos, end_pos, steps=None):
|
||||
"""生成模拟人类的鼠标移动路径"""
|
||||
@@ -353,6 +558,9 @@ class DNFAutoClient:
|
||||
t = np.linspace(0, 1, steps)
|
||||
path = []
|
||||
|
||||
# 动态调整平滑度
|
||||
smoothness = float(self.config.get("Game", "movement_smoothness", fallback="0.8"))
|
||||
|
||||
for i in range(steps):
|
||||
# 基础线性插值
|
||||
x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i]
|
||||
@@ -360,7 +568,7 @@ class DNFAutoClient:
|
||||
|
||||
# 添加随机偏移(越靠近中间偏移越大)
|
||||
mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大
|
||||
max_offset = distance * 0.05 * mid_factor # 最大偏移为距离的5%
|
||||
max_offset = distance * 0.05 * mid_factor * (1 - smoothness) # 最大偏移为距离的5%,受平滑度影响
|
||||
|
||||
offset_x = random.normalvariate(0, max_offset / 3)
|
||||
offset_y = random.normalvariate(0, max_offset / 3)
|
||||
@@ -374,10 +582,68 @@ class DNFAutoClient:
|
||||
|
||||
return path
|
||||
|
||||
def analyze_image(self, image_data, detection_results):
|
||||
"""
|
||||
分析图像数据,更新游戏状态
|
||||
|
||||
参数:
|
||||
image_data (dict): 图像数据
|
||||
detection_results (list): 检测结果
|
||||
"""
|
||||
# 更新战斗状态
|
||||
monsters_detected = False
|
||||
for det in detection_results:
|
||||
if det["class_name"] in ["monster", "boss"]:
|
||||
monsters_detected = True
|
||||
break
|
||||
|
||||
self.game_state["in_battle"] = monsters_detected
|
||||
|
||||
# 分析血条和蓝条
|
||||
hp_bars = [d for d in detection_results if d["class_name"] == "hp_bar"]
|
||||
mp_bars = [d for d in detection_results if d["class_name"] == "mp_bar"]
|
||||
|
||||
if hp_bars:
|
||||
# 估算血量百分比
|
||||
self.game_state["hp_percent"] = self.estimate_bar_percent(hp_bars[0])
|
||||
|
||||
if mp_bars:
|
||||
# 估算蓝量百分比
|
||||
self.game_state["mp_percent"] = self.estimate_bar_percent(mp_bars[0])
|
||||
|
||||
# 检测技能冷却
|
||||
cooldowns = [d for d in detection_results if d["class_name"] == "cooldown"]
|
||||
self.game_state["cooldowns"] = {}
|
||||
|
||||
for cd in cooldowns:
|
||||
if "skill_id" in cd:
|
||||
self.game_state["cooldowns"][cd["skill_id"]] = cd.get("remaining_time", 1.0)
|
||||
|
||||
def estimate_bar_percent(self, bar_detection):
|
||||
"""
|
||||
估计血条/蓝条的百分比
|
||||
|
||||
参数:
|
||||
bar_detection (dict): 条形检测结果
|
||||
|
||||
返回:
|
||||
float: 百分比值(0-100)
|
||||
"""
|
||||
# 这里需要根据实际情况实现
|
||||
# 临时方案:返回检测结果中的值,如果没有则返回默认值
|
||||
return bar_detection.get("percent", 100)
|
||||
|
||||
async def capture_and_process_loop(self):
|
||||
"""截图和处理循环"""
|
||||
consecutive_errors = 0
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 检查WebSocket连接
|
||||
if not self.ws or self.ws.closed:
|
||||
await asyncio.sleep(0.5) # 连接断开时等待
|
||||
continue
|
||||
|
||||
# 获取游戏窗口
|
||||
hwnd = self.get_game_window()
|
||||
if hwnd is None:
|
||||
@@ -413,6 +679,10 @@ class DNFAutoClient:
|
||||
response = json.loads(response_raw)
|
||||
|
||||
if response.get("type") == "action_response":
|
||||
# 获取检测结果并更新游戏状态
|
||||
if "detections" in response:
|
||||
self.analyze_image(screen_data, response["detections"])
|
||||
|
||||
# 执行动作
|
||||
actions = response.get("actions", [])
|
||||
|
||||
@@ -421,30 +691,43 @@ class DNFAutoClient:
|
||||
|
||||
for action in actions:
|
||||
await self.execute_action(action)
|
||||
|
||||
# 重置错误计数
|
||||
consecutive_errors = 0
|
||||
|
||||
elif response.get("type") == "error":
|
||||
logger.error(f"服务器返回错误: {response.get('message')}")
|
||||
consecutive_errors += 1
|
||||
|
||||
# 等待指定的间隔时间
|
||||
await asyncio.sleep(self.capture_interval)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("等待服务器响应超时")
|
||||
consecutive_errors += 1
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.error("WebSocket连接已关闭")
|
||||
self.running = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理循环出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
consecutive_errors += 1
|
||||
await asyncio.sleep(1.0) # 出错时等待1秒
|
||||
|
||||
# 如果连续错误过多,尝试重新连接
|
||||
if consecutive_errors >= 5:
|
||||
logger.warning(f"连续出错 {consecutive_errors} 次,尝试重新连接")
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
consecutive_errors = 0
|
||||
|
||||
async def run(self):
|
||||
"""运行客户端"""
|
||||
self.running = True
|
||||
|
||||
# 连接到服务器
|
||||
if not await self.connect():
|
||||
if not await self.connect_with_retry():
|
||||
self.running = False
|
||||
return
|
||||
|
||||
@@ -452,9 +735,10 @@ class DNFAutoClient:
|
||||
# 创建任务
|
||||
capture_task = asyncio.create_task(self.capture_and_process_loop())
|
||||
heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
||||
reconnect_task = asyncio.create_task(self.reconnect_loop())
|
||||
|
||||
# 等待任务完成
|
||||
await asyncio.gather(capture_task, heartbeat_task)
|
||||
await asyncio.gather(capture_task, heartbeat_task, reconnect_task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("客户端任务已取消")
|
||||
@@ -469,15 +753,80 @@ class DNFAutoClient:
|
||||
|
||||
def start(self):
|
||||
"""启动客户端"""
|
||||
asyncio.run(self.run())
|
||||
try:
|
||||
# 启动心跳监控线程(备用方案,以防异步心跳失效)
|
||||
self._monitor_thread = threading.Thread(target=self._monitor_connection)
|
||||
self._monitor_thread.daemon = True
|
||||
self._monitor_thread.start()
|
||||
|
||||
# 运行主循环
|
||||
asyncio.run(self.run())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断,正在退出...")
|
||||
except Exception as e:
|
||||
logger.error(f"客户端出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _monitor_connection(self):
|
||||
"""监控连接的后台线程"""
|
||||
while True:
|
||||
try:
|
||||
time.sleep(30) # 每30秒检查一次
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
# 检查心跳时间
|
||||
if self.last_heartbeat_time > 0 and time.time() - self.last_heartbeat_time > 60:
|
||||
logger.warning("心跳超时,可能需要重连")
|
||||
# 不直接重连,留给重连循环处理
|
||||
except Exception as e:
|
||||
logger.error(f"连接监控线程出错: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止客户端"""
|
||||
self.running = False
|
||||
logger.info("正在停止客户端...")
|
||||
|
||||
# 创建默认配置文件(如果不存在)
|
||||
def ensure_config():
|
||||
if not os.path.exists(CONFIG_FILE):
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
config["Server"] = {
|
||||
"url": "wss://your-server-url:8080/ws",
|
||||
"verify_ssl": "false"
|
||||
}
|
||||
|
||||
config["Capture"] = {
|
||||
"interval": "0.5",
|
||||
"quality": "70"
|
||||
}
|
||||
|
||||
config["Game"] = {
|
||||
"window_title": "地下城与勇士",
|
||||
"key_mapping": "default",
|
||||
"movement_smoothness": "0.8"
|
||||
}
|
||||
|
||||
config["Connection"] = {
|
||||
"max_retries": "5",
|
||||
"retry_delay": "5",
|
||||
"heartbeat_interval": "5"
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
|
||||
print(f"已创建默认配置文件: {CONFIG_FILE}")
|
||||
print("请编辑配置文件设置正确的服务器地址等信息")
|
||||
|
||||
# 启动客户端
|
||||
if __name__ == "__main__":
|
||||
# 确保配置文件存在
|
||||
ensure_config()
|
||||
|
||||
try:
|
||||
client = DNFAutoClient()
|
||||
client.start()
|
||||
|
||||
19
config/cert.pem
Normal file
19
config/cert.pem
Normal file
@@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDCTCCAfGgAwIBAgIUPKAkPsxw1moIm//p7Cin44rkneQwDQYJKoZIhvcNAQEL
|
||||
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDMyNjE3MzM0MVoXDTI2MDMy
|
||||
NjE3MzM0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
|
||||
AAOCAQ8AMIIBCgKCAQEAya16YWOlIvnUHo7ubqTsQslYGPGbBiCNEGux/YNUagCr
|
||||
51RrNP6pwPXzRNIdDpudSz78wUVCtHEgn1shDIRJuGkENK4s6XpvGbg655DOEev4
|
||||
sKo4s/lDAjjuN450bth9Iv1CjuP/mLEOPQ9la+PwF0rcpGhh2qLGkMX399C2a3qf
|
||||
/UTOg7ifATWXM81VwPfYgpCV+jblAjh56n3rfN/lnIMMsAIxJRa23joLzmR4WIjI
|
||||
fJRMJFYYhccUBeALAPvy+qxOpe1xRz1SrCpAe8mFE5OTUB6kmNXvfu+xJZ/t+k7s
|
||||
XHcGccd4ZMwCBkHc18bmvtD/XoB6XAMSGvswznXMOwIDAQABo1MwUTAdBgNVHQ4E
|
||||
FgQUsU/YgtZL1WkbZHRQ2Xu5GJ7SlgwwHwYDVR0jBBgwFoAUsU/YgtZL1WkbZHRQ
|
||||
2Xu5GJ7SlgwwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAMf4f
|
||||
3kxcg2HkLnq9lbA3gFkkaXQ5OyFGPBXe/uQ+jk3Kh/2rfnMCci1RCEMOK9gD7Lh+
|
||||
fQXPHDtGTmLE5wm9M43NxhRfzfnckCDENmsEzKssvdinqq2zbhTzSXrPN7W7PYPQ
|
||||
wvYo3ZZIDTXmyZhTgph4su0zuglP/kkfmRrUSxaFnGWDKJE+W+KQ6r6uestGoa9w
|
||||
ScoLYcB2ssWEK2uU1/kto8PzRAJuOLiGMur7j5fIDfaM13V2+Qt3LpFd32cZnKFH
|
||||
x9W9/e+I5dbAfqqItIkbgjM0akMhq10t9MYouXbAw1OscXDmkomObC2Wq2QviPkY
|
||||
zP3UOCSWi9eAWvknLA==
|
||||
-----END CERTIFICATE-----
|
||||
28
config/key.pem
Normal file
28
config/key.pem
Normal file
@@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDJrXphY6Ui+dQe
|
||||
ju5upOxCyVgY8ZsGII0Qa7H9g1RqAKvnVGs0/qnA9fNE0h0Om51LPvzBRUK0cSCf
|
||||
WyEMhEm4aQQ0rizpem8ZuDrnkM4R6/iwqjiz+UMCOO43jnRu2H0i/UKO4/+YsQ49
|
||||
D2Vr4/AXStykaGHaosaQxff30LZrep/9RM6DuJ8BNZczzVXA99iCkJX6NuUCOHnq
|
||||
fet83+WcgwywAjElFrbeOgvOZHhYiMh8lEwkVhiFxxQF4AsA+/L6rE6l7XFHPVKs
|
||||
KkB7yYUTk5NQHqSY1e9+77Eln+36TuxcdwZxx3hkzAIGQdzXxua+0P9egHpcAxIa
|
||||
+zDOdcw7AgMBAAECggEACrGTIqTY9cDPeYtUozNFf8kTTcdJ1ApX0H4VYv7atAAz
|
||||
HUIBqT6zm5KvAoAtoD+qGHpPhqP4hH7XHvwDBZniGtAes/hkU0D1sSRuoyavdo3P
|
||||
kvaDqS9XWT/RicqY6+O4xuks5Uy7mcoRmjU9yHm+mk2S43jRb3lgE/8bRd2gPpSa
|
||||
14HqNwFqLmB4ZiYaq20KuO+cNTS/wPrxicVABVNwJ2cM+6vdoonkNYaPNPulWzeg
|
||||
3CgSlageaGPVByaCUK24/0+b7Jgd2GMd34FWTMils8TkeuodQtcTwHEWphAbgpVB
|
||||
5VlJRA6jYMRRfsZNlT82VvR90uZowLhJ84eGMwdqIQKBgQDsefnK7L0JtIWALljC
|
||||
BVA6ljbRxZcolVFo6jMLtw/yfChzdP8d14x9s3sxX5J+69kmAAAGeS8tjJeautW6
|
||||
kDtsUCdhN3X2GftvR2Y8wUNHW8kMoU9rCN09HaxXq9HK2/RZlNhiHA6FycP/Pe6t
|
||||
1Qlo1Hn5XniNtbr71NgVM/NX8QKBgQDaVAPxntGm2SBzpsDpXwAVHxWJzK4aveMf
|
||||
rq7I5JJ5b9PnMV9cNJncF5vOdWM5htb34orJtBZ8IzuZMXAEU5OJwfmaGMleaKMv
|
||||
mGRoDpWtAE+zj8BB0RHr828tCs3M3/klx0/FUDRPsYXRfbaJb5CCsQmv7QHV4FRp
|
||||
+bUQ9Fcy6wKBgQDJ1yrQe9S2bfDtAaIcqPBbsU9FKYPlzd1Y0V2UiEICVNsqARin
|
||||
3g06VXG3KL4fuyrzdliPLeyI0lGsbgBzZxxxTNDv96il0HN9/dFT1hmY1Mz8DMt+
|
||||
rmg3/BXYFv3QSoF73MH8q7nxk8/JEpGgqg+H/KPHp0z6l7zrqjZtkpQH4QKBgQDW
|
||||
jcHibITzRmURwknKDUXze7ya0r42IW1V8UBqw9T96duAU5C2+CpLlBfVaJ6+JbiT
|
||||
mdlyJrwB+k3TWjYOymMu+aTkvn8FfCcB2uyxJcQJY0jv2NDC3UaTbYNP7FIah/A8
|
||||
JAZMjWka+AXdvYDoxu5owLoYXP10xSOvkWlS5AvdSQKBgB0365BytqUd8/ijGuTa
|
||||
mSaL4WDFFdCkOhPNEmIn1llh0HAh2jWpyBRnJeGyQ/4vZJjA8gIgz4IZnZfiw+ov
|
||||
SDRnpqvguklboB+mPWOM4Itj/5G7XOAEN1kkA81SUDJBoCuLiZ7KwIaSlefLWi0l
|
||||
VwmN0w+3GIt4F7rg0dmDsVV4
|
||||
-----END PRIVATE KEY-----
|
||||
34
config/runtime_config.json
Normal file
34
config/runtime_config.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"version": "1.0.1",
|
||||
"build_date": "2025-03-27",
|
||||
"model": {
|
||||
"name": "yolov8m",
|
||||
"weights": "/workspace/dnf-auto-cloud/models/weights/dnf_yolo8m.pt",
|
||||
"conf_threshold": 0.5,
|
||||
"iou_threshold": 0.45,
|
||||
"device": "cuda",
|
||||
"half_precision": true,
|
||||
"batch_size": 1,
|
||||
"img_size": 640,
|
||||
"engine": "pytorch",
|
||||
"class_names": "/workspace/dnf-auto-cloud/data/training/data.yaml"
|
||||
},
|
||||
"server": {
|
||||
"max_connections": 10,
|
||||
"timeout": 60,
|
||||
"heartbeat_interval": 5,
|
||||
"inactive_timeout": 300,
|
||||
"max_retries": 3,
|
||||
"status_update_interval": 30,
|
||||
"api_enabled": true
|
||||
},
|
||||
"security": {
|
||||
"token_expiry": 3600,
|
||||
"encryption_enabled": false,
|
||||
"ssl_enabled": false,
|
||||
"ssl_cert": "/workspace/dnf-auto-cloud/config/cert.pem",
|
||||
"ip_whitelist": [],
|
||||
"rate_limit": 100,
|
||||
"authentication_required": false
|
||||
}
|
||||
}
|
||||
@@ -3,70 +3,193 @@
|
||||
|
||||
"""
|
||||
全局配置参数
|
||||
优化版 - 增加更多配置选项和兼容性检查
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# 检测当前环境
|
||||
IS_LINUX = sys.platform.startswith('linux')
|
||||
IS_WINDOWS = sys.platform.startswith('win')
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
|
||||
# 基础路径
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
# 版本信息
|
||||
VERSION = "1.0.1"
|
||||
BUILD_DATE = "2025-03-27"
|
||||
|
||||
# 加载环境变量
|
||||
def load_env_var(name, default):
|
||||
"""从环境变量加载设置,如果不存在则使用默认值"""
|
||||
return os.environ.get(f"DNF_{name}", default)
|
||||
|
||||
# 模型配置
|
||||
MODEL = {
|
||||
"name": "yolov8m",
|
||||
"weights": os.path.join(BASE_DIR, "models", "weights", "dnf_yolo8m.pt"),
|
||||
"conf_threshold": 0.5,
|
||||
"iou_threshold": 0.45,
|
||||
"device": "cuda" # 使用GPU
|
||||
"name": load_env_var("MODEL_NAME", "yolov8m"),
|
||||
"weights": os.path.join(BASE_DIR, "models", "weights", load_env_var("MODEL_FILE", "dnf_yolo8m.pt")),
|
||||
"conf_threshold": float(load_env_var("CONF_THRESHOLD", "0.5")),
|
||||
"iou_threshold": float(load_env_var("IOU_THRESHOLD", "0.45")),
|
||||
"device": load_env_var("DEVICE", "cuda" if HAS_CUDA else "cpu"),
|
||||
"half_precision": load_env_var("HALF_PRECISION", "true").lower() == "true" and HAS_CUDA,
|
||||
"batch_size": int(load_env_var("BATCH_SIZE", "1")),
|
||||
"img_size": int(load_env_var("IMG_SIZE", "640")),
|
||||
"engine": load_env_var("ENGINE", "pytorch"), # pytorch 或 onnx
|
||||
"class_names": os.path.join(BASE_DIR, "data", "training", "data.yaml")
|
||||
}
|
||||
|
||||
# 服务器配置
|
||||
SERVER = {
|
||||
"max_connections": 10,
|
||||
"timeout": 60,
|
||||
"heartbeat_interval": 5
|
||||
"max_connections": int(load_env_var("MAX_CONNECTIONS", "10")),
|
||||
"timeout": int(load_env_var("TIMEOUT", "60")),
|
||||
"heartbeat_interval": int(load_env_var("HEARTBEAT_INTERVAL", "5")),
|
||||
"inactive_timeout": int(load_env_var("INACTIVE_TIMEOUT", "300")), # 用户不活跃超时(秒)
|
||||
"max_retries": int(load_env_var("MAX_RETRIES", "3")), # 操作最大重试次数
|
||||
"status_update_interval": int(load_env_var("STATUS_UPDATE_INTERVAL", "30")), # 状态更新间隔(秒)
|
||||
"api_enabled": load_env_var("API_ENABLED", "true").lower() == "true" # 是否启用API
|
||||
}
|
||||
|
||||
# 安全配置
|
||||
SECURITY = {
|
||||
"token_expiry": 3600, # 1小时
|
||||
"encryption_enabled": True,
|
||||
"ssl_enabled": True,
|
||||
"token_expiry": int(load_env_var("TOKEN_EXPIRY", "3600")), # 1小时
|
||||
"encryption_enabled": False, # 禁用加密以简化调试
|
||||
"ssl_enabled": False, # 禁用SSL
|
||||
"ssl_cert": os.path.join(BASE_DIR, "config", "cert.pem"),
|
||||
"ssl_key": os.path.join(BASE_DIR, "config", "key.pem")
|
||||
"ssl_key": os.path.join(BASE_DIR, "config", "key.pem"),
|
||||
"ip_whitelist": load_env_var("IP_WHITELIST", "").split(",") if load_env_var("IP_WHITELIST", "") else [],
|
||||
"rate_limit": int(load_env_var("RATE_LIMIT", "100")), # 每分钟最大请求数
|
||||
"authentication_required": False # 暂时禁用认证要求
|
||||
}
|
||||
|
||||
# 行为模拟配置
|
||||
BEHAVIOR = {
|
||||
"min_delay": 0.1,
|
||||
"max_delay": 0.8,
|
||||
"click_variance": 0.15, # 点击位置随机偏移比例
|
||||
"movement_smoothness": 0.8 # 移动平滑度 (0-1)
|
||||
"min_delay": float(load_env_var("MIN_DELAY", "0.1")),
|
||||
"max_delay": float(load_env_var("MAX_DELAY", "0.8")),
|
||||
"click_variance": float(load_env_var("CLICK_VARIANCE", "0.15")), # 点击位置随机偏移比例
|
||||
"movement_smoothness": float(load_env_var("MOVEMENT_SMOOTHNESS", "0.8")), # 移动平滑度 (0-1)
|
||||
"double_click_chance": float(load_env_var("DOUBLE_CLICK_CHANCE", "0.05")), # 双击概率
|
||||
"user_profile": load_env_var("USER_PROFILE", "normal"), # 用户行为配置 (fast, normal, casual)
|
||||
"randomize_behavior": load_env_var("RANDOMIZE_BEHAVIOR", "true").lower() == "true" # 是否随机化行为
|
||||
}
|
||||
|
||||
# 日志配置
|
||||
LOGGING = {
|
||||
"level": "INFO",
|
||||
"level": load_env_var("LOG_LEVEL", "INFO"),
|
||||
"file": os.path.join(BASE_DIR, "data", "logs", "server.log"),
|
||||
"max_size": 10 * 1024 * 1024, # 10MB
|
||||
"backup_count": 5
|
||||
"max_size": int(load_env_var("LOG_MAX_SIZE", "10")) * 1024 * 1024, # 10MB
|
||||
"backup_count": int(load_env_var("LOG_BACKUP_COUNT", "5")),
|
||||
"console_output": load_env_var("LOG_CONSOLE", "true").lower() == "true",
|
||||
"log_requests": load_env_var("LOG_REQUESTS", "true").lower() == "true",
|
||||
"log_performance": load_env_var("LOG_PERFORMANCE", "true").lower() == "true"
|
||||
}
|
||||
|
||||
# 缓存配置
|
||||
CACHE = {
|
||||
"enabled": load_env_var("CACHE_ENABLED", "true").lower() == "true",
|
||||
"directory": os.path.join(BASE_DIR, "data", "cache"),
|
||||
"max_size": int(load_env_var("CACHE_MAX_SIZE", "1000")), # 最大缓存项数
|
||||
"ttl": int(load_env_var("CACHE_TTL", "3600")) # 缓存有效期(秒)
|
||||
}
|
||||
|
||||
# 会话配置
|
||||
SESSION = {
|
||||
"directory": os.path.join(BASE_DIR, "data", "sessions"),
|
||||
"save_interval": int(load_env_var("SESSION_SAVE_INTERVAL", "300")), # 会话保存间隔(秒)
|
||||
"max_idle_time": int(load_env_var("SESSION_MAX_IDLE", "1800")), # 最大空闲时间(秒)
|
||||
"persistence_enabled": load_env_var("SESSION_PERSISTENCE", "true").lower() == "true"
|
||||
}
|
||||
|
||||
# DNF游戏特定配置
|
||||
DNF = {
|
||||
"classes": [
|
||||
"monster", "boss", "door", "item", "npc", "player",
|
||||
"hp_bar", "mp_bar", "skill_ready", "cooldown"
|
||||
"hp_bar", "mp_bar", "skill_ready", "cooldown", "dialog",
|
||||
"dialog_option", "revive_button", "confirm_button",
|
||||
"dungeon_portal", "dungeon_select", "dungeon_option",
|
||||
"room_select", "room_option", "death_ui"
|
||||
],
|
||||
"maps": {
|
||||
"赫顿玛尔": {"entrance": (123, 456), "exit": (789, 101)},
|
||||
"天空之城": {"entrance": (234, 567), "exit": (890, 121)},
|
||||
"格兰之森": {"entrance": (345, 678), "exit": (901, 234)},
|
||||
"诺斯玛尔": {"entrance": (456, 789), "exit": (345, 678)},
|
||||
# 更多地图配置...
|
||||
},
|
||||
"skills": {
|
||||
"1": "普通攻击",
|
||||
"2": "技能1",
|
||||
"3": "技能2",
|
||||
"1": {"name": "普通攻击", "type": "melee", "cooldown": 0.5},
|
||||
"2": {"name": "技能1", "type": "melee", "cooldown": 3.0},
|
||||
"3": {"name": "技能2", "type": "ranged", "cooldown": 5.0},
|
||||
"4": {"name": "技能3", "type": "aoe", "cooldown": 8.0},
|
||||
"5": {"name": "技能4", "type": "buff", "cooldown": 15.0},
|
||||
"6": {"name": "技能5", "type": "ultimate", "cooldown": 30.0},
|
||||
# 更多技能配置...
|
||||
},
|
||||
"difficulty_levels": {
|
||||
"normal": {"monster_hp": 100, "monster_damage": 100},
|
||||
"hard": {"monster_hp": 150, "monster_damage": 150},
|
||||
"expert": {"monster_hp": 200, "monster_damage": 200},
|
||||
"master": {"monster_hp": 250, "monster_damage": 250},
|
||||
"hell": {"monster_hp": 300, "monster_damage": 300}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 检查并创建必要的目录
|
||||
def ensure_directories():
|
||||
"""确保所有需要的目录都存在"""
|
||||
directories = [
|
||||
os.path.dirname(LOGGING["file"]), # 日志目录
|
||||
CACHE["directory"], # 缓存目录
|
||||
SESSION["directory"], # 会话目录
|
||||
os.path.join(BASE_DIR, "data", "training"), # 训练数据目录
|
||||
os.path.dirname(MODEL["weights"]), # 模型权重目录
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
# 配置验证
|
||||
def validate_config():
|
||||
"""验证配置有效性"""
|
||||
# 检查模型配置
|
||||
if not os.path.exists(MODEL["weights"]) and not os.path.exists(MODEL["weights"] + ".onnx"):
|
||||
print(f"警告: 模型权重文件不存在: {MODEL['weights']}")
|
||||
|
||||
# 检查CUDA设置
|
||||
if MODEL["device"] == "cuda" and not HAS_CUDA:
|
||||
print("警告: 已配置使用CUDA但系统中没有可用的GPU,将使用CPU模式")
|
||||
MODEL["device"] = "cpu"
|
||||
MODEL["half_precision"] = False
|
||||
|
||||
# 检查日志级别
|
||||
valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
if LOGGING["level"] not in valid_log_levels:
|
||||
print(f"警告: 无效的日志级别 '{LOGGING['level']}',将使用 'INFO'")
|
||||
LOGGING["level"] = "INFO"
|
||||
|
||||
# 保存验证后的配置
|
||||
save_validated_config()
|
||||
|
||||
def save_validated_config():
|
||||
"""保存验证后的配置到文件,便于其他进程读取"""
|
||||
config = {
|
||||
"version": VERSION,
|
||||
"build_date": BUILD_DATE,
|
||||
"model": MODEL,
|
||||
"server": SERVER,
|
||||
"security": {k: v for k, v in SECURITY.items() if k not in ["ssl_key"]} # 不保存敏感信息
|
||||
}
|
||||
|
||||
try:
|
||||
with open(os.path.join(BASE_DIR, "config", "runtime_config.json"), "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存运行时配置失败: {e}")
|
||||
|
||||
# 执行初始化
|
||||
ensure_directories()
|
||||
validate_config()
|
||||
9
dnf-client-package/README.txt
Normal file
9
dnf-client-package/README.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
DNF自动化客户端
|
||||
|
||||
安装说明:
|
||||
1. 安装Python 3.8或更高版本
|
||||
2. 安装依赖包:
|
||||
pip install pillow numpy websockets keyboard mouse pywin32 mss
|
||||
|
||||
3. 编辑config.ini设置服务器地址
|
||||
4. 运行start.bat启动客户端
|
||||
837
dnf-client-package/client.py
Normal file
837
dnf-client-package/client.py
Normal file
@@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
DNF自动化客户端,负责截取游戏画面并执行服务器返回的操作
|
||||
优化版 - 增加断线重连、性能优化和增强的游戏状态管理
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
import random
|
||||
import asyncio
|
||||
import websockets
|
||||
import configparser
|
||||
import ssl
|
||||
import ctypes
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
# 图像处理
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# Windows API
|
||||
import win32gui
|
||||
import win32con
|
||||
import win32api
|
||||
import win32process
|
||||
|
||||
# 输入模拟
|
||||
import keyboard
|
||||
import mouse
|
||||
|
||||
# 尝试导入mss库,如果不存在则继续使用PIL
|
||||
try:
|
||||
import mss
|
||||
MSS_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSS_AVAILABLE = False
|
||||
print("警告: mss库未安装,将使用PIL进行截图(性能较低)")
|
||||
print("请使用 pip install mss 安装以获得更好的性能")
|
||||
|
||||
# 配置文件路径 - 使用绝对路径
|
||||
CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.ini")
|
||||
|
||||
# 日志设置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), "client.log"), encoding="utf-8"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("DNFAutoClient")
|
||||
|
||||
class DNFAutoClient:
|
||||
"""DNF自动化客户端类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化客户端"""
|
||||
self.config = self.load_config()
|
||||
self.server_url = self.config.get("Server", "url")
|
||||
self.client_id = None
|
||||
self.running = False
|
||||
self.ws = None
|
||||
self.capture_interval = float(self.config.get("Capture", "interval"))
|
||||
self.max_retries = int(self.config.get("Connection", "max_retries", fallback="5"))
|
||||
self.retry_delay = int(self.config.get("Connection", "retry_delay", fallback="5"))
|
||||
|
||||
# 增强的游戏状态
|
||||
self.game_state = {
|
||||
"in_battle": False,
|
||||
"current_map": "",
|
||||
"hp_percent": 100,
|
||||
"mp_percent": 100,
|
||||
"active_buffs": [],
|
||||
"cooldowns": {},
|
||||
"inventory_full": False,
|
||||
"current_quest": None,
|
||||
"last_operation_time": time.time(),
|
||||
"session_start_time": time.time()
|
||||
}
|
||||
|
||||
# 连接状态
|
||||
self.last_heartbeat_time = 0
|
||||
self.connection_attempts = 0
|
||||
self.reconnecting = False
|
||||
|
||||
def load_config(self):
|
||||
"""加载配置文件"""
|
||||
if not os.path.exists(CONFIG_FILE):
|
||||
self.create_default_config()
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG_FILE, encoding="utf-8")
|
||||
return config
|
||||
|
||||
def create_default_config(self):
|
||||
"""创建默认配置文件"""
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
config["Server"] = {
|
||||
"url": "wss://your-server-url:8080/ws",
|
||||
"verify_ssl": "false"
|
||||
}
|
||||
|
||||
config["Capture"] = {
|
||||
"interval": "0.5",
|
||||
"quality": "70"
|
||||
}
|
||||
|
||||
config["Game"] = {
|
||||
"window_title": "地下城与勇士",
|
||||
"key_mapping": "default"
|
||||
}
|
||||
|
||||
config["Connection"] = {
|
||||
"max_retries": "5",
|
||||
"retry_delay": "5",
|
||||
"heartbeat_interval": "5"
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
|
||||
logger.info(f"已创建默认配置文件: {CONFIG_FILE}")
|
||||
|
||||
async def connect(self):
|
||||
"""连接到服务器"""
|
||||
logger.info(f"正在连接到服务器: {self.server_url}")
|
||||
|
||||
ssl_context = None
|
||||
if self.server_url.startswith("wss://"):
|
||||
ssl_context = ssl.create_default_context()
|
||||
if self.config.get("Server", "verify_ssl").lower() == "false":
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
try:
|
||||
self.ws = await websockets.connect(
|
||||
self.server_url,
|
||||
ssl=ssl_context,
|
||||
max_size=10 * 1024 * 1024, # 10MB
|
||||
ping_interval=None, # 禁用自动ping,我们将使用自己的心跳
|
||||
close_timeout=5
|
||||
)
|
||||
|
||||
# 等待认证
|
||||
await self.authenticate()
|
||||
|
||||
logger.info("已连接到服务器")
|
||||
self.connection_attempts = 0 # 重置连接尝试次数
|
||||
self.last_heartbeat_time = time.time()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接服务器失败: {e}")
|
||||
return False
|
||||
|
||||
async def connect_with_retry(self):
|
||||
"""带重试机制的连接函数"""
|
||||
if self.reconnecting:
|
||||
logger.info("已有重连过程在进行中,跳过")
|
||||
return False
|
||||
|
||||
self.reconnecting = True
|
||||
retry_count = 0
|
||||
|
||||
try:
|
||||
while retry_count < self.max_retries and not self.ws:
|
||||
try:
|
||||
logger.info(f"尝试连接服务器 (尝试 {retry_count + 1}/{self.max_retries})...")
|
||||
success = await self.connect()
|
||||
if success:
|
||||
self.reconnecting = False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"连接服务器失败: {e}")
|
||||
|
||||
retry_count += 1
|
||||
retry_delay = min(60, self.retry_delay * (2 ** retry_count)) # 指数退避策略
|
||||
logger.info(f"等待 {retry_delay} 秒后重试...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
if not self.ws:
|
||||
logger.error(f"达到最大重试次数 ({self.max_retries}),连接失败")
|
||||
self.reconnecting = False
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"重连过程中发生错误: {e}")
|
||||
self.reconnecting = False
|
||||
return False
|
||||
|
||||
self.reconnecting = False
|
||||
return True
|
||||
|
||||
async def authenticate(self):
|
||||
"""客户端认证"""
|
||||
try:
|
||||
# 等待认证挑战
|
||||
challenge_raw = await self.ws.recv()
|
||||
challenge = json.loads(challenge_raw)
|
||||
|
||||
if challenge.get("type") != "auth_challenge":
|
||||
raise ValueError("无效的认证挑战")
|
||||
|
||||
# 收集系统信息
|
||||
system_info = self.get_system_info()
|
||||
|
||||
# 发送认证响应
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "auth_response",
|
||||
"response": f"client_{random.getrandbits(32)}",
|
||||
"client_info": {
|
||||
"version": "1.0.1", # 更新版本号
|
||||
"os": "Windows",
|
||||
"screen_resolution": self.get_screen_resolution(),
|
||||
"system_info": system_info
|
||||
}
|
||||
}))
|
||||
|
||||
# 等待认证结果
|
||||
result_raw = await self.ws.recv()
|
||||
result = json.loads(result_raw)
|
||||
|
||||
if result.get("type") != "auth_result" or result.get("status") != "success":
|
||||
raise ValueError("认证失败")
|
||||
|
||||
# 保存客户端ID
|
||||
self.client_id = result.get("client_id")
|
||||
logger.info(f"认证成功,客户端ID: {self.client_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"认证失败: {e}")
|
||||
raise
|
||||
|
||||
def get_system_info(self):
|
||||
"""获取系统信息"""
|
||||
system_info = {}
|
||||
try:
|
||||
system_info["hostname"] = os.environ.get("COMPUTERNAME", "Unknown")
|
||||
system_info["username"] = os.environ.get("USERNAME", "Unknown")
|
||||
system_info["processor"] = os.environ.get("PROCESSOR_IDENTIFIER", "Unknown")
|
||||
|
||||
# 获取系统内存信息
|
||||
mem = ctypes.c_ulonglong()
|
||||
ctypes.windll.kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem))
|
||||
system_info["memory_gb"] = round(mem.value / (1024 * 1024), 2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取系统信息失败: {e}")
|
||||
|
||||
return system_info
|
||||
|
||||
def get_screen_resolution(self):
|
||||
"""获取屏幕分辨率"""
|
||||
user32 = ctypes.windll.user32
|
||||
return [user32.GetSystemMetrics(0), user32.GetSystemMetrics(1)]
|
||||
|
||||
def get_game_window(self):
|
||||
"""获取游戏窗口句柄"""
|
||||
window_title = self.config.get("Game", "window_title")
|
||||
hwnd = win32gui.FindWindow(None, window_title)
|
||||
if hwnd == 0:
|
||||
logger.warning(f"找不到游戏窗口: {window_title}")
|
||||
return None
|
||||
return hwnd
|
||||
|
||||
def capture_game_screen(self, hwnd=None):
|
||||
"""截取游戏画面"""
|
||||
try:
|
||||
if hwnd is None:
|
||||
hwnd = self.get_game_window()
|
||||
if hwnd is None:
|
||||
return None
|
||||
|
||||
# 获取窗口位置和大小
|
||||
rect = win32gui.GetWindowRect(hwnd)
|
||||
x, y, width, height = rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1]
|
||||
|
||||
# 检查窗口是否最小化
|
||||
if width <= 0 or height <= 0:
|
||||
logger.warning("游戏窗口被最小化或大小无效")
|
||||
return None
|
||||
|
||||
# 使用mss进行截图(如果可用),速度比PIL快5-10倍
|
||||
if MSS_AVAILABLE:
|
||||
with mss.mss() as sct:
|
||||
monitor = {"top": y, "left": x, "width": width, "height": height}
|
||||
sct_img = sct.grab(monitor)
|
||||
# 将mss图像转换为PIL图像
|
||||
img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")
|
||||
else:
|
||||
# 使用PIL进行截图(备选方案)
|
||||
img = ImageGrab.grab(bbox=(x, y, x + width, y + height))
|
||||
|
||||
# 压缩图像
|
||||
quality = int(self.config.get("Capture", "quality"))
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format="JPEG", quality=quality)
|
||||
|
||||
# 转换为Base64
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
return {
|
||||
"image": img_base64,
|
||||
"window_rect": [x, y, width, height],
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"截取游戏画面失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
async def send_heartbeat(self):
|
||||
"""发送心跳包"""
|
||||
if not self.ws or not self.running:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self.ws.send(json.dumps({
|
||||
"type": "heartbeat",
|
||||
"timestamp": time.time(),
|
||||
"client_id": self.client_id,
|
||||
"game_state": {
|
||||
"in_battle": self.game_state["in_battle"],
|
||||
"current_map": self.game_state["current_map"]
|
||||
}
|
||||
}))
|
||||
|
||||
# 等待心跳响应
|
||||
response_raw = await asyncio.wait_for(
|
||||
self.ws.recv(),
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
# 检查响应
|
||||
response = json.loads(response_raw)
|
||||
if response.get("type") != "heartbeat_response":
|
||||
logger.warning(f"收到非心跳响应: {response.get('type')}")
|
||||
return False
|
||||
|
||||
self.last_heartbeat_time = time.time()
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("心跳超时")
|
||||
return False
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("发送心跳时连接已关闭")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"发送心跳失败: {e}")
|
||||
return False
|
||||
|
||||
async def heartbeat_loop(self):
|
||||
"""心跳循环"""
|
||||
heartbeat_interval = float(self.config.get("Connection", "heartbeat_interval", fallback="5"))
|
||||
|
||||
while self.running:
|
||||
if self.ws and not self.ws.closed:
|
||||
success = await self.send_heartbeat()
|
||||
if not success:
|
||||
# 心跳失败,检查连接状态
|
||||
if time.time() - self.last_heartbeat_time > heartbeat_interval * 3:
|
||||
logger.warning(f"心跳超时超过 {heartbeat_interval * 3} 秒,尝试重新连接")
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
await asyncio.sleep(heartbeat_interval)
|
||||
|
||||
async def reconnect_loop(self):
|
||||
"""重连检查循环"""
|
||||
while self.running:
|
||||
if not self.ws or self.ws.closed:
|
||||
logger.warning("WebSocket连接已断开,尝试重连...")
|
||||
if await self.connect_with_retry():
|
||||
logger.info("重连成功")
|
||||
else:
|
||||
logger.error("重连失败")
|
||||
# 不要立即停止,继续尝试
|
||||
|
||||
await asyncio.sleep(5) # 每5秒检查一次连接状态
|
||||
|
||||
async def execute_action(self, action):
|
||||
"""执行操作"""
|
||||
try:
|
||||
action_type = action.get("type")
|
||||
|
||||
# 等待指定的延迟时间
|
||||
if "delay" in action:
|
||||
await asyncio.sleep(action["delay"])
|
||||
|
||||
# 获取游戏窗口
|
||||
hwnd = self.get_game_window()
|
||||
if hwnd is None:
|
||||
return
|
||||
|
||||
# 确保窗口处于前台
|
||||
if win32gui.GetForegroundWindow() != hwnd:
|
||||
try:
|
||||
# 尝试多种方法激活窗口
|
||||
win32gui.ShowWindow(hwnd, win32con.SW_RESTORE) # 恢复窗口(如果最小化)
|
||||
win32gui.SetForegroundWindow(hwnd) # 尝试置为前台
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 如果窗口仍然不在前台,使用更强的方法
|
||||
if win32gui.GetForegroundWindow() != hwnd:
|
||||
# 获取当前激活窗口的线程和进程ID
|
||||
curr_hwnd = win32gui.GetForegroundWindow()
|
||||
curr_thread_id = win32process.GetWindowThreadProcessId(curr_hwnd)[0]
|
||||
# 获取目标窗口的线程和进程ID
|
||||
target_thread_id = win32process.GetWindowThreadProcessId(hwnd)[0]
|
||||
# 附加线程输入
|
||||
win32process.AttachThreadInput(target_thread_id, curr_thread_id, True)
|
||||
win32gui.SetForegroundWindow(hwnd)
|
||||
win32gui.BringWindowToTop(hwnd)
|
||||
win32process.AttachThreadInput(target_thread_id, curr_thread_id, False)
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
logger.error(f"激活窗口失败: {e}")
|
||||
|
||||
# 获取窗口位置
|
||||
rect = win32gui.GetWindowRect(hwnd)
|
||||
window_x, window_y = rect[0], rect[1]
|
||||
|
||||
# 执行不同类型的操作
|
||||
if action_type == "move_to":
|
||||
# 移动到指定位置
|
||||
position = action.get("position", [0, 0])
|
||||
x, y = position[0] + window_x, position[1] + window_y
|
||||
|
||||
# 生成人类化的移动路径
|
||||
current_pos = win32gui.GetCursorPos()
|
||||
path = self.generate_movement_path(current_pos, [x, y])
|
||||
|
||||
# 执行移动
|
||||
for point in path:
|
||||
win32api.SetCursorPos((int(point[0]), int(point[1])))
|
||||
await asyncio.sleep(0.01) # 10ms延迟
|
||||
|
||||
elif action_type == "click":
|
||||
# 点击指定位置
|
||||
position = action.get("position", [0, 0])
|
||||
x, y = position[0] + window_x, position[1] + window_y
|
||||
|
||||
# 移动到位置
|
||||
current_pos = win32gui.GetCursorPos()
|
||||
path = self.generate_movement_path(current_pos, [x, y])
|
||||
|
||||
for point in path:
|
||||
win32api.SetCursorPos((int(point[0]), int(point[1])))
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# 执行点击
|
||||
mouse.click()
|
||||
|
||||
elif action_type == "use_skill":
|
||||
# 使用技能
|
||||
key = action.get("key", "1")
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(0.05)
|
||||
keyboard.release(key)
|
||||
|
||||
# 如果有目标位置,移动鼠标并点击
|
||||
if "target_position" in action:
|
||||
pos = action["target_position"]
|
||||
x, y = pos[0] + window_x, pos[1] + window_y
|
||||
win32api.SetCursorPos((int(x), int(y)))
|
||||
await asyncio.sleep(0.05)
|
||||
mouse.click()
|
||||
|
||||
elif action_type == "interact":
|
||||
# 交互
|
||||
key = action.get("key", "f")
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(0.1)
|
||||
keyboard.release(key)
|
||||
|
||||
elif action_type == "move_random":
|
||||
# 随机移动
|
||||
direction = action.get("direction", "right")
|
||||
duration = action.get("duration", 1.0)
|
||||
|
||||
# 方向键映射
|
||||
dir_keys = {
|
||||
"up": "w",
|
||||
"down": "s",
|
||||
"left": "a",
|
||||
"right": "d"
|
||||
}
|
||||
|
||||
key = dir_keys.get(direction, "d")
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(duration)
|
||||
keyboard.release(key)
|
||||
|
||||
elif action_type == "use_item":
|
||||
# 使用物品
|
||||
key = action.get("key", "f1")
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(0.1)
|
||||
keyboard.release(key)
|
||||
|
||||
elif action_type == "stop":
|
||||
# 停止所有按键
|
||||
for key in ["w", "a", "s", "d", "1", "2", "3", "4", "5", "6"]:
|
||||
if keyboard.is_pressed(key):
|
||||
keyboard.release(key)
|
||||
|
||||
elif action_type == "type_text":
|
||||
# 输入文本
|
||||
text = action.get("text", "")
|
||||
if text:
|
||||
keyboard.write(text, delay=0.05)
|
||||
|
||||
elif action_type == "press_key_combo":
|
||||
# 按下组合键
|
||||
keys = action.get("keys", [])
|
||||
if keys:
|
||||
for key in keys:
|
||||
keyboard.press(key)
|
||||
await asyncio.sleep(0.1)
|
||||
for key in reversed(keys):
|
||||
keyboard.release(key)
|
||||
|
||||
else:
|
||||
logger.warning(f"未知操作类型: {action_type}")
|
||||
|
||||
# 更新游戏状态
|
||||
self.game_state["last_operation_time"] = time.time()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行操作失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def generate_movement_path(self, start_pos, end_pos, steps=None):
|
||||
"""生成模拟人类的鼠标移动路径"""
|
||||
# 计算距离
|
||||
distance = ((end_pos[0] - start_pos[0])**2 + (end_pos[1] - start_pos[1])**2)**0.5
|
||||
|
||||
# 如果未指定步数,根据距离计算
|
||||
if steps is None:
|
||||
steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点
|
||||
|
||||
# 生成基础路径
|
||||
t = np.linspace(0, 1, steps)
|
||||
path = []
|
||||
|
||||
# 动态调整平滑度
|
||||
smoothness = float(self.config.get("Game", "movement_smoothness", fallback="0.8"))
|
||||
|
||||
for i in range(steps):
|
||||
# 基础线性插值
|
||||
x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i]
|
||||
y = start_pos[1] + (end_pos[1] - start_pos[1]) * t[i]
|
||||
|
||||
# 添加随机偏移(越靠近中间偏移越大)
|
||||
mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大
|
||||
max_offset = distance * 0.05 * mid_factor * (1 - smoothness) # 最大偏移为距离的5%,受平滑度影响
|
||||
|
||||
offset_x = random.normalvariate(0, max_offset / 3)
|
||||
offset_y = random.normalvariate(0, max_offset / 3)
|
||||
|
||||
# 添加到路径
|
||||
path.append([x + offset_x, y + offset_y])
|
||||
|
||||
# 确保起点和终点准确
|
||||
path[0] = start_pos
|
||||
path[-1] = end_pos
|
||||
|
||||
return path
|
||||
|
||||
def analyze_image(self, image_data, detection_results):
|
||||
"""
|
||||
分析图像数据,更新游戏状态
|
||||
|
||||
参数:
|
||||
image_data (dict): 图像数据
|
||||
detection_results (list): 检测结果
|
||||
"""
|
||||
# 更新战斗状态
|
||||
monsters_detected = False
|
||||
for det in detection_results:
|
||||
if det["class_name"] in ["monster", "boss"]:
|
||||
monsters_detected = True
|
||||
break
|
||||
|
||||
self.game_state["in_battle"] = monsters_detected
|
||||
|
||||
# 分析血条和蓝条
|
||||
hp_bars = [d for d in detection_results if d["class_name"] == "hp_bar"]
|
||||
mp_bars = [d for d in detection_results if d["class_name"] == "mp_bar"]
|
||||
|
||||
if hp_bars:
|
||||
# 估算血量百分比
|
||||
self.game_state["hp_percent"] = self.estimate_bar_percent(hp_bars[0])
|
||||
|
||||
if mp_bars:
|
||||
# 估算蓝量百分比
|
||||
self.game_state["mp_percent"] = self.estimate_bar_percent(mp_bars[0])
|
||||
|
||||
# 检测技能冷却
|
||||
cooldowns = [d for d in detection_results if d["class_name"] == "cooldown"]
|
||||
self.game_state["cooldowns"] = {}
|
||||
|
||||
for cd in cooldowns:
|
||||
if "skill_id" in cd:
|
||||
self.game_state["cooldowns"][cd["skill_id"]] = cd.get("remaining_time", 1.0)
|
||||
|
||||
def estimate_bar_percent(self, bar_detection):
|
||||
"""
|
||||
估计血条/蓝条的百分比
|
||||
|
||||
参数:
|
||||
bar_detection (dict): 条形检测结果
|
||||
|
||||
返回:
|
||||
float: 百分比值(0-100)
|
||||
"""
|
||||
# 这里需要根据实际情况实现
|
||||
# 临时方案:返回检测结果中的值,如果没有则返回默认值
|
||||
return bar_detection.get("percent", 100)
|
||||
|
||||
async def capture_and_process_loop(self):
|
||||
"""截图和处理循环"""
|
||||
consecutive_errors = 0
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# 检查WebSocket连接
|
||||
if not self.ws or self.ws.closed:
|
||||
await asyncio.sleep(0.5) # 连接断开时等待
|
||||
continue
|
||||
|
||||
# 获取游戏窗口
|
||||
hwnd = self.get_game_window()
|
||||
if hwnd is None:
|
||||
await asyncio.sleep(1.0) # 找不到窗口时等待1秒
|
||||
continue
|
||||
|
||||
# 截取游戏画面
|
||||
screen_data = self.capture_game_screen(hwnd)
|
||||
if screen_data is None:
|
||||
await asyncio.sleep(0.5) # 截图失败时等待0.5秒
|
||||
continue
|
||||
|
||||
# 准备请求数据
|
||||
request = {
|
||||
"type": "image",
|
||||
"request_id": f"req_{random.getrandbits(32)}",
|
||||
"timestamp": time.time(),
|
||||
"data": screen_data["image"],
|
||||
"game_state": self.game_state,
|
||||
"window_rect": screen_data["window_rect"]
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
await self.ws.send(json.dumps(request))
|
||||
|
||||
# 等待响应
|
||||
response_raw = await asyncio.wait_for(
|
||||
self.ws.recv(),
|
||||
timeout=5.0
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
response = json.loads(response_raw)
|
||||
|
||||
if response.get("type") == "action_response":
|
||||
# 获取检测结果并更新游戏状态
|
||||
if "detections" in response:
|
||||
self.analyze_image(screen_data, response["detections"])
|
||||
|
||||
# 执行动作
|
||||
actions = response.get("actions", [])
|
||||
|
||||
# 按优先级排序
|
||||
actions.sort(key=lambda x: x.get("execution_priority", 1.0))
|
||||
|
||||
for action in actions:
|
||||
await self.execute_action(action)
|
||||
|
||||
# 重置错误计数
|
||||
consecutive_errors = 0
|
||||
|
||||
elif response.get("type") == "error":
|
||||
logger.error(f"服务器返回错误: {response.get('message')}")
|
||||
consecutive_errors += 1
|
||||
|
||||
# 等待指定的间隔时间
|
||||
await asyncio.sleep(self.capture_interval)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("等待服务器响应超时")
|
||||
consecutive_errors += 1
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.error("WebSocket连接已关闭")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理循环出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
consecutive_errors += 1
|
||||
await asyncio.sleep(1.0) # 出错时等待1秒
|
||||
|
||||
# 如果连续错误过多,尝试重新连接
|
||||
if consecutive_errors >= 5:
|
||||
logger.warning(f"连续出错 {consecutive_errors} 次,尝试重新连接")
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
self.ws = None
|
||||
consecutive_errors = 0
|
||||
|
||||
async def run(self):
|
||||
"""运行客户端"""
|
||||
self.running = True
|
||||
|
||||
# 连接到服务器
|
||||
if not await self.connect_with_retry():
|
||||
self.running = False
|
||||
return
|
||||
|
||||
try:
|
||||
# 创建任务
|
||||
capture_task = asyncio.create_task(self.capture_and_process_loop())
|
||||
heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
||||
reconnect_task = asyncio.create_task(self.reconnect_loop())
|
||||
|
||||
# 等待任务完成
|
||||
await asyncio.gather(capture_task, heartbeat_task, reconnect_task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("客户端任务已取消")
|
||||
except Exception as e:
|
||||
logger.error(f"客户端运行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
self.running = False
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
logger.info("客户端已停止")
|
||||
|
||||
def start(self):
|
||||
"""启动客户端"""
|
||||
try:
|
||||
# 启动心跳监控线程(备用方案,以防异步心跳失效)
|
||||
self._monitor_thread = threading.Thread(target=self._monitor_connection)
|
||||
self._monitor_thread.daemon = True
|
||||
self._monitor_thread.start()
|
||||
|
||||
# 运行主循环
|
||||
asyncio.run(self.run())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断,正在退出...")
|
||||
except Exception as e:
|
||||
logger.error(f"客户端出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _monitor_connection(self):
|
||||
"""监控连接的后台线程"""
|
||||
while True:
|
||||
try:
|
||||
time.sleep(30) # 每30秒检查一次
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
# 检查心跳时间
|
||||
if self.last_heartbeat_time > 0 and time.time() - self.last_heartbeat_time > 60:
|
||||
logger.warning("心跳超时,可能需要重连")
|
||||
# 不直接重连,留给重连循环处理
|
||||
except Exception as e:
|
||||
logger.error(f"连接监控线程出错: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止客户端"""
|
||||
self.running = False
|
||||
logger.info("正在停止客户端...")
|
||||
|
||||
# 创建默认配置文件(如果不存在)
|
||||
def ensure_config():
|
||||
if not os.path.exists(CONFIG_FILE):
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
config["Server"] = {
|
||||
"url": "wss://your-server-url:8080/ws",
|
||||
"verify_ssl": "false"
|
||||
}
|
||||
|
||||
config["Capture"] = {
|
||||
"interval": "0.5",
|
||||
"quality": "70"
|
||||
}
|
||||
|
||||
config["Game"] = {
|
||||
"window_title": "地下城与勇士",
|
||||
"key_mapping": "default",
|
||||
"movement_smoothness": "0.8"
|
||||
}
|
||||
|
||||
config["Connection"] = {
|
||||
"max_retries": "5",
|
||||
"retry_delay": "5",
|
||||
"heartbeat_interval": "5"
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
|
||||
print(f"已创建默认配置文件: {CONFIG_FILE}")
|
||||
print("请编辑配置文件设置正确的服务器地址等信息")
|
||||
|
||||
# 启动客户端
|
||||
if __name__ == "__main__":
|
||||
# 确保配置文件存在
|
||||
ensure_config()
|
||||
|
||||
try:
|
||||
client = DNFAutoClient()
|
||||
client.start()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断,正在退出...")
|
||||
except Exception as e:
|
||||
logger.error(f"客户端出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
17
dnf-client-package/config.ini
Normal file
17
dnf-client-package/config.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[Server]
|
||||
url = wss://your-server-ip:8080/ws
|
||||
verify_ssl = false
|
||||
|
||||
[Capture]
|
||||
interval = 0.5
|
||||
quality = 70
|
||||
|
||||
[Game]
|
||||
window_title = 地下城与勇士
|
||||
key_mapping = default
|
||||
movement_smoothness = 0.8
|
||||
|
||||
[Connection]
|
||||
max_retries = 5
|
||||
retry_delay = 5
|
||||
heartbeat_interval = 5
|
||||
4
dnf-client-package/start.bat
Normal file
4
dnf-client-package/start.bat
Normal file
@@ -0,0 +1,4 @@
|
||||
@echo off
|
||||
echo 正在启动DNF自动化客户端...
|
||||
python client.py
|
||||
pause
|
||||
BIN
dnf-client-windows.zip
Normal file
BIN
dnf-client-windows.zip
Normal file
Binary file not shown.
@@ -3,48 +3,74 @@
|
||||
|
||||
"""
|
||||
YOLO模型封装,提供图像识别功能
|
||||
优化版 - 增强性能和模型缓存管理
|
||||
修复版 - 解决模型加载依赖问题
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys # 添加缺失的sys模块导入
|
||||
import logging
|
||||
import torch
|
||||
import yaml
|
||||
from PIL import Image
|
||||
import time
|
||||
import numpy as np
|
||||
import subprocess
|
||||
from PIL import Image
|
||||
from collections import deque
|
||||
|
||||
from config.settings import MODEL
|
||||
from config.settings import MODEL, BASE_DIR
|
||||
|
||||
logger = logging.getLogger("DNFAutoCloud")
|
||||
|
||||
class YOLOModel:
|
||||
"""YOLO模型封装类"""
|
||||
"""YOLO模型封装类 - 优化版"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化YOLO模型"""
|
||||
self.device = MODEL["device"]
|
||||
self.conf_threshold = MODEL["conf_threshold"]
|
||||
self.iou_threshold = MODEL["iou_threshold"]
|
||||
self.device = MODEL.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.conf_threshold = MODEL.get("conf_threshold", 0.5)
|
||||
self.iou_threshold = MODEL.get("iou_threshold", 0.45)
|
||||
|
||||
logger.info(f"正在加载YOLO模型: {MODEL['name']}")
|
||||
logger.info(f"模型权重路径: {MODEL['weights']}")
|
||||
# 性能监控
|
||||
self.inference_times = deque(maxlen=100)
|
||||
self.batch_size = MODEL.get("batch_size", 1)
|
||||
self.use_half_precision = MODEL.get("half_precision", True) and self.device == "cuda"
|
||||
|
||||
logger.info(f"正在加载YOLO模型: {MODEL.get('name', 'unknown')}")
|
||||
logger.info(f"模型权重路径: {MODEL.get('weights', 'unknown')}")
|
||||
logger.info(f"设备: {self.device}, 半精度: {self.use_half_precision}")
|
||||
|
||||
# 检查权重文件是否存在
|
||||
if not os.path.exists(MODEL["weights"]):
|
||||
raise FileNotFoundError(f"找不到模型权重文件: {MODEL['weights']}")
|
||||
weights_path = MODEL.get("weights", "")
|
||||
if not os.path.exists(weights_path):
|
||||
# 尝试从备选路径加载
|
||||
alt_paths = [
|
||||
os.path.join(BASE_DIR, "models", "weights", "best.pt"),
|
||||
os.path.join(BASE_DIR, "models", "best.pt"),
|
||||
os.path.join(BASE_DIR, "yolov5", "runs", "train", "exp", "weights", "best.pt")
|
||||
]
|
||||
|
||||
for path in alt_paths:
|
||||
if os.path.exists(path):
|
||||
weights_path = path
|
||||
logger.info(f"使用备选模型权重: {weights_path}")
|
||||
break
|
||||
|
||||
if not os.path.exists(weights_path):
|
||||
raise FileNotFoundError(f"找不到模型权重文件: {weights_path}")
|
||||
|
||||
# 加载模型
|
||||
try:
|
||||
self.model = torch.hub.load('ultralytics/yolov5', 'custom',
|
||||
path=MODEL["weights"],
|
||||
device=self.device)
|
||||
|
||||
# 设置模型参数
|
||||
self.model.conf = self.conf_threshold
|
||||
self.model.iou = self.iou_threshold
|
||||
|
||||
# 预热模型 (运行一个空白图像以初始化)
|
||||
dummy_img = torch.zeros((1, 3, 640, 640), device=self.device)
|
||||
self.model(dummy_img)
|
||||
if MODEL.get("engine", "") == "onnx":
|
||||
self._load_onnx_model()
|
||||
else:
|
||||
# 尝试不同的加载方法
|
||||
try:
|
||||
self._load_yolov5_custom()
|
||||
except Exception as e:
|
||||
logger.warning(f"使用torch.hub加载YOLOv5模型失败: {e}")
|
||||
logger.info("尝试备选方法加载模型...")
|
||||
self._load_yolov5_manually()
|
||||
|
||||
logger.info("YOLO模型加载成功")
|
||||
|
||||
@@ -52,6 +78,256 @@ class YOLOModel:
|
||||
logger.error(f"加载YOLO模型失败: {e}")
|
||||
raise
|
||||
|
||||
def _load_yolov5_custom(self):
|
||||
"""使用torch.hub加载YOLOv5模型"""
|
||||
try:
|
||||
# 尝试从本地加载YOLOv5
|
||||
yolov5_dir = os.path.join(BASE_DIR, "tools", "yolov5")
|
||||
if os.path.exists(yolov5_dir):
|
||||
logger.info(f"从本地目录加载YOLOv5: {yolov5_dir}")
|
||||
sys.path.insert(0, yolov5_dir)
|
||||
|
||||
try:
|
||||
# 尝试直接导入YOLOv5模块
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.torch_utils import select_device
|
||||
from utils.general import check_img_size, non_max_suppression, scale_coords
|
||||
|
||||
# 加载模型
|
||||
self.model = DetectMultiBackend(MODEL.get("weights"), device=self.device)
|
||||
self.stride = self.model.stride
|
||||
self.names = self.model.names
|
||||
|
||||
# 设置参数
|
||||
self.imgsz = check_img_size((640, 640), s=self.stride)
|
||||
|
||||
# 使用半精度
|
||||
if self.use_half_precision:
|
||||
self.model.half()
|
||||
|
||||
# 预热模型
|
||||
self._warmup_model()
|
||||
|
||||
return
|
||||
except ImportError as e:
|
||||
logger.warning(f"无法直接导入YOLOv5模块: {e}")
|
||||
|
||||
# 尝试从torch.hub加载
|
||||
logger.info("尝试从torch.hub加载YOLOv5模型")
|
||||
self.model = torch.hub.load(
|
||||
'ultralytics/yolov5',
|
||||
'custom',
|
||||
path=MODEL.get("weights"),
|
||||
device=self.device,
|
||||
force_reload=True
|
||||
)
|
||||
|
||||
# 设置模型参数
|
||||
self.model.conf = self.conf_threshold
|
||||
self.model.iou = self.iou_threshold
|
||||
|
||||
# 使用半精度
|
||||
if self.use_half_precision:
|
||||
self.model.half()
|
||||
|
||||
# 预热模型
|
||||
self._warmup_model()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载YOLOv5模型失败: {e}")
|
||||
raise
|
||||
|
||||
def _load_yolov5_manually(self):
|
||||
"""手动加载YOLOv5模型(不依赖torch.hub)"""
|
||||
try:
|
||||
# 检查YOLOv5目录是否存在,不存在则克隆
|
||||
yolov5_dir = os.path.join(BASE_DIR, "tools", "yolov5")
|
||||
if not os.path.exists(yolov5_dir):
|
||||
logger.info("YOLOv5目录不存在,正在克隆仓库...")
|
||||
os.makedirs(os.path.dirname(yolov5_dir), exist_ok=True)
|
||||
|
||||
# 克隆YOLOv5仓库
|
||||
subprocess.run(
|
||||
["git", "clone", "https://github.com/ultralytics/yolov5.git", yolov5_dir],
|
||||
check=True
|
||||
)
|
||||
|
||||
# 添加YOLOv5目录到系统路径
|
||||
sys.path.insert(0, yolov5_dir)
|
||||
|
||||
try:
|
||||
# 尝试导入YOLOv5模块
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.torch_utils import select_device
|
||||
from utils.general import check_img_size, non_max_suppression, scale_coords
|
||||
|
||||
# 加载模型
|
||||
self.model = DetectMultiBackend(MODEL.get("weights"), device=self.device)
|
||||
self.stride = self.model.stride
|
||||
self.names = self.model.names
|
||||
|
||||
# 设置参数
|
||||
self.imgsz = check_img_size((640, 640), s=self.stride)
|
||||
|
||||
# 使用半精度
|
||||
if self.use_half_precision:
|
||||
self.model.half()
|
||||
|
||||
# 保存必要的函数
|
||||
self.non_max_suppression = non_max_suppression
|
||||
self.scale_coords = scale_coords
|
||||
|
||||
# 预热模型
|
||||
dummy_img = torch.zeros((1, 3, self.imgsz[0], self.imgsz[1]), device=self.device)
|
||||
if self.use_half_precision:
|
||||
dummy_img = dummy_img.half()
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(2):
|
||||
self.model(dummy_img)
|
||||
|
||||
logger.info("手动加载YOLOv5模型成功")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"导入YOLOv5模块失败: {e}")
|
||||
# 尝试使用更简单的模型
|
||||
self._load_fallback_model()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"手动加载YOLOv5模型失败: {e}")
|
||||
# 尝试使用更简单的模型
|
||||
self._load_fallback_model()
|
||||
|
||||
def _load_fallback_model(self):
|
||||
"""加载备用模型(使用PyTorch内置模型)"""
|
||||
logger.info("尝试加载备用模型 (PyTorch YOLO)")
|
||||
|
||||
try:
|
||||
# 使用PyTorch的预训练模型
|
||||
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
||||
|
||||
self.model = fasterrcnn_resnet50_fpn(pretrained=True)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# 备用模型的类别
|
||||
self.names = [
|
||||
'background', 'monster', 'boss', 'door', 'item', 'npc', 'player',
|
||||
'hp_bar', 'mp_bar', 'skill_ready', 'cooldown'
|
||||
]
|
||||
|
||||
# 标记使用备用模型
|
||||
self.using_fallback = True
|
||||
|
||||
logger.info("备用模型加载成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载备用模型失败: {e}")
|
||||
raise
|
||||
|
||||
def _load_onnx_model(self):
|
||||
"""加载ONNX版YOLO模型"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
|
||||
# 设置ONNX运行时参数
|
||||
if self.device == "cuda":
|
||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
providers = ['CPUExecutionProvider']
|
||||
|
||||
# 检查ONNX模型文件
|
||||
onnx_path = MODEL.get("weights", "")
|
||||
if not onnx_path.endswith('.onnx'):
|
||||
onnx_path = onnx_path + '.onnx'
|
||||
|
||||
if not os.path.exists(onnx_path):
|
||||
logger.warning(f"找不到ONNX模型: {onnx_path}")
|
||||
logger.info("尝试导出PyTorch模型到ONNX格式...")
|
||||
|
||||
# 尝试加载PyTorch模型并导出为ONNX
|
||||
self._load_yolov5_custom()
|
||||
self._export_to_onnx(onnx_path)
|
||||
|
||||
# 创建ONNX会话
|
||||
self.onnx_session = ort.InferenceSession(onnx_path, providers=providers)
|
||||
|
||||
# 获取模型输入输出名称
|
||||
self.input_name = self.onnx_session.get_inputs()[0].name
|
||||
self.output_names = [output.name for output in self.onnx_session.get_outputs()]
|
||||
|
||||
# 加载类别名称
|
||||
if os.path.exists(MODEL.get("class_names", "")):
|
||||
with open(MODEL.get("class_names"), "r") as f:
|
||||
self.class_names = yaml.safe_load(f)
|
||||
else:
|
||||
self.class_names = ["object"]
|
||||
|
||||
# 标记使用ONNX
|
||||
self.using_onnx = True
|
||||
|
||||
# 预热模型
|
||||
self._warmup_onnx_model()
|
||||
|
||||
except ImportError:
|
||||
logger.error("缺少ONNX运行时依赖,请安装onnxruntime或onnxruntime-gpu")
|
||||
logger.info("尝试使用PyTorch模型代替...")
|
||||
self._load_yolov5_custom()
|
||||
except Exception as e:
|
||||
logger.error(f"加载ONNX模型失败: {e}")
|
||||
logger.info("尝试使用PyTorch模型代替...")
|
||||
self._load_yolov5_custom()
|
||||
|
||||
def _export_to_onnx(self, onnx_path):
|
||||
"""将PyTorch模型导出为ONNX格式"""
|
||||
try:
|
||||
import torch.onnx
|
||||
|
||||
# 准备导出
|
||||
dummy_input = torch.randn(1, 3, 640, 640, device=self.device)
|
||||
if self.use_half_precision:
|
||||
dummy_input = dummy_input.half()
|
||||
|
||||
# 导出ONNX
|
||||
torch.onnx.export(
|
||||
self.model,
|
||||
dummy_input,
|
||||
onnx_path,
|
||||
verbose=False,
|
||||
opset_version=12,
|
||||
input_names=['images'],
|
||||
output_names=['output']
|
||||
)
|
||||
|
||||
logger.info(f"PyTorch模型成功导出为ONNX格式: {onnx_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"导出ONNX模型失败: {e}")
|
||||
raise
|
||||
|
||||
def _warmup_model(self):
|
||||
"""预热模型(运行空白图像以初始化)"""
|
||||
dummy_img = torch.zeros((1, 3, 640, 640), device=self.device)
|
||||
if self.use_half_precision:
|
||||
dummy_img = dummy_img.half()
|
||||
|
||||
# 进行多次推理预热
|
||||
with torch.no_grad():
|
||||
for _ in range(2):
|
||||
self.model(dummy_img)
|
||||
|
||||
logger.info("模型预热完成")
|
||||
|
||||
def _warmup_onnx_model(self):
|
||||
"""预热ONNX模型"""
|
||||
dummy_img = np.zeros((1, 3, 640, 640), dtype=np.float32)
|
||||
|
||||
# 进行多次推理预热
|
||||
for _ in range(2):
|
||||
self.onnx_session.run(self.output_names, {self.input_name: dummy_img})
|
||||
|
||||
logger.info("ONNX模型预热完成")
|
||||
|
||||
def detect(self, image):
|
||||
"""
|
||||
对图像进行目标检测
|
||||
@@ -63,20 +339,25 @@ class YOLOModel:
|
||||
list: 检测结果,每个结果包含边界框、类别和置信度
|
||||
"""
|
||||
try:
|
||||
# 在GPU上进行推理
|
||||
with torch.no_grad():
|
||||
results = self.model(image)
|
||||
start_time = time.time()
|
||||
|
||||
# 处理结果
|
||||
detections = []
|
||||
for pred in results.xyxy[0].cpu().numpy():
|
||||
x1, y1, x2, y2, conf, cls = pred
|
||||
detections.append({
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'confidence': float(conf),
|
||||
'class_id': int(cls),
|
||||
'class_name': self.model.names[int(cls)]
|
||||
})
|
||||
# 根据模型类型选择检测方法
|
||||
if hasattr(self, 'using_onnx') and self.using_onnx:
|
||||
detections = self._detect_onnx(image)
|
||||
elif hasattr(self, 'using_fallback') and self.using_fallback:
|
||||
detections = self._detect_fallback(image)
|
||||
else:
|
||||
detections = self._detect_pytorch(image)
|
||||
|
||||
# 记录推理时间
|
||||
inference_time = time.time() - start_time
|
||||
self.inference_times.append(inference_time)
|
||||
|
||||
# 计算平均推理时间
|
||||
avg_time = sum(self.inference_times) / len(self.inference_times)
|
||||
|
||||
if len(self.inference_times) % 10 == 0:
|
||||
logger.debug(f"平均推理时间: {avg_time:.3f}秒, 当前: {inference_time:.3f}秒")
|
||||
|
||||
return detections
|
||||
|
||||
@@ -84,6 +365,217 @@ class YOLOModel:
|
||||
logger.error(f"目标检测出错: {e}")
|
||||
return []
|
||||
|
||||
def _detect_pytorch(self, image):
|
||||
"""使用PyTorch模型进行检测"""
|
||||
# 在GPU上进行推理
|
||||
with torch.no_grad():
|
||||
# 处理不同的模型接口
|
||||
if hasattr(self, 'stride') and hasattr(self, 'names'):
|
||||
# 自定义加载的YOLOv5
|
||||
# 转换图像
|
||||
img = self._prepare_image_custom(image)
|
||||
|
||||
# 推理
|
||||
output = self.model(img)
|
||||
|
||||
# 处理输出
|
||||
pred = self.non_max_suppression(output[0], self.conf_threshold, self.iou_threshold)
|
||||
|
||||
# 解析结果
|
||||
detections = []
|
||||
for det in pred[0]:
|
||||
x1, y1, x2, y2, conf, cls = det
|
||||
detections.append({
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'confidence': float(conf),
|
||||
'class_id': int(cls),
|
||||
'class_name': self.names[int(cls)]
|
||||
})
|
||||
else:
|
||||
# 标准torch.hub加载的模型
|
||||
results = self.model(image)
|
||||
|
||||
# 处理结果
|
||||
detections = []
|
||||
for pred in results.xyxy[0].cpu().numpy():
|
||||
x1, y1, x2, y2, conf, cls = pred
|
||||
detections.append({
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'confidence': float(conf),
|
||||
'class_id': int(cls),
|
||||
'class_name': self.model.names[int(cls)]
|
||||
})
|
||||
|
||||
return detections
|
||||
|
||||
def _detect_fallback(self, image):
|
||||
"""使用备用模型进行检测"""
|
||||
# 预处理图像
|
||||
img = self._prepare_image_pytorch(image)
|
||||
|
||||
# 进行推理
|
||||
with torch.no_grad():
|
||||
predictions = self.model(img)
|
||||
|
||||
# 处理结果
|
||||
detections = []
|
||||
for i, prediction in enumerate(predictions[0]['boxes']):
|
||||
score = predictions[0]['scores'][i].item()
|
||||
if score > self.conf_threshold:
|
||||
x1, y1, x2, y2 = prediction.tolist()
|
||||
class_id = predictions[0]['labels'][i].item()
|
||||
|
||||
# 映射torchvision的COCO类别到我们的类别
|
||||
class_name = self.names[min(class_id, len(self.names) - 1)]
|
||||
|
||||
detections.append({
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'confidence': float(score),
|
||||
'class_id': class_id,
|
||||
'class_name': class_name
|
||||
})
|
||||
|
||||
return detections
|
||||
|
||||
def _detect_onnx(self, image):
|
||||
"""使用ONNX模型进行检测"""
|
||||
# 预处理图像
|
||||
input_tensor = self._prepare_image_onnx(image)
|
||||
|
||||
# 运行推理
|
||||
outputs = self.onnx_session.run(self.output_names, {self.input_name: input_tensor})
|
||||
|
||||
# 解析输出(根据ONNX模型的输出格式可能需要调整)
|
||||
# 假设输出格式为 [batch_id, x1, y1, x2, y2, conf, class_id]
|
||||
predictions = outputs[0]
|
||||
|
||||
# 过滤低置信度预测
|
||||
mask = predictions[:, 5] > self.conf_threshold
|
||||
filtered_preds = predictions[mask]
|
||||
|
||||
# 组织检测结果
|
||||
detections = []
|
||||
for pred in filtered_preds:
|
||||
x1, y1, x2, y2, conf, cls = pred[1:7]
|
||||
cls_id = int(cls)
|
||||
|
||||
detections.append({
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'confidence': float(conf),
|
||||
'class_id': cls_id,
|
||||
'class_name': self.class_names[cls_id] if cls_id < len(self.class_names) else "unknown"
|
||||
})
|
||||
|
||||
return detections
|
||||
|
||||
def _prepare_image_custom(self, image):
|
||||
"""为自定义加载的YOLOv5模型准备图像"""
|
||||
# 转换PIL图像为numpy数组
|
||||
img = np.array(image)
|
||||
|
||||
# 调整大小
|
||||
img = self._letterbox(img, new_shape=self.imgsz)[0]
|
||||
|
||||
# BGR转RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
# 转换为PyTorch张量
|
||||
img = torch.from_numpy(img).to(self.device)
|
||||
img = img.half() if self.use_half_precision else img.float()
|
||||
img /= 255.0
|
||||
|
||||
# 增加批次维度
|
||||
if img.ndimension() == 3:
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
return img
|
||||
|
||||
def _prepare_image_pytorch(self, image):
|
||||
"""为PyTorch模型准备图像"""
|
||||
# 转换PIL图像为PyTorch张量
|
||||
from torchvision import transforms
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
img = transform(image).unsqueeze(0).to(self.device)
|
||||
|
||||
return img
|
||||
|
||||
def _prepare_image_onnx(self, image):
|
||||
"""为ONNX模型准备图像"""
|
||||
# 调整图像大小
|
||||
img_size = MODEL.get("img_size", 640)
|
||||
image = image.resize((img_size, img_size), Image.LANCZOS)
|
||||
|
||||
# 转换为numpy数组
|
||||
img = np.array(image).astype(np.float32) / 255.0
|
||||
|
||||
# 从HWC转换为CHW格式
|
||||
img = img.transpose(2, 0, 1)
|
||||
|
||||
# 添加批次维度
|
||||
img = np.expand_dims(img, axis=0)
|
||||
|
||||
return img
|
||||
|
||||
def _letterbox(self, img, new_shape=(640, 640), color=(114, 114, 114)):
|
||||
"""调整图像大小并填充(YOLOv5风格)"""
|
||||
shape = img.shape[:2] # 当前形状 [高, 宽]
|
||||
|
||||
# 缩放比例 (新 / 旧)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
|
||||
# 计算填充
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
|
||||
|
||||
# 平均分配填充
|
||||
dw /= 2
|
||||
dh /= 2
|
||||
|
||||
# 调整大小
|
||||
if shape[::-1] != new_unpad:
|
||||
import cv2
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# 添加边框
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
|
||||
|
||||
return img, r, (dw, dh)
|
||||
|
||||
def get_class_names(self):
|
||||
"""获取类别名称列表"""
|
||||
return self.model.names
|
||||
if hasattr(self, 'using_onnx') and self.using_onnx:
|
||||
return self.class_names
|
||||
elif hasattr(self, 'using_fallback') and self.using_fallback:
|
||||
return self.names
|
||||
elif hasattr(self, 'names'):
|
||||
return self.names
|
||||
else:
|
||||
return self.model.names
|
||||
|
||||
def get_performance_stats(self):
|
||||
"""获取性能统计信息"""
|
||||
if not self.inference_times:
|
||||
return {"average_time": 0, "min_time": 0, "max_time": 0, "count": 0, "fps": 0}
|
||||
|
||||
avg_time = sum(self.inference_times) / len(self.inference_times)
|
||||
|
||||
# 修复fps计算逻辑
|
||||
if avg_time > 0:
|
||||
fps = 1.0 / avg_time
|
||||
else:
|
||||
fps = 0
|
||||
|
||||
return {
|
||||
"average_time": avg_time,
|
||||
"min_time": min(self.inference_times),
|
||||
"max_time": max(self.inference_times),
|
||||
"count": len(self.inference_times),
|
||||
"fps": fps
|
||||
}
|
||||
41
package.sh
Executable file
41
package.sh
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/bin/bash
|
||||
# 创建客户端打包目录
|
||||
mkdir -p dnf-client-package
|
||||
cp client.py dnf-client-package/
|
||||
cp -r win_libs/* dnf-client-package/
|
||||
|
||||
# 创建默认配置文件
|
||||
cat > dnf-client-package/config.ini << EOF
|
||||
[Server]
|
||||
url = wss://your-server-ip:8080/ws
|
||||
verify_ssl = false
|
||||
|
||||
[Capture]
|
||||
interval = 0.5
|
||||
quality = 70
|
||||
|
||||
[Game]
|
||||
window_title = 地下城与勇士
|
||||
key_mapping = default
|
||||
movement_smoothness = 0.8
|
||||
|
||||
[Connection]
|
||||
max_retries = 5
|
||||
retry_delay = 5
|
||||
heartbeat_interval = 5
|
||||
EOF
|
||||
|
||||
# 创建启动脚本
|
||||
cat > dnf-client-package/start.bat << EOF
|
||||
@echo off
|
||||
echo 正在启动DNF自动化客户端...
|
||||
python client.py
|
||||
pause
|
||||
EOF
|
||||
|
||||
# 打包为zip
|
||||
cd dnf-client-package
|
||||
zip -r ../dnf-client-windows.zip *
|
||||
cd ..
|
||||
|
||||
echo "打包完成: dnf-client-windows.zip"
|
||||
238
run.py
238
run.py
@@ -3,25 +3,257 @@
|
||||
|
||||
"""
|
||||
DNF自动化云服务主启动脚本
|
||||
优化版 - 增加命令行参数和状态监控
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import platform
|
||||
import subprocess
|
||||
import threading
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from server.main import start_server
|
||||
from utils.logging_utils import setup_logging
|
||||
from config.settings import SERVER, MODEL, LOGGING
|
||||
|
||||
# 版本信息
|
||||
VERSION = "1.0.1"
|
||||
|
||||
def get_system_info():
|
||||
"""获取系统信息"""
|
||||
info = {
|
||||
"os": platform.system(),
|
||||
"os_version": platform.version(),
|
||||
"python": platform.python_version(),
|
||||
"cpu": platform.processor(),
|
||||
"architecture": platform.machine(),
|
||||
"hostname": platform.node()
|
||||
}
|
||||
|
||||
# 检查CUDA和PyTorch
|
||||
try:
|
||||
import torch
|
||||
info["torch"] = torch.__version__
|
||||
|
||||
if torch.cuda.is_available():
|
||||
info["cuda"] = torch.version.cuda
|
||||
info["gpu"] = torch.cuda.get_device_name(0)
|
||||
info["gpu_count"] = torch.cuda.device_count()
|
||||
info["gpu_memory"] = torch.cuda.get_device_properties(0).total_memory / (1024**3) # GB
|
||||
else:
|
||||
info["cuda"] = "Not available"
|
||||
except ImportError:
|
||||
info["torch"] = "Not installed"
|
||||
info["cuda"] = "Unknown"
|
||||
|
||||
return info
|
||||
|
||||
def print_banner():
|
||||
"""打印启动横幅"""
|
||||
banner = f"""
|
||||
╔═══════════════════════════════════════════════════════════╗
|
||||
║ ║
|
||||
║ DNF 自动化云服务 v{VERSION} ║
|
||||
║ ║
|
||||
║ 适用于地下城与勇士游戏的AI辅助系统 ║
|
||||
║ 启动时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ║
|
||||
║ ║
|
||||
╚═══════════════════════════════════════════════════════════╝
|
||||
"""
|
||||
print(banner)
|
||||
|
||||
# 添加警告信息
|
||||
if MODEL.get("device", "") == "cpu":
|
||||
print("警告: 正在使用CPU模式运行,性能可能受限,建议使用GPU以获得更好的体验\n")
|
||||
|
||||
def monitor_server(port):
|
||||
"""监控服务器状态"""
|
||||
try:
|
||||
# 定义状态监控线程
|
||||
def status_thread():
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_time = time.time()
|
||||
uptime = current_time - start_time
|
||||
|
||||
# 读取server_info.json获取服务器信息
|
||||
if os.path.exists("server_info.json"):
|
||||
with open("server_info.json", "r") as f:
|
||||
info = json.load(f)
|
||||
|
||||
# 显示状态
|
||||
os.system('cls' if os.name == 'nt' else 'clear')
|
||||
print(f"\n==== DNF自动化云服务状态监控 ====")
|
||||
print(f"服务器版本: {info.get('version', 'unknown')}")
|
||||
print(f"启动时间: {info.get('started_at', 'unknown')}")
|
||||
print(f"运行时长: {timedelta(seconds=int(uptime))}")
|
||||
print(f"服务地址: {info.get('server', {}).get('url', 'unknown')}")
|
||||
|
||||
# 尝试获取服务器状态
|
||||
try:
|
||||
status_file = "server_status.json"
|
||||
if os.path.exists(status_file) and os.path.getmtime(status_file) > current_time - 60:
|
||||
with open(status_file, "r") as sf:
|
||||
status = json.load(sf)
|
||||
|
||||
print(f"\n==== 连接统计 ====")
|
||||
print(f"活跃连接数: {status.get('connections', {}).get('active', 0)}")
|
||||
print(f"总连接数: {status.get('connections', {}).get('total', 0)}")
|
||||
|
||||
print(f"\n==== 性能统计 ====")
|
||||
print(f"已处理图像: {status.get('performance', {}).get('images_processed', 0)}")
|
||||
print(f"已生成动作: {status.get('performance', {}).get('actions_generated', 0)}")
|
||||
print(f"错误数: {status.get('performance', {}).get('errors', 0)}")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 等待10秒更新一次
|
||||
time.sleep(10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"监控出错: {e}")
|
||||
time.sleep(30)
|
||||
|
||||
# 启动监控线程
|
||||
t = threading.Thread(target=status_thread, daemon=True)
|
||||
t.start()
|
||||
|
||||
except Exception as e:
|
||||
print(f"启动监控失败: {e}")
|
||||
|
||||
def clean_old_logs():
|
||||
"""清理旧的日志文件"""
|
||||
log_dir = os.path.dirname(LOGGING.get("file", ""))
|
||||
if not os.path.exists(log_dir):
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取所有日志文件
|
||||
log_files = [f for f in os.listdir(log_dir) if f.endswith(".log")]
|
||||
|
||||
# 按修改时间排序
|
||||
log_files.sort(key=lambda x: os.path.getmtime(os.path.join(log_dir, x)))
|
||||
|
||||
# 如果超过10个日志文件,删除最旧的
|
||||
if len(log_files) > 10:
|
||||
for f in log_files[:-10]:
|
||||
try:
|
||||
os.remove(os.path.join(log_dir, f))
|
||||
print(f"已删除旧日志文件: {f}")
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"清理日志时出错: {e}")
|
||||
|
||||
def check_updates():
|
||||
"""检查更新"""
|
||||
try:
|
||||
# 这里可以实现检查更新的逻辑
|
||||
# 例如,检查Git仓库或向API服务器查询
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
def create_status_api(port):
|
||||
"""创建一个简单的状态API服务器"""
|
||||
try:
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
import threading
|
||||
import json
|
||||
|
||||
class StatusHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == "/status":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
# 读取服务器状态
|
||||
status = {"status": "running"}
|
||||
if os.path.exists("server_status.json"):
|
||||
try:
|
||||
with open("server_status.json", "r") as f:
|
||||
status = json.load(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
self.wfile.write(json.dumps(status).encode())
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def run_server():
|
||||
server_address = ("", port + 1) # 使用主端口+1
|
||||
httpd = HTTPServer(server_address, StatusHandler)
|
||||
print(f"状态API服务器已启动: http://localhost:{port + 1}/status")
|
||||
httpd.serve_forever()
|
||||
|
||||
# 启动API服务器线程
|
||||
t = threading.Thread(target=run_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
except Exception as e:
|
||||
print(f"启动状态API服务器失败: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="DNF自动化云服务")
|
||||
parser = argparse.ArgumentParser(description=f"DNF自动化云服务 v{VERSION}")
|
||||
parser.add_argument("--port", type=int, default=8080, help="服务端口")
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="服务地址")
|
||||
parser.add_argument("--debug", action="store_true", help="开启调试模式")
|
||||
parser.add_argument("--cpu", action="store_true", help="强制使用CPU模式")
|
||||
parser.add_argument("--no-monitor", action="store_true", help="禁用状态监控")
|
||||
parser.add_argument("--version", action="version", version=f"DNF自动化云服务 v{VERSION}")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 开启状态监控
|
||||
if not args.no_monitor:
|
||||
monitor_server(args.port)
|
||||
create_status_api(args.port)
|
||||
|
||||
# 打印横幅
|
||||
print_banner()
|
||||
|
||||
# 清理旧日志
|
||||
clean_old_logs()
|
||||
|
||||
# 检查更新
|
||||
check_updates()
|
||||
|
||||
# 强制CPU模式
|
||||
if args.cpu:
|
||||
MODEL["device"] = "cpu"
|
||||
print("已强制使用CPU模式")
|
||||
|
||||
# 设置日志
|
||||
setup_logging(debug=args.debug)
|
||||
logger = logging.getLogger("DNFAutoCloud")
|
||||
logger.info("正在启动DNF自动化云服务...")
|
||||
|
||||
# 打印系统信息
|
||||
system_info = get_system_info()
|
||||
logger.info(f"系统信息: {json.dumps(system_info, ensure_ascii=False)}")
|
||||
|
||||
# 启动服务
|
||||
start_server(host=args.host, port=args.port, debug=args.debug)
|
||||
logger.info(f"正在启动DNF自动化云服务 v{VERSION}...")
|
||||
|
||||
try:
|
||||
# 启动服务器
|
||||
success = start_server(host=args.host, port=args.port, debug=args.debug)
|
||||
|
||||
if not success:
|
||||
logger.error("服务器启动失败")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断,服务已停止")
|
||||
except Exception as e:
|
||||
logger.critical(f"启动服务时出错: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,10 +3,12 @@
|
||||
|
||||
"""
|
||||
图像处理模块,负责对接收到的图像进行预处理和后处理
|
||||
优化版 - 增强图像处理性能和准确性
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
|
||||
logger = logging.getLogger("DNFAutoCloud")
|
||||
@@ -32,7 +34,10 @@ def process_image(yolo_model, image):
|
||||
# 后处理检测结果
|
||||
processed_detections = postprocess_detections(detections, image.size)
|
||||
|
||||
return processed_detections
|
||||
# 计算额外特征
|
||||
enhanced_detections = calculate_extra_features(processed_detections, image)
|
||||
|
||||
return enhanced_detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图像处理出错: {e}")
|
||||
@@ -40,7 +45,7 @@ def process_image(yolo_model, image):
|
||||
|
||||
def preprocess_image(image):
|
||||
"""
|
||||
图像预处理
|
||||
图像预处理 - 优化版
|
||||
|
||||
参数:
|
||||
image (PIL.Image): 输入图像
|
||||
@@ -49,17 +54,38 @@ def preprocess_image(image):
|
||||
PIL.Image: 预处理后的图像
|
||||
"""
|
||||
try:
|
||||
# 图像增强
|
||||
image = image.convert("RGB") # 确保图像是RGB模式
|
||||
# 确保图像是RGB模式
|
||||
image = image.convert("RGB")
|
||||
|
||||
# 调整亮度和对比度以便更好地检测
|
||||
enhancer = ImageEnhance.Contrast(image)
|
||||
image = enhancer.enhance(1.2) # 增加对比度
|
||||
# 应用自适应直方图均衡化 (对于UI元素检测很有帮助)
|
||||
# 将PIL图像转换为numpy数组/OpenCV格式
|
||||
img_cv = np.array(image)
|
||||
img_cv = img_cv[:, :, ::-1].copy() # RGB -> BGR
|
||||
|
||||
# 调整大小(可选,如果需要的话)
|
||||
# image = image.resize((640, 640), Image.LANCZOS)
|
||||
# 转换为LAB色彩空间并对L通道应用CLAHE
|
||||
lab = cv2.cvtColor(img_cv, cv2.COLOR_BGR2LAB)
|
||||
l, a, b = cv2.split(lab)
|
||||
|
||||
return image
|
||||
# 创建CLAHE对象
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
cl = clahe.apply(l)
|
||||
|
||||
# 合并处理后的通道
|
||||
limg = cv2.merge((cl, a, b))
|
||||
enhanced_cv = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
|
||||
|
||||
# 将OpenCV图像转换回PIL格式
|
||||
enhanced = Image.fromarray(cv2.cvtColor(enhanced_cv, cv2.COLOR_BGR2RGB))
|
||||
|
||||
# 适度锐化
|
||||
enhancer = ImageEnhance.Sharpness(enhanced)
|
||||
enhanced = enhancer.enhance(1.3)
|
||||
|
||||
# 稍微增加对比度
|
||||
enhancer = ImageEnhance.Contrast(enhanced)
|
||||
enhanced = enhancer.enhance(1.2)
|
||||
|
||||
return enhanced
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图像预处理出错: {e}")
|
||||
@@ -67,7 +93,7 @@ def preprocess_image(image):
|
||||
|
||||
def postprocess_detections(detections, image_size):
|
||||
"""
|
||||
后处理检测结果
|
||||
后处理检测结果 - 优化版
|
||||
|
||||
参数:
|
||||
detections (list): 检测结果
|
||||
@@ -79,25 +105,45 @@ def postprocess_detections(detections, image_size):
|
||||
try:
|
||||
# 过滤或合并重叠的检测框
|
||||
filtered_detections = []
|
||||
|
||||
# 按类别和置信度排序
|
||||
detections.sort(key=lambda x: (x['class_name'], -x['confidence']))
|
||||
|
||||
# 按类别分组
|
||||
class_groups = {}
|
||||
for det in detections:
|
||||
# 例如,过滤掉边缘的检测结果
|
||||
bbox = det['bbox']
|
||||
if is_valid_detection(bbox, image_size):
|
||||
# 添加相对位置信息
|
||||
det['center'] = [
|
||||
(bbox[0] + bbox[2]) / 2, # x中心
|
||||
(bbox[1] + bbox[3]) / 2 # y中心
|
||||
]
|
||||
det['size'] = [
|
||||
bbox[2] - bbox[0], # 宽度
|
||||
bbox[3] - bbox[1] # 高度
|
||||
]
|
||||
det['relative_position'] = [
|
||||
det['center'][0] / image_size[0], # 相对x位置
|
||||
det['center'][1] / image_size[1] # 相对y位置
|
||||
]
|
||||
|
||||
filtered_detections.append(det)
|
||||
class_name = det['class_name']
|
||||
if class_name not in class_groups:
|
||||
class_groups[class_name] = []
|
||||
class_groups[class_name].append(det)
|
||||
|
||||
# 处理每个类别内的重叠框
|
||||
for class_name, dets in class_groups.items():
|
||||
# 使用非极大值抑制
|
||||
remaining = non_max_suppression(dets, iou_threshold=0.5)
|
||||
|
||||
# 处理每个保留的检测
|
||||
for det in remaining:
|
||||
# 例如,过滤掉边缘的检测结果
|
||||
bbox = det['bbox']
|
||||
if is_valid_detection(bbox, image_size):
|
||||
# 添加相对位置信息
|
||||
det['center'] = [
|
||||
(bbox[0] + bbox[2]) / 2, # x中心
|
||||
(bbox[1] + bbox[3]) / 2 # y中心
|
||||
]
|
||||
det['size'] = [
|
||||
bbox[2] - bbox[0], # 宽度
|
||||
bbox[3] - bbox[1] # 高度
|
||||
]
|
||||
det['relative_position'] = [
|
||||
det['center'][0] / image_size[0], # 相对x位置
|
||||
det['center'][1] / image_size[1] # 相对y位置
|
||||
]
|
||||
# 添加一个唯一ID
|
||||
det['id'] = f"{class_name}_{len(filtered_detections)}"
|
||||
|
||||
filtered_detections.append(det)
|
||||
|
||||
return filtered_detections
|
||||
|
||||
@@ -105,6 +151,71 @@ def postprocess_detections(detections, image_size):
|
||||
logger.error(f"检测结果后处理出错: {e}")
|
||||
return detections # 返回原始检测结果
|
||||
|
||||
def non_max_suppression(detections, iou_threshold=0.5):
|
||||
"""
|
||||
非极大值抑制实现
|
||||
|
||||
参数:
|
||||
detections (list): 检测列表
|
||||
iou_threshold (float): IOU阈值
|
||||
|
||||
返回:
|
||||
list: 抑制后的检测列表
|
||||
"""
|
||||
# 如果列表为空,直接返回
|
||||
if not detections:
|
||||
return []
|
||||
|
||||
# 获取所有边界框和置信度
|
||||
boxes = [d['bbox'] for d in detections]
|
||||
scores = [d['confidence'] for d in detections]
|
||||
|
||||
# 初始化保留的框索引列表
|
||||
keep = []
|
||||
|
||||
# 按置信度降序排序
|
||||
idxs = np.argsort(scores)[::-1]
|
||||
|
||||
while len(idxs) > 0:
|
||||
# 取最高置信度的框
|
||||
current = idxs[0]
|
||||
keep.append(current)
|
||||
|
||||
# 计算当前框与其他框的IoU
|
||||
ious = []
|
||||
current_box = boxes[current]
|
||||
|
||||
for i in idxs[1:]:
|
||||
iou = calculate_iou(current_box, boxes[i])
|
||||
ious.append(iou)
|
||||
|
||||
# 保留IoU小于阈值的框
|
||||
idxs = idxs[1:][np.array(ious) < iou_threshold]
|
||||
|
||||
# 返回保留的检测
|
||||
return [detections[i] for i in keep]
|
||||
|
||||
def calculate_iou(box1, box2):
|
||||
"""计算两个边界框的IoU"""
|
||||
# 交集
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
|
||||
if x2 < x1 or y2 < y1:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2 - x1) * (y2 - y1)
|
||||
|
||||
# 两个边界框的面积
|
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
||||
|
||||
# IoU
|
||||
iou = intersection / float(box1_area + box2_area - intersection)
|
||||
return iou
|
||||
|
||||
def is_valid_detection(bbox, image_size):
|
||||
"""
|
||||
检查检测框是否有效
|
||||
@@ -119,11 +230,175 @@ def is_valid_detection(bbox, image_size):
|
||||
# 检查框是否太小
|
||||
width = bbox[2] - bbox[0]
|
||||
height = bbox[3] - bbox[1]
|
||||
if width < 10 or height < 10:
|
||||
if width < 5 or height < 5:
|
||||
return False
|
||||
|
||||
# 检查框是否部分超出图像
|
||||
if bbox[0] < 0 or bbox[1] < 0 or bbox[2] > image_size[0] or bbox[3] > image_size[1]:
|
||||
# 检查框是否大部分超出图像
|
||||
if bbox[0] < -width/2 or bbox[1] < -height/2 or bbox[2] > image_size[0] + width/2 or bbox[3] > image_size[1] + height/2:
|
||||
return False
|
||||
|
||||
return True
|
||||
return True
|
||||
|
||||
def calculate_extra_features(detections, image):
|
||||
"""
|
||||
计算额外特征,例如血条/蓝条百分比
|
||||
|
||||
参数:
|
||||
detections (list): 检测结果
|
||||
image (PIL.Image): 原始图像
|
||||
|
||||
返回:
|
||||
list: 增强的检测结果
|
||||
"""
|
||||
try:
|
||||
# 转换图像为numpy数组
|
||||
img_array = np.array(image)
|
||||
|
||||
# 增强检测结果的额外信息
|
||||
for det in detections:
|
||||
# 处理血条
|
||||
if det["class_name"] == "hp_bar":
|
||||
det["percent"] = estimate_bar_percent(img_array, det["bbox"], "hp")
|
||||
|
||||
# 处理蓝条
|
||||
elif det["class_name"] == "mp_bar":
|
||||
det["percent"] = estimate_bar_percent(img_array, det["bbox"], "mp")
|
||||
|
||||
# 处理技能冷却
|
||||
elif det["class_name"] == "cooldown":
|
||||
# 如果有文本,尝试提取技能ID和剩余时间
|
||||
if "text" in det:
|
||||
# 简单示例,真实情况可能需要OCR
|
||||
det["skill_id"] = det.get("text", "unknown").split("_")[0]
|
||||
det["remaining_time"] = 1.0 # 假设值
|
||||
|
||||
# 处理物品
|
||||
elif det["class_name"] == "item":
|
||||
# 尝试识别物品稀有度(通过颜色)
|
||||
det["rarity"] = estimate_item_rarity(img_array, det["bbox"])
|
||||
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算额外特征时出错: {e}")
|
||||
return detections
|
||||
|
||||
def estimate_bar_percent(img_array, bbox, bar_type):
|
||||
"""
|
||||
估计血条/蓝条的百分比 - 优化版
|
||||
|
||||
参数:
|
||||
img_array (numpy.ndarray): 图像数组
|
||||
bbox (list): 边界框 [x1, y1, x2, y2]
|
||||
bar_type (str): 条形类型 ("hp" 或 "mp")
|
||||
|
||||
返回:
|
||||
float: 百分比值(0-100)
|
||||
"""
|
||||
try:
|
||||
# 确保边界框在图像范围内
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox]
|
||||
height, width = img_array.shape[:2]
|
||||
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(width - 1, x2)
|
||||
y2 = min(height - 1, y2)
|
||||
|
||||
if x1 >= x2 or y1 >= y2:
|
||||
return 50.0 # 默认值
|
||||
|
||||
# 提取条形区域
|
||||
bar_region = img_array[y1:y2, x1:x2]
|
||||
|
||||
# 根据条形类型设置颜色阈值
|
||||
if bar_type == "hp":
|
||||
# 红色血条,查找红色像素
|
||||
lower_threshold = np.array([150, 0, 0])
|
||||
upper_threshold = np.array([255, 70, 70])
|
||||
elif bar_type == "mp":
|
||||
# 蓝色蓝条,查找蓝色像素
|
||||
lower_threshold = np.array([0, 0, 150])
|
||||
upper_threshold = np.array([70, 70, 255])
|
||||
else:
|
||||
return 50.0 # 默认值
|
||||
|
||||
# 创建掩码
|
||||
mask = np.all((bar_region >= lower_threshold) & (bar_region <= upper_threshold), axis=2)
|
||||
|
||||
# 计算比例(假设条形从左到右填充)
|
||||
bar_width = x2 - x1
|
||||
|
||||
# 找出最右边的填充像素
|
||||
filled_width = 0
|
||||
for col in range(bar_width):
|
||||
if np.any(mask[:, col]):
|
||||
filled_width = col + 1
|
||||
|
||||
# 计算百分比
|
||||
percent = (filled_width / bar_width) * 100
|
||||
|
||||
return max(0, min(100, percent)) # 确保在0-100范围内
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"估计条形百分比时出错: {e}")
|
||||
return 50.0 # 出错时返回默认值
|
||||
|
||||
def estimate_item_rarity(img_array, bbox):
|
||||
"""
|
||||
估计物品稀有度(通过颜色)
|
||||
|
||||
参数:
|
||||
img_array (numpy.ndarray): 图像数组
|
||||
bbox (list): 边界框 [x1, y1, x2, y2]
|
||||
|
||||
返回:
|
||||
str: 稀有度描述
|
||||
"""
|
||||
try:
|
||||
# 提取物品区域
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox]
|
||||
height, width = img_array.shape[:2]
|
||||
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(width - 1, x2)
|
||||
y2 = min(height - 1, y2)
|
||||
|
||||
# 提取边框区域(只取几个像素宽的边框)
|
||||
border_size = max(2, min(5, int((x2 - x1) * 0.1)))
|
||||
|
||||
top_border = img_array[y1:y1+border_size, x1:x2]
|
||||
bottom_border = img_array[y2-border_size:y2, x1:x2]
|
||||
left_border = img_array[y1:y2, x1:x1+border_size]
|
||||
right_border = img_array[y1:y2, x2-border_size:x2]
|
||||
|
||||
# 合并所有边框像素
|
||||
borders = np.vstack([
|
||||
top_border.reshape(-1, 3),
|
||||
bottom_border.reshape(-1, 3),
|
||||
left_border.reshape(-1, 3),
|
||||
right_border.reshape(-1, 3)
|
||||
])
|
||||
|
||||
# 计算平均颜色
|
||||
avg_color = np.mean(borders, axis=0)
|
||||
|
||||
# 根据颜色判断稀有度
|
||||
r, g, b = avg_color
|
||||
|
||||
# 简单启发式规则
|
||||
if r > 200 and g < 100 and b < 100:
|
||||
return "legendary" # 红色:传说
|
||||
elif r > 200 and g > 150 and b < 100:
|
||||
return "epic" # 橙色:史诗
|
||||
elif r < 100 and g < 100 and b > 180:
|
||||
return "rare" # 蓝色:稀有
|
||||
elif r < 100 and g > 180 and b < 100:
|
||||
return "uncommon" # 绿色:优秀
|
||||
else:
|
||||
return "common" # 白色/灰色:普通
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"估计物品稀有度时出错: {e}")
|
||||
return "unknown"
|
||||
256
server/main.py
256
server/main.py
@@ -3,6 +3,7 @@
|
||||
|
||||
"""
|
||||
服务器主模块,负责初始化和启动WebSocket服务
|
||||
优化版 - 增加更好的资源管理和错误恢复
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -10,64 +11,241 @@ import asyncio
|
||||
import ssl
|
||||
import sys
|
||||
import os
|
||||
import signal
|
||||
import gc
|
||||
import json
|
||||
import time
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到系统路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from server.websocket_server import WebsocketServer
|
||||
from config.settings import SECURITY, SERVER
|
||||
from config.settings import SECURITY, SERVER, MODEL
|
||||
from models.yolo_model import YOLOModel
|
||||
from utils.security import generate_ssl_cert
|
||||
from utils.logging_utils import setup_logging
|
||||
|
||||
logger = logging.getLogger("DNFAutoCloud")
|
||||
|
||||
def start_server(host="0.0.0.0", port=8080, debug=False):
|
||||
"""启动WebSocket服务器"""
|
||||
try:
|
||||
# 初始化YOLO模型
|
||||
logger.info("正在加载YOLO模型...")
|
||||
yolo_model = YOLOModel()
|
||||
|
||||
# 创建SSL上下文(如果启用)
|
||||
ssl_context = None
|
||||
if SECURITY["ssl_enabled"]:
|
||||
if not os.path.exists(SECURITY["ssl_cert"]) or not os.path.exists(SECURITY["ssl_key"]):
|
||||
logger.info("未找到SSL证书,正在生成自签名证书...")
|
||||
# 全局变量,用于优雅关闭
|
||||
websocket_server = None
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def setup_signals():
|
||||
"""设置信号处理器"""
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"收到信号 {sig},准备关闭服务器...")
|
||||
shutdown_event.set()
|
||||
|
||||
# 注册SIGINT和SIGTERM信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
logger.info("信号处理器已设置")
|
||||
|
||||
async def cleanup():
|
||||
"""清理资源"""
|
||||
logger.info("开始清理资源...")
|
||||
|
||||
# 关闭WebSocket服务器
|
||||
global websocket_server
|
||||
if websocket_server:
|
||||
await websocket_server.stop()
|
||||
|
||||
# 强制执行垃圾回收
|
||||
gc.collect()
|
||||
|
||||
logger.info("资源清理完成")
|
||||
|
||||
async def startup_checks():
|
||||
"""启动前检查"""
|
||||
logger.info("执行启动前检查...")
|
||||
|
||||
# 检查模型权重
|
||||
if not os.path.exists(MODEL.get("weights", "")):
|
||||
logger.error(f"模型权重文件不存在: {MODEL.get('weights', '')}")
|
||||
return False
|
||||
|
||||
# 检查GPU可用性
|
||||
if MODEL.get("device", "").startswith("cuda"):
|
||||
try:
|
||||
import torch
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning("已配置使用CUDA,但无法找到可用的GPU。将使用CPU模式。")
|
||||
MODEL["device"] = "cpu"
|
||||
else:
|
||||
logger.info(f"已检测到GPU: {torch.cuda.get_device_name(0)}")
|
||||
# 打印CUDA版本等详细信息
|
||||
logger.info(f"CUDA版本: {torch.version.cuda}")
|
||||
logger.info(f"可用GPU数量: {torch.cuda.device_count()}")
|
||||
except ImportError:
|
||||
logger.warning("无法导入torch模块,将使用CPU模式。")
|
||||
MODEL["device"] = "cpu"
|
||||
|
||||
# 检查SSL证书
|
||||
if SECURITY.get("ssl_enabled", False):
|
||||
if not os.path.exists(SECURITY.get("ssl_cert", "")) or not os.path.exists(SECURITY.get("ssl_key", "")):
|
||||
logger.info("未找到SSL证书,正在生成自签名证书...")
|
||||
try:
|
||||
generate_ssl_cert()
|
||||
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(SECURITY["ssl_cert"], SECURITY["ssl_key"])
|
||||
logger.info("SSL证书加载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"生成SSL证书失败: {e}")
|
||||
return False
|
||||
|
||||
# 检查数据目录
|
||||
data_dirs = ["data/logs", "data/cache", "data/sessions"]
|
||||
for dir_path in data_dirs:
|
||||
try:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {dir_path}: {e}")
|
||||
return False
|
||||
|
||||
logger.info("启动前检查完成")
|
||||
return True
|
||||
|
||||
async def load_model():
|
||||
"""加载YOLO模型"""
|
||||
logger.info("正在加载YOLO模型...")
|
||||
try:
|
||||
# 使用上下文管理器捕获并正确处理CUDA内存错误
|
||||
yolo_model = YOLOModel()
|
||||
logger.info("YOLO模型加载成功")
|
||||
|
||||
# 创建事件循环
|
||||
# 打印模型性能基准
|
||||
performance = yolo_model.get_performance_stats()
|
||||
logger.info(f"模型性能基准: 平均推理时间 {performance['average_time']:.4f}秒 ({performance['fps']:.2f} FPS)")
|
||||
|
||||
return yolo_model
|
||||
except Exception as e:
|
||||
logger.error(f"加载YOLO模型失败: {e}")
|
||||
return None
|
||||
|
||||
async def save_server_info(host, port, ssl_enabled):
|
||||
"""保存服务器信息到文件,便于客户端连接"""
|
||||
info = {
|
||||
"server": {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"url": f"{'wss' if ssl_enabled else 'ws'}://{host}:{port}/ws",
|
||||
"ssl_enabled": ssl_enabled
|
||||
},
|
||||
"started_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"version": "1.0.1"
|
||||
}
|
||||
|
||||
# 保存为JSON
|
||||
try:
|
||||
with open("server_info.json", "w") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
logger.info(f"服务器信息已保存到 server_info.json")
|
||||
except Exception as e:
|
||||
logger.error(f"保存服务器信息失败: {e}")
|
||||
|
||||
async def run_server(host, port, debug):
|
||||
"""运行服务器"""
|
||||
global websocket_server
|
||||
|
||||
# 执行启动前检查
|
||||
if not await startup_checks():
|
||||
logger.error("启动前检查失败,服务器未启动")
|
||||
return False
|
||||
|
||||
# 加载YOLO模型
|
||||
yolo_model = await load_model()
|
||||
if not yolo_model:
|
||||
logger.error("加载模型失败,服务器未启动")
|
||||
return False
|
||||
|
||||
# 创建SSL上下文(如果启用)
|
||||
ssl_context = None
|
||||
if SECURITY.get("ssl_enabled", False):
|
||||
try:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(SECURITY.get("ssl_cert", ""), SECURITY.get("ssl_key", ""))
|
||||
logger.info("SSL证书加载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"加载SSL证书失败: {e}")
|
||||
logger.warning("将以非SSL模式启动服务器")
|
||||
|
||||
try:
|
||||
# 创建WebSocket服务器
|
||||
websocket_server = WebsocketServer(yolo_model, host, port, ssl_context)
|
||||
|
||||
# 启动服务器
|
||||
await websocket_server.start()
|
||||
|
||||
# 保存服务器信息
|
||||
await save_server_info(host, port, SECURITY.get("ssl_enabled", False))
|
||||
|
||||
# 等待关闭信号
|
||||
await shutdown_event.wait()
|
||||
|
||||
logger.info("收到关闭信号,正在停止服务器...")
|
||||
|
||||
# 清理资源
|
||||
await cleanup()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"运行服务器时出错: {e}")
|
||||
return False
|
||||
|
||||
async def start_server_async(host="0.0.0.0", port=8080, debug=False):
|
||||
"""异步启动服务器"""
|
||||
# 设置信号处理器
|
||||
setup_signals()
|
||||
|
||||
# 运行服务器
|
||||
success = await run_server(host, port, debug)
|
||||
|
||||
return success
|
||||
|
||||
def start_server(host="0.0.0.0", port=8080, debug=False):
|
||||
"""同步启动服务器"""
|
||||
# 设置日志
|
||||
setup_logging(debug=debug)
|
||||
|
||||
# 打印启动信息
|
||||
logger.info("=" * 50)
|
||||
logger.info("正在启动DNF自动化云服务...")
|
||||
logger.info(f"主机: {host}, 端口: {port}, 调试模式: {'开启' if debug else '关闭'}")
|
||||
logger.info(f"服务器版本: 1.0.1")
|
||||
logger.info("=" * 50)
|
||||
|
||||
try:
|
||||
# 获取事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
if debug:
|
||||
loop.set_debug(True)
|
||||
logger.info("调试模式已启用")
|
||||
|
||||
# 创建WebSocket服务器
|
||||
server = WebsocketServer(yolo_model, host, port, ssl_context)
|
||||
# 运行服务器
|
||||
success = loop.run_until_complete(start_server_async(host, port, debug))
|
||||
|
||||
# 启动服务器
|
||||
logger.info(f"正在启动WebSocket服务器: {host}:{port}")
|
||||
loop.run_until_complete(server.start())
|
||||
# 如果服务器启动成功,运行事件循环直到关闭
|
||||
if success:
|
||||
loop.run_until_complete(shutdown_event.wait())
|
||||
|
||||
# 运行服务器直到被中断
|
||||
logger.info("服务器启动成功,等待连接...")
|
||||
loop.run_forever()
|
||||
# 清理资源
|
||||
loop.run_until_complete(cleanup())
|
||||
|
||||
# 关闭事件循环
|
||||
loop.close()
|
||||
|
||||
logger.info("服务器已关闭")
|
||||
return success
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("服务器收到终止信号,正在关闭...")
|
||||
logger.info("接收到用户中断,正在关闭...")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"服务器启动失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 清理资源
|
||||
tasks = asyncio.all_tasks(loop=loop)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
logger.info("正在关闭事件循环...")
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
logger.info("服务器已关闭")
|
||||
logger.error(f"启动服务器时出错: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 直接运行此脚本时的默认行为
|
||||
start_server(host="0.0.0.0", port=8080, debug=True)
|
||||
@@ -2,7 +2,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
WebSocket服务器实现,处理客户端连接和通信 - Linux兼容版
|
||||
WebSocket服务器实现,处理客户端连接和通信 - 优化版本
|
||||
增强稳定性、连接管理和会话持久化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -11,11 +12,14 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import threading
|
||||
|
||||
import websockets
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from config.settings import SERVER, SECURITY
|
||||
from server.image_processor import process_image
|
||||
@@ -26,7 +30,7 @@ from utils.human_behavior import generate_human_delay
|
||||
logger = logging.getLogger("DNFAutoCloud")
|
||||
|
||||
class WebsocketServer:
|
||||
"""WebSocket服务器类 - Linux兼容版"""
|
||||
"""WebSocket服务器类 - 优化版本"""
|
||||
|
||||
def __init__(self, yolo_model, host="0.0.0.0", port=8080, ssl_context=None):
|
||||
"""初始化WebSocket服务器"""
|
||||
@@ -35,38 +39,136 @@ class WebsocketServer:
|
||||
self.port = port
|
||||
self.ssl_context = ssl_context
|
||||
self.clients = {} # 客户端连接管理
|
||||
self.client_sessions = {} # 客户端会话数据
|
||||
self.start_time = datetime.now()
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.stats = {
|
||||
"total_connections": 0,
|
||||
"active_connections": 0,
|
||||
"images_processed": 0,
|
||||
"actions_generated": 0,
|
||||
"errors": 0
|
||||
}
|
||||
|
||||
# 客户端监控线程
|
||||
self.monitoring_task = None
|
||||
|
||||
# 保存最近处理的图像(用于调试)
|
||||
self.debug_mode = os.environ.get("DNF_DEBUG", "0") == "1"
|
||||
if self.debug_mode:
|
||||
os.makedirs("debug/images", exist_ok=True)
|
||||
|
||||
async def start(self):
|
||||
"""启动WebSocket服务器"""
|
||||
return await websockets.serve(
|
||||
ws_server = await websockets.serve(
|
||||
self.handle_connection,
|
||||
self.host,
|
||||
self.port,
|
||||
ssl=self.ssl_context,
|
||||
max_size=10 * 1024 * 1024, # 10MB最大消息大小
|
||||
ping_interval=SERVER["heartbeat_interval"],
|
||||
ping_timeout=SERVER["timeout"]
|
||||
ping_interval=SERVER.get("heartbeat_interval", 5),
|
||||
ping_timeout=SERVER.get("timeout", 60),
|
||||
close_timeout=10
|
||||
)
|
||||
|
||||
# 启动客户端监控
|
||||
self.monitoring_task = asyncio.create_task(self.monitor_clients())
|
||||
|
||||
logger.info(f"WebSocket服务器已启动: {self.host}:{self.port}")
|
||||
return ws_server
|
||||
|
||||
async def handle_connection(self, websocket, path):
|
||||
async def stop(self):
|
||||
"""停止WebSocket服务器"""
|
||||
logger.info("正在停止WebSocket服务器...")
|
||||
|
||||
# 设置关闭事件
|
||||
self.shutdown_event.set()
|
||||
|
||||
# 关闭所有客户端连接
|
||||
close_tasks = []
|
||||
for client_id, info in list(self.clients.items()):
|
||||
if "websocket" in info:
|
||||
try:
|
||||
close_tasks.append(info["websocket"].close(1001, "服务器关闭"))
|
||||
except:
|
||||
pass
|
||||
|
||||
if close_tasks:
|
||||
await asyncio.gather(*close_tasks, return_exceptions=True)
|
||||
|
||||
# 停止监控任务
|
||||
if self.monitoring_task:
|
||||
self.monitoring_task.cancel()
|
||||
try:
|
||||
await self.monitoring_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("WebSocket服务器已停止")
|
||||
|
||||
async def monitor_clients(self):
|
||||
"""监控客户端连接状态"""
|
||||
inactive_timeout = SERVER.get("inactive_timeout", 300) # 默认5分钟
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# 检查不活跃的客户端
|
||||
for client_id, info in list(self.clients.items()):
|
||||
if "last_activity" in info:
|
||||
inactive_time = current_time - info["last_activity"]
|
||||
|
||||
# 清理超时客户端
|
||||
if inactive_time > inactive_timeout:
|
||||
logger.info(f"客户端 {client_id} 不活跃超过 {inactive_timeout} 秒,断开连接")
|
||||
try:
|
||||
if "websocket" in info:
|
||||
await info["websocket"].close(1000, "不活跃超时")
|
||||
except:
|
||||
pass
|
||||
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
self.stats["active_connections"] -= 1
|
||||
|
||||
# 更新服务器状态
|
||||
self.stats["active_connections"] = len(self.clients)
|
||||
|
||||
# 等待下一个检查周期
|
||||
await asyncio.sleep(60) # 每分钟检查一次
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"监控客户端出错: {e}")
|
||||
await asyncio.sleep(60) # 出错后等待一分钟再重试
|
||||
|
||||
async def handle_connection(self, websocket, path=None):
|
||||
"""处理新的WebSocket连接"""
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
# 创建客户端信息
|
||||
client_info = {
|
||||
"id": client_id,
|
||||
"connected_at": datetime.now(),
|
||||
"remote": websocket.remote_address,
|
||||
"last_activity": time.time()
|
||||
"last_activity": time.time(),
|
||||
"websocket": websocket,
|
||||
"authenticated": False
|
||||
}
|
||||
|
||||
# 限制最大连接数
|
||||
if len(self.clients) >= SERVER["max_connections"]:
|
||||
if len(self.clients) >= SERVER.get("max_connections", 10):
|
||||
logger.warning(f"达到最大连接数限制,拒绝新连接: {client_info['remote']}")
|
||||
await websocket.close(1013, "服务器连接数已满")
|
||||
return
|
||||
|
||||
# 添加到客户端列表
|
||||
self.clients[client_id] = client_info
|
||||
self.stats["total_connections"] += 1
|
||||
self.stats["active_connections"] += 1
|
||||
|
||||
logger.info(f"新客户端连接: {client_id} 来自 {client_info['remote']}")
|
||||
|
||||
try:
|
||||
@@ -83,10 +185,12 @@ class WebsocketServer:
|
||||
logger.info(f"客户端连接关闭: {client_id}, 代码: {e.code}, 原因: {e.reason}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端连接时出错: {client_id}, 错误: {e}")
|
||||
self.stats["errors"] += 1
|
||||
finally:
|
||||
# 清理客户端连接
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
self.stats["active_connections"] -= 1
|
||||
logger.info(f"客户端连接已关闭: {client_id}")
|
||||
|
||||
async def _authenticate(self, websocket, client_id):
|
||||
@@ -96,13 +200,14 @@ class WebsocketServer:
|
||||
challenge = str(uuid.uuid4())
|
||||
await websocket.send(json.dumps({
|
||||
"type": "auth_challenge",
|
||||
"challenge": challenge
|
||||
"challenge": challenge,
|
||||
"timestamp": time.time()
|
||||
}))
|
||||
|
||||
# 等待认证响应
|
||||
response_raw = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=SERVER["timeout"]
|
||||
timeout=SERVER.get("timeout", 60)
|
||||
)
|
||||
|
||||
# 检查认证响应
|
||||
@@ -112,23 +217,37 @@ class WebsocketServer:
|
||||
await websocket.close(1008, "认证格式错误")
|
||||
return False
|
||||
|
||||
# 这里可以实现更复杂的认证逻辑
|
||||
# 简单示例仅检查是否有响应
|
||||
if "response" in response:
|
||||
# 认证成功
|
||||
self.clients[client_id]["authenticated"] = True
|
||||
await websocket.send(json.dumps({
|
||||
"type": "auth_result",
|
||||
"status": "success",
|
||||
"client_id": client_id
|
||||
}))
|
||||
logger.info(f"客户端认证成功: {client_id}")
|
||||
return True
|
||||
else:
|
||||
# 认证失败
|
||||
await websocket.close(1008, "认证失败")
|
||||
logger.warning(f"客户端认证失败: {client_id}")
|
||||
return False
|
||||
# 获取客户端信息
|
||||
client_info = response.get("client_info", {})
|
||||
|
||||
# 保存会话信息
|
||||
self.client_sessions[client_id] = {
|
||||
"client_info": client_info,
|
||||
"game_state": {},
|
||||
"last_actions": [],
|
||||
"statistics": {
|
||||
"images_processed": 0,
|
||||
"actions_sent": 0,
|
||||
"errors": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 认证成功
|
||||
self.clients[client_id]["authenticated"] = True
|
||||
self.clients[client_id]["client_info"] = client_info
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "auth_result",
|
||||
"status": "success",
|
||||
"client_id": client_id,
|
||||
"server_info": {
|
||||
"version": "1.0.1",
|
||||
"uptime": (datetime.now() - self.start_time).total_seconds()
|
||||
}
|
||||
}))
|
||||
|
||||
logger.info(f"客户端认证成功: {client_id}, 版本: {client_info.get('version', 'unknown')}")
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"客户端认证超时: {client_id}")
|
||||
@@ -137,11 +256,12 @@ class WebsocketServer:
|
||||
except Exception as e:
|
||||
logger.error(f"客户端认证出错: {client_id}, 错误: {e}")
|
||||
await websocket.close(1011, "认证处理错误")
|
||||
self.stats["errors"] += 1
|
||||
return False
|
||||
|
||||
async def _handle_messages(self, websocket, client_id):
|
||||
"""处理客户端消息"""
|
||||
while True:
|
||||
while not self.shutdown_event.is_set():
|
||||
# 接收消息
|
||||
message_raw = await websocket.recv()
|
||||
|
||||
@@ -150,7 +270,7 @@ class WebsocketServer:
|
||||
|
||||
try:
|
||||
# 解析消息(可能需要解密)
|
||||
if SECURITY["encryption_enabled"]:
|
||||
if SECURITY.get("encryption_enabled", False):
|
||||
message_data = decrypt_message(message_raw)
|
||||
else:
|
||||
message_data = json.loads(message_raw)
|
||||
@@ -163,7 +283,7 @@ class WebsocketServer:
|
||||
await self._handle_image_request(websocket, client_id, message_data)
|
||||
elif msg_type == "heartbeat":
|
||||
# 处理心跳消息
|
||||
await self._handle_heartbeat(websocket, client_id)
|
||||
await self._handle_heartbeat(websocket, client_id, message_data)
|
||||
else:
|
||||
# 未知消息类型
|
||||
logger.warning(f"收到未知消息类型: {msg_type} 来自客户端: {client_id}")
|
||||
@@ -180,6 +300,7 @@ class WebsocketServer:
|
||||
"error": "invalid_json",
|
||||
"message": "无效的JSON格式"
|
||||
}))
|
||||
self.stats["errors"] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错,来自客户端: {client_id}, 错误: {e}")
|
||||
await websocket.send(json.dumps({
|
||||
@@ -187,44 +308,76 @@ class WebsocketServer:
|
||||
"error": "processing_error",
|
||||
"message": f"处理消息时出错: {str(e)}"
|
||||
}))
|
||||
self.stats["errors"] += 1
|
||||
|
||||
async def _handle_image_request(self, websocket, client_id, message_data):
|
||||
"""处理图像识别请求"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 提取图像数据
|
||||
image_base64 = message_data.get("data", "")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
|
||||
# 记录请求
|
||||
logger.debug(f"收到图像识别请求,来自客户端: {client_id}, 图像尺寸: {image.size}")
|
||||
# 提取游戏状态
|
||||
game_state = message_data.get("game_state", {})
|
||||
window_rect = message_data.get("window_rect", [0, 0, 0, 0])
|
||||
|
||||
# 更新会话中的游戏状态
|
||||
if client_id in self.client_sessions:
|
||||
self.client_sessions[client_id]["game_state"] = game_state
|
||||
|
||||
# 保存调试图像
|
||||
if self.debug_mode:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
debug_path = f"debug/images/{client_id}_{timestamp}.jpg"
|
||||
try:
|
||||
image.save(debug_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 处理图像并进行检测
|
||||
detections = process_image(self.yolo_model, image)
|
||||
|
||||
# 根据检测结果生成动作
|
||||
game_state = message_data.get("game_state", {})
|
||||
actions = generate_actions(detections, game_state)
|
||||
|
||||
# 生成人类化延迟
|
||||
delay = generate_human_delay()
|
||||
|
||||
# 追踪处理时间
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 构建响应
|
||||
response = {
|
||||
"type": "action_response",
|
||||
"request_id": message_data.get("request_id"),
|
||||
"timestamp": time.time(),
|
||||
"actions": actions,
|
||||
"delay": delay
|
||||
"detections": detections, # 返回检测结果供客户端分析
|
||||
"delay": delay,
|
||||
"processing_time": processing_time
|
||||
}
|
||||
|
||||
# 发送响应(可能需要加密)
|
||||
if SECURITY["encryption_enabled"]:
|
||||
if SECURITY.get("encryption_enabled", False):
|
||||
await websocket.send(encrypt_message(response))
|
||||
else:
|
||||
await websocket.send(json.dumps(response))
|
||||
|
||||
logger.debug(f"已发送动作响应,客户端: {client_id}, 动作数: {len(actions)}")
|
||||
# 更新统计信息
|
||||
self.stats["images_processed"] += 1
|
||||
self.stats["actions_generated"] += len(actions)
|
||||
|
||||
if client_id in self.client_sessions:
|
||||
self.client_sessions[client_id]["statistics"]["images_processed"] += 1
|
||||
self.client_sessions[client_id]["statistics"]["actions_sent"] += len(actions)
|
||||
self.client_sessions[client_id]["last_actions"] = actions
|
||||
|
||||
# 记录处理信息
|
||||
if len(actions) > 0:
|
||||
logger.debug(f"已发送 {len(actions)} 个动作给客户端: {client_id}, 处理时间: {processing_time:.3f}秒")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图像请求时出错,客户端: {client_id}, 错误: {e}")
|
||||
@@ -234,20 +387,58 @@ class WebsocketServer:
|
||||
"request_id": message_data.get("request_id"),
|
||||
"message": f"处理图像时出错: {str(e)}"
|
||||
}))
|
||||
self.stats["errors"] += 1
|
||||
if client_id in self.client_sessions:
|
||||
self.client_sessions[client_id]["statistics"]["errors"] += 1
|
||||
|
||||
async def _handle_heartbeat(self, websocket, client_id):
|
||||
async def _handle_heartbeat(self, websocket, client_id, message_data):
|
||||
"""处理心跳消息"""
|
||||
try:
|
||||
# 提取客户端游戏状态
|
||||
if "game_state" in message_data and client_id in self.client_sessions:
|
||||
self.client_sessions[client_id]["game_state"].update(message_data["game_state"])
|
||||
|
||||
# 构建心跳响应
|
||||
response = {
|
||||
"type": "heartbeat_response",
|
||||
"timestamp": time.time(),
|
||||
"server_uptime": (datetime.now() - self.start_time).total_seconds()
|
||||
"server_uptime": (datetime.now() - self.start_time).total_seconds(),
|
||||
"server_stats": {
|
||||
"connections": len(self.clients),
|
||||
"images_processed": self.stats["images_processed"],
|
||||
"actions_generated": self.stats["actions_generated"]
|
||||
}
|
||||
}
|
||||
|
||||
# 发送响应
|
||||
await websocket.send(json.dumps(response))
|
||||
logger.debug(f"已发送心跳响应,客户端: {client_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理心跳消息时出错,客户端: {client_id}, 错误: {e}")
|
||||
logger.error(f"处理心跳消息时出错,客户端: {client_id}, 错误: {e}")
|
||||
self.stats["errors"] += 1
|
||||
|
||||
def get_server_stats(self):
|
||||
"""获取服务器统计信息"""
|
||||
return {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"uptime": (datetime.now() - self.start_time).total_seconds(),
|
||||
"connections": {
|
||||
"total": self.stats["total_connections"],
|
||||
"active": self.stats["active_connections"]
|
||||
},
|
||||
"performance": {
|
||||
"images_processed": self.stats["images_processed"],
|
||||
"actions_generated": self.stats["actions_generated"],
|
||||
"errors": self.stats["errors"]
|
||||
},
|
||||
"clients": [
|
||||
{
|
||||
"id": client_id,
|
||||
"connected_at": info["connected_at"].isoformat(),
|
||||
"remote": info["remote"],
|
||||
"authenticated": info["authenticated"],
|
||||
"version": info.get("client_info", {}).get("version", "unknown")
|
||||
}
|
||||
for client_id, info in self.clients.items()
|
||||
]
|
||||
}
|
||||
10
server_info.json
Normal file
10
server_info.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"server": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 8080,
|
||||
"url": "ws://0.0.0.0:8080/ws",
|
||||
"ssl_enabled": false
|
||||
},
|
||||
"started_at": "2025-03-26 17:52:39",
|
||||
"version": "1.0.1"
|
||||
}
|
||||
@@ -3,33 +3,161 @@
|
||||
|
||||
"""
|
||||
人类行为模拟工具,提供模拟人类操作的功能
|
||||
优化版 - 更自然的行为模式和个性化特征
|
||||
"""
|
||||
|
||||
import random
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from config.settings import BEHAVIOR
|
||||
|
||||
# 不同用户的行为模式
|
||||
USER_PROFILES = {
|
||||
"fast": {
|
||||
"delay_mean": 0.15,
|
||||
"delay_std": 0.05,
|
||||
"movement_smoothness": 0.7,
|
||||
"click_accuracy": 0.9,
|
||||
"double_click_chance": 0.05,
|
||||
"description": "快速玩家 - 反应迅速,操作精准"
|
||||
},
|
||||
"normal": {
|
||||
"delay_mean": 0.25,
|
||||
"delay_std": 0.1,
|
||||
"movement_smoothness": 0.8,
|
||||
"click_accuracy": 0.85,
|
||||
"double_click_chance": 0.1,
|
||||
"description": "普通玩家 - 中等速度,操作稳定"
|
||||
},
|
||||
"casual": {
|
||||
"delay_mean": 0.4,
|
||||
"delay_std": 0.15,
|
||||
"movement_smoothness": 0.6,
|
||||
"click_accuracy": 0.7,
|
||||
"double_click_chance": 0.15,
|
||||
"description": "休闲玩家 - 操作较慢,精确度一般"
|
||||
}
|
||||
}
|
||||
|
||||
# 当前用户配置文件
|
||||
current_profile = "normal"
|
||||
|
||||
# 用户注意力随时间变化的模拟
|
||||
attention_level = 1.0
|
||||
last_attention_update = time.time()
|
||||
|
||||
# 活动模式(早上、下午、晚上)
|
||||
activity_mode = "normal"
|
||||
|
||||
def set_user_profile(profile_name):
|
||||
"""
|
||||
设置用户行为配置文件
|
||||
|
||||
参数:
|
||||
profile_name (str): 配置文件名称 ("fast", "normal", "casual")
|
||||
"""
|
||||
global current_profile
|
||||
if profile_name in USER_PROFILES:
|
||||
current_profile = profile_name
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_attention_level():
|
||||
"""更新用户注意力水平,模拟人类随时间的疲劳"""
|
||||
global attention_level, last_attention_update
|
||||
|
||||
current_time = time.time()
|
||||
elapsed = current_time - last_attention_update
|
||||
|
||||
# 随着时间推移,注意力逐渐下降
|
||||
if elapsed > 60: # 每分钟更新一次
|
||||
# 注意力随时间缓慢下降(0.5%/分钟)
|
||||
attention_decay = 0.005 * (elapsed / 60)
|
||||
|
||||
# 随机波动(+/- 2%)
|
||||
attention_fluctuation = random.uniform(-0.02, 0.02)
|
||||
|
||||
# 更新注意力水平
|
||||
attention_level = max(0.7, min(1.0, attention_level - attention_decay + attention_fluctuation))
|
||||
|
||||
# 重置时间戳
|
||||
last_attention_update = current_time
|
||||
|
||||
return attention_level
|
||||
|
||||
def set_activity_mode():
|
||||
"""根据一天中的时间设置活动模式"""
|
||||
global activity_mode
|
||||
|
||||
# 获取当前小时
|
||||
current_hour = datetime.now().hour
|
||||
|
||||
# 根据时间段设置模式
|
||||
if 5 <= current_hour < 9: # 早晨
|
||||
activity_mode = "morning"
|
||||
elif 9 <= current_hour < 17: # 工作时间
|
||||
activity_mode = "normal"
|
||||
elif 17 <= current_hour < 22: # 晚上
|
||||
activity_mode = "evening"
|
||||
else: # 深夜
|
||||
activity_mode = "night"
|
||||
|
||||
return activity_mode
|
||||
|
||||
def generate_human_delay():
|
||||
"""
|
||||
生成人类化的延迟时间
|
||||
生成人类化的延迟时间 - 优化版
|
||||
|
||||
返回:
|
||||
float: 延迟时间(秒)
|
||||
"""
|
||||
# 获取当前用户配置
|
||||
profile = USER_PROFILES[current_profile]
|
||||
|
||||
# 更新注意力水平
|
||||
attention = update_attention_level()
|
||||
|
||||
# 检查活动模式
|
||||
mode = set_activity_mode()
|
||||
|
||||
# 基础延迟参数
|
||||
delay_mean = profile["delay_mean"]
|
||||
delay_std = profile["delay_std"]
|
||||
|
||||
# 根据注意力调整
|
||||
# 注意力下降,延迟增加
|
||||
delay_mean = delay_mean * (2 - attention)
|
||||
|
||||
# 根据活动模式调整
|
||||
if mode == "morning":
|
||||
# 早晨略慢
|
||||
delay_mean *= 1.1
|
||||
elif mode == "evening":
|
||||
# 晚上正常
|
||||
pass
|
||||
elif mode == "night":
|
||||
# 深夜反应较慢
|
||||
delay_mean *= 1.2
|
||||
delay_std *= 1.2
|
||||
|
||||
# 偶尔的停顿(思考时间)
|
||||
if random.random() < 0.05: # 5%几率
|
||||
thinking_time = random.uniform(0.5, 1.5)
|
||||
return delay_mean + thinking_time
|
||||
|
||||
# 使用正态分布生成随机延迟
|
||||
mean_delay = (BEHAVIOR["min_delay"] + BEHAVIOR["max_delay"]) / 2
|
||||
std_dev = (BEHAVIOR["max_delay"] - BEHAVIOR["min_delay"]) / 6 # 使99.7%的值在范围内
|
||||
delay = random.normalvariate(delay_mean, delay_std)
|
||||
|
||||
delay = random.normalvariate(mean_delay, std_dev)
|
||||
|
||||
# 确保在设定范围内
|
||||
return max(BEHAVIOR["min_delay"], min(BEHAVIOR["max_delay"], delay))
|
||||
# 确保在合理范围内
|
||||
min_delay = BEHAVIOR.get("min_delay", 0.1)
|
||||
max_delay = BEHAVIOR.get("max_delay", 0.8)
|
||||
return max(min_delay, min(max_delay, delay))
|
||||
|
||||
def generate_human_movement_path(start_pos, end_pos, steps=None):
|
||||
"""
|
||||
生成模拟人类的鼠标移动路径
|
||||
生成模拟人类的鼠标移动路径 - 优化版
|
||||
|
||||
参数:
|
||||
start_pos (list): 起始位置 [x, y]
|
||||
@@ -39,41 +167,99 @@ def generate_human_movement_path(start_pos, end_pos, steps=None):
|
||||
返回:
|
||||
list: 路径点列表 [[x1, y1], [x2, y2], ...]
|
||||
"""
|
||||
# 获取当前用户配置
|
||||
profile = USER_PROFILES[current_profile]
|
||||
smoothness = profile["movement_smoothness"]
|
||||
|
||||
# 计算距离
|
||||
distance = math.sqrt((end_pos[0] - start_pos[0])**2 + (end_pos[1] - start_pos[1])**2)
|
||||
|
||||
# 如果未指定步数,根据距离计算
|
||||
# 如果距离太短,使用直线路径
|
||||
if distance < 10:
|
||||
return [start_pos, end_pos]
|
||||
|
||||
# 如果未指定步数,根据距离和速度计算
|
||||
if steps is None:
|
||||
steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点
|
||||
# 根据距离、平滑度和注意力计算步数
|
||||
attention = update_attention_level()
|
||||
base_steps = max(int(distance / 20), 5) # 每20像素一个点,至少5个点
|
||||
|
||||
# 平滑度越高,步数越多;注意力越高,步数越多
|
||||
steps = int(base_steps * smoothness * attention)
|
||||
steps = max(5, min(30, steps)) # 限制在合理范围内
|
||||
|
||||
# 生成基础路径
|
||||
t = np.linspace(0, 1, steps)
|
||||
path = []
|
||||
|
||||
# 添加曲率 - 模拟手腕运动
|
||||
# 贝塞尔曲线的控制点
|
||||
control_x = start_pos[0] + (end_pos[0] - start_pos[0]) / 2
|
||||
control_y = start_pos[1] + (end_pos[1] - start_pos[1]) / 2
|
||||
|
||||
# 添加随机偏移到控制点
|
||||
offset_factor = (1.0 - smoothness) * 0.5 # 平滑度越低,偏移越大
|
||||
max_offset = distance * offset_factor
|
||||
control_x += random.normalvariate(0, max_offset / 2)
|
||||
control_y += random.normalvariate(0, max_offset / 2)
|
||||
|
||||
for i in range(steps):
|
||||
# 基础线性插值
|
||||
x = start_pos[0] + (end_pos[0] - start_pos[0]) * t[i]
|
||||
y = start_pos[1] + (end_pos[1] - start_pos[1]) * t[i]
|
||||
# 二次贝塞尔曲线
|
||||
t_i = t[i]
|
||||
x = (1 - t_i)**2 * start_pos[0] + 2 * (1 - t_i) * t_i * control_x + t_i**2 * end_pos[0]
|
||||
y = (1 - t_i)**2 * start_pos[1] + 2 * (1 - t_i) * t_i * control_y + t_i**2 * end_pos[1]
|
||||
|
||||
# 添加随机偏移(越靠近中间偏移越大)
|
||||
mid_factor = 4 * t[i] * (1 - t[i]) # 在中间最大
|
||||
max_offset = distance * 0.05 * mid_factor # 最大偏移为距离的5%
|
||||
# 添加细微的随机抖动(手部微小颤抖)
|
||||
jitter_factor = (1.0 - smoothness) * 0.02 * distance
|
||||
jitter_x = random.normalvariate(0, jitter_factor)
|
||||
jitter_y = random.normalvariate(0, jitter_factor)
|
||||
|
||||
offset_x = random.normalvariate(0, max_offset / 3)
|
||||
offset_y = random.normalvariate(0, max_offset / 3)
|
||||
# 注意力越低,抖动越大
|
||||
attention = update_attention_level()
|
||||
jitter_x *= (2 - attention)
|
||||
jitter_y *= (2 - attention)
|
||||
|
||||
# 添加到路径
|
||||
path.append([x + offset_x, y + offset_y])
|
||||
path.append([x + jitter_x, y + jitter_y])
|
||||
|
||||
# 确保起点和终点准确
|
||||
path[0] = start_pos.copy()
|
||||
path[-1] = end_pos.copy()
|
||||
|
||||
# 模拟加速和减速
|
||||
# 路径采样 - 开始慢,中间快,结束慢
|
||||
if steps > 10:
|
||||
resampled_path = []
|
||||
resampled_path.append(path[0]) # 起点
|
||||
|
||||
# 前20%缓慢加速
|
||||
accel_end = int(steps * 0.2)
|
||||
for i in range(1, accel_end):
|
||||
resampled_path.append(path[i])
|
||||
|
||||
# 中间60%快速移动(跳过一些点)
|
||||
mid_start = accel_end
|
||||
mid_end = int(steps * 0.8)
|
||||
|
||||
# 根据平滑度决定中间部分的采样率
|
||||
skip_factor = int(2 + (1 - smoothness) * 3)
|
||||
for i in range(mid_start, mid_end, skip_factor):
|
||||
resampled_path.append(path[min(i, steps - 1)])
|
||||
|
||||
# 最后20%缓慢减速
|
||||
for i in range(mid_end, steps):
|
||||
resampled_path.append(path[i])
|
||||
|
||||
if resampled_path[-1] != end_pos:
|
||||
resampled_path.append(end_pos) # 确保终点
|
||||
|
||||
return resampled_path
|
||||
|
||||
return path
|
||||
|
||||
def generate_human_click_offset(target_pos, target_size=(50, 50)):
|
||||
"""
|
||||
生成人类化的点击位置偏移
|
||||
生成人类化的点击位置偏移 - 优化版
|
||||
|
||||
参数:
|
||||
target_pos (list): 目标中心位置 [x, y]
|
||||
@@ -82,13 +268,23 @@ def generate_human_click_offset(target_pos, target_size=(50, 50)):
|
||||
返回:
|
||||
list: 带偏移的点击位置 [x, y]
|
||||
"""
|
||||
# 获取当前用户配置
|
||||
profile = USER_PROFILES[current_profile]
|
||||
accuracy = profile["click_accuracy"]
|
||||
|
||||
# 更新注意力水平
|
||||
attention = update_attention_level()
|
||||
|
||||
# 综合准确度和注意力
|
||||
effective_accuracy = accuracy * attention
|
||||
|
||||
# 计算最大偏移(不超过目标大小的一半)
|
||||
max_offset_x = min(target_size[0] / 2 * 0.8, 10) # 最大偏移不超过10像素
|
||||
max_offset_y = min(target_size[1] / 2 * 0.8, 10)
|
||||
max_offset_x = min(target_size[0] / 2 * (1 - effective_accuracy) * 1.5, 15)
|
||||
max_offset_y = min(target_size[1] / 2 * (1 - effective_accuracy) * 1.5, 15)
|
||||
|
||||
# 生成偏移(中心位置概率更高)
|
||||
offset_x = random.normalvariate(0, max_offset_x / 3)
|
||||
offset_y = random.normalvariate(0, max_offset_y / 3)
|
||||
offset_x = random.normalvariate(0, max_offset_x / 2)
|
||||
offset_y = random.normalvariate(0, max_offset_y / 2)
|
||||
|
||||
# 应用偏移
|
||||
click_pos = [
|
||||
@@ -98,22 +294,137 @@ def generate_human_click_offset(target_pos, target_size=(50, 50)):
|
||||
|
||||
return click_pos
|
||||
|
||||
def generate_typing_speed(base_cps=5.0, variance=0.2):
|
||||
def should_double_click():
|
||||
"""
|
||||
生成人类化的打字速度
|
||||
确定是否应该双击
|
||||
|
||||
返回:
|
||||
bool: 是否应该双击
|
||||
"""
|
||||
profile = USER_PROFILES[current_profile]
|
||||
return random.random() < profile["double_click_chance"]
|
||||
|
||||
def generate_typing_speed(text_length, base_cps=5.0):
|
||||
"""
|
||||
生成人类化的打字速度序列
|
||||
|
||||
参数:
|
||||
text_length (int): 文本长度
|
||||
base_cps (float): 基础字符每秒速度
|
||||
variance (float): 速度方差
|
||||
|
||||
返回:
|
||||
float: 字符间延迟时间(秒)
|
||||
list: 每个字符的延迟时间列表
|
||||
"""
|
||||
# 获取当前用户配置和注意力
|
||||
profile = USER_PROFILES[current_profile]
|
||||
attention = update_attention_level()
|
||||
|
||||
# 基础打字速度 - 根据配置文件调整
|
||||
if current_profile == "fast":
|
||||
base_cps = 8.0
|
||||
elif current_profile == "casual":
|
||||
base_cps = 3.5
|
||||
|
||||
# 根据注意力调整速度
|
||||
base_cps = base_cps * attention
|
||||
|
||||
# 生成每个字符的延迟
|
||||
delays = []
|
||||
current_speed = base_cps
|
||||
|
||||
for i in range(text_length):
|
||||
# 随机波动
|
||||
speed_variance = 0.2 * base_cps
|
||||
current_speed = random.normalvariate(base_cps, speed_variance)
|
||||
current_speed = max(base_cps * 0.5, current_speed) # 确保速度不会太慢
|
||||
|
||||
# 将速度转换为延迟
|
||||
delay = 1.0 / current_speed
|
||||
|
||||
# 某些特殊情况
|
||||
if random.random() < 0.05: # 偶尔停顿
|
||||
delay += random.uniform(0.2, 0.7)
|
||||
|
||||
if i > 0 and i % 10 == 0 and random.random() < 0.2: # 句子结束停顿
|
||||
delay += random.uniform(0.3, 0.8)
|
||||
|
||||
delays.append(delay)
|
||||
|
||||
return delays
|
||||
|
||||
def generate_action_sequence(action_count):
|
||||
"""
|
||||
生成一系列人类化的动作延迟
|
||||
|
||||
参数:
|
||||
action_count (int): 动作数量
|
||||
|
||||
返回:
|
||||
list: 延迟时间列表
|
||||
"""
|
||||
delays = []
|
||||
|
||||
for i in range(action_count):
|
||||
# 基础延迟
|
||||
delay = generate_human_delay()
|
||||
|
||||
# 连续动作的相关性
|
||||
if i > 0:
|
||||
# 连续动作通常更快
|
||||
delay *= 0.8
|
||||
|
||||
# 序列中的变化
|
||||
if i == 0:
|
||||
# 第一个动作前可能有更长的准备时间
|
||||
delay *= 1.5
|
||||
elif i == action_count - 1:
|
||||
# 最后一个动作后可能有更长的思考时间
|
||||
delay *= 1.2
|
||||
|
||||
delays.append(delay)
|
||||
|
||||
return delays
|
||||
|
||||
def simulate_fatigue(session_duration):
|
||||
"""
|
||||
模拟随时间产生的疲劳效应
|
||||
|
||||
参数:
|
||||
session_duration (float): 会话持续时间(秒)
|
||||
|
||||
返回:
|
||||
float: 疲劳系数 (1.0=正常, >1.0=疲劳)
|
||||
"""
|
||||
# 基础疲劳曲线
|
||||
hours = session_duration / 3600.0
|
||||
|
||||
# 前1小时基本无疲劳
|
||||
if hours < 1:
|
||||
base_fatigue = 1.0
|
||||
# 1-2小时逐渐增加疲劳
|
||||
elif hours < 2:
|
||||
base_fatigue = 1.0 + (hours - 1) * 0.1
|
||||
# 2-3小时疲劳加重
|
||||
elif hours < 3:
|
||||
base_fatigue = 1.1 + (hours - 2) * 0.2
|
||||
# 3小时以上疲劳显著
|
||||
else:
|
||||
base_fatigue = 1.3 + min(0.4, (hours - 3) * 0.1)
|
||||
|
||||
# 添加随机波动
|
||||
speed = random.normalvariate(base_cps, base_cps * variance)
|
||||
speed = max(speed, base_cps * 0.5) # 确保速度不会太慢
|
||||
fatigue = base_fatigue * random.uniform(0.95, 1.05)
|
||||
|
||||
# 将速度转换为延迟
|
||||
delay = 1.0 / speed
|
||||
return fatigue
|
||||
|
||||
def get_behavior_profile():
|
||||
"""
|
||||
获取当前行为配置文件信息
|
||||
|
||||
return delay
|
||||
返回:
|
||||
dict: 行为配置信息
|
||||
"""
|
||||
profile = USER_PROFILES[current_profile].copy()
|
||||
profile["attention"] = update_attention_level()
|
||||
profile["activity_mode"] = set_activity_mode()
|
||||
|
||||
return profile
|
||||
Reference in New Issue
Block a user