
random_forest.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <list>
#include <numeric>
#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "matrix.hxx"
#include "random.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"
#include "random_forest/rf_ridge_split.hxx"

namespace vigra
{

/** \addtogroup MachineLearning Machine Learning

    This module provides classification algorithms that map
    features to labels or label probabilities.
    Look at the RandomForest class first for an overview of most of the
    functionality provided, as well as the use cases.
**/
//@{

namespace detail
{

/* \brief sampling option factory function
 */
inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
{
    SamplerOptions return_opt;
    return_opt.withReplacement(RF_opt.sample_with_replacement_);
    return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
    return return_opt;
}
} //namespace detail

/** Random Forest class
 *
 * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
 *          the input while learning and predicting. Currently available:
 *          ClassificationTag and RegressionTag. It is recommended to use
 *          Splitfunctor::Preprocessor_t when using custom split functors,
 *          as they may need the data to be in a different format.
 * \sa Preprocessor
 *
 * Simple usage for classification (regression is not yet supported):
 * look at RandomForest::learn() as well as RandomForestOptions() for
 * additional options.
 *
 * \code
 * typedef xxx feature_t; // replace xxx with whichever type
 * typedef yyy label_t;   // likewise
 * MultiArrayView<2, feature_t> f = get_some_features();
 * MultiArrayView<2, label_t> l = get_some_labels();
 * RandomForest<label_t> rf;
 * rf.learn(f, l);
 *
 * MultiArrayView<2, feature_t> pf = get_some_unknown_features();
 * MultiArrayView<2, label_t> prediction = allocate_space_for_response();
 * MultiArrayView<2, double> prob = allocate_space_for_probability();
 *
 * rf.predictLabels(pf, prediction);
 * rf.predictProbabilities(pf, prob);
 *
 * \endcode
 *
 * Additional information such as the OOB error and variable importance
 * measures are accessed via visitors defined in rf::visitors.
 * Have a look at rf::split for other splitting methods.
 *
*/
template <class LabelType = double, class PreprocessorTag = ClassificationTag>
class RandomForest
{

  public:
    //public typedefs
    typedef RandomForestOptions             Options_t;
    typedef detail::DecisionTree            DecisionTree_t;
    typedef ProblemSpec<LabelType>          ProblemSpec_t;
    typedef GiniSplit                       Default_Split_t;
    typedef EarlyStoppStd                   Default_Stop_t;
    typedef rf::visitors::StopVisiting      Default_Visitor_t;
    typedef DT_StackEntry<ArrayVector<Int32>::iterator>
                                            StackEntry_t;
    typedef LabelType                       LabelT;
  protected:

    /** optimisation for predictLabels
     * */
    mutable MultiArray<2, double> garbage_prediction_;

  public:

    //problem independent data.
    Options_t options_;
    //problem dependent data members - these are only set if
    //a copy constructor, some sort of import
    //function or the learn function is called
    ArrayVector<DecisionTree_t> trees_;
    ProblemSpec_t ext_param_;
    /*mutable ArrayVector<int> tree_indices_;*/
    rf::visitors::OnlineLearnVisitor online_visitor_;


    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }

  public:

    /** \name Constructors
     * Note: no copy constructor is specified as no pointers are manipulated
     * in this class.
     */
    /*\{*/
    /**\brief default constructor
     *
     * \param options   general options for the Random Forest. Must be of type
     *                  Options_t
     * \param ext_param problem specific values that can be supplied
     *                  additionally (class weights, labels etc.)
     * \sa RandomForestOptions, ProblemSpec
     *
     */
    RandomForest(Options_t const & options = Options_t(),
                 ProblemSpec_t const & ext_param = ProblemSpec_t())
    :
        options_(options),
        ext_param_(ext_param)/*,
        tree_indices_(options.tree_count_,0)*/
    {
        /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
            tree_indices_[ii] = ii;*/
    }

    /**\brief Create RF from external source
     * \param treeCount Number of trees to add.
     * \param topology_begin
     *                  Iterator to a container where the topology_ data
     *                  of the trees is stored.
     *                  The iterator should support at least treeCount forward
     *                  iterations (i.e. topology_end - topology_begin >= treeCount).
     * \param parameter_begin
     *                  Iterator to a container where the parameters_ data
     *                  of the trees is stored. The iterator should support at
     *                  least treeCount forward iterations.
     * \param problem_spec
     *                  Extrinsic parameters that specify the problem, e.g.
     *                  ClassCount, featureCount etc.
     * \param options   (optional) options used to train the original
     *                  Random Forest. This parameter is not used anywhere
     *                  during prediction and is thus optional.
     *
     */
    /* TODO: This constructor may be replaced by a constructor using
     * NodeProxy iterators to encapsulate the underlying data type.
     */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int treeCount,
                 TopologyIterator topology_begin,
                 ParameterIterator parameter_begin,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec),
        options_(options)
    {
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_ = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }
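
    /* Example (sketch): round-tripping a forest through external storage.
     * DecisionTree_t exposes its topology_ (ArrayVector<Int32>) and
     * parameters_ (ArrayVector<double>) members, so a trained forest can be
     * rebuilt from copies of those arrays. The container names below are
     * placeholders, not part of the API.
     *
     * \code
     * std::vector<ArrayVector<Int32> >  topologies;
     * std::vector<ArrayVector<double> > parameters;
     * for(int k = 0; k < rf.tree_count(); ++k)
     * {
     *     topologies.push_back(rf.tree(k).topology_);
     *     parameters.push_back(rf.tree(k).parameters_);
     * }
     * RandomForest<> restored(rf.tree_count(),
     *                         topologies.begin(), parameters.begin(),
     *                         rf.ext_param());
     * \endcode
     */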

    /*\}*/


    /** \name Data Access
     * data access interface - usage of member variables is deprecated
     */

    /*\{*/


    /**\brief return external parameters for viewing
     * \return ProblemSpec_t
     */
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }

    /**\brief set external parameters
     *
     * \param in external parameters to be set
     *
     * Set external parameters explicitly.
     * If the Random Forest has not been trained, the preprocessor will
     * either ignore filling values set this way, or will throw an exception
     * if values specified manually do not match the values calculated
     * during the preparation step.
     */
    void set_ext_param(ProblemSpec_t const & in)
    {
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param():"
            " Random forest has been trained! Call reset()"
            " before specifying new extrinsic parameters.");
        ext_param_ = in;
    }

    /**\brief access random forest options
     *
     * \return random forest options
     */
    Options_t & options()
    {
        return options_;
    }


    /**\brief access const random forest options
     *
     * \return const Options_t
     */
    Options_t const & options() const
    {
        return options_;
    }

    /**\brief access const trees
     */
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    /**\brief access trees
     */
    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }

    /*\}*/

    /**\brief return the number of features used while
     * training.
     */
    int feature_count() const
    {
        return ext_param_.column_count_;
    }


    /**\brief return the number of features used while
     * training.
     *
     * Deprecated. Use feature_count() instead.
     */
    int column_count() const
    {
        return ext_param_.column_count_;
    }

    /**\brief return the number of classes used while
     * training.
     */
    int class_count() const
    {
        return ext_param_.class_count_;
    }

    /**\brief return the number of trees
     */
    int tree_count() const
    {
        return options_.tree_count_;
    }



    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void onlineLearn( MultiArrayView<2,U,C1> const & features,
                      MultiArrayView<2,U2,C2> const & response,
                      int new_start_index,
                      Visitor_t visitor_,
                      Split_t split_,
                      Stop_t stop_,
                      Random_t & random,
                      bool adjust_thresholds=false);

    template <class U, class C1, class U2,class C2>
    void onlineLearn( MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, U2,C2> const & labels,
                      int new_start_index,
                      bool adjust_thresholds=false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        onlineLearn(features,
                    labels,
                    new_start_index,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd,
                    adjust_thresholds);
    }
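
    /* Usage sketch for online learning: the feature/label matrices passed
     * here must contain the old samples followed by the new ones, and
     * new_start_index marks the first new row. Online learning must have
     * been prepared at training time (the options_.prepare_online_learning_
     * flag; the corresponding setter name on RandomForestOptions is assumed
     * here).
     *
     * \code
     * rf.learn(f, l);                       // f, l: initial training data
     * // ... append new rows to f and l ...
     * rf.onlineLearn(f, l, old_row_count);  // rows >= old_row_count are new
     * \endcode
     */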

    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void reLearnTree(MultiArrayView<2,U,C1> const & features,
                     MultiArrayView<2,U2,C2> const & response,
                     int treeId,
                     Visitor_t visitor_,
                     Split_t split_,
                     Stop_t stop_,
                     Random_t & random);

    template<class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features,
                    labels,
                    treeId,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd);
    }
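
    /* Sketch: re-grow a single tree (here tree 0) on the current data, e.g.
     * after repeated onlineLearn() calls have degraded it. As with
     * onlineLearn(), this requires online learning to have been prepared at
     * training time (see the precondition in the implementation below).
     *
     * \code
     * rf.reLearnTree(f, l, 0);
     * \endcode
     */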


    /**\name Learning
     * The following functions differ in the degree of customization
     * allowed.
     */
    /*\{*/
    /**\brief learn on data with custom config and random number generator
     *
     * \param features a N x M matrix containing N samples with M
     *                 features
     * \param response a N x D matrix containing the corresponding
     *                 response. Current split functors assume D to
     *                 be 1 and ignore any additional columns.
     *                 This is not enforced to allow future support
     *                 for uncertain labels, label independent strata etc.
     *                 The Preprocessor specified during construction
     *                 should be able to handle the given features and
     *                 labels.
     *                 See also: SplitFunctor, Preprocessing
     *
     * \param visitor  visitor which is to be applied after each split,
     *                 tree and at the end. Use rf_default() for the
     *                 default value (no visitors).
     *                 See also: rf::visitors
     * \param split    split functor to be used to calculate each split.
     *                 Use rf_default() for the default value (GiniSplit).
     *                 See also: rf::split
     * \param stop
     *                 predicate used to decide when to stop splitting.
     *                 Use rf_default() for the default value (EarlyStoppStd).
     * \param random   RandomNumberGenerator to be used. Use
     *                 rf_default() for the default value (RandomMT19937).
     *
     *
     */
    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t,
              class Random_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop,
                Random_t const & random);

    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        learn( features,
               response,
               visitor,
               split,
               stop,
               rnd);
    }

    template <class U, class C1, class U2,class C2, class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor)
    {
        learn( features,
               labels,
               visitor,
               rf_default(),
               rf_default());
    }

    template <class U, class C1, class U2,class C2,
              class Visitor_t, class Split_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor,
                Split_t split)
    {
        learn( features,
               labels,
               visitor,
               split,
               rf_default());
    }

    /**\brief learn on data with default configuration
     *
     * \param features a N x M matrix containing N samples with M
     *                 features
     * \param labels   a N x D matrix containing the corresponding
     *                 N labels. Current split functors assume D to
     *                 be 1 and ignore any additional columns.
     *                 This is not enforced to allow future support
     *                 for uncertain labels.
     *
     * Learning is done with:
     *
     * \sa rf::split, EarlyStoppStd
     *
     * - a randomly seeded random number generator
     * - the default gini split functor as described by Breiman
     * - the standard early stopping criterion
     */
    template <class U, class C1, class U2,class C2>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels)
    {
        learn( features,
               labels,
               rf_default(),
               rf_default(),
               rf_default());
    }
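
    /* Sketch: collecting the out-of-bag error while learning. This assumes
     * an OOB visitor (here called rf::visitors::OOB_Error, with an
     * oob_breiman result member) and the rf::visitors::create_visitor()
     * helper are available in this version of rf_visitors.hxx; treat the
     * exact names as assumptions, not guaranteed API.
     *
     * \code
     * rf::visitors::OOB_Error oob;
     * rf.learn(f, l, rf::visitors::create_visitor(oob),
     *          rf_default(), rf_default());
     * // oob.oob_breiman would then hold the OOB error estimate
     * \endcode
     */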
    /*\}*/



    /**\name Prediction
     */
    /*\{*/
    /** \brief predict a label given a feature.
     *
     * \param features a 1 x featureCount matrix containing the
     *        data point to be predicted (this only works in the
     *        classification setting)
     * \param stop     early stopping criterion
     * \return double value representing the class. You can use the
     *         predictLabels() function together with the
     *         rf.ext_param().class_type_ attribute
     *         to get back the same type used during learning.
     */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;

    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features)
    {
        return predictLabel(features, rf_default());
    }
    /** \brief predict a label with features and class priors
     *
     * \param features same as above.
     * \param prior    iterator to the prior weighting of the classes
     * \return same as above.
     */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           ArrayVectorView<double> prior) const;
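
    /* Sketch: biasing the prediction with class priors. The prior vector
     * holds one multiplicative weight per class (class_count() entries);
     * the per-class probabilities are multiplied by these weights before
     * the argmax is taken.
     *
     * \code
     * ArrayVector<double> prior(rf.class_count(), 1.0);
     * prior[0] = 2.0; // double the weight of class 0
     * LabelType lbl = rf.predictLabel(rowVector(test_features, 0), prior);
     * \endcode
     */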

    /** \brief predict multiple labels with given features
     *
     * \param features a n x featureCount matrix containing the
     *        data points to be predicted (this only works in the
     *        classification setting)
     * \param labels   a n x 1 matrix passed by reference to store
     *        the output.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
    }

    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
    }
    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 store the class probabilities
     * \param stop     early stopping criterion
     * \sa EarlyStopping
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob,
                              Stop & stop) const;
    template <class T1,class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                              MultiArrayView<2, T2, C> & prob);

    /** \brief predict the class probabilities for multiple labels
     *
     * \param features same as above
     * \param prob     a n x class_count_ matrix passed by reference to
     *                 store the class probabilities
     */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob) const
    {
        predictProbabilities(features, prob, rf_default());
    }
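
    /* Sketch: the relation between predictProbabilities() and
     * predictLabels() - the label of row k is the class whose probability
     * is maximal in row k of prob (n is the number of test samples):
     *
     * \code
     * MultiArray<2, double> prob(MultiArrayShape<2>::type(n, rf.class_count()));
     * rf.predictProbabilities(test_features, prob);
     * // argMax(rowVector(prob, k)) indexes the predicted class of sample k
     * \endcode
     */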


    /*\}*/

};


template <class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int new_start_index,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random,
                                                           bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds=adjust_thresholds;

    using namespace rf;
    //typedefs
    typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;
    // default values and initialization
    // Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).
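    // For example, RF_CHOOSER(Stop_t)::type (defined below) evaluates to
    // Default_Stop_t when the caller passed rf_default() (i.e. Stop_t ==
    // RF_DEFAULT), and to Stop_t itself otherwise; ::choose() then returns
    // the corresponding object.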

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type>
            IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    ext_param_.class_count_=0;
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Make an stl compatible random functor.
    RandFunctor_t randint ( random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //Create poisson samples
    PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));

    //TODO: visitors for online learning
    //visitor.visit_at_beginning(*this, preprocessor);

    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        online_visitor_.tree_id=ii;
        poisson_sampler.sample();
        std::map<int,int> leaf_parents;
        leaf_parents.clear();
        //Get all the leaf nodes for that sample
        for(int s=0;s<poisson_sampler.numOfSamples();++s)
        {
            int sample=poisson_sampler[s];
            online_visitor_.current_label=preprocessor.response()(sample,0);
            online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
            int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);


            //Add to the index list for that leaf
            online_visitor_.add_to_index_list(ii,leaf,sample);
            //TODO: Class count?
            //Store the parent of the leaf, unless the leaf is already pure
            if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
            {
                leaf_parents[leaf]=online_visitor_.last_node_id;
            }
        }


        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
        {
            int leaf=leaf_iterator->first;
            int parent=leaf_iterator->second;
            int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indices;
            indices.clear();
            indices.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indices.begin(),
                                     indices.end(),
                                     ext_param_.class_count_);


            if(parent!=-1)
            {
                if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
                {
                    stack_entry.leftParent=parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
                    stack_entry.rightParent=parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
            trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
            //Now, the last node was moved onto the position of the old leaf
            online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
            //Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                poisson_sampler,
                                stack_entry,
                                ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();
}

template<class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int treeId,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random)
{
    using namespace rf;


    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    ext_param_.class_count_=0;
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // default values and initialization
    // Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree: relearning trees only makes sense if online learning is enabled");
    online_visitor_.activate();

    // Make an stl compatible random functor.
    RandFunctor_t randint ( random);

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    /**\todo replace this crappy class. It uses function pointers
     * and is making the code slower according to me.
     * Comment from Nathan: This is copied from Rahul, so me=Rahul
     */
    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                    .sampleSize(ext_param().actual_msample_),
                               random);
    //initialize the first region/node/stack entry
    sampler
        .sample();

    StackEntry_t
        first_stack_entry(  sampler.sampledIndices().begin(),
                            sampler.sampledIndices().end(),
                            ext_param_.class_count_);
    first_stack_entry
        .set_oob_range(     sampler.oobIndices().begin(),
                            sampler.oobIndices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id=treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn( preprocessor.features(),
                preprocessor.response(),
                first_stack_entry,
                split,
                stop,
                visitor,
                randint);
    visitor
        .visit_after_tree(  *this,
                            preprocessor,
                            sampler,
                            first_stack_entry,
                            treeId);

    online_visitor_.deactivate();
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1,
          class U2,class C2,
          class Split_t,
          class Stop_t,
          class Visitor_t,
          class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
                     learn( MultiArrayView<2, U, C1> const & features,
                            MultiArrayView<2, U2,C2> const & response,
                            Visitor_t visitor_,
                            Split_t split_,
                            Stop_t stop_,
                            Random_t const & random)
{
    using namespace rf;
    //this->reset();
    //typedefs
    typedef UniformIntRandomFunctor<Random_t>
                                            RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // default values and initialization
    // Value Chooser chooses the second argument as value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // worry about it, just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make an stl compatible random functor.
    RandFunctor_t randint ( random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //initialize trees.
    trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));

    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                    .sampleSize(ext_param().actual_msample_),
                               random);

    visitor.visit_at_beginning(*this, preprocessor);
    // THE MAIN EFFING RF LOOP - YEAY DUDE!

    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        //initialize the first region/node/stack entry
        sampler
            .sample();
        StackEntry_t
            first_stack_entry(  sampler.sampledIndices().begin(),
                                sampler.sampledIndices().end(),
                                ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(     sampler.oobIndices().begin(),
                                sampler.oobIndices().end());
        trees_[ii]
            .learn( preprocessor.features(),
                    preprocessor.response(),
                    first_stack_entry,
                    split,
                    stop,
                    visitor,
                    randint);
        visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                sampler,
                                first_stack_entry,
                                ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // Only for online learning?
    online_visitor_.deactivate();
}



template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
        " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    typedef MultiArrayShape<2>::type Shp;
    garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
    LabelType d;
    predictProbabilities(features, garbage_prediction_, stop);
    ext_param_.to_classlabel(argMax(garbage_prediction_), d);
    return d;
}


//Same thing as above with priors for each label !!!
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel( MultiArrayView<2, U, C> const & features,
                    ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> prob(1,ext_param_.class_count_);
    predictProbabilities(features, prob);
    std::transform( prob.begin(), prob.end(),
                    priors.begin(), prob.begin(),
                    Arg1()*Arg2());
    LabelType d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}

template<class LabelType,class PreprocessorTag>
template <class T1,class T2, class C>
void RandomForest<LabelType,PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                           MultiArrayView<2, T2, C> & prob)
{
    //Features are n x p
    //prob is n x NumOfLabel: the probability for each sample in each class

    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    // The feature matrix may have additional columns beyond those used
    // during training; the extra columns are ignored.
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    //store total weights
    std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
    //Go through all trees
    int set_id=-1;
    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id=(set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
        //Build a stack with all the ranges we have
        std::vector<std::pair<int,set_it> > stack;
        stack.clear();
        set_it i;
        for(i=predictionSet.ranges[set_id].begin();i!=predictionSet.ranges[set_id].end();++i)
            stack.push_back(std::pair<int,set_it>(2,i));
        //get the weights predicted by a single tree
        int num_decisions=0;
        while(!stack.empty())
        {
            set_it range=stack.back().second;
            int index=stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
                                                                            trees_[k].parameters_,
                                                                            index).prob_begin();
                for(int i=range->start;i!=range->end;++i)
                {
                    //update the vote count.
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
                        //every weight in totalWeight.
                        totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
                    }
                }
            }

            else
            {
                if(trees_[k].topology_[index]!=i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
                if(range->min_boundaries[node.column()]>=node.threshold())
                {
                    //Everything goes to the right child
                    stack.push_back(std::pair<int,set_it>(node.child(1),range));
                    continue;
                }
                if(range->max_boundaries[node.column()]<node.threshold())
                {
                    //Everything goes to the left child
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                    continue;
                }
                //We have to split at this node
                SampleRange<T1> new_range=*range;
                new_range.min_boundaries[node.column()]=FLT_MAX;
                range->max_boundaries[node.column()]=-FLT_MAX;
                new_range.start=new_range.end=range->end;
                int i=range->start;
                while(i!=range->end)
                {
                    //Decide for range->indices[i]
                    if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
                    {
                        new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
                                                                         predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);

                    }
                    else
                    {
                        range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
                                                                      predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        ++i;
                    }
                }
                //The old range ...
                if(range->start==range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                }
                //And the new one ...
                if(new_range.start!=new_range.end)
                {
                    std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
                }
            }
        }
        predictionSet.cumulativePredTime[k]=num_decisions;
    }
    for(unsigned int i=0;i<totalWeights.size();++i)
    {
        double test=0.0;
        //Normalise the votes in each row by the total vote count (totalWeights[i])
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test+=prob(i,l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test==totalWeights[i]);
        assert(totalWeights[i]>0.0);
    }
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
                           MultiArrayView<2, T, C2> & prob,
                           Stop_t & stop_) const
{
    //Features are n x p
    //prob is n x NumOfLabel: the probability for each sample in each class

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix may have additional columns beyond those used
    // during training; the extra columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    #undef RF_CHOOSER
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    */
    //Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        //totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        //Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            //get the weights predicted by a single tree
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            //update the vote count. If weighted prediction is enabled, each
            //class weight is additionally multiplied by the leaf's total
            //sample weight, which is stored directly before the per-class
            //entries (hence *(weights-1)).
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += (T)cur_w;
                //every weight in totalWeight.
                totalWeight += cur_w;
            }
            if(stop.after_prediction(weights,
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        //Normalise the votes in each row by the total vote count (totalWeight)
        for(int l=0; l< ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }

}

//@}

} // namespace vigra

#include "random_forest/rf_algorithm.hxx"
#endif // VIGRA_RANDOM_FOREST_HXX
