MeterDectect.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. #include "MeterDectect.h"
  2. #include "DoublePointerCountALGO.h"
  3. #include "AtmosphericPressureALGO.h"
  4. #include "CircularArresterCurrentALGO.h"
  5. #include "ArresterCircularZeroThreeALGO.h"
  6. #include "CircularOilSinglePointALGO.h"
  7. #include "DoubleTemperatureGuageALGO.h"
  8. #include "AmpereMeterALGO.h"
  9. #include "VoltageMeterALGO.h"
  10. using namespace std;
  11. static AtmosphericPressureALGO atmosphericPressureALGO;
  12. static CircularArresterCurrentALGO circularArresterCurrentALGO;
  13. static CircularOilSinglePointALGO circularOilSinglePointALGO;
  14. static DoubleTemperatureGuageALGO doubleTemperatureGuageALGO;
  15. static AmpereMeterALGO ampereMeterALGO;
  16. static VoltageMeterALGO voltageMeterALGO;
  17. static DoublePointerCountALGO doublePointerCountALGO;
  18. bool MeterDectect::Init(bool isCuda)
  19. {
  20. string model_path = "models/meter-sim.onnx";
  21. try {
  22. net = cv::dnn::readNet(model_path);
  23. atmosphericPressureALGO.Init();
  24. circularArresterCurrentALGO.Init();
  25. circularOilSinglePointALGO.Init(isCuda);
  26. doubleTemperatureGuageALGO.Init(isCuda);
  27. ampereMeterALGO.Init(isCuda);
  28. voltageMeterALGO.Init(isCuda);
  29. doublePointerCountALGO.Init(isCuda);
  30. }
  31. catch (const std::exception& ex)
  32. {
  33. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  34. return false;
  35. }
  36. //cuda
  37. if (isCuda) {
  38. net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
  39. net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);
  40. }
  41. //cpu
  42. else {
  43. net.setPreferableBackend(cv::dnn::DNN_BACKEND_DEFAULT);
  44. net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
  45. }
  46. return true;
  47. return false;
  48. }
  49. IDetection::DectectResult MeterDectect::GetStateResult(cv::Mat img, cv::Rect rec)
  50. {
  51. return resultValue;
  52. }
  53. IDetection::DectectResult MeterDectect::GetDigitResult(cv::Mat img, cv::Rect rec)
  54. {
  55. //resultValue.clear();
  56. //std::cout << "test" << std::endl;
  57. //try
  58. //{
  59. // cv::Mat ROI = img(rec);
  60. // /*imwrite("test.png", ROI);
  61. // YunDaISASImageRecognitionService::SetImage(QString::fromStdString("test.png"));*/
  62. // Detect(ROI);
  63. //}
  64. //catch (const std::exception& ex)
  65. //{
  66. // YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  67. //}
  68. //if (resultValue.m_confidence < 0.1)
  69. //{
  70. // resultValue = DectectResult(0.45, 0, "");
  71. //}
  72. return resultValue;
  73. }
  74. vector<IDetection::DectectResult> MeterDectect::GetDigitResults(cv::Mat img, cv::Rect rec)
  75. {
  76. resultValues.clear();
  77. try
  78. {
  79. cv::Mat ROI = img(rec);
  80. /*imwrite("test.png", ROI);
  81. YunDaISASImageRecognitionService::SetImage(QString::fromStdString("test.png"));*/
  82. Detect(ROI);
  83. if (YunDaISASImageRecognitionService::GetIsShowDectect())
  84. {
  85. cv::Mat drawROI;
  86. ROI.copyTo(drawROI);
  87. DrawPred(drawROI, output,className);
  88. imwrite("test.png", drawROI);
  89. YunDaISASImageRecognitionService::SetImage(QString::fromStdString("test.png"));
  90. }
  91. }
  92. catch (const std::exception& ex)
  93. {
  94. YunDaISASImageRecognitionService::ConsoleLog(ex.what());
  95. }
  96. /*if (resultValue.m_confidence < 0.1)
  97. {
  98. resultValue = DectectResult(0.45, 0, "");
  99. }*/
  100. return resultValues;
  101. }
  102. bool MeterDectect::Detect(cv::Mat& SrcImg)
  103. {
  104. cv::Mat blob;
  105. int col = SrcImg.cols;
  106. int row = SrcImg.rows;
  107. int maxLen = MAX(col, row);
  108. cv::Mat netInputImg = SrcImg.clone();
  109. if (maxLen > 1.2 * col || maxLen > 1.2 * row) {
  110. cv::Mat resizeImg = cv::Mat::zeros(maxLen, maxLen, CV_8UC3);
  111. SrcImg.copyTo(resizeImg(cv::Rect(0, 0, col, row)));
  112. netInputImg = resizeImg;
  113. }
  114. cv::dnn::blobFromImage(netInputImg, blob, 1 / 255.0, cv::Size(netWidth, netHeight), cv::Scalar(0, 0, 0), true, false);
  115. net.setInput(blob);
  116. std::vector<cv::Mat> netOutputImg;
  117. net.forward(netOutputImg, net.getUnconnectedOutLayersNames());
  118. std::vector<int> classIds;//结果id数组
  119. std::vector<float> confidences;//结果每个id对应置信度数组
  120. std::vector<cv::Rect> boxes;//每个id矩形框
  121. float ratio_h = (float)netInputImg.rows / netHeight;
  122. float ratio_w = (float)netInputImg.cols / netWidth;
  123. int net_width = className.size() + 5; //输出的网络宽度是类别数+5
  124. float* pdata = (float*)netOutputImg[0].data;
  125. for (int stride = 0; stride < strideSize; stride++) { //stride
  126. int grid_x = (int)(netWidth / netStride[stride]);
  127. int grid_y = (int)(netHeight / netStride[stride]);
  128. for (int anchor = 0; anchor < 3; anchor++) { //anchors
  129. const float anchor_w = netAnchors[stride][anchor * 2];
  130. const float anchor_h = netAnchors[stride][anchor * 2 + 1];
  131. for (int i = 0; i < grid_y; i++) {
  132. for (int j = 0; j < grid_x; j++) {
  133. float box_score = pdata[4]; ;//获取每一行的box框中含有某个物体的概率
  134. if (box_score >= boxThreshold) {
  135. cv::Mat scores(1, className.size(), CV_32FC1, pdata + 5);
  136. cv::Point classIdPoint;
  137. double max_class_socre;
  138. minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);
  139. max_class_socre = (float)max_class_socre;
  140. if (max_class_socre >= classThreshold)
  141. {
  142. //rect [x,y,w,h]
  143. float x = pdata[0]; //x
  144. float y = pdata[1]; //y
  145. float w = pdata[2]; //w
  146. float h = pdata[3]; //h
  147. int left = (x - 0.5 * w) * ratio_w;
  148. int top = (y - 0.5 * h) * ratio_h;
  149. left = left < 0 ? 0 : left;
  150. top = top < 0 ? 0 : top;
  151. int widthBox = int(w * ratio_w);
  152. int heightBox = int(h * ratio_h);
  153. widthBox = widthBox > col ? col : widthBox;
  154. heightBox = heightBox > row ? row : heightBox;
  155. if (left < 0|| left>col|| top < 0|| top>row|| widthBox > col|| heightBox > row
  156. || left+ widthBox> col|| top+ heightBox> row
  157. )
  158. {
  159. continue;
  160. }
  161. classIds.push_back(classIdPoint.x);
  162. confidences.push_back(max_class_socre * box_score);
  163. boxes.push_back(cv::Rect(left, top, widthBox, heightBox));
  164. }
  165. }
  166. pdata += net_width;//下一行
  167. }
  168. }
  169. }
  170. }
  171. //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS)
  172. vector<int> nms_result;
  173. cv::dnn::NMSBoxes(boxes, confidences, nmsScoreThreshold, nmsThreshold, nms_result);
  174. float confidenceMax = -1;
  175. int confidenceMaxId = 0;
  176. output.clear();
  177. if (nms_result.size() > 0)
  178. {
  179. for (int i = 0; i < nms_result.size()&&i<3; i++)
  180. {
  181. int idx = nms_result[i];
  182. Output result(classIds[idx], confidences[idx], boxes[idx]);
  183. /*int figure = 0;
  184. result.id = classIds[idx];
  185. result.confidence = confidences[idx];
  186. result.box = boxes[idx];*/
  187. output.push_back(result);
  188. if (className[classIds[idx]] == "rect_arrester_current")
  189. {
  190. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  191. auto resValue = circularArresterCurrentALGO.detect(SrcImg);
  192. tempResultValue.m_dValue = resValue;
  193. resultValues.push_back(tempResultValue);
  194. }
  195. else if (className[classIds[idx]] == "atmospheric_pressure")
  196. {
  197. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  198. //auto resState = atmosphericPressureALGO.detect(SrcImg(boxes[idx]));
  199. auto resState = atmosphericPressureALGO.detect(SrcImg);
  200. tempResultValue.m_dValue = atmosphericPressureALGO.resultValue;
  201. tempResultValue.m_sValue = "atmospheric_pressure";
  202. resultValues.push_back(tempResultValue);
  203. }
  204. else if (className[classIds[idx]] == "digital_gear")
  205. {
  206. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  207. resultValues.push_back(tempResultValue);
  208. }
  209. else if (className[classIds[idx]] == "circular_arrester_current") //0-3mA避雷器电流
  210. {
  211. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  212. ArresterCircularZeroThreeALGO* algo = new ArresterCircularZeroThreeALGO;
  213. auto resNum = algo->GetResult(SrcImg(boxes[idx]), false);
  214. tempResultValue.m_dValue = resNum;
  215. delete algo;
  216. resultValues.push_back(tempResultValue);
  217. }
  218. else if (className[classIds[idx]] == "ampere_meter")
  219. {
  220. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  221. tempResultValue.m_dValue = ampereMeterALGO.detect(SrcImg);
  222. resultValues.push_back(tempResultValue);
  223. }
  224. else if (className[classIds[idx]] == "voltage_meter")
  225. {
  226. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  227. tempResultValue.m_dValue = voltageMeterALGO.detect(SrcImg);
  228. resultValues.push_back(tempResultValue);
  229. }
  230. /*else if (className[classIds[idx]] == "double_pointer_count1")
  231. {
  232. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  233. DoublePointerCountALGO* doublePointerCountALGO = new DoublePointerCountALGO();
  234. int resCount = doublePointerCountALGO->GetResult(SrcImg(boxes[idx]), false);
  235. tempResultValue.m_dValue = resCount;
  236. delete doublePointerCountALGO;
  237. resultValues.push_back(tempResultValue);
  238. }*/
  239. else if (className[classIds[idx]] == "double_pointer_count")
  240. {
  241. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  242. tempResultValue.m_dValue = doublePointerCountALGO.detect(SrcImg);
  243. resultValues.push_back(tempResultValue);
  244. }
  245. else if (className[classIds[idx]] == "transformers_oil_surface_thermometer")
  246. {
  247. auto tempResultValueMin = DectectResult(confidences[idx], 60, className[classIds[idx]] + "_Min");
  248. auto tempResultValueMax = DectectResult(confidences[idx], 80, className[classIds[idx]] + "_Max");
  249. auto min_max = doubleTemperatureGuageALGO.detect(SrcImg, 150);//150表示量程
  250. tempResultValueMin.m_dValue = min_max.first;
  251. tempResultValueMax.m_dValue = min_max.second;
  252. resultValues.push_back(tempResultValueMin);
  253. resultValues.push_back(tempResultValueMax);
  254. }
  255. else if (className[classIds[idx]] == "close")
  256. {
  257. resultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  258. }
  259. else if (className[classIds[idx]] == "open")
  260. {
  261. resultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  262. }
  263. else if (className[classIds[idx]] == "round_single_point_oil_level")
  264. {
  265. auto tempResultValue = DectectResult(confidences[idx], 0, className[classIds[idx]]);
  266. tempResultValue.m_dValue = circularOilSinglePointALGO.detect(SrcImg);
  267. resultValues.push_back(tempResultValue);
  268. }
  269. else if (className[classIds[idx]] == "oil_thermometer")//油位温控
  270. {
  271. auto tempResultValueMin = DectectResult(confidences[idx], 60, className[classIds[idx]] + "_Min");
  272. auto tempResultValueMax = DectectResult(confidences[idx], 80, className[classIds[idx]] + "_Max");
  273. auto min_max = doubleTemperatureGuageALGO.detect(SrcImg,120);//120表示量程
  274. tempResultValueMin.m_dValue = min_max.first;
  275. tempResultValueMax.m_dValue = min_max.second;
  276. resultValues.push_back(tempResultValueMin);
  277. resultValues.push_back(tempResultValueMax);
  278. }
  279. YunDaISASImageRecognitionService::ConsoleLog(QString::fromStdString(className[classIds[idx]]));
  280. }
  281. }
  282. else {
  283. resultValue = DectectResult(confidenceMax, 0.01, "");
  284. }
  285. return false;
  286. }