yolo11/main/yolo_detector.py
2025-04-16 16:20:42 +08:00

128 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import torch
import numpy as np
from PIL import ImageFont, ImageDraw, Image
class YOLODetector:
    """Wrap an Ultralytics YOLO model and draw its detections onto frames.

    On construction the model is loaded and moved to ``cuda:0`` when CUDA is
    available, otherwise to the CPU. ``process_frame`` runs detection on one
    image (optionally restricted to a rectangular ROI) and returns an
    annotated copy together with the raw result.
    """

    # Display names for known model class names (keys are lower-cased).
    # On-frame text rendering of these labels is currently disabled; the
    # mapping and ``self.font`` are kept so it can be re-enabled.
    LABEL_MAP = {
        "helmet": "安全帽",
        "person": "人员",
        "safevest": "工服",
        "smoke": "吸烟",
        "animal": "异物入侵",
        "cellphone": "玩手机",
        "fire": "起火",
    }

    def __init__(self, model_path='best_weights/best.pt', font_path="msyh.ttc", font_size=21):
        """Load the model, the label font and the inference settings.

        Args:
            model_path: Path to the weights file passed to ``YOLO()``.
            font_path: TrueType font for label text. If it cannot be loaded,
                Pillow's built-in default font is used instead.
            font_size: Point size used with ``font_path``.

        Raises:
            Exception: Any error while loading the model or moving it to the
                selected device is printed and re-raised.
        """
        try:
            # Initialise the YOLO model.
            self.model = YOLO(model_path)
            print(f"成功加载模型:{model_path}")
            # The default font (Microsoft YaHei) is normally only present on
            # Windows; a missing font must not stop the detector from starting.
            try:
                self.font = ImageFont.truetype(font_path, font_size)
            except OSError as font_err:
                print(f"字体加载失败, 使用默认字体: {font_err}")
                self.font = ImageFont.load_default()
            # Inference settings forwarded to the model on every frame.
            self.predict_config = {
                'conf_thres': 0.25,
                'iou_thres': 0.30,
                'imgsz': 640,
                'line_width': 2
            }
            # Select the device: first CUDA GPU if available, else CPU.
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            if self.device == 'cuda:0':
                print(f"使用GPU: {torch.cuda.get_device_name(0)}")
                # Let cuDNN pick the fastest kernels for the (fixed) input size.
                torch.backends.cudnn.benchmark = True
            else:
                print("警告检测将回退到CPU模式")
            self.model.to(self.device)
        except Exception as e:
            print(f"模型初始化失败: {str(e)}")
            raise

    @staticmethod
    def _resolve_roi(frame_shape, roi):
        """Return ``roi`` as an int ``(x, y, w, h)`` tuple, or None if unusable.

        The ROI is accepted only if it has four numeric entries, a positive
        width and height, and lies entirely inside a frame whose shape starts
        with ``(height, width)``. ``None``, ``(0, 0, 0, 0)``, out-of-bounds
        and malformed ROIs all yield None (meaning: use the full frame).
        """
        if roi is None:
            return None
        try:
            if len(roi) != 4:
                return None
            x, y, w, h = (int(v) for v in roi)
        except (TypeError, ValueError):
            return None
        height, width = frame_shape[:2]
        if x >= 0 and y >= 0 and w > 0 and h > 0 and x + w <= width and y + h <= height:
            return x, y, w, h
        return None

    @staticmethod
    def _box_to_frame_coords(box, offset_x, offset_y):
        """Convert an ``(x1, y1, x2, y2)`` box to ints shifted by the ROI origin."""
        x1, y1, x2, y2 = (int(v) for v in box)
        return x1 + offset_x, y1 + offset_y, x2 + offset_x, y2 + offset_y

    @classmethod
    def _display_label(cls, class_name):
        """Return the display label for a model class name (case-insensitive).

        Unknown names are returned lower-cased.
        """
        name = class_name.lower()
        return cls.LABEL_MAP.get(name, name)

    def process_frame(self, frame, roi=None):
        """Run detection on a single frame and draw the results on a copy.

        Args:
            frame: Image array (presumably BGR as read by OpenCV — confirm
                with callers). Empty or None frames are rejected.
            roi: Optional ``(x, y, w, h)`` pixel rectangle. If it lies fully
                inside the frame, detection runs only on that region and box
                coordinates are mapped back to full-frame coordinates;
                otherwise detection runs on the whole frame.

        Returns:
            ``(annotated_frame, result)``: a copy of ``frame`` with green
            detection boxes (and a red rectangle for the ROI actually used),
            and the first result returned by the model (None if it returned
            nothing). ``(None, None)`` for an empty frame, and
            ``(frame, None)`` if any error occurs during processing.
        """
        try:
            if frame is None or frame.size == 0:
                print("收到空帧")
                return None, None

            roi_rect = self._resolve_roi(frame.shape, roi)
            if roi_rect is not None:
                rx, ry, rw, rh = roi_rect
                frame_roi = frame[ry:ry + rh, rx:rx + rw]
                offset_x, offset_y = rx, ry
            else:
                # Invalid/absent ROI: detect on the full frame, so boxes must
                # NOT be shifted by the ROI origin.
                frame_roi = frame
                offset_x = offset_y = 0

            # Run YOLO detection.
            results = self.model(
                source=frame_roi,
                conf=self.predict_config['conf_thres'],
                iou=self.predict_config['iou_thres'],
                imgsz=self.predict_config['imgsz'],
                device=self.device,
                verbose=False
            )
            result = results[0] if len(results) > 0 else None

            # Draw detection boxes on a copy so the caller's frame is untouched.
            annotated_frame = frame.copy()
            if result is not None and result.boxes is not None and len(result.boxes) > 0:
                for box in result.boxes.xyxy:
                    x1, y1, x2, y2 = self._box_to_frame_coords(box, offset_x, offset_y)
                    cv2.rectangle(annotated_frame,
                                  (x1, y1),
                                  (x2, y2),
                                  (0, 255, 0),
                                  self.predict_config['line_width'])

            # Outline the ROI, but only when detection was actually limited to it.
            if roi_rect is not None:
                rx, ry, rw, rh = roi_rect
                cv2.rectangle(annotated_frame, (rx, ry), (rx + rw, ry + rh), (0, 0, 255), 2)

            return annotated_frame, result
        except Exception as e:
            print(f"处理帧时出错: {str(e)}")
            return frame, None