#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Worker management module.

Creates, monitors and resizes a pool of workers. Each "worker" is an
``asyncio.Task`` wrapping the coroutine returned by a user-supplied factory.
"""

import asyncio
import logging
import random
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Set, Callable, Coroutine

import psutil

from utils.logger import get_logger

# Module-level logger
logger = get_logger("services.enhanced_scheduler.worker_manager")


class WorkerManager:
    """
    Worker pool manager.

    Keeps between ``min_workers`` and ``max_workers`` asyncio tasks alive,
    tracks a heartbeat timestamp and a status dict per worker, restarts idle
    workers whose heartbeat has timed out, and periodically grows or shrinks
    the pool based on system CPU/memory usage and the reported queue size.
    """

    def __init__(self,
                 min_workers: int = 5,
                 max_workers: int = 20,
                 worker_heartbeat_interval: int = 30,
                 auto_scale_interval: int = 60,
                 cpu_threshold: float = 80.0,
                 memory_threshold: float = 80.0,
                 queue_manager=None):
        """
        Initialise the worker manager.

        Args:
            min_workers: Minimum number of workers.
            max_workers: Maximum number of workers.
            worker_heartbeat_interval: Heartbeat interval in seconds; also the
                sleep time between two monitor passes.
            auto_scale_interval: Minimum time between two auto-scale passes,
                in seconds.
            cpu_threshold: CPU usage threshold (percent) above which the pool
                is shrunk.
            memory_threshold: Memory usage threshold (percent) above which the
                pool is shrunk.
            queue_manager: Optional object used by ``get_queue_size``; it must
                provide ``get_queue_sizes()``.
        """
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.worker_heartbeat_interval = worker_heartbeat_interval
        self.auto_scale_interval = auto_scale_interval
        self.cpu_threshold = cpu_threshold
        self.memory_threshold = memory_threshold

        self.workers: Dict[int, "asyncio.Task"] = {}        # {worker_id: worker task}
        self.worker_status: Dict[int, Dict[str, Any]] = {}  # {worker_id: status dict}
        self.worker_heartbeats: Dict[int, datetime] = {}    # {worker_id: last heartbeat time}

        self.worker_factory: Optional[Callable[[int], Coroutine]] = None
        self.is_running = False
        self.monitor_task: Optional["asyncio.Task"] = None  # background monitor task
        self.last_auto_scale_time = datetime.now()
        self.queue_manager = queue_manager

        logger.info(f"初始化工作线程管理器: min={min_workers}, max={max_workers}, "
                    f"心跳间隔={worker_heartbeat_interval}秒, 自动扩缩容间隔={auto_scale_interval}秒")

    def set_worker_factory(self, factory: Callable[[int], Coroutine]):
        """
        Set the worker factory.

        Args:
            factory: Callable that takes a ``worker_id`` and returns the
                coroutine object to run as that worker.
        """
        self.worker_factory = factory

    async def start(self, initial_workers: Optional[int] = None) -> None:
        """
        Start the manager: spawn the initial workers and the monitor task.

        Args:
            initial_workers: Number of workers to start with; defaults to
                ``min_workers`` and is clamped to
                ``[min_workers, max_workers]``.

        Raises:
            ValueError: If no worker factory has been set.
        """
        if self.is_running:
            logger.warning("工作线程管理器已经在运行中")
            return

        if not self.worker_factory:
            raise ValueError("未设置工作线程工厂函数")

        self.is_running = True

        # Spawn the initial workers
        initial_count = initial_workers if initial_workers is not None else self.min_workers
        initial_count = max(self.min_workers, min(initial_count, self.max_workers))

        for _ in range(initial_count):
            await self.add_worker()

        # Start the background monitor
        self.monitor_task = asyncio.create_task(self._monitor_workers())

        logger.info(f"工作线程管理器启动成功,初始工作线程数: {initial_count}")

    async def stop(self) -> None:
        """
        Stop the manager: cancel the monitor task, then remove every worker.
        """
        if not self.is_running:
            logger.warning("工作线程管理器未在运行")
            return

        self.is_running = False

        # Cancel the monitor first so it cannot restart workers during teardown
        if self.monitor_task:
            self.monitor_task.cancel()
            try:
                await self.monitor_task
            except asyncio.CancelledError:
                pass
            self.monitor_task = None

        # Iterate over a snapshot: remove_worker mutates self.workers
        for worker_id in list(self.workers):
            await self.remove_worker(worker_id)

        logger.info("工作线程管理器已停止")

    async def add_worker(self) -> int:
        """
        Add one worker to the pool.

        Returns:
            int: The new worker's ID, or -1 if the pool is already at
            ``max_workers``.
        """
        if len(self.workers) >= self.max_workers:
            logger.warning(f"已达到最大工作线程数 {self.max_workers}")
            return -1

        # Use the smallest ID not currently in use
        worker_id = 0
        while worker_id in self.workers:
            worker_id += 1

        # Create the worker task
        worker = asyncio.create_task(self.worker_factory(worker_id))

        # Record the worker and its initial bookkeeping
        self.workers[worker_id] = worker
        self.worker_status[worker_id] = {
            "state": "running",
            "start_time": datetime.now(),
            "last_activity": datetime.now(),
            "task_count": 0
        }
        self.worker_heartbeats[worker_id] = datetime.now()

        logger.info(f"添加工作线程 {worker_id}, 当前工作线程数: {len(self.workers)}")
        return worker_id

    async def remove_worker(self, worker_id: int) -> bool:
        """
        Cancel a worker and drop all bookkeeping for it.

        Args:
            worker_id: ID of the worker to remove.

        Returns:
            bool: True if the worker existed and was removed, False otherwise.
        """
        if worker_id not in self.workers:
            logger.warning(f"工作线程 {worker_id} 不存在")
            return False

        worker = self.workers.pop(worker_id)
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The worker had already finished with its own exception; awaiting
            # it re-raises that exception here. Log it rather than propagate,
            # so one crashed worker cannot abort stop() or the monitor pass.
            logger.error(f"工作线程 {worker_id} 退出时异常: {str(e)}")
        finally:
            # Always drop status/heartbeat so they never outlive the task
            self.worker_status.pop(worker_id, None)
            self.worker_heartbeats.pop(worker_id, None)

        logger.info(f"移除工作线程 {worker_id}, 当前工作线程数: {len(self.workers)}")
        return True

    def update_worker_heartbeat(self, worker_id: int) -> None:
        """
        Record a heartbeat for a worker; unknown IDs are ignored.

        Args:
            worker_id: ID of the worker.
        """
        if worker_id in self.worker_heartbeats:
            self.worker_heartbeats[worker_id] = datetime.now()
            self.worker_status[worker_id]["last_activity"] = datetime.now()

    def update_worker_status(self, worker_id: int, status_update: Dict[str, Any]) -> None:
        """
        Merge a status update into a worker's status dict and refresh its
        heartbeat; unknown IDs are ignored.

        Args:
            worker_id: ID of the worker.
            status_update: Keys/values to merge into the worker's status.
        """
        if worker_id in self.worker_status:
            self.worker_status[worker_id].update(status_update)
            self.update_worker_heartbeat(worker_id)

    async def _monitor_workers(self) -> None:
        """
        Background loop: check heartbeats and trigger auto-scaling.

        A worker whose last heartbeat is older than twice the heartbeat
        interval is restarted, but only if its status has no ``current_task``.
        """
        logger.info("工作线程监控任务启动")

        while self.is_running:
            try:
                # Heartbeat check
                now = datetime.now()
                for worker_id, last_heartbeat in list(self.worker_heartbeats.items()):
                    if (now - last_heartbeat).total_seconds() > self.worker_heartbeat_interval * 2:
                        if self.worker_status[worker_id].get("current_task", None) is None:
                            logger.info(f"工作线程 {worker_id} 心跳超时,重启中...")
                            # Restart the worker
                            await self.remove_worker(worker_id)
                            await self.add_worker()
                        else:
                            # NOTE(review): a worker whose task has already
                            # finished but whose status still carries a
                            # current_task is never restarted by this branch.
                            logger.info(f"工作线程 {worker_id} 心跳超时,但当前有任务,不重启")

                # Auto-scale at most once per auto_scale_interval
                if (now - self.last_auto_scale_time).total_seconds() > self.auto_scale_interval:
                    await self._auto_scale()
                    self.last_auto_scale_time = now

                # Sleep until the next pass
                await asyncio.sleep(self.worker_heartbeat_interval)

            except asyncio.CancelledError:
                # Cancelled (e.g. by stop()): leave the loop
                logger.info("工作线程监控任务被取消")
                break
            except Exception as e:
                logger.error(f"工作线程监控任务异常: {str(e)}")
                # Back off briefly so a persistent error does not spin
                await asyncio.sleep(5.0)

        logger.info("工作线程监控任务结束")

    async def _check_workers(self, flag: bool = False) -> List[int]:
        """
        Split workers into busy and idle by the ``current_task`` status key.

        Args:
            flag: If True return the idle worker IDs (no ``current_task``),
                otherwise return the busy worker IDs.

        Returns:
            List[int]: The selected worker IDs.
        """
        busy_ids = []
        idle_ids = []
        for worker_id in list(self.worker_heartbeats):
            if self.worker_status[worker_id].get("current_task", None) is None:
                idle_ids.append(worker_id)
            else:
                busy_ids.append(worker_id)
        return idle_ids if flag else busy_ids

    async def _delete_unused_workers(self) -> bool:
        """
        Remove one randomly chosen idle worker.

        Returns:
            bool: True if a worker was removed, False if none was idle.
        """
        idle_ids = await self._check_workers(flag=True)
        if not idle_ids:
            logger.info("没有未使用的空闲工作线程")
            return False

        worker_id = random.choice(idle_ids)
        await self.remove_worker(worker_id)
        logger.info(f"移除未使用的空闲工作线程 {worker_id}")
        return True

    def _calculate_threads_to_add(self, queue_size: int, current_workers: int) -> int:
        """
        Compute how many workers to add for the current load.

        Args:
            queue_size: Number of tasks waiting in the queue.
            current_workers: Number of workers currently running a task.

        Returns:
            int: Number of workers to add, never more than the remaining
            capacity up to ``max_workers``.
        """
        total_tasks = queue_size + current_workers
        current_pool_size = len(self.workers)
        remaining_capacity = self.max_workers - current_pool_size

        if remaining_capacity <= 0:
            return 0

        # Ratio of tasks to workers; max(1, ...) avoids dividing by zero
        task_thread_ratio = total_tasks / max(1, current_pool_size)

        if task_thread_ratio > 3.0:
            # Very high load: grow aggressively
            threads_to_add = min(remaining_capacity, max(3, total_tasks // 3))
        elif task_thread_ratio > 2.0:
            # Medium load: grow moderately
            threads_to_add = min(remaining_capacity, max(2, total_tasks // 4))
        else:
            # Light load: grow by one
            threads_to_add = min(remaining_capacity, 1)

        return threads_to_add

    async def _auto_scale(self) -> None:
        """
        Grow or shrink the pool based on system resources and queue size.

        Order of checks: shrink by one idle worker when CPU or memory usage
        exceeds its threshold; otherwise grow when queued plus running tasks
        exceed the pool size; otherwise shrink by one idle worker when the
        queue is empty. The pool never shrinks below ``min_workers``. Any
        exception is logged and swallowed.
        """
        try:
            # Number of workers currently running a task
            current_workers = len(await self._check_workers(flag=False))

            # System-wide resource usage
            cpu_percent = psutil.cpu_percent()
            memory_percent = psutil.virtual_memory().percent

            queue_size = self.get_queue_size()

            if cpu_percent > self.cpu_threshold or memory_percent > self.memory_threshold:
                # Resources are tight: drop one idle worker (chosen at random)
                if len(self.workers) > self.min_workers:
                    if await self._delete_unused_workers():
                        logger.info(f"资源使用率过高(CPU:{cpu_percent}%, MEM:{memory_percent}%), 减少工作线程至 {len(self.workers)}")
            elif queue_size + current_workers > len(self.workers):
                # More work than workers: grow the pool
                threads_to_add = self._calculate_threads_to_add(queue_size, current_workers)
                if threads_to_add > 0:
                    for _ in range(threads_to_add):
                        await self.add_worker()
                    logger.info(f"队列任务较多(size:{queue_size}, running:{current_workers}), "
                                f"动态增加{threads_to_add}个工作线程至 {len(self.workers)}")
            elif queue_size == 0 and len(self.workers) > self.min_workers:
                # Queue is empty: drop one idle worker
                if await self._delete_unused_workers():
                    logger.info(f"队列为空,减少工作线程至 {len(self.workers)}")

        except Exception as e:
            logger.error(f"自动扩缩容异常: {str(e)}")

    def get_queue_size(self) -> int:
        """
        Return the total number of queued tasks.

        Default implementation; it can be overridden in a subclass or
        replaced through ``set_queue_size_getter``.

        Returns:
            int: Sum of ``queue_manager.get_queue_sizes()``, or 0 when no
            queue manager was configured.
        """
        if self.queue_manager is None:
            # Without a queue manager there is nothing to measure
            return 0
        # assumes get_queue_sizes() returns an iterable of ints — TODO confirm
        return sum(self.queue_manager.get_queue_sizes())

    def set_queue_size_getter(self, getter: Callable[[], int]) -> None:
        """
        Replace ``get_queue_size`` on this instance with a custom callable.

        Args:
            getter: Zero-argument callable returning the queue size.
        """
        self.get_queue_size = getter

    def get_worker_count(self) -> int:
        """
        Return the current number of workers.

        Returns:
            int: Size of the worker pool.
        """
        return len(self.workers)

    def get_worker_status(self) -> Dict[str, Any]:
        """
        Return a snapshot of pool size, limits, per-worker status and system
        CPU/memory usage.

        Returns:
            Dict[str, Any]: Status information. ``"workers"`` is the live
            per-worker status dict, not a copy.
        """
        return {
            "worker_count": len(self.workers),
            "min_workers": self.min_workers,
            "max_workers": self.max_workers,
            "workers": self.worker_status,
            "cpu_usage": psutil.cpu_percent(),
            "memory_usage": psutil.virtual_memory().percent
        }