from ultralytics import YOLO
import cv2
import torch
import numpy as np
from PIL import ImageFont, ImageDraw, Image


class YOLODetector:
    """Wrap an Ultralytics YOLO model for single-frame detection and annotation.

    The model is loaded once in ``__init__`` and moved to ``cuda:0`` when
    ``torch.cuda.is_available()``; otherwise it runs on the CPU.
    ``process_frame`` runs detection on one frame (optionally restricted to a
    rectangular ROI) and returns an annotated copy of that frame.
    """

    # Lower-cased model class name -> Chinese display label. Classes not
    # listed here are shown with their (lower-cased) model name.
    LABEL_MAP = {
        "helmet": "安全帽",
        "person": "人员",
        "safevest": "工服",
        "smoke": "吸烟",
        "animal": "异物入侵",
        "cellphone": "玩手机",
        "fire": "起火",
    }

    def __init__(self, model_path='best_weights/best.pt', font_path="msyh.ttc", font_size=21):
        """Load the YOLO model, the label font and the detection settings.

        Args:
            model_path: Weights file passed to ``YOLO``.
            font_path: TrueType font used for the optional Chinese labels.
            font_size: Size passed to ``ImageFont.truetype``.

        Raises:
            Exception: Any error while loading or moving the model is printed
                and re-raised unchanged.
        """
        try:
            # Initialise the YOLO model.
            self.model = YOLO(model_path)
            print(f"成功加载模型:{model_path}")

            self.font = self._load_font(font_path, font_size)

            # Detection configuration forwarded to every model call.
            self.predict_config = {
                'conf_thres': 0.25,
                'iou_thres': 0.30,
                'imgsz': 640,
                'line_width': 2
            }

            # Select the inference device.
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            if self.device == 'cuda:0':
                print(f"使用GPU: {torch.cuda.get_device_name(0)}")
                # Let cuDNN autotune kernels; pays off when input sizes stay fixed.
                torch.backends.cudnn.benchmark = True
            else:
                print("警告:未检测到可用GPU,将回退到CPU模式")

            self.model.to(self.device)
        except Exception as e:
            print(f"模型初始化失败: {str(e)}")
            raise

    @staticmethod
    def _load_font(font_path, font_size):
        """Return the requested TrueType font, or Pillow's default font if it cannot be opened."""
        try:
            return ImageFont.truetype(font_path, font_size)
        except OSError as e:
            # Labels are optional, so a missing font must not abort model setup.
            # NOTE(review): the default font likely lacks Chinese glyphs.
            print(f"加载字体失败({font_path}): {e},改用默认字体")
            return ImageFont.load_default()

    @staticmethod
    def _resolve_roi(frame, roi):
        """Return ``(x, y, w, h)`` as ints if *roi* is a non-empty in-bounds rectangle, else None."""
        if roi is None or len(roi) != 4:
            return None
        x, y, w, h = (int(v) for v in roi)
        frame_h, frame_w = frame.shape[:2]
        # Also rejects the (0, 0, 0, 0) "no ROI" sentinel via w/h <= 0.
        if x < 0 or y < 0 or w <= 0 or h <= 0 or x + w > frame_w or y + h > frame_h:
            return None
        return x, y, w, h

    def _draw_labels(self, image, labels):
        """Draw ``(x, y, text)`` labels onto a BGR image with one PIL round-trip and return the result."""
        pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        for x1, y1, text in labels:
            # Place text just above the box, clamped so it stays on-screen.
            draw.text((x1, max(y1 - 30, 0)), text, font=self.font, fill=(0, 255, 0))
        return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    def process_frame(self, frame, roi=None, draw_labels=False):
        """Run detection on one frame and return an annotated copy.

        Args:
            frame: Image array indexed ``frame[row, col]``; assumed to be a
                3-channel BGR OpenCV frame (TODO confirm against callers).
            roi: Optional ``(x, y, w, h)`` in frame pixels. If it is a non-empty
                rectangle fully inside the frame, detection runs on that crop,
                boxes are shifted back to frame coordinates, and the ROI is
                outlined in red. Otherwise the whole frame is used.
            draw_labels: If True, draw "<label> <conf>" above each box using
                ``self.font``. Defaults to False (boxes only).

        Returns:
            ``(annotated_frame, result)``. ``result`` is the first Ultralytics
            result, or None if the model returned nothing. Returns
            ``(None, None)`` for a None/empty frame and ``(frame, None)`` if
            processing raises.
        """
        try:
            if frame is None or frame.size == 0:
                print("收到空帧")
                return None, None

            # Crop only when the ROI is usable. Otherwise detection runs on the
            # full frame and box coordinates must NOT be offset.
            roi_rect = self._resolve_roi(frame, roi)
            if roi_rect is not None:
                x, y, w, h = roi_rect
                frame_roi = frame[y:y + h, x:x + w]
                offset_x, offset_y = x, y
            else:
                frame_roi = frame
                offset_x = offset_y = 0

            # Run YOLO detection.
            results = self.model(
                source=frame_roi,
                conf=self.predict_config['conf_thres'],
                iou=self.predict_config['iou_thres'],
                imgsz=self.predict_config['imgsz'],
                device=self.device,
                verbose=False
            )
            result = results[0] if len(results) > 0 else None

            annotated_frame = frame.copy()
            labels = []
            if result is not None and result.boxes is not None and len(result.boxes) > 0:
                boxes = result.boxes
                # Bulk-convert to Python lists instead of reading tensor items one at a time.
                for box, conf, cls in zip(boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()):
                    x1, y1, x2, y2 = (int(v) for v in box)
                    x1 += offset_x
                    y1 += offset_y
                    x2 += offset_x
                    y2 += offset_y

                    # Draw the bounding box.
                    cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0),
                                  self.predict_config['line_width'])

                    class_name = result.names[int(cls)].lower()
                    label = self.LABEL_MAP.get(class_name, class_name)
                    labels.append((x1, y1, f"{label} {conf:.2f}"))

            if draw_labels and labels:
                annotated_frame = self._draw_labels(annotated_frame, labels)

            # Outline the ROI that was actually searched.
            if roi_rect is not None:
                x, y, w, h = roi_rect
                cv2.rectangle(annotated_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)

            return annotated_frame, result
        except Exception as e:
            print(f"处理帧时出错: {str(e)}")
            return frame, None