#!/usr/bin/env python # -*- coding: utf-8 -*- """ WebSocket API模块 提供WebSocket相关的API接口,支持实时数据推送 """ import json import asyncio from typing import Dict, List, Any, Optional, Set from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query from datetime import datetime, timedelta from services.task_record_service import TaskRecordService from utils.logger import get_logger # 创建路由 router = APIRouter( prefix="/ws", tags=["WebSocket"] ) # 设置日志 logger = get_logger("app.websocket_api") def json_serializer(obj): """自定义JSON序列化器,处理datetime对象""" if isinstance(obj, datetime): return obj.isoformat() raise TypeError(f"Object of type {type(obj)} is not JSON serializable") def safe_json_dumps(data, **kwargs): """安全的JSON序列化函数""" return json.dumps(data, default=json_serializer, **kwargs) # 存储活跃的WebSocket连接 class ConnectionManager: def __init__(self): # 存储WebSocket连接,按task_record_id分组 self.active_connections: Dict[str, Set[WebSocket]] = {} # 存储连接的最后推送时间 self.last_push_time: Dict[str, datetime] = {} async def connect(self, websocket: WebSocket, task_record_id: str): """连接WebSocket""" await websocket.accept() if task_record_id not in self.active_connections: self.active_connections[task_record_id] = set() self.active_connections[task_record_id].add(websocket) logger.info(f"WebSocket连接已建立,任务记录ID: {task_record_id}. 当前连接数: {len(self.active_connections[task_record_id])}") def disconnect(self, websocket: WebSocket, task_record_id: str): """断开WebSocket连接""" if task_record_id in self.active_connections: self.active_connections[task_record_id].discard(websocket) if not self.active_connections[task_record_id]: # 如果没有连接了,清理数据 del self.active_connections[task_record_id] self.last_push_time.pop(task_record_id, None) logger.info(f"WebSocket连接已断开,任务记录ID: {task_record_id}") async def send_personal_message(self, message: str, websocket: WebSocket): """发送个人消息""" try: await websocket.send_text(message) except Exception as e: logger.error(f"发送个人消息失败: {str(e)}") async def broadcast_to_task(self, message: str, task_record_id: str): """向特定任务的所有连接广播消息""" if task_record_id not in self.active_connections: return disconnected_websockets = [] for websocket in self.active_connections[task_record_id].copy(): try: await websocket.send_text(message) except Exception as e: logger.error(f"广播消息失败: {str(e)}") disconnected_websockets.append(websocket) # 清理断开的连接 for websocket in disconnected_websockets: self.disconnect(websocket, task_record_id) # 连接管理器实例 manager = ConnectionManager() @router.websocket("/task-execution/{task_record_id}") async def websocket_task_execution( websocket: WebSocket, task_record_id: str = Path(..., description="任务记录ID"), interval: int = Query(default=2, description="推送间隔(秒)", ge=1, le=30) ): """ 任务执行结果WebSocket连接 Args: websocket: WebSocket连接对象 task_record_id: 任务记录ID interval: 推送间隔(秒),默认2秒,范围1-30秒 """ await manager.connect(websocket, task_record_id) try: # 立即发送一次当前状态 await send_task_execution_status(task_record_id, websocket) # 启动定时推送任务 push_task = asyncio.create_task( periodic_push_task_status(websocket, task_record_id, interval) ) try: # 监听客户端消息 while True: # 接收客户端消息 data = await websocket.receive_text() try: message = json.loads(data) await handle_websocket_message(websocket, task_record_id, message) except json.JSONDecodeError: await websocket.send_text(safe_json_dumps({ "type": "error", "message": "无效的JSON格式" }, ensure_ascii=False)) except Exception as e: logger.error(f"处理WebSocket消息失败: {str(e)}") await websocket.send_text(safe_json_dumps({ "type": "error", "message": f"处理消息失败: {str(e)}" }, ensure_ascii=False)) finally: # 取消定时推送任务 push_task.cancel() try: await push_task except asyncio.CancelledError: pass except WebSocketDisconnect: logger.info(f"WebSocket客户端断开连接,任务记录ID: {task_record_id}") except Exception as e: logger.error(f"WebSocket连接异常: {str(e)}") finally: manager.disconnect(websocket, task_record_id) async def handle_websocket_message(websocket: WebSocket, task_record_id: str, message: Dict[str, Any]): """ 处理WebSocket客户端消息 Args: websocket: WebSocket连接对象 task_record_id: 任务记录ID message: 客户端消息 """ message_type = message.get("type", "") if message_type == "get_status": # 获取当前状态 await send_task_execution_status(task_record_id, websocket) elif message_type == "ping": # 心跳检测 await websocket.send_text(safe_json_dumps({ "type": "pong", "timestamp": datetime.now().isoformat() }, ensure_ascii=False)) else: await websocket.send_text(safe_json_dumps({ "type": "error", "message": f"不支持的消息类型: {message_type}" }, ensure_ascii=False)) async def send_task_execution_status(task_record_id: str, websocket: WebSocket): """ 发送任务执行状态 Args: task_record_id: 任务记录ID websocket: WebSocket连接对象 """ try: # 获取任务执行结果 result = await TaskRecordService.get_block_results(task_record_id) if result["success"]: response_data = { "type": "task_execution_update", "task_record_id": task_record_id, "timestamp": datetime.now().isoformat(), "data": result["data"], "message": result["message"] } else: response_data = { "type": "error", "task_record_id": task_record_id, "timestamp": datetime.now().isoformat(), "message": result["message"] } await websocket.send_text(safe_json_dumps(response_data, ensure_ascii=False)) except Exception as e: logger.error(f"发送任务执行状态失败: {str(e)}") try: await websocket.send_text(safe_json_dumps({ "type": "error", "message": f"获取任务执行状态失败: {str(e)}", "timestamp": datetime.now().isoformat() }, ensure_ascii=False)) except: # 如果连接已断开,忽略错误 pass async def periodic_push_task_status(websocket: WebSocket, task_record_id: str, interval: int): """ 定期推送任务状态 Args: websocket: WebSocket连接对象 task_record_id: 任务记录ID interval: 推送间隔(秒) """ logger.info(f"开始定期推送任务状态,任务记录ID: {task_record_id}, 间隔: {interval}秒") last_data_hash = None # 用于检测数据是否发生变化 try: while True: await asyncio.sleep(interval) # 获取当前数据 try: result = await TaskRecordService.get_block_results(task_record_id) if result["success"]: # 计算数据哈希,只有数据变化时才推送 import hashlib current_data = safe_json_dumps(result["data"], sort_keys=True, ensure_ascii=False) current_hash = hashlib.md5(current_data.encode()).hexdigest() if current_hash != last_data_hash: await send_task_execution_status(task_record_id, websocket) last_data_hash = current_hash logger.debug(f"任务状态已更新并推送,任务记录ID: {task_record_id}") else: logger.debug(f"任务状态无变化,跳过推送,任务记录ID: {task_record_id}") else: # 如果获取失败,仍然推送错误信息 await send_task_execution_status(task_record_id, websocket) except Exception as e: logger.error(f"获取任务状态失败: {str(e)}") # 发送错误状态 try: await websocket.send_text(safe_json_dumps({ "type": "error", "message": f"获取任务状态失败: {str(e)}", "timestamp": datetime.now().isoformat() }, ensure_ascii=False)) except: # 连接可能已断开 break except asyncio.CancelledError: logger.info(f"定期推送任务已取消,任务记录ID: {task_record_id}") raise except Exception as e: logger.error(f"定期推送任务状态失败: {str(e)}") @router.websocket("/task-execution-broadcast/{task_record_id}") async def websocket_task_execution_broadcast( websocket: WebSocket, task_record_id: str = Path(..., description="任务记录ID") ): """ 任务执行结果广播WebSocket连接(只接收广播,不主动推送) Args: websocket: WebSocket连接对象 task_record_id: 任务记录ID """ await manager.connect(websocket, task_record_id) try: # 发送初始状态 await send_task_execution_status(task_record_id, websocket) # 等待连接断开或消息 while True: try: data = await websocket.receive_text() # 可以处理客户端的心跳或其他控制消息 try: message = json.loads(data) if message.get("type") == "ping": await websocket.send_text(safe_json_dumps({ "type": "pong", "timestamp": datetime.now().isoformat() }, ensure_ascii=False)) except: pass except WebSocketDisconnect: break except Exception as e: logger.error(f"广播WebSocket连接异常: {str(e)}") finally: manager.disconnect(websocket, task_record_id) # 提供给其他模块调用的广播接口 async def broadcast_task_update(task_record_id: str, data: Dict[str, Any]): """ 广播任务更新消息给所有相关连接 Args: task_record_id: 任务记录ID data: 要广播的数据 """ if task_record_id not in manager.active_connections: return message = safe_json_dumps({ "type": "task_execution_update", "task_record_id": task_record_id, "timestamp": datetime.now().isoformat(), "data": data }, ensure_ascii=False) await manager.broadcast_to_task(message, task_record_id) logger.info(f"已广播任务更新消息,任务记录ID: {task_record_id}")