"""YOLO detection pipeline: model wrapper plus per-channel result publishing.

This text was reconstructed from a patch that added two modules,
``main/yolo_detector.py`` and ``main/yolo_processor.py``. Each section below
is marked with the path it came from.
"""
from datetime import datetime

import cv2
import msgpack
import numpy as np
import torch
from PIL import ImageFont, ImageDraw, Image
from ultralytics import YOLO


def _resolve_roi(frame_shape, roi):
    """Return ``roi`` as ``(x, y, w, h)`` if it can be used to crop the frame.

    Args:
        frame_shape: ``shape`` of the frame (``(height, width, ...)``).
        roi: Candidate region as ``(x, y, w, h)`` in pixels, or ``None``.

    Returns:
        The ROI tuple when it has positive size and lies entirely inside the
        frame, otherwise ``None`` (meaning "use the whole frame").
    """
    if roi is None or len(roi) != 4:
        return None
    x, y, w, h = roi
    frame_h, frame_w = frame_shape[:2]
    inside = (
        x >= 0 and y >= 0 and w > 0 and h > 0
        and x + w <= frame_w and y + h <= frame_h
    )
    return (x, y, w, h) if inside else None


# ---------------------------------------------------------------------------
# main/yolo_detector.py
# ---------------------------------------------------------------------------

class YOLODetector:
    """Wrap an Ultralytics YOLO model and draw its detections on frames."""

    # Chinese display labels keyed by lower-cased model class name. Label
    # rendering is currently disabled in ``process_frame``; only boxes are
    # drawn.
    LABEL_MAP = {
        "helmet": "安全帽",
        "person": "人员",
        "safevest": "工服",
        "smoke": "吸烟",
        "animal": "异物入侵",
        "cellphone": "玩手机",
        "fire": "起火",
    }

    def __init__(self, model_path='best_weights/best.pt'):
        """Load the model, pick a device and set the prediction config.

        Args:
            model_path: Path to the YOLO weights file.

        Raises:
            Exception: Any error raised while loading the model or moving it
                to the device is logged and re-raised.
        """
        try:
            # Initialise the YOLO model.
            self.model = YOLO(model_path)
            print(f"成功加载模型:{model_path}")

            self.font = self._load_font("msyh.ttc", 21)

            # Detection configuration.
            self.predict_config = {
                'conf_thres': 0.25,
                'iou_thres': 0.30,
                'imgsz': 640,
                'line_width': 2,
            }

            # Select the device.
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            if self.device == 'cuda:0':
                print(f"使用GPU: {torch.cuda.get_device_name(0)}")
                torch.backends.cudnn.benchmark = True
            else:
                print("警告:检测,将回退到CPU模式")

            self.model.to(self.device)

        except Exception as e:
            print(f"模型初始化失败: {str(e)}")
            raise

    @staticmethod
    def _load_font(font_path, size):
        """Load a TrueType font, falling back to PIL's built-in font.

        A missing font file must not prevent the detector from starting, so
        the ``OSError`` raised by ``ImageFont.truetype`` is handled here.
        """
        try:
            return ImageFont.truetype(font_path, size)
        except OSError:
            print(f"警告: 无法加载字体 {font_path},使用默认字体")
            return ImageFont.load_default()

    def process_frame(self, frame, roi=None):
        """Run detection on one frame and draw the resulting boxes.

        Args:
            frame: Image array. ``None`` or an empty array is rejected.
            roi: Optional ``(x, y, w, h)`` pixel region. When it lies fully
                inside the frame, inference runs on that crop only;
                otherwise the whole frame is used.

        Returns:
            ``(annotated_frame, result)``. ``annotated_frame`` is a copy of
            ``frame`` with detection boxes (green) and the configured ROI
            (red) drawn in full-frame coordinates. ``result`` is the first
            model result, whose box coordinates are relative to the image
            that was actually fed to the model (the crop when the ROI was
            applied). Returns ``(None, None)`` for an empty frame and
            ``(frame, None)`` if processing fails.
        """
        try:
            if frame is None or frame.size == 0:
                print("收到空帧")
                return None, None

            # Crop to the ROI only when it is valid. Remember the offset that
            # was actually applied so boxes can be mapped back correctly.
            applied_roi = _resolve_roi(frame.shape, roi)
            if applied_roi is None:
                frame_roi, off_x, off_y = frame, 0, 0
            else:
                off_x, off_y, roi_w, roi_h = applied_roi
                frame_roi = frame[off_y:off_y + roi_h, off_x:off_x + roi_w]

            # Run YOLO inference.
            results = self.model(
                source=frame_roi,
                conf=self.predict_config['conf_thres'],
                iou=self.predict_config['iou_thres'],
                imgsz=self.predict_config['imgsz'],
                device=self.device,
                verbose=False,
            )

            annotated_frame = frame.copy()
            first = results[0] if len(results) > 0 else None

            if first is not None and first.boxes is not None and len(first.boxes) > 0:
                for box in first.boxes.xyxy:
                    x1, y1, x2, y2 = map(int, box)
                    # Shift by the applied crop offset. This is (0, 0) when
                    # the ROI was rejected and the full frame was used.
                    cv2.rectangle(
                        annotated_frame,
                        (x1 + off_x, y1 + off_y),
                        (x2 + off_x, y2 + off_y),
                        (0, 255, 0),
                        self.predict_config['line_width'],
                    )

            # Outline the configured ROI.
            if roi is not None and len(roi) == 4 and tuple(roi) != (0, 0, 0, 0):
                x, y, w, h = roi
                cv2.rectangle(annotated_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)

            return annotated_frame, first

        except Exception as e:
            print(f"处理帧时出错: {str(e)}")
            return frame, None


# ---------------------------------------------------------------------------
# main/yolo_processor.py
# ---------------------------------------------------------------------------

# Minimum IoU for a helmet / vest / smoke box to be attributed to a person.
_ASSOCIATION_IOU = 0.1

# Classes reported on their own (not folded into a person's status), mapped
# to the Chinese label that is published.
_STANDALONE_LABELS = {
    "animal": "异物入侵",
    "cellphone": "玩手机",
    "fire": "起火",
}

# Classes that only modify the status of an overlapping person.
_ACCESSORY_CLASSES = ("helmet", "safevest", "smoke")


def _calculate_iou(box_a, box_b):
    """Return the IoU of two ``(x1, y1, x2, y2)`` boxes (0.0 if union is empty)."""
    inter_w = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter_area = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter_area
    # Degenerate (zero-area) boxes would otherwise divide by zero.
    return inter_area / float(union) if union > 0 else 0.0


def _normalized_bbox(coords, width, height):
    """Convert pixel ``(x1, y1, x2, y2)`` to a frame-normalised bbox dict."""
    x1, y1, x2, y2 = coords
    return {
        "x_min": x1 / width,
        "y_min": y1 / height,
        "width": (x2 - x1) / width,
        "height": (y2 - y1) / height,
    }


def _person_status(person_box, types, accessory_boxes):
    """Derive the status label for one person from overlapping accessories.

    Smoking takes precedence. Otherwise each enabled-but-missing item
    (helmet, vest) is listed as a violation. With no violations the person is
    "compliant" if any dress check is enabled, else just "person".
    """
    def has(kind):
        return kind in types and any(
            _calculate_iou(person_box, other) > _ASSOCIATION_IOU
            for other in accessory_boxes[kind]
        )

    if has("smoke"):
        return "吸烟"

    violations = []
    if "helmet" in types and not has("helmet"):
        violations.append("未戴安全帽")
    if "safevest" in types and not has("safevest"):
        violations.append("未穿工服")
    if violations:
        return "违规:" + "、".join(violations)
    if "helmet" in types or "safevest" in types:
        return "着装规范"
    return "人员"


def process_frame_with_yolo(frame, channel_index, camera_configs, yolo_detector, redis_client):
    """Detect objects in one channel's frame and publish the results.

    Args:
        frame: Image array for the channel.
        channel_index: Index into ``camera_configs``.
        camera_configs: Sequence of per-channel dicts. The keys read here are
            ``'box'`` (ROI as ``[x_min, y_min, x_max, y_max]`` fractions of
            the frame), ``'types'`` (class names to report) and
            ``'channel'``.
        yolo_detector: Object exposing ``process_frame(frame, roi)``.
        redis_client: Object exposing ``publish(channel, payload)``.

    Returns:
        The annotated frame returned by the detector, or ``None`` when
        ``frame`` is ``None``.

    The msgpack payload published on ``detection_result_channel`` carries
    bounding boxes normalised to the full frame.
    """
    if frame is None:
        print("收到空帧")
        return None

    config = camera_configs[channel_index]
    roi = config['box']
    types = config['types']
    height, width = frame.shape[:2]

    # Convert the fractional ROI to pixel (x, y, w, h).
    x_min = int(roi[0] * width)
    y_min = int(roi[1] * height)
    x_max = int(roi[2] * width)
    y_max = int(roi[3] * height)
    roi_converted = (x_min, y_min, x_max - x_min, y_max - y_min)

    # Run detection.
    start_time = datetime.now()
    frame_with_boxes, results = yolo_detector.process_frame(frame, roi_converted)
    process_time = (datetime.now() - start_time).total_seconds() * 1000
    print(f"\n通道 {channel_index + 1} 检测结果 (处理时间: {process_time: .2f}ms):")

    # Result boxes are relative to the crop the detector used, so add the
    # offset it applied. The detector uses the same validity rule.
    applied_roi = _resolve_roi(frame.shape, roi_converted)
    off_x, off_y = applied_roi[:2] if applied_roi is not None else (0, 0)

    person_boxes = []  # (confidence, (x1, y1, x2, y2))
    accessory_boxes = {name: [] for name in _ACCESSORY_CLASSES}
    other_detections = []

    boxes = getattr(results, 'boxes', None)
    if boxes is not None and len(boxes) > 0:
        # Step 1: bucket every detection whose class is enabled in ``types``.
        for box, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
            class_name = results.names[int(cls)]
            if class_name not in types:
                continue

            x1, y1, x2, y2 = map(int, box)
            coords = (x1 + off_x, y1 + off_y, x2 + off_x, y2 + off_y)
            confidence = float(conf)

            if class_name == "person":
                person_boxes.append((confidence, coords))
            elif class_name in accessory_boxes:
                accessory_boxes[class_name].append(coords)
            elif class_name in _STANDALONE_LABELS:
                class_name_cn = _STANDALONE_LABELS[class_name]
                other_detections.append({
                    "class": class_name_cn,
                    "confidence": confidence,
                    "bbox": _normalized_bbox(coords, width, height),
                })
                print(f"- 独立检测: {class_name_cn}, 置信度: {confidence: .2f}, 位置: {list(coords)}")

    # Step 2: one status entry per person (empty unless "person" is enabled).
    detections = []
    for confidence, coords in person_boxes:
        status_label = _person_status(coords, types, accessory_boxes)
        detections.append({
            "class": status_label,
            "confidence": confidence,
            "bbox": _normalized_bbox(coords, width, height),
        })
        print(f"- 状态: {status_label}, 置信度: {confidence:.2f}, 位置: {coords}")

    # Append the standalone classes.
    detections.extend(other_detections)
    if not detections:
        print("- 未检测到任何目标")

    # Serialise and publish.
    data = {
        "channel": str(config['channel']),
        "detections": detections,
        "image_size": [width, height],
    }

    try:
        serialized_data = msgpack.packb(data)
        redis_client.publish('detection_result_channel', serialized_data)
        print(f"[Redis] 通道 {config['channel']} 数据发送成功")
    except Exception as e:
        # Publishing is best-effort; the annotated frame is still returned.
        print(f"[Redis] 发送失败: {str(e)}")

    return frame_with_boxes