kdtree_index.h 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. #ifndef OPENCV_FLANN_KDTREE_INDEX_H_
  31. #define OPENCV_FLANN_KDTREE_INDEX_H_
  32. //! @cond IGNORED
  33. #include <algorithm>
  34. #include <map>
  35. #include <cstring>
  36. #include "nn_index.h"
  37. #include "dynamic_bitset.h"
  38. #include "matrix.h"
  39. #include "result_set.h"
  40. #include "heap.h"
  41. #include "allocator.h"
  42. #include "random.h"
  43. #include "saving.h"
  44. namespace cvflann
  45. {
  46. struct KDTreeIndexParams : public IndexParams
  47. {
  48. KDTreeIndexParams(int trees = 4)
  49. {
  50. (*this)["algorithm"] = FLANN_INDEX_KDTREE;
  51. (*this)["trees"] = trees;
  52. }
  53. };
  54. /**
  55. * Randomized kd-tree index
  56. *
  57. * Contains the k-d trees and other information for indexing a set of points
  58. * for nearest-neighbor matching.
  59. */
  60. template <typename Distance>
  61. class KDTreeIndex : public NNIndex<Distance>
  62. {
  63. public:
  64. typedef typename Distance::ElementType ElementType;
  65. typedef typename Distance::ResultType DistanceType;
  66. /**
  67. * KDTree constructor
  68. *
  69. * Params:
  70. * inputData = dataset with the input features
  71. * params = parameters passed to the kdtree algorithm
  72. */
  73. KDTreeIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeIndexParams(),
  74. Distance d = Distance() ) :
  75. dataset_(inputData), index_params_(params), distance_(d)
  76. {
  77. size_ = dataset_.rows;
  78. veclen_ = dataset_.cols;
  79. trees_ = get_param(index_params_,"trees",4);
  80. tree_roots_ = new NodePtr[trees_];
  81. // Create a permutable array of indices to the input vectors.
  82. vind_.resize(size_);
  83. for (size_t i = 0; i < size_; ++i) {
  84. vind_[i] = int(i);
  85. }
  86. mean_ = new DistanceType[veclen_];
  87. var_ = new DistanceType[veclen_];
  88. }
  89. KDTreeIndex(const KDTreeIndex&);
  90. KDTreeIndex& operator=(const KDTreeIndex&);
  91. /**
  92. * Standard destructor
  93. */
  94. ~KDTreeIndex()
  95. {
  96. if (tree_roots_!=NULL) {
  97. delete[] tree_roots_;
  98. }
  99. delete[] mean_;
  100. delete[] var_;
  101. }
  102. /**
  103. * Builds the index
  104. */
  105. void buildIndex() CV_OVERRIDE
  106. {
  107. /* Construct the randomized trees. */
  108. for (int i = 0; i < trees_; i++) {
  109. /* Randomize the order of vectors to allow for unbiased sampling. */
  110. #ifndef OPENCV_FLANN_USE_STD_RAND
  111. cv::randShuffle(vind_);
  112. #else
  113. std::random_shuffle(vind_.begin(), vind_.end());
  114. #endif
  115. tree_roots_[i] = divideTree(&vind_[0], int(size_) );
  116. }
  117. }
  118. flann_algorithm_t getType() const CV_OVERRIDE
  119. {
  120. return FLANN_INDEX_KDTREE;
  121. }
  122. void saveIndex(FILE* stream) CV_OVERRIDE
  123. {
  124. save_value(stream, trees_);
  125. for (int i=0; i<trees_; ++i) {
  126. save_tree(stream, tree_roots_[i]);
  127. }
  128. }
  129. void loadIndex(FILE* stream) CV_OVERRIDE
  130. {
  131. load_value(stream, trees_);
  132. if (tree_roots_!=NULL) {
  133. delete[] tree_roots_;
  134. }
  135. tree_roots_ = new NodePtr[trees_];
  136. for (int i=0; i<trees_; ++i) {
  137. load_tree(stream,tree_roots_[i]);
  138. }
  139. index_params_["algorithm"] = getType();
  140. index_params_["trees"] = tree_roots_;
  141. }
  142. /**
  143. * Returns size of index.
  144. */
  145. size_t size() const CV_OVERRIDE
  146. {
  147. return size_;
  148. }
  149. /**
  150. * Returns the length of an index feature.
  151. */
  152. size_t veclen() const CV_OVERRIDE
  153. {
  154. return veclen_;
  155. }
  156. /**
  157. * Computes the inde memory usage
  158. * Returns: memory used by the index
  159. */
  160. int usedMemory() const CV_OVERRIDE
  161. {
  162. return int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int)); // pool memory and vind array memory
  163. }
  164. /**
  165. * Find set of nearest neighbors to vec. Their indices are stored inside
  166. * the result object.
  167. *
  168. * Params:
  169. * result = the result object in which the indices of the nearest-neighbors are stored
  170. * vec = the vector for which to search the nearest neighbors
  171. * maxCheck = the maximum number of restarts (in a best-bin-first manner)
  172. */
  173. void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) CV_OVERRIDE
  174. {
  175. const int maxChecks = get_param(searchParams,"checks", 32);
  176. const float epsError = 1+get_param(searchParams,"eps",0.0f);
  177. const bool explore_all_trees = get_param(searchParams,"explore_all_trees",false);
  178. if (maxChecks==FLANN_CHECKS_UNLIMITED) {
  179. getExactNeighbors(result, vec, epsError);
  180. }
  181. else {
  182. getNeighbors(result, vec, maxChecks, epsError, explore_all_trees);
  183. }
  184. }
  185. IndexParams getParameters() const CV_OVERRIDE
  186. {
  187. return index_params_;
  188. }
  189. private:
  190. /*--------------------- Internal Data Structures --------------------------*/
  191. struct Node
  192. {
  193. /**
  194. * Dimension used for subdivision.
  195. */
  196. int divfeat;
  197. /**
  198. * The values used for subdivision.
  199. */
  200. DistanceType divval;
  201. /**
  202. * The child nodes.
  203. */
  204. Node* child1, * child2;
  205. };
  206. typedef Node* NodePtr;
  207. typedef BranchStruct<NodePtr, DistanceType> BranchSt;
  208. typedef BranchSt* Branch;
  209. void save_tree(FILE* stream, NodePtr tree)
  210. {
  211. save_value(stream, *tree);
  212. if (tree->child1!=NULL) {
  213. save_tree(stream, tree->child1);
  214. }
  215. if (tree->child2!=NULL) {
  216. save_tree(stream, tree->child2);
  217. }
  218. }
  219. void load_tree(FILE* stream, NodePtr& tree)
  220. {
  221. tree = pool_.allocate<Node>();
  222. load_value(stream, *tree);
  223. if (tree->child1!=NULL) {
  224. load_tree(stream, tree->child1);
  225. }
  226. if (tree->child2!=NULL) {
  227. load_tree(stream, tree->child2);
  228. }
  229. }
  230. /**
  231. * Create a tree node that subdivides the list of vecs from vind[first]
  232. * to vind[last]. The routine is called recursively on each sublist.
  233. * Place a pointer to this new tree node in the location pTree.
  234. *
  235. * Params: pTree = the new node to create
  236. * first = index of the first vector
  237. * last = index of the last vector
  238. */
  239. NodePtr divideTree(int* ind, int count)
  240. {
  241. NodePtr node = pool_.allocate<Node>(); // allocate memory
  242. /* If too few exemplars remain, then make this a leaf node. */
  243. if ( count == 1) {
  244. node->child1 = node->child2 = NULL; /* Mark as leaf node. */
  245. node->divfeat = *ind; /* Store index of this vec. */
  246. }
  247. else {
  248. int idx;
  249. int cutfeat;
  250. DistanceType cutval;
  251. meanSplit(ind, count, idx, cutfeat, cutval);
  252. node->divfeat = cutfeat;
  253. node->divval = cutval;
  254. node->child1 = divideTree(ind, idx);
  255. node->child2 = divideTree(ind+idx, count-idx);
  256. }
  257. return node;
  258. }
  259. /**
  260. * Choose which feature to use in order to subdivide this set of vectors.
  261. * Make a random choice among those with the highest variance, and use
  262. * its variance as the threshold value.
  263. */
  264. void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
  265. {
  266. memset(mean_,0,veclen_*sizeof(DistanceType));
  267. memset(var_,0,veclen_*sizeof(DistanceType));
  268. /* Compute mean values. Only the first SAMPLE_MEAN values need to be
  269. sampled to get a good estimate.
  270. */
  271. int cnt = std::min((int)SAMPLE_MEAN+1, count);
  272. for (int j = 0; j < cnt; ++j) {
  273. ElementType* v = dataset_[ind[j]];
  274. for (size_t k=0; k<veclen_; ++k) {
  275. mean_[k] += v[k];
  276. }
  277. }
  278. for (size_t k=0; k<veclen_; ++k) {
  279. mean_[k] /= cnt;
  280. }
  281. /* Compute variances (no need to divide by count). */
  282. for (int j = 0; j < cnt; ++j) {
  283. ElementType* v = dataset_[ind[j]];
  284. for (size_t k=0; k<veclen_; ++k) {
  285. DistanceType dist = v[k] - mean_[k];
  286. var_[k] += dist * dist;
  287. }
  288. }
  289. /* Select one of the highest variance indices at random. */
  290. cutfeat = selectDivision(var_);
  291. cutval = mean_[cutfeat];
  292. int lim1, lim2;
  293. planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
  294. if (lim1>count/2) index = lim1;
  295. else if (lim2<count/2) index = lim2;
  296. else index = count/2;
  297. /* If either list is empty, it means that all remaining features
  298. * are identical. Split in the middle to maintain a balanced tree.
  299. */
  300. if ((lim1==count)||(lim2==0)) index = count/2;
  301. }
  302. /**
  303. * Select the top RAND_DIM largest values from v and return the index of
  304. * one of these selected at random.
  305. */
  306. int selectDivision(DistanceType* v)
  307. {
  308. int num = 0;
  309. size_t topind[RAND_DIM];
  310. /* Create a list of the indices of the top RAND_DIM values. */
  311. for (size_t i = 0; i < veclen_; ++i) {
  312. if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
  313. /* Put this element at end of topind. */
  314. if (num < RAND_DIM) {
  315. topind[num++] = i; /* Add to list. */
  316. }
  317. else {
  318. topind[num-1] = i; /* Replace last element. */
  319. }
  320. /* Bubble end value down to right location by repeated swapping. */
  321. int j = num - 1;
  322. while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
  323. std::swap(topind[j], topind[j-1]);
  324. --j;
  325. }
  326. }
  327. }
  328. /* Select a random integer in range [0,num-1], and return that index. */
  329. int rnd = rand_int(num);
  330. return (int)topind[rnd];
  331. }
  332. /**
  333. * Subdivide the list of points by a plane perpendicular on axe corresponding
  334. * to the 'cutfeat' dimension at 'cutval' position.
  335. *
  336. * On return:
  337. * dataset[ind[0..lim1-1]][cutfeat]<cutval
  338. * dataset[ind[lim1..lim2-1]][cutfeat]==cutval
  339. * dataset[ind[lim2..count]][cutfeat]>cutval
  340. */
  341. void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
  342. {
  343. /* Move vector indices for left subtree to front of list. */
  344. int left = 0;
  345. int right = count-1;
  346. for (;; ) {
  347. while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
  348. while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
  349. if (left>right) break;
  350. std::swap(ind[left], ind[right]); ++left; --right;
  351. }
  352. lim1 = left;
  353. right = count-1;
  354. for (;; ) {
  355. while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
  356. while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
  357. if (left>right) break;
  358. std::swap(ind[left], ind[right]); ++left; --right;
  359. }
  360. lim2 = left;
  361. }
  362. /**
  363. * Performs an exact nearest neighbor search. The exact search performs a full
  364. * traversal of the tree.
  365. */
  366. void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
  367. {
  368. // checkID -= 1; /* Set a different unique ID for each search. */
  369. if (trees_ > 1) {
  370. fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
  371. }
  372. if (trees_>0) {
  373. searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
  374. }
  375. CV_Assert(result.full());
  376. }
  377. /**
  378. * Performs the approximate nearest-neighbor search. The search is approximate
  379. * because the tree traversal is abandoned after a given number of descends in
  380. * the tree.
  381. */
  382. void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec,
  383. int maxCheck, float epsError, bool explore_all_trees = false)
  384. {
  385. int i;
  386. BranchSt branch;
  387. int checkCount = 0;
  388. DynamicBitset checked(size_);
  389. // Priority queue storing intermediate branches in the best-bin-first search
  390. const cv::Ptr<Heap<BranchSt>>& heap = Heap<BranchSt>::getPooledInstance(cv::utils::getThreadID(), (int)size_);
  391. /* Search once through each tree down to root. */
  392. for (i = 0; i < trees_; ++i) {
  393. searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck,
  394. epsError, heap, checked, explore_all_trees);
  395. if (!explore_all_trees && (checkCount >= maxCheck) && result.full())
  396. break;
  397. }
  398. /* Keep searching other branches from heap until finished. */
  399. while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
  400. searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck,
  401. epsError, heap, checked, false);
  402. }
  403. CV_Assert(result.full());
  404. }
  405. /**
  406. * Search starting from a given node of the tree. Based on any mismatches at
  407. * higher levels, all exemplars below this level must have a distance of
  408. * at least "mindistsq".
  409. */
  410. void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
  411. float epsError, const cv::Ptr<Heap<BranchSt>>& heap, DynamicBitset& checked, bool explore_all_trees = false)
  412. {
  413. if (result_set.worstDist()<mindist) {
  414. // printf("Ignoring branch, too far\n");
  415. return;
  416. }
  417. /* If this is a leaf node, then do check and return. */
  418. if ((node->child1 == NULL)&&(node->child2 == NULL)) {
  419. /* Do not check same node more than once when searching multiple trees.
  420. Once a vector is checked, we set its location in vind to the
  421. current checkID.
  422. */
  423. int index = node->divfeat;
  424. if ( checked.test(index) ||
  425. (!explore_all_trees && (checkCount>=maxCheck) && result_set.full()) ) {
  426. return;
  427. }
  428. checked.set(index);
  429. checkCount++;
  430. DistanceType dist = distance_(dataset_[index], vec, veclen_);
  431. result_set.addPoint(dist,index);
  432. return;
  433. }
  434. /* Which child branch should be taken first? */
  435. ElementType val = vec[node->divfeat];
  436. DistanceType diff = val - node->divval;
  437. NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
  438. NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
  439. /* Create a branch record for the branch not taken. Add distance
  440. of this feature boundary (we don't attempt to correct for any
  441. use of this feature in a parent node, which is unlikely to
  442. happen and would have only a small effect). Don't bother
  443. adding more branches to heap after halfway point, as cost of
  444. adding exceeds their value.
  445. */
  446. DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
  447. // if (2 * checkCount < maxCheck || !result.full()) {
  448. if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
  449. heap->insert( BranchSt(otherChild, new_distsq) );
  450. }
  451. /* Call recursively to search next level down. */
  452. searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
  453. }
  454. /**
  455. * Performs an exact search in the tree starting from a node.
  456. */
  457. void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
  458. {
  459. /* If this is a leaf node, then do check and return. */
  460. if ((node->child1 == NULL)&&(node->child2 == NULL)) {
  461. int index = node->divfeat;
  462. DistanceType dist = distance_(dataset_[index], vec, veclen_);
  463. result_set.addPoint(dist,index);
  464. return;
  465. }
  466. /* Which child branch should be taken first? */
  467. ElementType val = vec[node->divfeat];
  468. DistanceType diff = val - node->divval;
  469. NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
  470. NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
  471. /* Create a branch record for the branch not taken. Add distance
  472. of this feature boundary (we don't attempt to correct for any
  473. use of this feature in a parent node, which is unlikely to
  474. happen and would have only a small effect). Don't bother
  475. adding more branches to heap after halfway point, as cost of
  476. adding exceeds their value.
  477. */
  478. DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
  479. /* Call recursively to search next level down. */
  480. searchLevelExact(result_set, vec, bestChild, mindist, epsError);
  481. if (new_distsq*epsError<=result_set.worstDist()) {
  482. searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
  483. }
  484. }
  485. private:
  486. enum
  487. {
  488. /**
  489. * To improve efficiency, only SAMPLE_MEAN random values are used to
  490. * compute the mean and variance at each level when building a tree.
  491. * A value of 100 seems to perform as well as using all values.
  492. */
  493. SAMPLE_MEAN = 100,
  494. /**
  495. * Top random dimensions to consider
  496. *
  497. * When creating random trees, the dimension on which to subdivide is
  498. * selected at random from among the top RAND_DIM dimensions with the
  499. * highest variance. A value of 5 works well.
  500. */
  501. RAND_DIM=5
  502. };
  503. /**
  504. * Number of randomized trees that are used
  505. */
  506. int trees_;
  507. /**
  508. * Array of indices to vectors in the dataset.
  509. */
  510. std::vector<int> vind_;
  511. /**
  512. * The dataset used by this index
  513. */
  514. const Matrix<ElementType> dataset_;
  515. IndexParams index_params_;
  516. size_t size_;
  517. size_t veclen_;
  518. DistanceType* mean_;
  519. DistanceType* var_;
  520. /**
  521. * Array of k-d trees used to find neighbours.
  522. */
  523. NodePtr* tree_roots_;
  524. /**
  525. * Pooled memory allocator.
  526. *
  527. * Using a pooled memory allocator is more efficient
  528. * than allocating memory directly when there is a large
  529. * number small of memory allocations.
  530. */
  531. PooledAllocator pool_;
  532. Distance distance_;
  533. }; // class KDTreeForest
  534. }
  535. //! @endcond
  536. #endif //OPENCV_FLANN_KDTREE_INDEX_H_