ContactDectect.cpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
#include "ContactDectect.h"

#include <algorithm>

using namespace std;
  3. /// <summary>
  4. /// 连接头
  5. /// </summary>
  6. /// <param name="isCuda"></param>
  7. /// <returns></returns>
  8. bool ContactDectect::Init(bool isCuda)
  9. {
  10. string model_path = "models/contact-sim.onnx";
  11. try {
  12. net = cv::dnn::readNet(model_path);
  13. }
  14. catch (const std::exception& ex)
  15. {
  16. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  17. return false;
  18. }
  19. //cuda
  20. if (isCuda) {
  21. net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
  22. net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
  23. }
  24. //cpu
  25. else {
  26. net.setPreferableBackend(cv::dnn::DNN_BACKEND_DEFAULT);
  27. net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
  28. }
  29. return true;
  30. }
  31. IDetection::DectectResult ContactDectect::GetStateResult(cv::Mat img, cv::Rect rec)
  32. {
  33. //resultValue.clear();
  34. std::cout << "test" << std::endl;
  35. try
  36. {
  37. cv::Mat ROI = img(rec);
  38. Detect(ROI);
  39. }
  40. catch (const std::exception& ex)
  41. {
  42. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  43. }
  44. if (resultValue.m_confidence < 0.1)
  45. {
  46. resultValue = DectectResult(0.45, 0, className[1]);
  47. }
  48. return resultValue;
  49. }
  50. IDetection::DectectResult ContactDectect::GetDigitResult(cv::Mat img, cv::Rect rec)
  51. {
  52. return resultValue;
  53. }
  54. bool ContactDectect::Detect(cv::Mat& SrcImg) {
  55. cv::Mat blob;
  56. int col = SrcImg.cols;
  57. int row = SrcImg.rows;
  58. int maxLen = MAX(col, row);
  59. cv::Mat netInputImg = SrcImg.clone();
  60. if (maxLen > 1.2 * col || maxLen > 1.2 * row) {
  61. cv::Mat resizeImg = cv::Mat::zeros(maxLen, maxLen, CV_8UC3);
  62. SrcImg.copyTo(resizeImg(cv::Rect(0, 0, col, row)));
  63. netInputImg = resizeImg;
  64. }
  65. cv::dnn::blobFromImage(netInputImg, blob, 1 / 255.0, cv::Size(netWidth, netHeight), cv::Scalar(0, 0, 0), true, false);
  66. net.setInput(blob);
  67. std::vector<cv::Mat> netOutputImg;
  68. net.forward(netOutputImg, net.getUnconnectedOutLayersNames());
  69. std::vector<int> classIds;//结果id数组
  70. std::vector<float> confidences;//结果每个id对应置信度数组
  71. std::vector<cv::Rect> boxes;//每个id矩形框
  72. float ratio_h = (float)netInputImg.rows / netHeight;
  73. float ratio_w = (float)netInputImg.cols / netWidth;
  74. int net_width = className.size() + 5; //输出的网络宽度是类别数+5
  75. float* pdata = (float*)netOutputImg[0].data;
  76. for (int stride = 0; stride < strideSize; stride++) { //stride
  77. int grid_x = (int)(netWidth / netStride[stride]);
  78. int grid_y = (int)(netHeight / netStride[stride]);
  79. for (int anchor = 0; anchor < 3; anchor++) { //anchors
  80. const float anchor_w = netAnchors[stride][anchor * 2];
  81. const float anchor_h = netAnchors[stride][anchor * 2 + 1];
  82. for (int i = 0; i < grid_y; i++) {
  83. for (int j = 0; j < grid_x; j++) {
  84. float box_score = pdata[4]; ;//获取每一行的box框中含有某个物体的概率
  85. if (box_score >= boxThreshold) {
  86. cv::Mat scores(1, className.size(), CV_32FC1, pdata + 5);
  87. cv::Point classIdPoint;
  88. double max_class_socre;
  89. minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);
  90. max_class_socre = (float)max_class_socre;
  91. if (max_class_socre >= classThreshold)
  92. {
  93. //rect [x,y,w,h]
  94. float x = pdata[0]; //x
  95. float y = pdata[1]; //y
  96. float w = pdata[2]; //w
  97. float h = pdata[3]; //h
  98. int left = (x - 0.5 * w) * ratio_w;
  99. int top = (y - 0.5 * h) * ratio_h;
  100. left = left < 0 ? 0 : left;
  101. top = top < 0 ? 0 : top;
  102. int widthBox = int(w * ratio_w);
  103. int heightBox = int(h * ratio_h);
  104. widthBox = widthBox > col ? col : widthBox;
  105. heightBox = heightBox > row ? row : heightBox;
  106. if (left < 0 || left>col || top < 0 || top>row || widthBox > col || heightBox > row)
  107. {
  108. continue;
  109. }
  110. classIds.push_back(classIdPoint.x);
  111. confidences.push_back(max_class_socre * box_score);
  112. boxes.push_back(cv::Rect(left, top, widthBox, heightBox));
  113. }
  114. }
  115. pdata += net_width;//下一行
  116. }
  117. }
  118. }
  119. }
  120. //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS)
  121. vector<int> nms_result;
  122. cv::dnn::NMSBoxes(boxes, confidences, nmsScoreThreshold, nmsThreshold, nms_result);
  123. float confidenceMax = 0;
  124. int confidenceMaxId = 0;
  125. if (nms_result.size() > 0)
  126. {
  127. for (int i = 0; i < nms_result.size(); i++) {
  128. int idx = nms_result[i];
  129. if (confidences[idx] > confidenceMax)
  130. {
  131. confidenceMax = confidences[idx];
  132. confidenceMaxId = idx;
  133. }
  134. }
  135. if (confidenceMax>0)
  136. {
  137. resultValue = DectectResult(confidenceMax, 0, className[classIds[confidenceMaxId]]);
  138. YunDaISASImageRecognitionService::ConsoleLog(QString::fromStdString(className[classIds[confidenceMaxId]]));
  139. }
  140. }
  141. else
  142. {
  143. resultValue = DectectResult(confidenceMax, 0, className[1]);
  144. }
  145. return true;
  146. }