result_set.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. #ifndef OPENCV_FLANN_RESULTSET_H
  31. #define OPENCV_FLANN_RESULTSET_H
  32. //! @cond IGNORED
  33. #include <algorithm>
  34. #include <cstring>
  35. #include <iostream>
  36. #include <limits>
  37. #include <set>
  38. #include <vector>
  39. namespace cvflann
  40. {
  /* This record represents a branch point when finding neighbors in
     the tree. It contains a record of the minimum distance to the query
     point, as well as the node at which the search resumes.
  */
  template <typename T, typename DistanceType>
  struct BranchStruct
  {
      T node;               /* Tree node at which search resumes */
      DistanceType mindist; /* Minimum distance to query for all nodes below. */

      /* Value-initialize both members so a default-constructed branch never
         carries indeterminate values (e.g. a garbage pointer as node). */
      BranchStruct() : node(), mindist() {}
      BranchStruct(const T& aNode, DistanceType dist) : node(aNode), mindist(dist) {}

      /* Branches are ordered by mindist only; the node takes no part in the comparison. */
      bool operator<(const BranchStruct<T, DistanceType>& rhs) const
      {
          return mindist < rhs.mindist;
      }
  };
  /**
   * Abstract interface shared by every result-set flavour below.
   * Search code feeds candidate points through addPoint() and queries
   * full()/worstDist() to decide what can still be accepted.
   */
  template <typename DistanceType>
  class ResultSet
  {
  public:
      /** Virtual so concrete sets can be destroyed through a base pointer. */
      virtual ~ResultSet() = default;

      /** @return whether the concrete set considers itself full */
      virtual bool full() const = 0;

      /** Offer a candidate neighbour with the given distance and point index. */
      virtual void addPoint(DistanceType dist, int index) = 0;

      /** @return the distance threshold the concrete set currently reports */
      virtual DistanceType worstDist() const = 0;
  };
  66. /**
  67. * KNNSimpleResultSet does not ensure that the element it holds are unique.
  68. * Is used in those cases where the nearest neighbour algorithm used does not
  69. * attempt to insert the same element multiple times.
  70. */
  71. template <typename DistanceType>
  72. class KNNSimpleResultSet : public ResultSet<DistanceType>
  73. {
  74. int* indices;
  75. DistanceType* dists;
  76. int capacity;
  77. int count;
  78. DistanceType worst_distance_;
  79. public:
  80. KNNSimpleResultSet(int capacity_) : capacity(capacity_), count(0)
  81. {
  82. }
  83. void init(int* indices_, DistanceType* dists_)
  84. {
  85. indices = indices_;
  86. dists = dists_;
  87. count = 0;
  88. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  89. dists[capacity-1] = worst_distance_;
  90. }
  91. size_t size() const
  92. {
  93. return count;
  94. }
  95. bool full() const CV_OVERRIDE
  96. {
  97. return count == capacity;
  98. }
  99. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  100. {
  101. if (dist >= worst_distance_) return;
  102. int i;
  103. for (i=count; i>0; --i) {
  104. #ifdef FLANN_FIRST_MATCH
  105. if ( (dists[i-1]>dist) || ((dist==dists[i-1])&&(indices[i-1]>index)) )
  106. #else
  107. if (dists[i-1]>dist)
  108. #endif
  109. {
  110. if (i<capacity) {
  111. dists[i] = dists[i-1];
  112. indices[i] = indices[i-1];
  113. }
  114. }
  115. else break;
  116. }
  117. if (count < capacity) ++count;
  118. dists[i] = dist;
  119. indices[i] = index;
  120. worst_distance_ = dists[capacity-1];
  121. }
  122. DistanceType worstDist() const CV_OVERRIDE
  123. {
  124. return worst_distance_;
  125. }
  126. };
  127. /**
  128. * K-Nearest neighbour result set. Ensures that the elements inserted are unique
  129. */
  130. template <typename DistanceType>
  131. class KNNResultSet : public ResultSet<DistanceType>
  132. {
  133. int* indices;
  134. DistanceType* dists;
  135. int capacity;
  136. int count;
  137. DistanceType worst_distance_;
  138. public:
  139. KNNResultSet(int capacity_)
  140. : indices(NULL), dists(NULL), capacity(capacity_), count(0), worst_distance_(0)
  141. {
  142. }
  143. void init(int* indices_, DistanceType* dists_)
  144. {
  145. indices = indices_;
  146. dists = dists_;
  147. count = 0;
  148. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  149. dists[capacity-1] = worst_distance_;
  150. }
  151. size_t size() const
  152. {
  153. return count;
  154. }
  155. bool full() const CV_OVERRIDE
  156. {
  157. return count == capacity;
  158. }
  159. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  160. {
  161. CV_DbgAssert(indices);
  162. CV_DbgAssert(dists);
  163. if (dist >= worst_distance_) return;
  164. int i;
  165. for (i = count; i > 0; --i) {
  166. #ifdef FLANN_FIRST_MATCH
  167. if ( (dists[i-1]<=dist) && ((dist!=dists[i-1])||(indices[i-1]<=index)) )
  168. #else
  169. if (dists[i-1]<=dist)
  170. #endif
  171. {
  172. // Check for duplicate indices
  173. for (int j = i; dists[j] == dist && j--;) {
  174. if (indices[j] == index) {
  175. return;
  176. }
  177. }
  178. break;
  179. }
  180. }
  181. if (count < capacity) ++count;
  182. for (int j = count-1; j > i; --j) {
  183. dists[j] = dists[j-1];
  184. indices[j] = indices[j-1];
  185. }
  186. dists[i] = dist;
  187. indices[i] = index;
  188. worst_distance_ = dists[capacity-1];
  189. }
  190. DistanceType worstDist() const CV_OVERRIDE
  191. {
  192. return worst_distance_;
  193. }
  194. };
  195. /**
  196. * A result-set class used when performing a radius based search.
  197. */
  198. template <typename DistanceType>
  199. class RadiusResultSet : public ResultSet<DistanceType>
  200. {
  201. DistanceType radius;
  202. int* indices;
  203. DistanceType* dists;
  204. size_t capacity;
  205. size_t count;
  206. public:
  207. RadiusResultSet(DistanceType radius_, int* indices_, DistanceType* dists_, int capacity_) :
  208. radius(radius_), indices(indices_), dists(dists_), capacity(capacity_)
  209. {
  210. init();
  211. }
  212. ~RadiusResultSet()
  213. {
  214. }
  215. void init()
  216. {
  217. count = 0;
  218. }
  219. size_t size() const
  220. {
  221. return count;
  222. }
  223. bool full() const
  224. {
  225. return true;
  226. }
  227. void addPoint(DistanceType dist, int index)
  228. {
  229. if (dist<radius) {
  230. if ((capacity>0)&&(count < capacity)) {
  231. dists[count] = dist;
  232. indices[count] = index;
  233. }
  234. count++;
  235. }
  236. }
  237. DistanceType worstDist() const
  238. {
  239. return radius;
  240. }
  241. };
  242. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  243. /** Class that holds the k NN neighbors
  244. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  245. */
  246. template<typename DistanceType>
  247. class UniqueResultSet : public ResultSet<DistanceType>
  248. {
  249. public:
  250. struct DistIndex
  251. {
  252. DistIndex(DistanceType dist, unsigned int index) :
  253. dist_(dist), index_(index)
  254. {
  255. }
  256. bool operator<(const DistIndex dist_index) const
  257. {
  258. return (dist_ < dist_index.dist_) || ((dist_ == dist_index.dist_) && index_ < dist_index.index_);
  259. }
  260. DistanceType dist_;
  261. unsigned int index_;
  262. };
  263. /** Default constructor */
  264. UniqueResultSet() :
  265. is_full_(false), worst_distance_(std::numeric_limits<DistanceType>::max())
  266. {
  267. }
  268. /** Check the status of the set
  269. * @return true if we have k NN
  270. */
  271. inline bool full() const CV_OVERRIDE
  272. {
  273. return is_full_;
  274. }
  275. /** Remove all elements in the set
  276. */
  277. virtual void clear() = 0;
  278. /** Copy the set to two C arrays
  279. * @param indices pointer to a C array of indices
  280. * @param dist pointer to a C array of distances
  281. * @param n_neighbors the number of neighbors to copy
  282. */
  283. virtual void copy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  284. {
  285. if (n_neighbors < 0) {
  286. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  287. dist_indices_.end(); dist_index != dist_index_end; ++dist_index, ++indices, ++dist) {
  288. *indices = dist_index->index_;
  289. *dist = dist_index->dist_;
  290. }
  291. }
  292. else {
  293. int i = 0;
  294. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  295. dist_indices_.end(); (dist_index != dist_index_end) && (i < n_neighbors); ++dist_index, ++indices, ++dist, ++i) {
  296. *indices = dist_index->index_;
  297. *dist = dist_index->dist_;
  298. }
  299. }
  300. }
  301. /** Copy the set to two C arrays but sort it according to the distance first
  302. * @param indices pointer to a C array of indices
  303. * @param dist pointer to a C array of distances
  304. * @param n_neighbors the number of neighbors to copy
  305. */
  306. virtual void sortAndCopy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  307. {
  308. copy(indices, dist, n_neighbors);
  309. }
  310. /** The number of neighbors in the set
  311. * @return
  312. */
  313. size_t size() const
  314. {
  315. return dist_indices_.size();
  316. }
  317. /** The distance of the furthest neighbor
  318. * If we don't have enough neighbors, it returns the max possible value
  319. * @return
  320. */
  321. inline DistanceType worstDist() const CV_OVERRIDE
  322. {
  323. return worst_distance_;
  324. }
  325. protected:
  326. /** Flag to say if the set is full */
  327. bool is_full_;
  328. /** The worst distance found so far */
  329. DistanceType worst_distance_;
  330. /** The best candidates so far */
  331. std::set<DistIndex> dist_indices_;
  332. };
  333. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  334. /** Class that holds the k NN neighbors
  335. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  336. */
  337. template<typename DistanceType>
  338. class KNNUniqueResultSet : public UniqueResultSet<DistanceType>
  339. {
  340. public:
  341. /** Constructor
  342. * @param capacity the number of neighbors to store at max
  343. */
  344. KNNUniqueResultSet(unsigned int capacity) : capacity_(capacity)
  345. {
  346. this->is_full_ = false;
  347. this->clear();
  348. }
  349. /** Add a possible candidate to the best neighbors
  350. * @param dist distance for that neighbor
  351. * @param index index of that neighbor
  352. */
  353. inline void addPoint(DistanceType dist, int index) CV_OVERRIDE
  354. {
  355. // Don't do anything if we are worse than the worst
  356. if (dist >= worst_distance_) return;
  357. dist_indices_.insert(DistIndex(dist, index));
  358. if (is_full_) {
  359. if (dist_indices_.size() > capacity_) {
  360. dist_indices_.erase(*dist_indices_.rbegin());
  361. worst_distance_ = dist_indices_.rbegin()->dist_;
  362. }
  363. }
  364. else if (dist_indices_.size() == capacity_) {
  365. is_full_ = true;
  366. worst_distance_ = dist_indices_.rbegin()->dist_;
  367. }
  368. }
  369. /** Remove all elements in the set
  370. */
  371. void clear() CV_OVERRIDE
  372. {
  373. dist_indices_.clear();
  374. worst_distance_ = std::numeric_limits<DistanceType>::max();
  375. is_full_ = false;
  376. }
  377. protected:
  378. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  379. using UniqueResultSet<DistanceType>::is_full_;
  380. using UniqueResultSet<DistanceType>::worst_distance_;
  381. using UniqueResultSet<DistanceType>::dist_indices_;
  382. /** The number of neighbors to keep */
  383. unsigned int capacity_;
  384. };
  385. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  386. /** Class that holds the radius nearest neighbors
  387. * It is more accurate than RadiusResult as it is not limited in the number of neighbors
  388. */
  389. template<typename DistanceType>
  390. class RadiusUniqueResultSet : public UniqueResultSet<DistanceType>
  391. {
  392. public:
  393. /** Constructor
  394. * @param radius the maximum distance of a neighbor
  395. */
  396. RadiusUniqueResultSet(DistanceType radius) :
  397. radius_(radius)
  398. {
  399. is_full_ = true;
  400. }
  401. /** Add a possible candidate to the best neighbors
  402. * @param dist distance for that neighbor
  403. * @param index index of that neighbor
  404. */
  405. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  406. {
  407. if (dist <= radius_) dist_indices_.insert(DistIndex(dist, index));
  408. }
  409. /** Remove all elements in the set
  410. */
  411. inline void clear() CV_OVERRIDE
  412. {
  413. dist_indices_.clear();
  414. }
  415. /** Check the status of the set
  416. * @return alwys false
  417. */
  418. inline bool full() const CV_OVERRIDE
  419. {
  420. return true;
  421. }
  422. /** The distance of the furthest neighbor
  423. * If we don't have enough neighbors, it returns the max possible value
  424. * @return
  425. */
  426. inline DistanceType worstDist() const CV_OVERRIDE
  427. {
  428. return radius_;
  429. }
  430. private:
  431. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  432. using UniqueResultSet<DistanceType>::dist_indices_;
  433. using UniqueResultSet<DistanceType>::is_full_;
  434. /** The furthest distance a neighbor can be */
  435. DistanceType radius_;
  436. };
  437. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  438. /** Class that holds the k NN neighbors within a radius distance
  439. */
  440. template<typename DistanceType>
  441. class KNNRadiusUniqueResultSet : public KNNUniqueResultSet<DistanceType>
  442. {
  443. public:
  444. /** Constructor
  445. * @param capacity the number of neighbors to store at max
  446. * @param radius the maximum distance of a neighbor
  447. */
  448. KNNRadiusUniqueResultSet(unsigned int capacity, DistanceType radius)
  449. {
  450. this->capacity_ = capacity;
  451. this->radius_ = radius;
  452. this->dist_indices_.reserve(capacity_);
  453. this->clear();
  454. }
  455. /** Remove all elements in the set
  456. */
  457. void clear()
  458. {
  459. dist_indices_.clear();
  460. worst_distance_ = radius_;
  461. is_full_ = false;
  462. }
  463. private:
  464. using KNNUniqueResultSet<DistanceType>::dist_indices_;
  465. using KNNUniqueResultSet<DistanceType>::is_full_;
  466. using KNNUniqueResultSet<DistanceType>::worst_distance_;
  467. /** The maximum number of neighbors to consider */
  468. unsigned int capacity_;
  469. /** The maximum distance of a neighbor */
  470. DistanceType radius_;
  471. };
  472. }
  473. //! @endcond
  474. #endif //OPENCV_FLANN_RESULTSET_H