// File: lsh_index.h (15 KB)

  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. /***********************************************************************
  31. * Author: Vincent Rabaud
  32. *************************************************************************/
  33. #ifndef OPENCV_FLANN_LSH_INDEX_H_
  34. #define OPENCV_FLANN_LSH_INDEX_H_
  35. //! @cond IGNORED
  36. #include <algorithm>
  37. #include <cstring>
  38. #include <map>
  39. #include <vector>
  40. #include "nn_index.h"
  41. #include "matrix.h"
  42. #include "result_set.h"
  43. #include "heap.h"
  44. #include "lsh_table.h"
  45. #include "allocator.h"
  46. #include "random.h"
  47. #include "saving.h"
  48. #ifdef _MSC_VER
  49. #pragma warning(push)
  50. #pragma warning(disable: 4702) //disable unreachable code
  51. #endif
  52. namespace cvflann
  53. {
  54. struct LshIndexParams : public IndexParams
  55. {
  56. LshIndexParams(int table_number = 12, int key_size = 20, int multi_probe_level = 2)
  57. {
  58. (*this)["algorithm"] = FLANN_INDEX_LSH;
  59. // The number of hash tables to use
  60. (*this)["table_number"] = table_number;
  61. // The length of the key in the hash tables
  62. (*this)["key_size"] = key_size;
  63. // Number of levels to use in multi-probe (0 for standard LSH)
  64. (*this)["multi_probe_level"] = multi_probe_level;
  65. }
  66. };
  67. /**
  68. * Locality-sensitive hashing index
  69. *
  70. * Contains the tables and other information for indexing a set of points
  71. * for nearest-neighbor matching.
  72. */
  73. template<typename Distance>
  74. class LshIndex : public NNIndex<Distance>
  75. {
  76. public:
  77. typedef typename Distance::ElementType ElementType;
  78. typedef typename Distance::ResultType DistanceType;
  79. /** Constructor
  80. * @param input_data dataset with the input features
  81. * @param params parameters passed to the LSH algorithm
  82. * @param d the distance used
  83. */
  84. LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(),
  85. Distance d = Distance()) :
  86. dataset_(input_data), index_params_(params), distance_(d)
  87. {
  88. // cv::flann::IndexParams sets integer params as 'int', so it is used with get_param
  89. // in place of 'unsigned int'
  90. table_number_ = get_param(index_params_,"table_number",12);
  91. key_size_ = get_param(index_params_,"key_size",20);
  92. multi_probe_level_ = get_param(index_params_,"multi_probe_level",2);
  93. feature_size_ = (unsigned)dataset_.cols;
  94. fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
  95. }
  96. LshIndex(const LshIndex&);
  97. LshIndex& operator=(const LshIndex&);
  98. /**
  99. * Builds the index
  100. */
  101. void buildIndex() CV_OVERRIDE
  102. {
  103. tables_.resize(table_number_);
  104. for (int i = 0; i < table_number_; ++i) {
  105. lsh::LshTable<ElementType>& table = tables_[i];
  106. table = lsh::LshTable<ElementType>(feature_size_, key_size_);
  107. // Add the features to the table
  108. table.add(dataset_);
  109. }
  110. }
  111. flann_algorithm_t getType() const CV_OVERRIDE
  112. {
  113. return FLANN_INDEX_LSH;
  114. }
  115. void saveIndex(FILE* stream) CV_OVERRIDE
  116. {
  117. save_value(stream,table_number_);
  118. save_value(stream,key_size_);
  119. save_value(stream,multi_probe_level_);
  120. save_value(stream, dataset_);
  121. }
  122. void loadIndex(FILE* stream) CV_OVERRIDE
  123. {
  124. load_value(stream, table_number_);
  125. load_value(stream, key_size_);
  126. load_value(stream, multi_probe_level_);
  127. load_value(stream, dataset_);
  128. // Building the index is so fast we can afford not storing it
  129. buildIndex();
  130. index_params_["algorithm"] = getType();
  131. index_params_["table_number"] = table_number_;
  132. index_params_["key_size"] = key_size_;
  133. index_params_["multi_probe_level"] = multi_probe_level_;
  134. }
  135. /**
  136. * Returns size of index.
  137. */
  138. size_t size() const CV_OVERRIDE
  139. {
  140. return dataset_.rows;
  141. }
  142. /**
  143. * Returns the length of an index feature.
  144. */
  145. size_t veclen() const CV_OVERRIDE
  146. {
  147. return feature_size_;
  148. }
  149. /**
  150. * Computes the index memory usage
  151. * Returns: memory used by the index
  152. */
  153. int usedMemory() const CV_OVERRIDE
  154. {
  155. return (int)(dataset_.rows * sizeof(int));
  156. }
  157. IndexParams getParameters() const CV_OVERRIDE
  158. {
  159. return index_params_;
  160. }
  161. /**
  162. * \brief Perform k-nearest neighbor search
  163. * \param[in] queries The query points for which to find the nearest neighbors
  164. * \param[out] indices The indices of the nearest neighbors found
  165. * \param[out] dists Distances to the nearest neighbors found
  166. * \param[in] knn Number of nearest neighbors to return
  167. * \param[in] params Search parameters
  168. */
  169. virtual void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params) CV_OVERRIDE
  170. {
  171. CV_Assert(queries.cols == veclen());
  172. CV_Assert(indices.rows >= queries.rows);
  173. CV_Assert(dists.rows >= queries.rows);
  174. CV_Assert(int(indices.cols) >= knn);
  175. CV_Assert(int(dists.cols) >= knn);
  176. KNNUniqueResultSet<DistanceType> resultSet(knn);
  177. for (size_t i = 0; i < queries.rows; i++) {
  178. resultSet.clear();
  179. std::fill_n(indices[i], knn, -1);
  180. std::fill_n(dists[i], knn, std::numeric_limits<DistanceType>::max());
  181. findNeighbors(resultSet, queries[i], params);
  182. if (get_param(params,"sorted",true)) resultSet.sortAndCopy(indices[i], dists[i], knn);
  183. else resultSet.copy(indices[i], dists[i], knn);
  184. }
  185. }
  186. /**
  187. * Find set of nearest neighbors to vec. Their indices are stored inside
  188. * the result object.
  189. *
  190. * Params:
  191. * result = the result object in which the indices of the nearest-neighbors are stored
  192. * vec = the vector for which to search the nearest neighbors
  193. * maxCheck = the maximum number of restarts (in a best-bin-first manner)
  194. */
  195. void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) CV_OVERRIDE
  196. {
  197. getNeighbors(vec, result);
  198. }
  199. private:
  200. /** Defines the comparator on score and index
  201. */
  202. typedef std::pair<float, unsigned int> ScoreIndexPair;
  203. struct SortScoreIndexPairOnSecond
  204. {
  205. bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const
  206. {
  207. return left.second < right.second;
  208. }
  209. };
  210. /** Fills the different xor masks to use when getting the neighbors in multi-probe LSH
  211. * @param key the key we build neighbors from
  212. * @param lowest_index the lowest index of the bit set
  213. * @param level the multi-probe level we are at
  214. * @param xor_masks all the xor mask
  215. */
  216. void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level,
  217. std::vector<lsh::BucketKey>& xor_masks)
  218. {
  219. xor_masks.push_back(key);
  220. if (level == 0) return;
  221. for (int index = lowest_index - 1; index >= 0; --index) {
  222. // Create a new key
  223. lsh::BucketKey new_key = key | (1 << index);
  224. fill_xor_mask(new_key, index, level - 1, xor_masks);
  225. }
  226. }
  227. /** Performs the approximate nearest-neighbor search.
  228. * @param vec the feature to analyze
  229. * @param do_radius flag indicating if we check the radius too
  230. * @param radius the radius if it is a radius search
  231. * @param do_k flag indicating if we limit the number of nn
  232. * @param k_nn the number of nearest neighbors
  233. * @param checked_average used for debugging
  234. */
  235. void getNeighbors(const ElementType* vec, bool /*do_radius*/, float radius, bool do_k, unsigned int k_nn,
  236. float& /*checked_average*/)
  237. {
  238. static std::vector<ScoreIndexPair> score_index_heap;
  239. if (do_k) {
  240. unsigned int worst_score = std::numeric_limits<unsigned int>::max();
  241. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  242. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  243. for (; table != table_end; ++table) {
  244. size_t key = table->getKey(vec);
  245. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  246. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  247. for (; xor_mask != xor_mask_end; ++xor_mask) {
  248. size_t sub_key = key ^ (*xor_mask);
  249. const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
  250. if (bucket == 0) continue;
  251. // Go over each descriptor index
  252. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  253. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  254. DistanceType hamming_distance;
  255. // Process the rest of the candidates
  256. for (; training_index < last_training_index; ++training_index) {
  257. hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
  258. if (hamming_distance < worst_score) {
  259. // Insert the new element
  260. score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
  261. std::push_heap(score_index_heap.begin(), score_index_heap.end());
  262. if (score_index_heap.size() > (unsigned int)k_nn) {
  263. // Remove the highest distance value as we have too many elements
  264. std::pop_heap(score_index_heap.begin(), score_index_heap.end());
  265. score_index_heap.pop_back();
  266. // Keep track of the worst score
  267. worst_score = score_index_heap.front().first;
  268. }
  269. }
  270. }
  271. }
  272. }
  273. }
  274. else {
  275. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  276. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  277. for (; table != table_end; ++table) {
  278. size_t key = table->getKey(vec);
  279. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  280. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  281. for (; xor_mask != xor_mask_end; ++xor_mask) {
  282. size_t sub_key = key ^ (*xor_mask);
  283. const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
  284. if (bucket == 0) continue;
  285. // Go over each descriptor index
  286. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  287. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  288. DistanceType hamming_distance;
  289. // Process the rest of the candidates
  290. for (; training_index < last_training_index; ++training_index) {
  291. // Compute the Hamming distance
  292. hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
  293. if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
  294. }
  295. }
  296. }
  297. }
  298. }
  299. /** Performs the approximate nearest-neighbor search.
  300. * This is a slower version than the above as it uses the ResultSet
  301. * @param vec the feature to analyze
  302. */
  303. void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result)
  304. {
  305. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  306. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  307. for (; table != table_end; ++table) {
  308. size_t key = table->getKey(vec);
  309. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  310. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  311. for (; xor_mask != xor_mask_end; ++xor_mask) {
  312. size_t sub_key = key ^ (*xor_mask);
  313. const lsh::Bucket* bucket = table->getBucketFromKey((lsh::BucketKey)sub_key);
  314. if (bucket == 0) continue;
  315. // Go over each descriptor index
  316. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  317. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  318. DistanceType hamming_distance;
  319. // Process the rest of the candidates
  320. for (; training_index < last_training_index; ++training_index) {
  321. // Compute the Hamming distance
  322. hamming_distance = distance_(vec, dataset_[*training_index], (int)dataset_.cols);
  323. result.addPoint(hamming_distance, *training_index);
  324. }
  325. }
  326. }
  327. }
  328. /** The different hash tables */
  329. std::vector<lsh::LshTable<ElementType> > tables_;
  330. /** The data the LSH tables where built from */
  331. Matrix<ElementType> dataset_;
  332. /** The size of the features (as ElementType[]) */
  333. unsigned int feature_size_;
  334. IndexParams index_params_;
  335. /** table number */
  336. int table_number_;
  337. /** key size */
  338. int key_size_;
  339. /** How far should we look for neighbors in multi-probe LSH */
  340. int multi_probe_level_;
  341. /** The XOR masks to apply to a key to get the neighboring buckets */
  342. std::vector<lsh::BucketKey> xor_masks_;
  343. Distance distance_;
  344. };
  345. }
  346. #ifdef _MSC_VER
  347. #pragma warning(pop)
  348. #endif
  349. //! @endcond
  350. #endif //OPENCV_FLANN_LSH_INDEX_H_