330 lines
12 KiB
Python
330 lines
12 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
WebSocket API模块
|
|||
|
提供WebSocket相关的API接口,支持实时数据推送
|
|||
|
"""
|
|||
|
|
|||
|
import json
|
|||
|
import asyncio
|
|||
|
from typing import Dict, List, Any, Optional, Set
|
|||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
|||
|
from datetime import datetime, timedelta
|
|||
|
|
|||
|
from services.task_record_service import TaskRecordService
|
|||
|
from utils.logger import get_logger
|
|||
|
|
|||
|
# 创建路由
|
|||
|
router = APIRouter(
|
|||
|
prefix="/ws",
|
|||
|
tags=["WebSocket"]
|
|||
|
)
|
|||
|
|
|||
|
# 设置日志
|
|||
|
logger = get_logger("app.websocket_api")
|
|||
|
|
|||
|
def json_serializer(obj):
|
|||
|
"""自定义JSON序列化器,处理datetime对象"""
|
|||
|
if isinstance(obj, datetime):
|
|||
|
return obj.isoformat()
|
|||
|
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
|||
|
|
|||
|
def safe_json_dumps(data, **kwargs):
|
|||
|
"""安全的JSON序列化函数"""
|
|||
|
return json.dumps(data, default=json_serializer, **kwargs)
|
|||
|
|
|||
|
# 存储活跃的WebSocket连接
|
|||
|
class ConnectionManager:
|
|||
|
def __init__(self):
|
|||
|
# 存储WebSocket连接,按task_record_id分组
|
|||
|
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
|||
|
# 存储连接的最后推送时间
|
|||
|
self.last_push_time: Dict[str, datetime] = {}
|
|||
|
|
|||
|
async def connect(self, websocket: WebSocket, task_record_id: str):
|
|||
|
"""连接WebSocket"""
|
|||
|
await websocket.accept()
|
|||
|
if task_record_id not in self.active_connections:
|
|||
|
self.active_connections[task_record_id] = set()
|
|||
|
self.active_connections[task_record_id].add(websocket)
|
|||
|
logger.info(f"WebSocket连接已建立,任务记录ID: {task_record_id}. 当前连接数: {len(self.active_connections[task_record_id])}")
|
|||
|
|
|||
|
def disconnect(self, websocket: WebSocket, task_record_id: str):
|
|||
|
"""断开WebSocket连接"""
|
|||
|
if task_record_id in self.active_connections:
|
|||
|
self.active_connections[task_record_id].discard(websocket)
|
|||
|
if not self.active_connections[task_record_id]:
|
|||
|
# 如果没有连接了,清理数据
|
|||
|
del self.active_connections[task_record_id]
|
|||
|
self.last_push_time.pop(task_record_id, None)
|
|||
|
logger.info(f"WebSocket连接已断开,任务记录ID: {task_record_id}")
|
|||
|
|
|||
|
async def send_personal_message(self, message: str, websocket: WebSocket):
|
|||
|
"""发送个人消息"""
|
|||
|
try:
|
|||
|
await websocket.send_text(message)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"发送个人消息失败: {str(e)}")
|
|||
|
|
|||
|
async def broadcast_to_task(self, message: str, task_record_id: str):
|
|||
|
"""向特定任务的所有连接广播消息"""
|
|||
|
if task_record_id not in self.active_connections:
|
|||
|
return
|
|||
|
|
|||
|
disconnected_websockets = []
|
|||
|
for websocket in self.active_connections[task_record_id].copy():
|
|||
|
try:
|
|||
|
await websocket.send_text(message)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"广播消息失败: {str(e)}")
|
|||
|
disconnected_websockets.append(websocket)
|
|||
|
|
|||
|
# 清理断开的连接
|
|||
|
for websocket in disconnected_websockets:
|
|||
|
self.disconnect(websocket, task_record_id)
|
|||
|
|
|||
|
# 连接管理器实例
|
|||
|
manager = ConnectionManager()
|
|||
|
|
|||
|
@router.websocket("/task-execution/{task_record_id}")
|
|||
|
async def websocket_task_execution(
|
|||
|
websocket: WebSocket,
|
|||
|
task_record_id: str = Path(..., description="任务记录ID"),
|
|||
|
interval: int = Query(default=2, description="推送间隔(秒)", ge=1, le=30)
|
|||
|
):
|
|||
|
"""
|
|||
|
任务执行结果WebSocket连接
|
|||
|
|
|||
|
Args:
|
|||
|
websocket: WebSocket连接对象
|
|||
|
task_record_id: 任务记录ID
|
|||
|
interval: 推送间隔(秒),默认2秒,范围1-30秒
|
|||
|
"""
|
|||
|
await manager.connect(websocket, task_record_id)
|
|||
|
|
|||
|
try:
|
|||
|
# 立即发送一次当前状态
|
|||
|
await send_task_execution_status(task_record_id, websocket)
|
|||
|
|
|||
|
# 启动定时推送任务
|
|||
|
push_task = asyncio.create_task(
|
|||
|
periodic_push_task_status(websocket, task_record_id, interval)
|
|||
|
)
|
|||
|
|
|||
|
try:
|
|||
|
# 监听客户端消息
|
|||
|
while True:
|
|||
|
# 接收客户端消息
|
|||
|
data = await websocket.receive_text()
|
|||
|
try:
|
|||
|
message = json.loads(data)
|
|||
|
await handle_websocket_message(websocket, task_record_id, message)
|
|||
|
except json.JSONDecodeError:
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "error",
|
|||
|
"message": "无效的JSON格式"
|
|||
|
}, ensure_ascii=False))
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"处理WebSocket消息失败: {str(e)}")
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "error",
|
|||
|
"message": f"处理消息失败: {str(e)}"
|
|||
|
}, ensure_ascii=False))
|
|||
|
finally:
|
|||
|
# 取消定时推送任务
|
|||
|
push_task.cancel()
|
|||
|
try:
|
|||
|
await push_task
|
|||
|
except asyncio.CancelledError:
|
|||
|
pass
|
|||
|
|
|||
|
except WebSocketDisconnect:
|
|||
|
logger.info(f"WebSocket客户端断开连接,任务记录ID: {task_record_id}")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"WebSocket连接异常: {str(e)}")
|
|||
|
finally:
|
|||
|
manager.disconnect(websocket, task_record_id)
|
|||
|
|
|||
|
async def handle_websocket_message(websocket: WebSocket, task_record_id: str, message: Dict[str, Any]):
|
|||
|
"""
|
|||
|
处理WebSocket客户端消息
|
|||
|
|
|||
|
Args:
|
|||
|
websocket: WebSocket连接对象
|
|||
|
task_record_id: 任务记录ID
|
|||
|
message: 客户端消息
|
|||
|
"""
|
|||
|
message_type = message.get("type", "")
|
|||
|
|
|||
|
if message_type == "get_status":
|
|||
|
# 获取当前状态
|
|||
|
await send_task_execution_status(task_record_id, websocket)
|
|||
|
elif message_type == "ping":
|
|||
|
# 心跳检测
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "pong",
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}, ensure_ascii=False))
|
|||
|
else:
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "error",
|
|||
|
"message": f"不支持的消息类型: {message_type}"
|
|||
|
}, ensure_ascii=False))
|
|||
|
|
|||
|
async def send_task_execution_status(task_record_id: str, websocket: WebSocket):
|
|||
|
"""
|
|||
|
发送任务执行状态
|
|||
|
|
|||
|
Args:
|
|||
|
task_record_id: 任务记录ID
|
|||
|
websocket: WebSocket连接对象
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 获取任务执行结果
|
|||
|
result = await TaskRecordService.get_block_results(task_record_id)
|
|||
|
|
|||
|
if result["success"]:
|
|||
|
response_data = {
|
|||
|
"type": "task_execution_update",
|
|||
|
"task_record_id": task_record_id,
|
|||
|
"timestamp": datetime.now().isoformat(),
|
|||
|
"data": result["data"],
|
|||
|
"message": result["message"]
|
|||
|
}
|
|||
|
else:
|
|||
|
response_data = {
|
|||
|
"type": "error",
|
|||
|
"task_record_id": task_record_id,
|
|||
|
"timestamp": datetime.now().isoformat(),
|
|||
|
"message": result["message"]
|
|||
|
}
|
|||
|
|
|||
|
await websocket.send_text(safe_json_dumps(response_data, ensure_ascii=False))
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"发送任务执行状态失败: {str(e)}")
|
|||
|
try:
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "error",
|
|||
|
"message": f"获取任务执行状态失败: {str(e)}",
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}, ensure_ascii=False))
|
|||
|
except:
|
|||
|
# 如果连接已断开,忽略错误
|
|||
|
pass
|
|||
|
|
|||
|
async def periodic_push_task_status(websocket: WebSocket, task_record_id: str, interval: int):
|
|||
|
"""
|
|||
|
定期推送任务状态
|
|||
|
|
|||
|
Args:
|
|||
|
websocket: WebSocket连接对象
|
|||
|
task_record_id: 任务记录ID
|
|||
|
interval: 推送间隔(秒)
|
|||
|
"""
|
|||
|
logger.info(f"开始定期推送任务状态,任务记录ID: {task_record_id}, 间隔: {interval}秒")
|
|||
|
|
|||
|
last_data_hash = None # 用于检测数据是否发生变化
|
|||
|
|
|||
|
try:
|
|||
|
while True:
|
|||
|
await asyncio.sleep(interval)
|
|||
|
|
|||
|
# 获取当前数据
|
|||
|
try:
|
|||
|
result = await TaskRecordService.get_block_results(task_record_id)
|
|||
|
if result["success"]:
|
|||
|
# 计算数据哈希,只有数据变化时才推送
|
|||
|
import hashlib
|
|||
|
current_data = safe_json_dumps(result["data"], sort_keys=True, ensure_ascii=False)
|
|||
|
current_hash = hashlib.md5(current_data.encode()).hexdigest()
|
|||
|
|
|||
|
if current_hash != last_data_hash:
|
|||
|
await send_task_execution_status(task_record_id, websocket)
|
|||
|
last_data_hash = current_hash
|
|||
|
logger.debug(f"任务状态已更新并推送,任务记录ID: {task_record_id}")
|
|||
|
else:
|
|||
|
logger.debug(f"任务状态无变化,跳过推送,任务记录ID: {task_record_id}")
|
|||
|
else:
|
|||
|
# 如果获取失败,仍然推送错误信息
|
|||
|
await send_task_execution_status(task_record_id, websocket)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"获取任务状态失败: {str(e)}")
|
|||
|
# 发送错误状态
|
|||
|
try:
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "error",
|
|||
|
"message": f"获取任务状态失败: {str(e)}",
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}, ensure_ascii=False))
|
|||
|
except:
|
|||
|
# 连接可能已断开
|
|||
|
break
|
|||
|
|
|||
|
except asyncio.CancelledError:
|
|||
|
logger.info(f"定期推送任务已取消,任务记录ID: {task_record_id}")
|
|||
|
raise
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"定期推送任务状态失败: {str(e)}")
|
|||
|
|
|||
|
@router.websocket("/task-execution-broadcast/{task_record_id}")
|
|||
|
async def websocket_task_execution_broadcast(
|
|||
|
websocket: WebSocket,
|
|||
|
task_record_id: str = Path(..., description="任务记录ID")
|
|||
|
):
|
|||
|
"""
|
|||
|
任务执行结果广播WebSocket连接(只接收广播,不主动推送)
|
|||
|
|
|||
|
Args:
|
|||
|
websocket: WebSocket连接对象
|
|||
|
task_record_id: 任务记录ID
|
|||
|
"""
|
|||
|
await manager.connect(websocket, task_record_id)
|
|||
|
|
|||
|
try:
|
|||
|
# 发送初始状态
|
|||
|
await send_task_execution_status(task_record_id, websocket)
|
|||
|
|
|||
|
# 等待连接断开或消息
|
|||
|
while True:
|
|||
|
try:
|
|||
|
data = await websocket.receive_text()
|
|||
|
# 可以处理客户端的心跳或其他控制消息
|
|||
|
try:
|
|||
|
message = json.loads(data)
|
|||
|
if message.get("type") == "ping":
|
|||
|
await websocket.send_text(safe_json_dumps({
|
|||
|
"type": "pong",
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}, ensure_ascii=False))
|
|||
|
except:
|
|||
|
pass
|
|||
|
except WebSocketDisconnect:
|
|||
|
break
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"广播WebSocket连接异常: {str(e)}")
|
|||
|
finally:
|
|||
|
manager.disconnect(websocket, task_record_id)
|
|||
|
|
|||
|
# 提供给其他模块调用的广播接口
|
|||
|
async def broadcast_task_update(task_record_id: str, data: Dict[str, Any]):
|
|||
|
"""
|
|||
|
广播任务更新消息给所有相关连接
|
|||
|
|
|||
|
Args:
|
|||
|
task_record_id: 任务记录ID
|
|||
|
data: 要广播的数据
|
|||
|
"""
|
|||
|
if task_record_id not in manager.active_connections:
|
|||
|
return
|
|||
|
|
|||
|
message = safe_json_dumps({
|
|||
|
"type": "task_execution_update",
|
|||
|
"task_record_id": task_record_id,
|
|||
|
"timestamp": datetime.now().isoformat(),
|
|||
|
"data": data
|
|||
|
}, ensure_ascii=False)
|
|||
|
|
|||
|
await manager.broadcast_to_task(message, task_record_id)
|
|||
|
logger.info(f"已广播任务更新消息,任务记录ID: {task_record_id}")
|