128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import torch
|
||
import numpy as np
|
||
from PIL import ImageFont, ImageDraw, Image
|
||
|
||
|
||
class YOLODetector:
    """Run a YOLO model on single frames and draw the detected boxes.

    The instance holds the loaded model, the inference settings
    (``predict_config``), the torch device string (``device``) and a PIL
    font (``font``).
    """

    # Chinese display names keyed by lower-cased model class name.
    # NOTE: label rendering is currently disabled in ``process_frame``; the
    # mapping lives here so it is built once instead of once per detected box.
    LABEL_MAP = {
        "helmet": "安全帽",
        "person": "人员",
        "safevest": "工服",
        "smoke": "吸烟",
        "animal": "异物入侵",
        "cellphone": "玩手机",
        "fire": "起火",
    }

    def __init__(self, model_path='best_weights/best.pt'):
        """Load the model, pick a device and set the detection defaults.

        Args:
            model_path: Path to the weights file handed to ``YOLO``.

        Raises:
            Exception: Any failure while loading the model or moving it to
                the device is printed and re-raised unchanged.
        """
        try:
            # Initialise the YOLO model.
            self.model = YOLO(model_path)
            print(f"成功加载模型:{model_path}")

            # "msyh.ttc" is not present on every machine. A missing font must
            # not prevent the detector from starting, so fall back to the
            # font bundled with Pillow.
            try:
                self.font = ImageFont.truetype("msyh.ttc", 21)
            except OSError:
                print("警告:无法加载字体 msyh.ttc,将使用默认字体")
                self.font = ImageFont.load_default()

            # Detection settings used by process_frame.
            self.predict_config = {
                'conf_thres': 0.25,
                'iou_thres': 0.30,
                'imgsz': 640,
                'line_width': 2,
            }

            # Select the device: first CUDA GPU when available, else CPU.
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            if self.device == 'cuda:0':
                print(f"使用GPU: {torch.cuda.get_device_name(0)}")
                torch.backends.cudnn.benchmark = True
            else:
                print("警告:未检测到GPU,将回退到CPU模式")

            self.model.to(self.device)

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    @staticmethod
    def _resolve_roi(frame_shape, roi):
        """Return ``roi`` as ``(x, y, w, h)`` if it is usable, else ``None``.

        A ROI is usable when it is truthy, is not ``(0, 0, 0, 0)``, has a
        positive width and height, and lies completely inside a frame of
        shape ``frame_shape`` (``(height, width, ...)``).
        """
        if not roi or roi == (0, 0, 0, 0):
            return None

        x, y, w, h = roi
        frame_h, frame_w = frame_shape[0], frame_shape[1]
        fits = (
            x >= 0 and y >= 0 and w > 0 and h > 0
            and x + w <= frame_w and y + h <= frame_h
        )
        return (x, y, w, h) if fits else None

    def process_frame(self, frame, roi=None):
        """Detect objects in one frame and return an annotated copy.

        Args:
            frame: Image array exposing ``size``, ``shape`` and ``copy()``
                (presumably a BGR ``numpy.ndarray`` from OpenCV -- confirm
                against the caller).
            roi: Optional ``(x, y, w, h)`` region in frame pixels. A falsy
                value or ``(0, 0, 0, 0)`` means "whole frame". A ROI that
                does not fit inside the frame is ignored for detection and
                the whole frame is used instead.

        Returns:
            ``(annotated_frame, result)`` where ``annotated_frame`` is a copy
            of ``frame`` with one rectangle per detection (colour
            ``(0, 255, 0)``) plus the ROI outline (colour ``(0, 0, 255)``),
            and ``result`` is the first item returned by the model.
            ``(None, None)`` when ``frame`` is ``None`` or empty.
            ``(frame, None)`` when any exception occurs; the error is printed.
        """
        try:
            if frame is None or frame.size == 0:
                print("收到空帧")
                return None, None

            # Crop only when the ROI is valid. The offsets are derived from
            # the crop that was really applied, so boxes are never shifted
            # when detection ran on the full frame.
            active_roi = self._resolve_roi(frame.shape, roi)
            if active_roi is None:
                offset_x = offset_y = 0
                detect_input = frame
            else:
                offset_x, offset_y, roi_w, roi_h = active_roi
                detect_input = frame[offset_y:offset_y + roi_h,
                                     offset_x:offset_x + roi_w]

            # Run YOLO detection.
            results = self.model(
                source=detect_input,
                conf=self.predict_config['conf_thres'],
                iou=self.predict_config['iou_thres'],
                imgsz=self.predict_config['imgsz'],
                device=self.device,
                verbose=False,
            )
            # An empty result list raises IndexError here and is reported by
            # the handler below, as before.
            result = results[0]

            # Draw on a copy so the caller's frame stays untouched.
            annotated_frame = frame.copy()

            boxes = result.boxes
            if boxes is not None:
                for box in boxes.xyxy:
                    x1, y1, x2, y2 = map(int, box)
                    # Map crop-relative coordinates back to the full frame.
                    cv2.rectangle(
                        annotated_frame,
                        (x1 + offset_x, y1 + offset_y),
                        (x2 + offset_x, y2 + offset_y),
                        (0, 255, 0),
                        self.predict_config['line_width'],
                    )

            # Outline the requested ROI (drawn even if it was not usable for
            # cropping, matching the previous behaviour).
            if roi and roi != (0, 0, 0, 0):
                x, y, w, h = roi
                cv2.rectangle(annotated_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)

            return annotated_frame, result

        except Exception as e:
            print(f"处理帧时出错: {e}")
            return frame, None
|
||
|
||
|