All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
NearestNeighborsGNAT.h
1 /*********************************************************************
2 * Software License Agreement (BSD License)
3 *
4 * Copyright (c) 2011, Rice University
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 * * Redistributions in binary form must reproduce the above
14 * copyright notice, this list of conditions and the following
15 * disclaimer in the documentation and/or other materials provided
16 * with the distribution.
17 * * Neither the name of the Rice University nor the names of its
18 * contributors may be used to endorse or promote products derived
19 * from this software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 * POSSIBILITY OF SUCH DAMAGE.
33 *********************************************************************/
34 
35 /* Author: Mark Moll */
36 
37 #ifndef OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
38 #define OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
39 
40 #include "ompl/datastructures/NearestNeighbors.h"
41 #include "ompl/datastructures/GreedyKCenters.h"
42 #include "ompl/util/Exception.h"
43 #include <boost/unordered_set.hpp>
44 #include <queue>
45 #include <algorithm>
46 
47 namespace ompl
48 {
49 
58  template<typename _T>
60  {
61  protected:
63  // internally, we use a priority queue for nearest neighbors, paired
64  // with their distance to the query point
65  typedef std::pair<const _T*,double> DataDist;
66  struct DataDistCompare
67  {
68  bool operator()(const DataDist& d0, const DataDist& d1)
69  {
70  return d0.second < d1.second;
71  }
72  };
73  typedef std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare> NearQueue;
74 
75  // another internal data structure is a priority queue of nodes to
76  // check next for possible nearest neighbors
77  class Node;
78  typedef std::pair<Node*,double> NodeDist;
79  struct NodeDistCompare
80  {
81  bool operator()(const NodeDist& n0, const NodeDist& n1) const
82  {
83  return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
84  }
85  };
86  typedef std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare> NodeQueue;
88 
89  public:
90  NearestNeighborsGNAT(unsigned int degree = 4, unsigned int minDegree = 2,
91  unsigned int maxDegree = 6, unsigned int maxNumPtsPerLeaf = 50,
92  unsigned int removedCacheSize = 50, bool rebalancing = false)
93  : NearestNeighbors<_T>(), tree_(NULL), degree_(degree),
94  minDegree_(std::min(degree,minDegree)), maxDegree_(std::max(maxDegree,degree)),
95  maxNumPtsPerLeaf_(maxNumPtsPerLeaf), size_(0),
96  rebuildSize_(rebalancing ? maxNumPtsPerLeaf*degree : std::numeric_limits<std::size_t>::max()),
97  removedCacheSize_(removedCacheSize)
98  {
99  }
100 
101  virtual ~NearestNeighborsGNAT(void)
102  {
103  if (tree_)
104  delete tree_;
105  }
107  virtual void setDistanceFunction(const typename NearestNeighbors<_T>::DistanceFunction &distFun)
108  {
110  pivotSelector_.setDistanceFunction(distFun);
111  }
112  virtual void clear(void)
113  {
114  if (tree_)
115  {
116  delete tree_;
117  tree_ = NULL;
118  }
119  size_ = 0;
120  removed_.clear();
121  }
122 
123  virtual void add(const _T &data)
124  {
125  if (tree_)
126  tree_->add(*this, data);
127  else
128  {
129  tree_ = new Node(degree_, maxNumPtsPerLeaf_, data);
130  size_ = 1;
131  }
132  }
133  virtual void add(const std::vector<_T> &data)
134  {
135  if (tree_)
137  else if (data.size()>0)
138  {
139  tree_ = new Node(degree_, maxNumPtsPerLeaf_, data[0]);
140  for (unsigned int i=1; i<data.size(); ++i)
141  tree_->data_.push_back(data[i]);
142  if (tree_->needToSplit(*this))
143  tree_->split(*this);
144  }
145  size_ += data.size();
146  }
149  {
150  std::vector<_T> lst;
151  list(lst);
152  clear();
153  add(lst);
154  }
160  virtual bool remove(const _T &data)
161  {
162  if (!tree_) return false;
163  NearQueue nbhQueue;
164  // find data in tree
165  bool isPivot = nearestKInternal(data, 1, nbhQueue);
166  if (*nbhQueue.top().first != data)
167  return false;
168  removed_.insert(nbhQueue.top().first);
169  size_--;
170  // if we removed a pivot or if the capacity of removed elements
171  // has been reached, we rebuild the entire GNAT
172  if (isPivot || removed_.size()>=removedCacheSize_)
174  return true;
175  }
176 
177  virtual _T nearest(const _T &data) const
178  {
179  if (tree_)
180  {
181  std::vector<_T> nbh;
182  nearestK(data, 1, nbh);
183  if (!nbh.empty()) return nbh[0];
184  }
185  throw Exception("No elements found in nearest neighbors data structure");
186  }
187 
188  virtual void nearestK(const _T &data, std::size_t k, std::vector<_T> &nbh) const
189  {
190  nbh.clear();
191  if (k == 0) return;
192  if (tree_)
193  {
194  NearQueue nbhQueue;
195  nearestKInternal(data, k, nbhQueue);
196  postprocessNearest(nbhQueue, nbh);
197  }
198  }
199 
200  virtual void nearestR(const _T &data, double radius, std::vector<_T> &nbh) const
201  {
202  nbh.clear();
203  if (tree_)
204  {
205  NearQueue nbhQueue;
206  nearestRInternal(data, radius, nbhQueue);
207  postprocessNearest(nbhQueue, nbh);
208  }
209  }
210 
211  virtual std::size_t size(void) const
212  {
213  return size_;
214  }
215 
216  virtual void list(std::vector<_T> &data) const
217  {
218  data.clear();
219  data.reserve(size());
220  if (tree_)
221  tree_->list(*this, data);
222  }
223 
225  friend std::ostream& operator<<(std::ostream& out, const NearestNeighborsGNAT<_T>& gnat)
226  {
227  if (gnat.tree_)
228  {
229  out << *gnat.tree_;
230  if (!gnat.removed_.empty())
231  {
232  out << "Elements marked for removal:\n";
233  for (typename boost::unordered_set<const _T*>::const_iterator it = gnat.removed_.begin();
234  it != gnat.removed_.end(); it++)
235  out << **it << '\t';
236  out << std::endl;
237  }
238  }
239  return out;
240  }
241 
242  // for debugging purposes
243  void integrityCheck()
244  {
245  std::vector<_T> lst;
246  boost::unordered_set<const _T*> tmp;
247  // get all elements, including those marked for removal
248  removed_.swap(tmp);
249  list(lst);
250  // check if every element marked for removal is also in the tree
251  for (typename boost::unordered_set<const _T*>::iterator it=tmp.begin(); it!=tmp.end(); it++)
252  {
253  unsigned int i;
254  for (i=0; i<lst.size(); ++i)
255  if (lst[i]==**it)
256  break;
257  if (i == lst.size())
258  {
259  // an element marked for removal is not actually in the tree
260  std::cout << "***** FAIL!! ******\n" << *this << '\n';
261  for (unsigned int j=0; j<lst.size(); ++j) std::cout<<lst[j]<<'\t';
262  std::cout<<std::endl;
263  }
264  assert(i != lst.size());
265  }
266  // restore
267  removed_.swap(tmp);
268  // get elements in the tree with elements marked for removal purged from the list
269  list(lst);
270  if (lst.size() != size_)
271  std::cout << "#########################################\n" << *this << std::endl;
272  assert(lst.size() == size_);
273  }
274  protected:
275  typedef NearestNeighborsGNAT<_T> GNAT;
276 
278  bool isRemoved(const _T& data) const
279  {
280  return !removed_.empty() && removed_.find(&data) != removed_.end();
281  }
282 
287  bool nearestKInternal(const _T &data, std::size_t k, NearQueue& nbhQueue) const
288  {
289  bool isPivot;
290  double dist;
291  NodeDist nodeDist;
292  NodeQueue nodeQueue;
293 
294  isPivot = tree_->insertNeighborK(nbhQueue, k, tree_->pivot_, data,
296  tree_->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
297  while (nodeQueue.size() > 0)
298  {
299  dist = nbhQueue.top().second; // note the difference with nearestRInternal
300  nodeDist = nodeQueue.top();
301  nodeQueue.pop();
302  if (nbhQueue.size() == k &&
303  (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
304  nodeDist.second < nodeDist.first->minRadius_ - dist))
305  break;
306  nodeDist.first->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
307  }
308  return isPivot;
309  }
311  void nearestRInternal(const _T &data, double radius, NearQueue& nbhQueue) const
312  {
313  double dist = radius; // note the difference with nearestKInternal
314  NodeQueue nodeQueue;
315  NodeDist nodeDist;
316 
317  tree_->insertNeighborR(nbhQueue, radius, tree_->pivot_,
319  tree_->nearestR(*this, data, radius, nbhQueue, nodeQueue);
320  while (nodeQueue.size() > 0)
321  {
322  nodeDist = nodeQueue.top();
323  nodeQueue.pop();
324  if (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
325  nodeDist.second < nodeDist.first->minRadius_ - dist)
326  break;
327  nodeDist.first->nearestR(*this, data, radius, nbhQueue, nodeQueue);
328  }
329  }
332  void postprocessNearest(NearQueue& nbhQueue, std::vector<_T> &nbh) const
333  {
334  typename std::vector<_T>::reverse_iterator it;
335  nbh.resize(nbhQueue.size());
336  for (it=nbh.rbegin(); it!=nbh.rend(); it++, nbhQueue.pop())
337  *it = *nbhQueue.top().first;
338  }
339 
341  class Node
342  {
343  public:
346  Node(int degree, int capacity, const _T& pivot)
347  : degree_(degree), pivot_(pivot),
348  minRadius_(std::numeric_limits<double>::infinity()),
350  maxRange_(degree, maxRadius_)
351  {
352  // The "+1" is needed because we add an element before we check whether to split
353  data_.reserve(capacity+1);
354  }
355 
356  ~Node()
357  {
358  for (unsigned int i=0; i<children_.size(); ++i)
359  delete children_[i];
360  }
361 
364  void updateRadius(double dist)
365  {
366  if (minRadius_ > dist)
367  minRadius_ = dist;
368  if (maxRadius_ < dist)
369  maxRadius_ = dist;
370  }
374  void updateRange(unsigned int i, double dist)
375  {
376  if (minRange_[i] > dist)
377  minRange_[i] = dist;
378  if (maxRange_[i] < dist)
379  maxRange_[i] = dist;
380  }
382  void add(GNAT& gnat, const _T& data)
383  {
384  if (children_.size()==0)
385  {
386  data_.push_back(data);
387  gnat.size_++;
388  if (needToSplit(gnat))
389  {
390  if (gnat.removed_.size() > 0)
391  gnat.rebuildDataStructure();
392  else if (gnat.size_ >= gnat.rebuildSize_)
393  {
394  gnat.rebuildSize_ <<= 1;
395  gnat.rebuildDataStructure();
396  }
397  else
398  split(gnat);
399  }
400  }
401  else
402  {
403  std::vector<double> dist(children_.size());
404  double minDist = dist[0] = gnat.distFun_(data, children_[0]->pivot_);
405  int minInd = 0;
406 
407  for (unsigned int i=1; i<children_.size(); ++i)
408  if ((dist[i] = gnat.distFun_(data, children_[i]->pivot_)) < minDist)
409  {
410  minDist = dist[i];
411  minInd = i;
412  }
413  for (unsigned int i=0; i<children_.size(); ++i)
414  children_[i]->updateRange(minInd, dist[i]);
415  children_[minInd]->updateRadius(minDist);
416  children_[minInd]->add(gnat, data);
417  }
418  }
420  bool needToSplit(const GNAT& gnat) const
421  {
422  unsigned int sz = data_.size();
423  return sz > gnat.maxNumPtsPerLeaf_ && sz > degree_;
424  }
428  void split(GNAT& gnat)
429  {
430  std::vector<std::vector<double> > dists;
431  std::vector<unsigned int> pivots;
432 
433  children_.reserve(degree_);
434  gnat.pivotSelector_.kcenters(data_, degree_, pivots, dists);
435  for(unsigned int i=0; i<pivots.size(); i++)
436  children_.push_back(new Node(degree_, gnat.maxNumPtsPerLeaf_, data_[pivots[i]]));
437  degree_ = pivots.size(); // in case fewer than degree_ pivots were found
438  for (unsigned int j=0; j<data_.size(); ++j)
439  {
440  unsigned int k = 0;
441  for (unsigned int i=1; i<degree_; ++i)
442  if (dists[j][i] < dists[j][k])
443  k = i;
444  Node* child = children_[k];
445  if (j != pivots[k])
446  {
447  child->data_.push_back(data_[j]);
448  child->updateRadius(dists[j][k]);
449  }
450  for (unsigned int i=0; i<degree_; ++i)
451  children_[i]->updateRange(k, dists[j][i]);
452  }
453 
454  for (unsigned int i=0; i<degree_; ++i)
455  {
456  // make sure degree lies between minDegree_ and maxDegree_
457  children_[i]->degree_ = std::min(std::max(
458  degree_ * (unsigned int)(children_[i]->data_.size() / data_.size()),
459  gnat.minDegree_), gnat.maxDegree_);
460  // singleton
461  if (children_[i]->minRadius_ == std::numeric_limits<double>::infinity())
462  children_[i]->minRadius_ = children_[i]->maxRadius_ = 0.;
463  }
464  // this does more than clear(); it also sets capacity to 0 and frees the memory
465  std::vector<_T> tmp;
466  data_.swap(tmp);
467  // check if new leaves need to be split
468  for (unsigned int i=0; i<degree_; ++i)
469  if (children_[i]->needToSplit(gnat))
470  children_[i]->split(gnat);
471  }
472 
474  bool insertNeighborK(NearQueue& nbh, std::size_t k, const _T& data, const _T& key, double dist) const
475  {
476  if (nbh.size() < k)
477  {
478  nbh.push(std::make_pair(&data, dist));
479  return true;
480  }
481  else if (dist < nbh.top().second ||
482  (dist < std::numeric_limits<double>::epsilon() && data==key))
483  {
484  nbh.pop();
485  nbh.push(std::make_pair(&data, dist));
486  return true;
487  }
488  return false;
489  }
490 
496  void nearestK(const GNAT& gnat, const _T &data, std::size_t k,
497  NearQueue& nbh, NodeQueue& nodeQueue, bool& isPivot) const
498  {
499  for (unsigned int i=0; i<data_.size(); ++i)
500  if (!gnat.isRemoved(data_[i]))
501  {
502  if (insertNeighborK(nbh, k, data_[i], data, gnat.distFun_(data, data_[i])))
503  isPivot = false;
504  }
505  if (children_.size() > 0)
506  {
507  double dist;
508  Node* child;
509  std::vector<double> distToPivot(children_.size());
510  std::vector<int> permutation(children_.size());
511 
512  for (unsigned int i=0; i<permutation.size(); ++i)
513  permutation[i] = i;
514  std::random_shuffle(permutation.begin(), permutation.end());
515 
516  for (unsigned int i=0; i<children_.size(); ++i)
517  if (permutation[i] >= 0)
518  {
519  child = children_[permutation[i]];
520  distToPivot[permutation[i]] = gnat.distFun_(data, child->pivot_);
521  if (insertNeighborK(nbh, k, child->pivot_, data, distToPivot[permutation[i]]))
522  isPivot = true;
523  if (nbh.size()==k)
524  {
525  dist = nbh.top().second; // note difference with nearestR
526  for (unsigned int j=0; j<children_.size(); ++j)
527  if (permutation[j] >=0 && i != j &&
528  (distToPivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
529  distToPivot[permutation[i]] + dist < child->minRange_[permutation[j]]))
530  permutation[j] = -1;
531  }
532  }
533 
534  dist = nbh.top().second;
535  for (unsigned int i=0; i<children_.size(); ++i)
536  if (permutation[i] >= 0)
537  {
538  child = children_[permutation[i]];
539  if (nbh.size()<k ||
540  (distToPivot[permutation[i]] - dist <= child->maxRadius_ &&
541  distToPivot[permutation[i]] + dist >= child->minRadius_))
542  nodeQueue.push(std::make_pair(child, distToPivot[permutation[i]]));
543  }
544  }
545  }
547  void insertNeighborR(NearQueue& nbh, double r, const _T& data, double dist) const
548  {
549  if (dist <= r)
550  nbh.push(std::make_pair(&data, dist));
551  }
555  void nearestR(const GNAT& gnat, const _T &data, double r, NearQueue& nbh, NodeQueue& nodeQueue) const
556  {
557  double dist = r; //note difference with nearestK
558 
559  for (unsigned int i=0; i<data_.size(); ++i)
560  if (!gnat.isRemoved(data_[i]))
561  insertNeighborR(nbh, r, data_[i], gnat.distFun_(data, data_[i]));
562  if (children_.size() > 0)
563  {
564  Node* child;
565  std::vector<double> distToPivot(children_.size());
566  std::vector<int> permutation(children_.size());
567 
568  for (unsigned int i=0; i<permutation.size(); ++i)
569  permutation[i] = i;
570  std::random_shuffle(permutation.begin(), permutation.end());
571 
572  for (unsigned int i=0; i<children_.size(); ++i)
573  if (permutation[i] >= 0)
574  {
575  child = children_[permutation[i]];
576  distToPivot[i] = gnat.distFun_(data, child->pivot_);
577  insertNeighborR(nbh, r, child->pivot_, distToPivot[i]);
578  for (unsigned int j=0; j<children_.size(); ++j)
579  if (permutation[j] >=0 && i != j &&
580  (distToPivot[i] - dist > child->maxRange_[permutation[j]] ||
581  distToPivot[i] + dist < child->minRange_[permutation[j]]))
582  permutation[j] = -1;
583  }
584 
585  for (unsigned int i=0; i<children_.size(); ++i)
586  if (permutation[i] >= 0)
587  {
588  child = children_[permutation[i]];
589  if (distToPivot[i] - dist <= child->maxRadius_ &&
590  distToPivot[i] + dist >= child->minRadius_)
591  nodeQueue.push(std::make_pair(child, distToPivot[i]));
592  }
593  }
594  }
595 
596  void list(const GNAT& gnat, std::vector<_T> &data) const
597  {
598  if (!gnat.isRemoved(pivot_))
599  data.push_back(pivot_);
600  for (unsigned int i=0; i<data_.size(); ++i)
601  if(!gnat.isRemoved(data_[i]))
602  data.push_back(data_[i]);
603  for (unsigned int i=0; i<children_.size(); ++i)
604  children_[i]->list(gnat, data);
605  }
606 
607  friend std::ostream& operator<<(std::ostream& out, const Node& node)
608  {
609  out << "\ndegree:\t" << node.degree_;
610  out << "\nminRadius:\t" << node.minRadius_;
611  out << "\nmaxRadius:\t" << node.maxRadius_;
612  out << "\nminRange:\t";
613  for (unsigned int i=0; i<node.minRange_.size(); ++i)
614  out << node.minRange_[i] << '\t';
615  out << "\nmaxRange: ";
616  for (unsigned int i=0; i<node.maxRange_.size(); ++i)
617  out << node.maxRange_[i] << '\t';
618  out << "\npivot:\t" << node.pivot_;
619  out << "\ndata: ";
620  for (unsigned int i=0; i<node.data_.size(); ++i)
621  out << node.data_[i] << '\t';
622  out << "\nthis:\t" << &node;
623  out << "\nchildren:\n";
624  for (unsigned int i=0; i<node.children_.size(); ++i)
625  out << node.children_[i] << '\t';
626  out << '\n';
627  for (unsigned int i=0; i<node.children_.size(); ++i)
628  out << *node.children_[i] << '\n';
629  return out;
630  }
631 
633  unsigned int degree_;
635  const _T pivot_;
637  double minRadius_;
639  double maxRadius_;
642  std::vector<double> minRange_;
645  std::vector<double> maxRange_;
648  std::vector<_T> data_;
651  std::vector<Node*> children_;
652  };
653 
657  unsigned int degree_;
662  unsigned int minDegree_;
667  unsigned int maxDegree_;
670  unsigned int maxNumPtsPerLeaf_;
672  std::size_t size_;
675  std::size_t rebuildSize_;
679  std::size_t removedCacheSize_;
683  boost::unordered_set<const _T*> removed_;
684  };
685 
686 }
687 
688 #endif