#include "ContactDectect.h"
using namespace std;
///
/// Connector (contact) detector implementation.
///
/// Loads the ONNX model into `net` and selects the inference backend/target.
///
/// @param isCuda  true  -> CUDA backend with the FP16 CUDA target;
///                false -> default backend with the CPU target.
/// @return false when reading the model throws (the message is logged),
///         true once the backend and target preferences have been set.
bool ContactDectect::Init(bool isCuda)
{
    const std::string modelFile = "models/contact-sim.onnx";

    try
    {
        net = cv::dnn::readNet(modelFile);
    }
    catch (const std::exception& loadError)
    {
        YunDaISASImageRecognitionService::ConsoleLog(loadError.what());
        return false;
    }

    // Pick the backend/target pair in one place instead of branching twice.
    const int backendId = isCuda ? cv::dnn::DNN_BACKEND_CUDA : cv::dnn::DNN_BACKEND_DEFAULT;
    const int targetId = isCuda ? cv::dnn::DNN_TARGET_CUDA_FP16 : cv::dnn::DNN_TARGET_CPU;

    net.setPreferableBackend(backendId);
    net.setPreferableTarget(targetId);
    return true;
}
/// Runs detection on the sub-region `rec` of `img` and returns the outcome.
///
/// The stored `resultValue` is reset before detection starts, so a failure
/// (for example an out-of-range `rec`, or an exception inside Detect) can no
/// longer return the detection left over from a previous call.
///
/// @param img  Source image.
/// @param rec  Region of interest inside `img`.
/// @return The detection stored by Detect; when its confidence is below 0.1
///         (including every failure path) a result with confidence 0.45 and
///         className[1] is returned instead.
///         NOTE(review): className[1] looks like the default/"normal" state - confirm.
IDetection::DectectResult ContactDectect::GetStateResult(cv::Mat img, cv::Rect rec)
{
    // Start from a zero-confidence result so stale state never leaks through.
    resultValue = DectectResult(0.0f, 0, className[1]);
    try
    {
        cv::Mat ROI = img(rec);
        Detect(ROI);
    }
    catch (const std::exception& ex)
    {
        YunDaISASImageRecognitionService::ConsoleLog(ex.what());
    }
    // Low-confidence (or failed) detections are replaced by the fixed fallback.
    if (resultValue.m_confidence < 0.1)
    {
        resultValue = DectectResult(0.45, 0, className[1]);
    }
    return resultValue;
}
/// Returns the currently stored `resultValue`.
///
/// No detection is performed here: neither `img` nor `rec` is read.
IDetection::DectectResult ContactDectect::GetDigitResult(cv::Mat img, cv::Rect rec)
{
    // Both arguments are intentionally unused by this implementation.
    (void)img;
    (void)rec;
    return resultValue;
}
/// Runs the network on `SrcImg` and stores the best detection in `resultValue`.
///
/// Steps:
///  1. Pad the image to a square (top-left aligned, zero filled) when one side
///     is more than 1.2x the other.
///  2. Build the input blob (scaled by 1/255, resized to netWidth x netHeight,
///     R and B swapped) and run a forward pass.
///  3. Walk the rows of the first output; each row is read as
///     [x, y, w, h, box score, per-class scores...].
///  4. Apply NMS and keep the surviving box with the highest confidence.
///
/// When nothing is detected, `resultValue` is set to confidence 0 with
/// className[1].
///
/// @param SrcImg  Image (region) to analyse; it is not modified.
/// @return true when inference ran; false when `SrcImg` is empty or the
///         network produced no output (`resultValue` is then the
///         "no detection" value).
bool ContactDectect::Detect(cv::Mat& SrcImg) {
    if (SrcImg.empty()) {
        resultValue = DectectResult(0.0f, 0, className[1]);
        return false;
    }

    const int col = SrcImg.cols;
    const int row = SrcImg.rows;
    const int maxLen = MAX(col, row);

    // Nothing below writes to the input, so a shared header is enough (no deep copy).
    cv::Mat netInputImg = SrcImg;
    if (maxLen > 1.2 * col || maxLen > 1.2 * row) {
        // The canvas must have the source type: copyTo() into a ROI of another
        // type would reallocate the temporary ROI and leave the canvas black.
        cv::Mat padded = cv::Mat::zeros(maxLen, maxLen, SrcImg.type());
        SrcImg.copyTo(padded(cv::Rect(0, 0, col, row)));
        netInputImg = padded;
    }

    cv::Mat blob;
    cv::dnn::blobFromImage(netInputImg, blob, 1 / 255.0, cv::Size(netWidth, netHeight), cv::Scalar(0, 0, 0), true, false);
    net.setInput(blob);
    std::vector<cv::Mat> netOutputImg;
    net.forward(netOutputImg, net.getUnconnectedOutLayersNames());
    if (netOutputImg.empty() || netOutputImg[0].empty()) {
        resultValue = DectectResult(0.0f, 0, className[1]);
        return false;
    }

    std::vector<int> classIds;      // class id of every candidate
    std::vector<float> confidences; // confidence of every candidate
    std::vector<cv::Rect> boxes;    // bounding box of every candidate

    const float ratio_h = static_cast<float>(netInputImg.rows) / netHeight;
    const float ratio_w = static_cast<float>(netInputImg.cols) / netWidth;
    const int classCount = static_cast<int>(className.size());
    const int net_width = classCount + 5; // floats per output row: 5 box values + class scores

    // Number of rows the original stride/anchor/grid nesting would visit
    // (3 anchors per grid cell at every stride).
    std::size_t expectedRows = 0;
    for (int stride = 0; stride < strideSize; stride++) {
        const int grid_x = (int)(netWidth / netStride[stride]);
        const int grid_y = (int)(netHeight / netStride[stride]);
        if (grid_x > 0 && grid_y > 0) {
            expectedRows += static_cast<std::size_t>(3) * grid_x * grid_y;
        }
    }
    // Never read past the end of the buffer the network actually returned.
    const std::size_t availableRows = netOutputImg[0].total() / static_cast<std::size_t>(net_width);
    const std::size_t rowCount = expectedRows < availableRows ? expectedRows : availableRows;

    float* pdata = netOutputImg[0].ptr<float>();
    // The pointer advances in the loop header so that every `continue` below
    // still moves on to the next row.
    for (std::size_t r = 0; r < rowCount; ++r, pdata += net_width) {
        const float box_score = pdata[4]; // probability that this box contains an object
        if (!(box_score >= boxThreshold)) {
            continue;
        }

        cv::Mat scores(1, classCount, CV_32FC1, pdata + 5);
        cv::Point classIdPoint;
        double bestClassScore = 0;
        cv::minMaxLoc(scores, nullptr, &bestClassScore, nullptr, &classIdPoint);
        const float max_class_score = static_cast<float>(bestClassScore);
        if (!(max_class_score >= classThreshold)) {
            continue;
        }

        // rect [x, y, w, h]: centre and size in network-input coordinates
        const float x = pdata[0];
        const float y = pdata[1];
        const float w = pdata[2];
        const float h = pdata[3];

        int left = static_cast<int>((x - 0.5 * w) * ratio_w);
        int top = static_cast<int>((y - 0.5 * h) * ratio_h);
        left = left < 0 ? 0 : left;
        top = top < 0 ? 0 : top;
        int widthBox = static_cast<int>(w * ratio_w);
        int heightBox = static_cast<int>(h * ratio_h);
        widthBox = widthBox > col ? col : widthBox;
        heightBox = heightBox > row ? row : heightBox;

        // After the clamping above, only a box starting beyond the source
        // image (i.e. inside the padding) can still be out of range.
        if (left > col || top > row) {
            continue;
        }

        classIds.push_back(classIdPoint.x);
        confidences.push_back(max_class_score * box_score);
        boxes.push_back(cv::Rect(left, top, widthBox, heightBox));
    }

    // Non-maximum suppression: drop redundant overlapping boxes with lower confidence.
    std::vector<int> nms_result;
    cv::dnn::NMSBoxes(boxes, confidences, nmsScoreThreshold, nmsThreshold, nms_result);

    float confidenceMax = 0;
    int confidenceMaxId = -1;
    for (int idx : nms_result) {
        if (confidences[idx] > confidenceMax) {
            confidenceMax = confidences[idx];
            confidenceMaxId = idx;
        }
    }

    if (confidenceMaxId >= 0) {
        const auto& bestName = className[classIds[confidenceMaxId]];
        resultValue = DectectResult(confidenceMax, 0, bestName);
        YunDaISASImageRecognitionService::ConsoleLog(QString::fromStdString(bestName));
    }
    else {
        // No surviving box with positive confidence: report "nothing detected"
        // rather than keeping whatever an earlier call stored.
        resultValue = DectectResult(confidenceMax, 0, className[1]);
    }
    return true;
}