AppearanceClassifyDectect.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
#include "AppearanceClassifyDectect.h"

#include <algorithm>
#include <string>
#include <vector>

#include "YunDaISASImageRecognitionService.h"
  3. //using namespace cv::dnn;
  4. //int img;
  5. cv::Mat img2;
  6. cv::Mat imageWarp2;
  7. bool AppearanceClassifyDectect::Init(bool isCuda) {
  8. try {
  9. string model_path = "models/appearanceClassify.onnx";
  10. net = cv::dnn::readNet(model_path);
  11. }
  12. catch (const std::exception& ex)
  13. {
  14. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  15. return false;
  16. }
  17. //cuda
  18. if (isCuda) {
  19. net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
  20. net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
  21. }
  22. //cpu
  23. else {
  24. net.setPreferableBackend(cv::dnn::DNN_BACKEND_DEFAULT);
  25. net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
  26. }
  27. return true;
  28. }
  29. bool AppearanceClassifyDectect::Detect(cv::Mat& SrcImg, vector<Output2>& output ) {
  30. cv::Mat blob;
  31. int col = SrcImg.cols;
  32. int row = SrcImg.rows;
  33. int maxLen = MAX(col, row);
  34. cv::Mat netInputImg = SrcImg.clone();
  35. if (maxLen > 1.2 * col || maxLen > 1.2 * row) {
  36. cv::Mat resizeImg = cv::Mat::zeros(maxLen, maxLen, CV_8UC3);
  37. SrcImg.copyTo(resizeImg(cv::Rect(0, 0, col, row)));
  38. netInputImg = resizeImg;
  39. }
  40. cv::dnn::blobFromImage(netInputImg, blob, 1 / 255.0, cv::Size(netWidth, netHeight), cv::Scalar(0, 0, 0), true, false);
  41. net.setInput(blob);
  42. std::vector<cv::Mat> netOutputImg;
  43. net.forward(netOutputImg, net.getUnconnectedOutLayersNames());
  44. std::vector<int> classIds;//结果id数组
  45. std::vector<float> confidences;//结果每个id对应置信度数组
  46. std::vector<cv::Rect> boxes;//每个id矩形框
  47. float ratio_h = (float)netInputImg.rows / netHeight;
  48. float ratio_w = (float)netInputImg.cols / netWidth;
  49. int net_width = className.size() + 5; //输出的网络宽度是类别数+5
  50. float* pdata = (float*)netOutputImg[0].data;
  51. for (int stride = 0; stride < strideSize; stride++) { //stride
  52. int grid_x = (int)(netWidth / netStride[stride]);
  53. int grid_y = (int)(netHeight / netStride[stride]);
  54. for (int anchor = 0; anchor < 3; anchor++) { //anchors
  55. const float anchor_w = netAnchors[stride][anchor * 2];
  56. const float anchor_h = netAnchors[stride][anchor * 2 + 1];
  57. for (int i = 0; i < grid_y; i++) {
  58. for (int j = 0; j < grid_x; j++) {
  59. float box_score = pdata[4]; ;//获取每一行的box框中含有某个物体的概率
  60. if (box_score >= boxThreshold) {
  61. cv::Mat scores(1, className.size(), CV_32FC1, pdata + 5);
  62. cv::Point classIdPoint;
  63. double max_class_socre;
  64. minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);
  65. max_class_socre = (float)max_class_socre;
  66. if (max_class_socre >= classThreshold)
  67. {
  68. //rect [x,y,w,h]
  69. float x = pdata[0]; //x
  70. float y = pdata[1]; //y
  71. float w = pdata[2]; //w
  72. float h = pdata[3]; //h
  73. int left = (x - 0.5 * w) * ratio_w;
  74. int top = (y - 0.5 * h) * ratio_h;
  75. int widthBox = int(w * ratio_w);
  76. int heightBox = int(h * ratio_h);
  77. widthBox = widthBox > col ? col : widthBox;
  78. heightBox = heightBox > row ? row : heightBox;
  79. left = left < 0 ? 0 : left;
  80. top = top < 0 ? 0 : top;
  81. if (left < 0 || left>col || top < 0 || top>row || widthBox > col || heightBox > row)
  82. {
  83. continue;
  84. }
  85. classIds.push_back(classIdPoint.x);
  86. confidences.push_back(max_class_socre * box_score);
  87. boxes.push_back(cv::Rect(left, top, widthBox, heightBox));
  88. }
  89. }
  90. pdata += net_width;//下一行
  91. }
  92. }
  93. }
  94. }
  95. //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS)
  96. vector<int> nms_result;
  97. cv::dnn::NMSBoxes(boxes, confidences, nmsScoreThreshold, nmsThreshold, nms_result);
  98. for (int i = 0; i < nms_result.size(); i++) {
  99. int idx = nms_result[i];
  100. Output2 result2;
  101. int figure =0;
  102. result2.id = classIds[idx];
  103. result2.confidence = confidences[idx];
  104. result2.box = boxes[idx];
  105. output.push_back(result2);
  106. }
  107. if (output.size())
  108. return true;
  109. else
  110. return false;
  111. }
  112. void AppearanceClassifyDectect::drawPred(cv::Mat& img, vector<Output2> result2)
  113. {
  114. for (int i = 0; i < result2.size(); i++) {
  115. try
  116. {
  117. int left = result2[i].box.x;
  118. int top = result2[i].box.y;
  119. int width = result2[i].box.width;
  120. int height = result2[i].box.height;
  121. int baseLine;
  122. //2022.10.17
  123. cv::Rect box = result2[i].box;
  124. /*MeterRead meterRead;*/
  125. cv::Mat input_image_copy = img.clone();
  126. cv::Mat cutMat = input_image_copy(box);
  127. //string label = className[result[i].id] + ":" + to_string(result[i].confidence);
  128. string label = className[result2[i].id]+ ": " + to_string(result2[i].confidence);
  129. cv::Size labelSize = getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  130. top = max(top, labelSize.height);
  131. cv::rectangle(img, cv::Point(left, top - int(2 * labelSize.height)), cv::Point(left + int(2 * labelSize.width), top + baseLine), cv::Scalar(0, 0, 255), cv::FILLED);
  132. cv::rectangle(img, cv::Point(left, top), cv::Point(left + width, top + height), cv::Scalar(0, 0, 255), 3);
  133. cv::putText(img, label, cv::Point(left, top), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(255, 255, 255), 3);
  134. }
  135. catch (const std::exception& ex)
  136. {
  137. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  138. }
  139. }
  140. }
  141. void AppearanceClassifyDectect::modifyConfidenceParameter(float boxThresholdPara, float classThresholdPara, float nmsThresholdPara)
  142. {
  143. boxThreshold = boxThresholdPara;
  144. classThreshold = classThresholdPara;
  145. nmsThreshold = nmsThresholdPara;
  146. }