[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_split.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
37 #include <algorithm>
38 #include <cstddef>
39 #include <map>
40 #include <numeric>
41 #include <math.h>
42 #include "../mathutil.hxx"
43 #include "../array_vector.hxx"
44 #include "../sized_int.hxx"
45 #include "../matrix.hxx"
46 #include "../random.hxx"
47 #include "../functorexpression.hxx"
48 #include "rf_nodeproxy.hxx"
49 //#include "rf_sampling.hxx"
50 #include "rf_region.hxx"
51 //#include "../hokashyap.hxx"
52 //#include "vigra/rf_helpers.hxx"
53 
54 namespace vigra
55 {
56 
57 // Incomplete Class to ensure that findBestSplit is always implemented in
58 // the derived classes of SplitBase
59 class CompileTimeError;
60 
61 
62 namespace detail
63 {
    /** Post-processing of a node's statistics vector.
     *  Primary template (used e.g. for RegressionTag): leaves the
     *  values in [begin, end) untouched.
     */
    template<class Tag>
    class Normalise
    {
      public:
        /// no-op for all tags; specialised for ClassificationTag below
        template<class Iter>
        static void exec(Iter begin, Iter end)
        {}
    };
72 
73  template<>
74  class Normalise<ClassificationTag>
75  {
76  public:
77  template<class Iter>
78  static void exec (Iter begin, Iter end)
79  {
80  double bla = std::accumulate(begin, end, 0.0);
81  for(int ii = 0; ii < end - begin; ++ii)
82  begin[ii] = begin[ii]/bla ;
83  }
84  };
85 }
86 
87 
88 /** Base Class for all SplitFunctors used with the \ref RandomForest class
89  defines the interface used while learning a tree.
90 **/
91 template<class Tag>
92 class SplitBase
93 {
94  public:
95 
96  typedef Tag RF_Tag;
99 
100  ProblemSpec<> ext_param_;
101 
104 
105  NodeBase node_;
106 
107  /** returns the DecisionTree Node created by
108  \ref findBestSplit or \ref makeTerminalNode.
109  **/
110 
111  template<class T>
113  {
114  ext_param_ = in;
115  t_data.push_back(in.column_count_);
116  t_data.push_back(in.class_count_);
117  }
118 
119  NodeBase & createNode()
120  {
121  return node_;
122  }
123 
124  int classCount() const
125  {
126  return int(t_data[1]);
127  }
128 
129  int featureCount() const
130  {
131  return int(t_data[0]);
132  }
133 
134  /** resets internal data. Should always be called before
135  calling findBestSplit or makeTerminalNode
136  **/
137  void reset()
138  {
139  t_data.resize(2);
140  p_data.resize(0);
141  }
142 
143 
144  /** findBestSplit has to be implemented in derived split functor.
145  these functions only insures That a CompileTime error is issued
146  if no such method was defined.
147  **/
148 
149  template<class T, class C, class T2, class C2, class Region, class Random>
152  Region region,
153  ArrayVector<Region> childs,
154  Random randint)
155  {
156  CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
157  return 0;
158  }
159 
160  /** default action for creating a terminal Node.
161  sets the Class probability of the remaining region according to
162  the class histogram
163  **/
164  template<class T, class C, class T2,class C2, class Region, class Random>
167  Region & region,
168  Random randint)
169  {
170  Node<e_ConstProbNode> ret(t_data, p_data);
171  node_ = ret;
172  if(ext_param_.class_weights_.size() != region.classCounts().size())
173  {
174  std::copy( region.classCounts().begin(),
175  region.classCounts().end(),
176  ret.prob_begin());
177  }
178  else
179  {
180  std::transform( region.classCounts().begin(),
181  region.classCounts().end(),
182  ext_param_.class_weights_.begin(),
183  ret.prob_begin(), std::multiplies<double>());
184  }
185  detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
186  ret.weights() = region.size();
187  return e_ConstProbNode;
188  }
189 
190 
191 };
192 
193 /** Functor to sort the indices of a feature Matrix by a certain dimension
194 **/
195 template<class DataMatrix>
197 {
198  DataMatrix const & data_;
199  MultiArrayIndex sortColumn_;
200  double thresVal_;
201  public:
202 
203  SortSamplesByDimensions(DataMatrix const & data,
204  MultiArrayIndex sortColumn,
205  double thresVal = 0.0)
206  : data_(data),
207  sortColumn_(sortColumn),
208  thresVal_(thresVal)
209  {}
210 
211  void setColumn(MultiArrayIndex sortColumn)
212  {
213  sortColumn_ = sortColumn;
214  }
215  void setThreshold(double value)
216  {
217  thresVal_ = value;
218  }
219 
220  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
221  {
222  return data_(l, sortColumn_) < data_(r, sortColumn_);
223  }
224  bool operator()(MultiArrayIndex l) const
225  {
226  return data_(l, sortColumn_) < thresVal_;
227  }
228 };
229 
230 template<class DataMatrix>
231 class DimensionNotEqual
232 {
233  DataMatrix const & data_;
234  MultiArrayIndex sortColumn_;
235 
236  public:
237 
238  DimensionNotEqual(DataMatrix const & data,
239  MultiArrayIndex sortColumn)
240  : data_(data),
241  sortColumn_(sortColumn)
242  {}
243 
244  void setColumn(MultiArrayIndex sortColumn)
245  {
246  sortColumn_ = sortColumn;
247  }
248 
249  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
250  {
251  return data_(l, sortColumn_) != data_(r, sortColumn_);
252  }
253 };
254 
255 template<class DataMatrix>
256 class SortSamplesByHyperplane
257 {
258  DataMatrix const & data_;
259  Node<i_HyperplaneNode> const & node_;
260 
261  public:
262 
263  SortSamplesByHyperplane(DataMatrix const & data,
264  Node<i_HyperplaneNode> const & node)
265  :
266  data_(data),
267  node_()
268  {}
269 
270  /** calculate the distance of a sample point to a hyperplane
271  */
272  double operator[](MultiArrayIndex l) const
273  {
274  double result_l = -1 * node_.intercept();
275  for(int ii = 0; ii < node_.columns_size(); ++ii)
276  {
277  result_l += rowVector(data_, l)[node_.columns_begin()[ii]]
278  * node_.weights()[ii];
279  }
280  return result_l;
281  }
282 
283  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
284  {
285  return (*this)[l] < (*this)[r];
286  }
287 
288 };
289 
290 /** makes a Class Histogram given indices in a labels_ array
291  * usage:
292  * MultiArrayView<2, T2, C2> labels = makeSomeLabels()
293  * ArrayVector<int> hist(numberOfLabels(labels), 0);
294  * RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist);
295  *
296  * Container<int> indices = getSomeIndices()
297  * std::for_each(indices, counter);
298  */
299 template <class DataSource, class CountArray>
301 {
302  DataSource const & labels_;
303  CountArray & counts_;
304 
305  public:
306 
307  RandomForestClassCounter(DataSource const & labels,
308  CountArray & counts)
309  : labels_(labels),
310  counts_(counts)
311  {
312  reset();
313  }
314 
315  void reset()
316  {
317  counts_.init(0);
318  }
319 
320  void operator()(MultiArrayIndex l) const
321  {
322  counts_[labels_[l]] +=1;
323  }
324 };
325 
326 
327 /** Functor To Calculate the Best possible Split Based on the Gini Index
328  given Labels and Features along a given Axis
329 */
330 
331 namespace detail
332 {
333  template<int N>
334  class ConstArr
335  {
336  public:
337  double operator[](size_t) const
338  {
339  return (double)N;
340  }
341  };
342 
343 
344 }
345 
346 
347 
348 
349 /** Functor to calculate the entropy based impurity
350  */
352 {
353 public:
354  /**caculate the weighted gini impurity based on class histogram
355  * and class weights
356  */
357  template<class Array, class Array2>
358  double operator() (Array const & hist,
359  Array2 const & weights,
360  double total = 1.0) const
361  {
362  return impurity(hist, weights, total);
363  }
364 
365  /** calculate the gini based impurity based on class histogram
366  */
367  template<class Array>
368  double operator()(Array const & hist, double total = 1.0) const
369  {
370  return impurity(hist, total);
371  }
372 
373  /** static version of operator(hist total)
374  */
375  template<class Array>
376  static double impurity(Array const & hist, double total)
377  {
378  return impurity(hist, detail::ConstArr<1>(), total);
379  }
380 
381  /** static version of operator(hist, weights, total)
382  */
383  template<class Array, class Array2>
384  static double impurity (Array const & hist,
385  Array2 const & weights,
386  double total)
387  {
388 
389  int class_count = hist.size();
390  double entropy = 0.0;
391  if(class_count == 2)
392  {
393  double p0 = (hist[0]/total);
394  double p1 = (hist[1]/total);
395  entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
396  }
397  else
398  {
399  for(int ii = 0; ii < class_count; ++ii)
400  {
401  double w = weights[ii];
402  double pii = hist[ii]/total;
403  entropy -= w*( pii*std::log(pii));
404  }
405  }
406  entropy = total * entropy;
407  return entropy;
408  }
409 };
410 
411 /** Functor to calculate the gini impurity
412  */
414 {
415 public:
416  /**caculate the weighted gini impurity based on class histogram
417  * and class weights
418  */
419  template<class Array, class Array2>
420  double operator() (Array const & hist,
421  Array2 const & weights,
422  double total = 1.0) const
423  {
424  return impurity(hist, weights, total);
425  }
426 
427  /** calculate the gini based impurity based on class histogram
428  */
429  template<class Array>
430  double operator()(Array const & hist, double total = 1.0) const
431  {
432  return impurity(hist, total);
433  }
434 
435  /** static version of operator(hist total)
436  */
437  template<class Array>
438  static double impurity(Array const & hist, double total)
439  {
440  return impurity(hist, detail::ConstArr<1>(), total);
441  }
442 
443  /** static version of operator(hist, weights, total)
444  */
445  template<class Array, class Array2>
446  static double impurity (Array const & hist,
447  Array2 const & weights,
448  double total)
449  {
450 
451  int class_count = hist.size();
452  double gini = 0.0;
453  if(class_count == 2)
454  {
455  double w = weights[0] * weights[1];
456  gini = w * (hist[0] * hist[1] / total);
457  }
458  else
459  {
460  for(int ii = 0; ii < class_count; ++ii)
461  {
462  double w = weights[ii];
463  gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
464  }
465  }
466  return gini;
467  }
468 };
469 
470 
/** Maintains a running class histogram over a set of samples and reports
 *  its (weighted) impurity after every update. Serves as the line-search
 *  loss during split finding.
 *
 *  \tparam DataSource  label container, accessed as labels_(sample, 0)
 *  \tparam Impurity    impurity functor (GiniCriterion or EntropyCriterion)
 */
template <class DataSource, class Impurity= GiniCriterion>
class ImpurityLoss
{

    DataSource const & labels_;                  // label column of the training data
    ArrayVector<double> counts_;                 // running class histogram
    ArrayVector<double> const class_weights_;    // per-class weights from the problem spec
    double total_counts_;                        // sum over counts_
    Impurity impurity_;                          // impurity functor

    public:

    template<class T>
    ImpurityLoss(DataSource const & labels,
                 ProblemSpec<T> const & ext_)
    : labels_(labels),
      counts_(ext_.class_count_, 0.0),
      class_weights_(ext_.class_weights_),
      total_counts_(0.0)
    {}

    /// clear the histogram and the total
    void reset()
    {
        counts_.init(0);
        total_counts_ = 0.0;
    }

    /** add a complete histogram to the running counts;
     *  returns the impurity of the updated histogram.
     */
    template<class Counts>
    double increment_histogram(Counts const & counts)
    {
        std::transform(counts.begin(), counts.end(),
                       counts_.begin(), counts_.begin(),
                       std::plus<double>());
        total_counts_ = std::accumulate( counts_.begin(),
                                         counts_.end(),
                                         0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    /** subtract a complete histogram from the running counts;
     *  returns the impurity of the updated histogram.
     */
    template<class Counts>
    double decrement_histogram(Counts const & counts)
    {
        std::transform(counts.begin(), counts.end(),
                       counts_.begin(), counts_.begin(),
                       std::minus<double>());
        total_counts_ = std::accumulate( counts_.begin(),
                                         counts_.end(),
                                         0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    /** add the labels of the samples whose indices lie in [begin, end);
     *  returns the impurity of the updated histogram.
     */
    template<class Iter>
    double increment(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter, 0)] +=1.0;
            total_counts_ +=1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    /** remove the labels of the samples whose indices lie in [begin, end);
     *  returns the impurity of the updated histogram.
     */
    template<class Iter>
    double decrement(Iter const & begin, Iter const & end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter,0)] -=1.0;
            total_counts_ -=1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    /** initialise the histogram from a precomputed class-count array
     *  (begin/end are accepted for interface compatibility but unused);
     *  returns the impurity of the region.
     */
    template<class Iter, class Resp_t>
    double init (Iter begin, Iter end, Resp_t resp)
    {
        reset();
        std::copy(resp.begin(), resp.end(), counts_.begin());
        total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
        return impurity_(counts_,class_weights_, total_counts_);
    }

    /// current class histogram
    ArrayVector<double> const & response()
    {
        return counts_;
    }
};
558 
559 template <class DataSource>
560 class RegressionForestCounter
561 {
562  typedef MultiArrayShape<2>::type Shp;
563  DataSource const & labels_;
564  ArrayVector <double> mean_;
565  ArrayVector <double> variance_;
566  ArrayVector <double> tmp_;
567  size_t count_;
568 
569  template<class T>
570  RegressionForestCounter(DataSource const & labels,
571  ProblemSpec<T> const & ext_)
572  :
573  labels_(labels),
574  mean_(ext_.response_size, 0.0),
575  variance_(ext_.response_size, 0.0),
576  tmp_(ext_.response_size),
577  count_(0)
578  {}
579 
580  // west's alorithm for incremental variance
581  // calculation
582  template<class Iter>
583  double increment (Iter begin, Iter end)
584  {
585  for(Iter iter = begin; iter != end; ++iter)
586  {
587  ++count_;
588  for(int ii = 0; ii < mean_.size(); ++ii)
589  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
590  double f = 1.0 / count_,
591  f1 = 1.0 - f;
592  for(int ii = 0; ii < mean_.size(); ++ii)
593  mean_[ii] += f*tmp_[ii];
594  for(int ii = 0; ii < mean_.size(); ++ii)
595  variance_[ii] += f1*sq(tmp_[ii]);
596  }
597  return std::accumulate(variance_.begin(),
598  variance_.end(),
599  0.0,
600  std::plus<double>())
601  /(count_ -1);
602  }
603 
604  template<class Iter>
605  double decrement (Iter begin, Iter end)
606  {
607  for(Iter iter = begin; iter != end; ++iter)
608  {
609  --count_;
610  for(int ii = 0; ii < mean_.size(); ++ii)
611  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
612  double f = 1.0 / count_,
613  f1 = 1.0 + f;
614  for(int ii = 0; ii < mean_.size(); ++ii)
615  mean_[ii] -= f*tmp_[ii];
616  for(int ii = 0; ii < mean_.size(); ++ii)
617  variance_[ii] -= f1*sq(tmp_[ii]);
618  }
619  return std::accumulate(variance_.begin(),
620  variance_.end(),
621  0.0,
622  std::plus<double>())
623  /(count_ -1);
624  }
625 
626  template<class Iter, class Resp_t>
627  double init (Iter begin, Iter end, Resp_t resp)
628  {
629  reset();
630  return increment(begin, end);
631  }
632 
633 
634  ArrayVector<double> const & response()
635  {
636  return mean_;
637  }
638 
639  void reset()
640  {
641  mean_.init(0.0);
642  variance_.init(0.0);
643  count_ = 0;
644  }
645 };
646 
// Maps a loss/criterion tag together with a data source onto the concrete
// loss-accumulator type used during the split search.
template<class Tag, class Datatyp>
struct LossTraits;

/// tag selecting least-squares loss (regression forests)
struct LSQLoss
{};

template<class Datatype>
struct LossTraits<GiniCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, GiniCriterion> type;
};

template<class Datatype>
struct LossTraits<EntropyCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, EntropyCriterion> type;
};

template<class Datatype>
struct LossTraits<LSQLoss, Datatype>
{
    typedef RegressionForestCounter<Datatype> type;
};
670 
671 /** Given a column, choose a split that minimizes some loss
672  */
673 template<class LineSearchLossTag>
675 {
676 public:
677  ArrayVector<double> class_weights_;
678  ArrayVector<double> bestCurrentCounts[2];
679  double min_gini_;
680  ptrdiff_t min_index_;
681  double min_threshold_;
682  ProblemSpec<> ext_param_;
683 
685  {}
686 
687  template<class T>
688  BestGiniOfColumn(ProblemSpec<T> const & ext)
689  :
690  class_weights_(ext.class_weights_),
691  ext_param_(ext)
692  {
693  bestCurrentCounts[0].resize(ext.class_count_);
694  bestCurrentCounts[1].resize(ext.class_count_);
695  }
696  template<class T>
697  void set_external_parameters(ProblemSpec<T> const & ext)
698  {
699  class_weights_ = ext.class_weights_;
700  ext_param_ = ext;
701  bestCurrentCounts[0].resize(ext.class_count_);
702  bestCurrentCounts[1].resize(ext.class_count_);
703  }
704  /** calculate the best gini split along a Feature Column
705  * \param column, the feature vector - has to support the [] operator
706  * \param labels, the label vector
707  * \param begin
708  * \param end (in and out)
709  * begin and end iterators to the indices of the
710  * samples in the current region.
711  * the range begin - end is sorted by the column supplied
712  * during function execution.
713  * \param class_counts
714  * class histogram of the range.
715  *
716  * precondition: begin, end valid range,
717  * class_counts positive integer valued array with the
718  * class counts in the current range.
719  * labels.size() >= max(begin, end);
720  * postcondition:
721  * begin, end sorted by column given.
722  * min_gini_ contains the minimum gini found or
723  * NumericTraits<double>::max if no split was found.
724  * min_index_ countains the splitting index in the range
725  * or invalid data if no split was found.
726  * BestCirremtcounts[0] and [1] contain the
727  * class histogram of the left and right region of
728  * the left and right regions.
729  */
730  template< class DataSourceF_t,
731  class DataSource_t,
732  class I_Iter,
733  class Array>
734  void operator()(DataSourceF_t const & column,
735  int g,
736  DataSource_t const & labels,
737  I_Iter & begin,
738  I_Iter & end,
739  Array const & region_response)
740  {
741  std::sort(begin, end,
743  typedef typename
744  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
745  LineSearchLoss left(labels, ext_param_);
746  LineSearchLoss right(labels, ext_param_);
747 
748 
749 
750  min_gini_ = right.init(begin, end, region_response);
751  min_threshold_ = *begin;
752  min_index_ = 0;
753  DimensionNotEqual<DataSourceF_t> comp(column, g);
754 
755  I_Iter iter = begin;
756  I_Iter next = std::adjacent_find(iter, end, comp);
757  while( next != end)
758  {
759 
760  double loss = right.decrement(iter, next + 1)
761  + left.increment(iter , next + 1);
762 #ifdef CLASSIFIER_TEST
763  if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
764 #else
765  if(loss < min_gini_ )
766 #endif
767  {
768  bestCurrentCounts[0] = left.response();
769  bestCurrentCounts[1] = right.response();
770 #ifdef CLASSIFIER_TEST
771  min_gini_ = loss < min_gini_? loss : min_gini_;
772 #else
773  min_gini_ = loss;
774 #endif
775  min_index_ = next - begin +1 ;
776  min_threshold_ = (double(column(*next,g)) + double(column(*(next +1), g)))/2.0;
777  }
778  iter = next +1 ;
779  next = std::adjacent_find(iter, end, comp);
780  }
781  }
782 
783  template<class DataSource_t, class Iter, class Array>
784  double loss_of_region(DataSource_t const & labels,
785  Iter & begin,
786  Iter & end,
787  Array const & region_response) const
788  {
789  typedef typename
790  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
791  LineSearchLoss region_loss(labels, ext_param_);
792  return
793  region_loss.init(begin, end, region_response);
794  }
795 
796 };
797 
798 
799 /** Chooses mtry columns ad applys ColumnDecisionFunctor to each of the
800  * columns. Then Chooses the column that is best
801  */
802 template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
803 class ThresholdSplit: public SplitBase<Tag>
804 {
805  public:
806 
807 
808  typedef SplitBase<Tag> SB;
809 
810  ArrayVector<Int32> splitColumns;
811  ColumnDecisionFunctor bgfunc;
812 
813  double region_gini_;
814  ArrayVector<double> min_gini_;
815  ArrayVector<ptrdiff_t> min_indices_;
816  ArrayVector<double> min_thresholds_;
817 
818  int bestSplitIndex;
819 
820  double minGini() const
821  {
822  return min_gini_[bestSplitIndex];
823  }
824  int bestSplitColumn() const
825  {
826  return splitColumns[bestSplitIndex];
827  }
828  double bestSplitThreshold() const
829  {
830  return min_thresholds_[bestSplitIndex];
831  }
832 
833  template<class T>
835  {
837  bgfunc.set_external_parameters( SB::ext_param_);
838  int featureCount_ = SB::ext_param_.column_count_;
839  splitColumns.resize(featureCount_);
840  for(int k=0; k<featureCount_; ++k)
841  splitColumns[k] = k;
842  min_gini_.resize(featureCount_);
843  min_indices_.resize(featureCount_);
844  min_thresholds_.resize(featureCount_);
845  }
846 
847 
848  template<class T, class C, class T2, class C2, class Region, class Random>
849  int findBestSplit(MultiArrayView<2, T, C> features,
851  Region & region,
852  ArrayVector<Region>& childRegions,
853  Random & randint)
854  {
855 
856  typedef typename Region::IndexIterator IndexIterator;
857  if(region.size() == 0)
858  {
859  std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
860  "continuing learning process....";
861  }
862  // calculate things that haven't been calculated yet.
863 
864  if(std::accumulate(region.classCounts().begin(),
865  region.classCounts().end(), 0) != region.size())
866  {
867  RandomForestClassCounter< MultiArrayView<2,T2, C2>,
868  ArrayVector<double> >
869  counter(labels, region.classCounts());
870  std::for_each( region.begin(), region.end(), counter);
871  region.classCountsIsValid = true;
872  }
873 
874  // Is the region pure already?
875  region_gini_ = bgfunc.loss_of_region(labels,
876  region.begin(),
877  region.end(),
878  region.classCounts());
879  if(region_gini_ <= SB::ext_param_.precision_)
880  return this->makeTerminalNode(features, labels, region, randint);
881 
882  // select columns to be tried.
883  for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
884  std::swap(splitColumns[ii],
885  splitColumns[ii+ randint(features.shape(1) - ii)]);
886 
887  // find the best gini index
888  bestSplitIndex = 0;
889  double current_min_gini = region_gini_;
890  int num2try = features.shape(1);
891  for(int k=0; k<num2try; ++k)
892  {
893  //this functor does all the work
894  bgfunc(features,
895  splitColumns[k],
896  labels,
897  region.begin(), region.end(),
898  region.classCounts());
899  min_gini_[k] = bgfunc.min_gini_;
900  min_indices_[k] = bgfunc.min_index_;
901  min_thresholds_[k] = bgfunc.min_threshold_;
902 #ifdef CLASSIFIER_TEST
903  if( bgfunc.min_gini_ < current_min_gini
904  && !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
905 #else
906  if(bgfunc.min_gini_ < current_min_gini)
907 #endif
908  {
909  current_min_gini = bgfunc.min_gini_;
910  childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
911  childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
912  childRegions[0].classCountsIsValid = true;
913  childRegions[1].classCountsIsValid = true;
914 
915  bestSplitIndex = k;
916  num2try = SB::ext_param_.actual_mtry_;
917  }
918  }
919 
920  // did not find any suitable split
921  if(closeAtTolerance(current_min_gini, region_gini_))
922  return this->makeTerminalNode(features, labels, region, randint);
923 
924  //create a Node for output
925  Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
926  SB::node_ = node;
927  node.threshold() = min_thresholds_[bestSplitIndex];
928  node.column() = splitColumns[bestSplitIndex];
929 
930  // partition the range according to the best dimension
931  SortSamplesByDimensions<MultiArrayView<2, T, C> >
932  sorter(features, node.column(), node.threshold());
933  IndexIterator bestSplit =
934  std::partition(region.begin(), region.end(), sorter);
935  // Save the ranges of the child stack entries.
936  childRegions[0].setRange( region.begin() , bestSplit );
937  childRegions[0].rule = region.rule;
938  childRegions[0].rule.push_back(std::make_pair(1, 1.0));
939  childRegions[1].setRange( bestSplit , region.end() );
940  childRegions[1].rule = region.rule;
941  childRegions[1].rule.push_back(std::make_pair(1, 1.0));
942 
943  return i_ThresholdNode;
944  }
945 };
946 
// ready-to-use split functor instantiations for the RandomForest
typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> > EntropySplit;
typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
950 
951 namespace rf
952 {
953 
/** This namespace contains additional split functors.
955  *
956  * The Split functor classes are designed in a modular fashion because new split functors may
957  * share a lot of code with existing ones.
958  *
959  * ThresholdSplit implements the functionality needed for any split functor, that makes its
960  * decision via one dimensional axis-parallel cuts. The Template parameter defines how the split
961  * along one dimension is chosen.
962  *
963  * The BestGiniOfColumn class chooses a split that minimizes one of the Loss functions supplied
964  * (GiniCriterion for classification and LSQLoss for regression). Median chooses the Split in a
965  * kD tree fashion.
966  *
967  *
968  * Currently defined typedefs:
969  * \code
970  * typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
971  * typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
972  * typedef ThresholdSplit<Median> MedianSplit;
973  * \endcode
974  */
975 namespace split
976 {
977 
978 /** This Functor chooses the median value of a column
979  */
980 class Median
981 {
982 public:
983 
985  ArrayVector<double> class_weights_;
986  ArrayVector<double> bestCurrentCounts[2];
987  double min_gini_;
988  ptrdiff_t min_index_;
989  double min_threshold_;
990  ProblemSpec<> ext_param_;
991 
992  Median()
993  {}
994 
995  template<class T>
996  Median(ProblemSpec<T> const & ext)
997  :
998  class_weights_(ext.class_weights_),
999  ext_param_(ext)
1000  {
1001  bestCurrentCounts[0].resize(ext.class_count_);
1002  bestCurrentCounts[1].resize(ext.class_count_);
1003  }
1004 
1005  template<class T>
1006  void set_external_parameters(ProblemSpec<T> const & ext)
1007  {
1008  class_weights_ = ext.class_weights_;
1009  ext_param_ = ext;
1010  bestCurrentCounts[0].resize(ext.class_count_);
1011  bestCurrentCounts[1].resize(ext.class_count_);
1012  }
1013 
1014  template< class DataSourceF_t,
1015  class DataSource_t,
1016  class I_Iter,
1017  class Array>
1018  void operator()(DataSourceF_t const & column,
1019  DataSource_t const & labels,
1020  I_Iter & begin,
1021  I_Iter & end,
1022  Array const & region_response)
1023  {
1024  std::sort(begin, end,
1026  typedef typename
1027  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1028  LineSearchLoss left(labels, ext_param_);
1029  LineSearchLoss right(labels, ext_param_);
1030  right.init(begin, end, region_response);
1031 
1032  min_gini_ = NumericTraits<double>::max();
1033  min_index_ = floor(double(end - begin)/2.0);
1034  min_threshold_ = column[*(begin + min_index_)];
1036  sorter(column, 0, min_threshold_);
1037  I_Iter part = std::partition(begin, end, sorter);
1038  DimensionNotEqual<DataSourceF_t> comp(column, 0);
1039  if(part == begin)
1040  {
1041  part= std::adjacent_find(part, end, comp)+1;
1042 
1043  }
1044  if(part >= end)
1045  {
1046  return;
1047  }
1048  else
1049  {
1050  min_threshold_ = column[*part];
1051  }
1052  min_gini_ = right.decrement(begin, part)
1053  + left.increment(begin , part);
1054 
1055  bestCurrentCounts[0] = left.response();
1056  bestCurrentCounts[1] = right.response();
1057 
1058  min_index_ = part - begin;
1059  }
1060 
1061  template<class DataSource_t, class Iter, class Array>
1062  double loss_of_region(DataSource_t const & labels,
1063  Iter & begin,
1064  Iter & end,
1065  Array const & region_response) const
1066  {
1067  typedef typename
1068  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1069  LineSearchLoss region_loss(labels, ext_param_);
1070  return
1071  region_loss.init(begin, end, region_response);
1072  }
1073 
1074 };
1075 
1077 
1078 
1079 /** This Functor chooses a random value of a column
1080  */
1082 {
1083 public:
1084 
1086  ArrayVector<double> class_weights_;
1087  ArrayVector<double> bestCurrentCounts[2];
1088  double min_gini_;
1089  ptrdiff_t min_index_;
1090  double min_threshold_;
1091  ProblemSpec<> ext_param_;
1092  typedef RandomMT19937 Random_t;
1093  Random_t random;
1094 
1096  {}
1097 
1098  template<class T>
1099  RandomSplitOfColumn(ProblemSpec<T> const & ext)
1100  :
1101  class_weights_(ext.class_weights_),
1102  ext_param_(ext),
1103  random(RandomSeed)
1104  {
1105  bestCurrentCounts[0].resize(ext.class_count_);
1106  bestCurrentCounts[1].resize(ext.class_count_);
1107  }
1108 
1109  template<class T>
1110  RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
1111  :
1112  class_weights_(ext.class_weights_),
1113  ext_param_(ext),
1114  random(random_)
1115  {
1116  bestCurrentCounts[0].resize(ext.class_count_);
1117  bestCurrentCounts[1].resize(ext.class_count_);
1118  }
1119 
1120  template<class T>
1121  void set_external_parameters(ProblemSpec<T> const & ext)
1122  {
1123  class_weights_ = ext.class_weights_;
1124  ext_param_ = ext;
1125  bestCurrentCounts[0].resize(ext.class_count_);
1126  bestCurrentCounts[1].resize(ext.class_count_);
1127  }
1128 
1129  template< class DataSourceF_t,
1130  class DataSource_t,
1131  class I_Iter,
1132  class Array>
1133  void operator()(DataSourceF_t const & column,
1134  DataSource_t const & labels,
1135  I_Iter & begin,
1136  I_Iter & end,
1137  Array const & region_response)
1138  {
1139  std::sort(begin, end,
1141  typedef typename
1142  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1143  LineSearchLoss left(labels, ext_param_);
1144  LineSearchLoss right(labels, ext_param_);
1145  right.init(begin, end, region_response);
1146 
1147 
1148  min_gini_ = NumericTraits<double>::max();
1149 
1150  min_index_ = begin + random.uniformInt(end -begin);
1151  min_threshold_ = column[*(begin + min_index_)];
1153  sorter(column, 0, min_threshold_);
1154  I_Iter part = std::partition(begin, end, sorter);
1155  DimensionNotEqual<DataSourceF_t> comp(column, 0);
1156  if(part == begin)
1157  {
1158  part= std::adjacent_find(part, end, comp)+1;
1159 
1160  }
1161  if(part >= end)
1162  {
1163  return;
1164  }
1165  else
1166  {
1167  min_threshold_ = column[*part];
1168  }
1169  min_gini_ = right.decrement(begin, part)
1170  + left.increment(begin , part);
1171 
1172  bestCurrentCounts[0] = left.response();
1173  bestCurrentCounts[1] = right.response();
1174 
1175  min_index_ = part - begin;
1176  }
1177 
1178  template<class DataSource_t, class Iter, class Array>
1179  double loss_of_region(DataSource_t const & labels,
1180  Iter & begin,
1181  Iter & end,
1182  Array const & region_response) const
1183  {
1184  typedef typename
1185  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1186  LineSearchLoss region_loss(labels, ext_param_);
1187  return
1188  region_loss.init(begin, end, region_response);
1189  }
1190 
1191 };
1192 
1194 }
1195 }
1196 
1197 
1198 } //namespace vigra
1199 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Thu Jun 14 2012)