VWED_server/routes/websocket_api.py
2025-07-14 10:29:37 +08:00

330 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
WebSocket API模块
提供WebSocket相关的API接口支持实时数据推送
"""
import json
import asyncio
from typing import Dict, List, Any, Optional, Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
from datetime import datetime, timedelta
from services.task_record_service import TaskRecordService
from utils.logger import get_logger
# 创建路由
router = APIRouter(
prefix="/ws",
tags=["WebSocket"]
)
# 设置日志
logger = get_logger("app.websocket_api")
def json_serializer(obj):
"""自定义JSON序列化器处理datetime对象"""
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
def safe_json_dumps(data, **kwargs):
"""安全的JSON序列化函数"""
return json.dumps(data, default=json_serializer, **kwargs)
# 存储活跃的WebSocket连接
class ConnectionManager:
def __init__(self):
# 存储WebSocket连接按task_record_id分组
self.active_connections: Dict[str, Set[WebSocket]] = {}
# 存储连接的最后推送时间
self.last_push_time: Dict[str, datetime] = {}
async def connect(self, websocket: WebSocket, task_record_id: str):
"""连接WebSocket"""
await websocket.accept()
if task_record_id not in self.active_connections:
self.active_connections[task_record_id] = set()
self.active_connections[task_record_id].add(websocket)
logger.info(f"WebSocket连接已建立任务记录ID: {task_record_id}. 当前连接数: {len(self.active_connections[task_record_id])}")
def disconnect(self, websocket: WebSocket, task_record_id: str):
"""断开WebSocket连接"""
if task_record_id in self.active_connections:
self.active_connections[task_record_id].discard(websocket)
if not self.active_connections[task_record_id]:
# 如果没有连接了,清理数据
del self.active_connections[task_record_id]
self.last_push_time.pop(task_record_id, None)
logger.info(f"WebSocket连接已断开任务记录ID: {task_record_id}")
async def send_personal_message(self, message: str, websocket: WebSocket):
"""发送个人消息"""
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送个人消息失败: {str(e)}")
async def broadcast_to_task(self, message: str, task_record_id: str):
"""向特定任务的所有连接广播消息"""
if task_record_id not in self.active_connections:
return
disconnected_websockets = []
for websocket in self.active_connections[task_record_id].copy():
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"广播消息失败: {str(e)}")
disconnected_websockets.append(websocket)
# 清理断开的连接
for websocket in disconnected_websockets:
self.disconnect(websocket, task_record_id)
# 连接管理器实例
manager = ConnectionManager()
@router.websocket("/task-execution/{task_record_id}")
async def websocket_task_execution(
websocket: WebSocket,
task_record_id: str = Path(..., description="任务记录ID"),
interval: int = Query(default=2, description="推送间隔(秒)", ge=1, le=30)
):
"""
任务执行结果WebSocket连接
Args:
websocket: WebSocket连接对象
task_record_id: 任务记录ID
interval: 推送间隔默认2秒范围1-30秒
"""
await manager.connect(websocket, task_record_id)
try:
# 立即发送一次当前状态
await send_task_execution_status(task_record_id, websocket)
# 启动定时推送任务
push_task = asyncio.create_task(
periodic_push_task_status(websocket, task_record_id, interval)
)
try:
# 监听客户端消息
while True:
# 接收客户端消息
data = await websocket.receive_text()
try:
message = json.loads(data)
await handle_websocket_message(websocket, task_record_id, message)
except json.JSONDecodeError:
await websocket.send_text(safe_json_dumps({
"type": "error",
"message": "无效的JSON格式"
}, ensure_ascii=False))
except Exception as e:
logger.error(f"处理WebSocket消息失败: {str(e)}")
await websocket.send_text(safe_json_dumps({
"type": "error",
"message": f"处理消息失败: {str(e)}"
}, ensure_ascii=False))
finally:
# 取消定时推送任务
push_task.cancel()
try:
await push_task
except asyncio.CancelledError:
pass
except WebSocketDisconnect:
logger.info(f"WebSocket客户端断开连接任务记录ID: {task_record_id}")
except Exception as e:
logger.error(f"WebSocket连接异常: {str(e)}")
finally:
manager.disconnect(websocket, task_record_id)
async def handle_websocket_message(websocket: WebSocket, task_record_id: str, message: Dict[str, Any]):
"""
处理WebSocket客户端消息
Args:
websocket: WebSocket连接对象
task_record_id: 任务记录ID
message: 客户端消息
"""
message_type = message.get("type", "")
if message_type == "get_status":
# 获取当前状态
await send_task_execution_status(task_record_id, websocket)
elif message_type == "ping":
# 心跳检测
await websocket.send_text(safe_json_dumps({
"type": "pong",
"timestamp": datetime.now().isoformat()
}, ensure_ascii=False))
else:
await websocket.send_text(safe_json_dumps({
"type": "error",
"message": f"不支持的消息类型: {message_type}"
}, ensure_ascii=False))
async def send_task_execution_status(task_record_id: str, websocket: WebSocket):
"""
发送任务执行状态
Args:
task_record_id: 任务记录ID
websocket: WebSocket连接对象
"""
try:
# 获取任务执行结果
result = await TaskRecordService.get_block_results(task_record_id)
if result["success"]:
response_data = {
"type": "task_execution_update",
"task_record_id": task_record_id,
"timestamp": datetime.now().isoformat(),
"data": result["data"],
"message": result["message"]
}
else:
response_data = {
"type": "error",
"task_record_id": task_record_id,
"timestamp": datetime.now().isoformat(),
"message": result["message"]
}
await websocket.send_text(safe_json_dumps(response_data, ensure_ascii=False))
except Exception as e:
logger.error(f"发送任务执行状态失败: {str(e)}")
try:
await websocket.send_text(safe_json_dumps({
"type": "error",
"message": f"获取任务执行状态失败: {str(e)}",
"timestamp": datetime.now().isoformat()
}, ensure_ascii=False))
except:
# 如果连接已断开,忽略错误
pass
async def periodic_push_task_status(websocket: WebSocket, task_record_id: str, interval: int):
"""
定期推送任务状态
Args:
websocket: WebSocket连接对象
task_record_id: 任务记录ID
interval: 推送间隔(秒)
"""
logger.info(f"开始定期推送任务状态任务记录ID: {task_record_id}, 间隔: {interval}")
last_data_hash = None # 用于检测数据是否发生变化
try:
while True:
await asyncio.sleep(interval)
# 获取当前数据
try:
result = await TaskRecordService.get_block_results(task_record_id)
if result["success"]:
# 计算数据哈希,只有数据变化时才推送
import hashlib
current_data = safe_json_dumps(result["data"], sort_keys=True, ensure_ascii=False)
current_hash = hashlib.md5(current_data.encode()).hexdigest()
if current_hash != last_data_hash:
await send_task_execution_status(task_record_id, websocket)
last_data_hash = current_hash
logger.debug(f"任务状态已更新并推送任务记录ID: {task_record_id}")
else:
logger.debug(f"任务状态无变化跳过推送任务记录ID: {task_record_id}")
else:
# 如果获取失败,仍然推送错误信息
await send_task_execution_status(task_record_id, websocket)
except Exception as e:
logger.error(f"获取任务状态失败: {str(e)}")
# 发送错误状态
try:
await websocket.send_text(safe_json_dumps({
"type": "error",
"message": f"获取任务状态失败: {str(e)}",
"timestamp": datetime.now().isoformat()
}, ensure_ascii=False))
except:
# 连接可能已断开
break
except asyncio.CancelledError:
logger.info(f"定期推送任务已取消任务记录ID: {task_record_id}")
raise
except Exception as e:
logger.error(f"定期推送任务状态失败: {str(e)}")
@router.websocket("/task-execution-broadcast/{task_record_id}")
async def websocket_task_execution_broadcast(
websocket: WebSocket,
task_record_id: str = Path(..., description="任务记录ID")
):
"""
任务执行结果广播WebSocket连接只接收广播不主动推送
Args:
websocket: WebSocket连接对象
task_record_id: 任务记录ID
"""
await manager.connect(websocket, task_record_id)
try:
# 发送初始状态
await send_task_execution_status(task_record_id, websocket)
# 等待连接断开或消息
while True:
try:
data = await websocket.receive_text()
# 可以处理客户端的心跳或其他控制消息
try:
message = json.loads(data)
if message.get("type") == "ping":
await websocket.send_text(safe_json_dumps({
"type": "pong",
"timestamp": datetime.now().isoformat()
}, ensure_ascii=False))
except:
pass
except WebSocketDisconnect:
break
except Exception as e:
logger.error(f"广播WebSocket连接异常: {str(e)}")
finally:
manager.disconnect(websocket, task_record_id)
# 提供给其他模块调用的广播接口
async def broadcast_task_update(task_record_id: str, data: Dict[str, Any]):
"""
广播任务更新消息给所有相关连接
Args:
task_record_id: 任务记录ID
data: 要广播的数据
"""
if task_record_id not in manager.active_connections:
return
message = safe_json_dumps({
"type": "task_execution_update",
"task_record_id": task_record_id,
"timestamp": datetime.now().isoformat(),
"data": data
}, ensure_ascii=False)
await manager.broadcast_to_task(message, task_record_id)
logger.info(f"已广播任务更新消息任务记录ID: {task_record_id}")