上传文件至 main
This commit is contained in:
parent
e864f3aab2
commit
0a605211ca
127
main/yolo_detector.py
Normal file
127
main/yolo_detector.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from ultralytics import YOLO
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import ImageFont, ImageDraw, Image
|
||||||
|
|
||||||
|
|
||||||
|
class YOLODetector:
|
||||||
|
def __init__(self, model_path='best_weights/best.pt'):
|
||||||
|
try:
|
||||||
|
# 初始化YOLO模型
|
||||||
|
self.model = YOLO(model_path)
|
||||||
|
print(f"成功加载模型:{model_path}")
|
||||||
|
|
||||||
|
self.font = ImageFont.truetype("msyh.ttc", 21)
|
||||||
|
# 检测配置
|
||||||
|
self.predict_config = {
|
||||||
|
'conf_thres': 0.25,
|
||||||
|
'iou_thres': 0.30,
|
||||||
|
'imgsz': 640,
|
||||||
|
'line_width': 2
|
||||||
|
}
|
||||||
|
|
||||||
|
# 设置设备
|
||||||
|
self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
|
if self.device == 'cuda:0':
|
||||||
|
print(f"使用GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
else:
|
||||||
|
print("警告:检测,将回退到CPU模式")
|
||||||
|
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"模型初始化失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def process_frame(self, frame, roi=None):
|
||||||
|
"""处理单帧图像"""
|
||||||
|
try:
|
||||||
|
if frame is None or frame.size == 0:
|
||||||
|
print("收到空帧")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 如果有ROI,处理ROI区域
|
||||||
|
if roi and roi != (0, 0, 0, 0):
|
||||||
|
x, y, w, h = roi
|
||||||
|
if x >= 0 and y >= 0 and w > 0 and h > 0 and \
|
||||||
|
x + w <= frame.shape[1] and y + h <= frame.shape[0]:
|
||||||
|
frame_roi = frame[y:y + h, x:x + w]
|
||||||
|
else:
|
||||||
|
# print("ROI 超出图像范围")
|
||||||
|
frame_roi = frame
|
||||||
|
else:
|
||||||
|
frame_roi = frame
|
||||||
|
|
||||||
|
# 运行YOLO检测
|
||||||
|
results = self.model(
|
||||||
|
source=frame_roi,
|
||||||
|
conf=self.predict_config['conf_thres'],
|
||||||
|
iou=self.predict_config['iou_thres'],
|
||||||
|
imgsz=self.predict_config['imgsz'],
|
||||||
|
device=self.device,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在图像上绘制检测结果
|
||||||
|
annotated_frame = frame.copy()
|
||||||
|
|
||||||
|
if len(results) > 0 and results[0].boxes is not None and len(results[0].boxes) > 0:
|
||||||
|
for box, conf, cls in zip(results[0].boxes.xyxy,
|
||||||
|
results[0].boxes.conf,
|
||||||
|
results[0].boxes.cls):
|
||||||
|
class_name = results[0].names[int(cls)]
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
|
||||||
|
# 如果使用ROI,调整坐标
|
||||||
|
# if roi and roi != (0, 0, 0, 0):
|
||||||
|
if roi and len(roi) == 4:
|
||||||
|
# x1, y1 = x1 + roi[0], y1 + roi[1]
|
||||||
|
# x2, y2 = x2 + roi[0], y2 + roi[1]
|
||||||
|
x1 += roi[0]
|
||||||
|
y1 += roi[1]
|
||||||
|
x2 += roi[0]
|
||||||
|
y2 += roi[1]
|
||||||
|
|
||||||
|
# 绘制边界框
|
||||||
|
cv2.rectangle(annotated_frame,
|
||||||
|
(x1, y1),
|
||||||
|
(x2, y2),
|
||||||
|
(0, 255, 0),
|
||||||
|
self.predict_config['line_width'])
|
||||||
|
|
||||||
|
# 添加中文标签
|
||||||
|
# pil_img = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
|
||||||
|
# draw = ImageDraw.Draw(pil_img)
|
||||||
|
|
||||||
|
# 根据类别显示不同的中文标签
|
||||||
|
class_name = results[0].names[int(cls)].lower()
|
||||||
|
label_map = {
|
||||||
|
"helmet": "安全帽",
|
||||||
|
"person": "人员",
|
||||||
|
"safevest": "工服",
|
||||||
|
"smoke": "吸烟",
|
||||||
|
"animal": "异物入侵",
|
||||||
|
"cellphone": "玩手机",
|
||||||
|
"fire": "起火"
|
||||||
|
}
|
||||||
|
|
||||||
|
label = label_map.get(class_name, class_name)
|
||||||
|
|
||||||
|
# draw.text((x1, y1 - 30), f"{label} {conf:.2f}",
|
||||||
|
# font=self.font, fill=(0, 255, 0))
|
||||||
|
# annotated_frame = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# 绘制ROI区域
|
||||||
|
if roi and roi != (0, 0, 0, 0):
|
||||||
|
x, y, w, h = roi
|
||||||
|
cv2.rectangle(annotated_frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
|
||||||
|
|
||||||
|
return annotated_frame, results[0]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理帧时出错: {str(e)}")
|
||||||
|
return frame, None
|
||||||
|
|
||||||
|
|
143
main/yolo_processor.py
Normal file
143
main/yolo_processor.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import msgpack
|
||||||
|
from datetime import datetime
|
||||||
|
from PIL import ImageFont, ImageDraw, Image
|
||||||
|
|
||||||
|
|
||||||
|
def process_frame_with_yolo(frame, channel_index, camera_configs, yolo_detector, redis_client):
|
||||||
|
roi = camera_configs[channel_index]['box']
|
||||||
|
types = camera_configs[channel_index]['types']
|
||||||
|
height, width = frame.shape[:2]
|
||||||
|
|
||||||
|
# ROI坐标转换
|
||||||
|
x_min = int(roi[0] * width)
|
||||||
|
y_min = int(roi[1] * height)
|
||||||
|
x_max = int(roi[2] * width)
|
||||||
|
y_max = int(roi[3] * height)
|
||||||
|
roi_converted = (x_min, y_min, x_max - x_min, y_max - y_min)
|
||||||
|
|
||||||
|
# 执行YOLO检测
|
||||||
|
start_time = datetime.now()
|
||||||
|
frame_with_boxes, results = yolo_detector.process_frame(frame, roi_converted)
|
||||||
|
process_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||||
|
print(f"\n通道 {channel_index + 1} 检测结果 (处理时间: {process_time: .2f}ms):")
|
||||||
|
|
||||||
|
detections = []
|
||||||
|
person_boxes = []
|
||||||
|
helmet_boxes = []
|
||||||
|
safevest_boxes = []
|
||||||
|
smoke_boxes = []
|
||||||
|
other_detections = []
|
||||||
|
|
||||||
|
if results is not None and hasattr(results, 'boxes') and len(results.boxes) > 0:
|
||||||
|
# 新增中文映射
|
||||||
|
class_mapping = {
|
||||||
|
"animal": "异物入侵",
|
||||||
|
"cellphone": "玩手机",
|
||||||
|
"fire": "起火"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 第一步:分类存储所有检测结果(根据types过滤)
|
||||||
|
for box, conf, cls in zip(results.boxes.xyxy, results.boxes.conf, results.boxes.cls):
|
||||||
|
class_name = results.names[int(cls)]
|
||||||
|
if class_name not in types: # 关键过滤逻辑
|
||||||
|
continue
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
normalized_box = (
|
||||||
|
x1 / width,
|
||||||
|
y1 / height,
|
||||||
|
(x2 - x1) / width,
|
||||||
|
(y2 - y1) / height
|
||||||
|
)
|
||||||
|
|
||||||
|
if class_name == "person":
|
||||||
|
person_boxes.append((box, conf, (x1, y1, x2, y2)))
|
||||||
|
elif class_name == "helmet":
|
||||||
|
helmet_boxes.append((x1, y1, x2, y2))
|
||||||
|
elif class_name == "safevest":
|
||||||
|
safevest_boxes.append((x1, y1, x2, y2))
|
||||||
|
elif class_name == "smoke":
|
||||||
|
smoke_boxes.append((x1, y1, x2, y2))
|
||||||
|
elif class_name in ["animal", "cellphone", "fire"]:
|
||||||
|
class_name_cn = class_mapping.get(class_name, class_name)
|
||||||
|
other_detections.append({
|
||||||
|
"class": class_name_cn,
|
||||||
|
"confidence": float(conf),
|
||||||
|
"bbox": {
|
||||||
|
"x_min": x1 / width,
|
||||||
|
"y_min": y1 / height,
|
||||||
|
"width": (x2 - x1) / width,
|
||||||
|
"height": (y2 - y1) / height
|
||||||
|
}
|
||||||
|
})
|
||||||
|
print(f"- 独立检测: {class_name_cn}, 置信度: {conf: .2f}, 位置: {box.tolist()}")
|
||||||
|
|
||||||
|
# 第二步:处理人员状态(仅在需要检测person时处理)
|
||||||
|
detections = []
|
||||||
|
if "person" in types:
|
||||||
|
for (box, conf, (x1, y1, x2, y2)) in person_boxes:
|
||||||
|
def calculate_iou(boxA, boxB):
|
||||||
|
xA = max(boxA[0], boxB[0])
|
||||||
|
yA = max(boxA[1], boxB[1])
|
||||||
|
xB = min(boxA[2], boxB[2])
|
||||||
|
yB = min(boxA[3], boxB[3])
|
||||||
|
interArea = max(0, xB - xA) * max(0, yB - yA)
|
||||||
|
boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
||||||
|
boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
||||||
|
return interArea / float(boxAArea + boxBArea - interArea)
|
||||||
|
|
||||||
|
# 根据配置动态判断关联项
|
||||||
|
has_helmet = any(calculate_iou((x1, y1, x2, y2), h_box) > 0.1 for h_box in helmet_boxes) if "helmet" in types else False
|
||||||
|
has_safevest = any(calculate_iou((x1, y1, x2, y2), s_box) > 0.1 for s_box in safevest_boxes) if "safevest" in types else False
|
||||||
|
has_smoke = any(calculate_iou((x1, y1, x2, y2), sm_box) > 0.1 for sm_box in smoke_boxes) if "smoke" in types else False
|
||||||
|
|
||||||
|
# 生成状态标签
|
||||||
|
status_label = "人员"
|
||||||
|
violations = []
|
||||||
|
if has_smoke:
|
||||||
|
status_label = "吸烟"
|
||||||
|
else:
|
||||||
|
if "helmet" in types and not has_helmet:
|
||||||
|
violations.append("未戴安全帽")
|
||||||
|
if "safevest" in types and not has_safevest:
|
||||||
|
violations.append("未穿工服")
|
||||||
|
if violations:
|
||||||
|
status_label = "违规:" + "、".join(violations)
|
||||||
|
else:
|
||||||
|
if "helmet" in types or "safevest" in types:
|
||||||
|
status_label = "着装规范"
|
||||||
|
|
||||||
|
detections.append({
|
||||||
|
"class": status_label,
|
||||||
|
"confidence": float(conf),
|
||||||
|
"bbox": {
|
||||||
|
"x_min": x1 / width,
|
||||||
|
"y_min": y1 / height,
|
||||||
|
"width": (x2 - x1) / width,
|
||||||
|
"height": (y2 - y1) / height
|
||||||
|
}
|
||||||
|
})
|
||||||
|
print(f"- 状态: {status_label}, 置信度: {conf:.2f}, 位置: {box}")
|
||||||
|
|
||||||
|
# 添加独立检测类别
|
||||||
|
detections.extend(other_detections)
|
||||||
|
if not detections:
|
||||||
|
print("- 未检测到任何目标")
|
||||||
|
|
||||||
|
# 序列化并发送结果
|
||||||
|
data = {
|
||||||
|
"channel": str(camera_configs[channel_index]['channel']),
|
||||||
|
"detections": detections,
|
||||||
|
"image_size": [width, height]
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
serialized_data = msgpack.packb(data)
|
||||||
|
redis_client.publish('detection_result_channel', serialized_data)
|
||||||
|
print(f"[Redis] 通道 {camera_configs[channel_index]['channel']} 数据发送成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Redis] 发送失败: {str(e)}")
|
||||||
|
|
||||||
|
return frame_with_boxes
|
Loading…
x
Reference in New Issue
Block a user