[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_common.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
39 
40 namespace vigra
41 {
42 
43 
/** \brief tag class selecting the classification variant of the random forest */
struct ClassificationTag
{};
46 
/** \brief tag class selecting the regression variant of the random forest */
struct RegressionTag
{};
49 
// forward declaration of the default tag, so that rf_default() can be
// declared before the class body is seen
namespace detail
{
    class RF_DEFAULT;
}
// factory function returning the RF_DEFAULT singleton (defined below)
inline detail::RF_DEFAULT& rf_default();
55 namespace detail
56 {
57 
58 /**\brief singleton default tag class -
59  *
60  * use the rf_default() factory function to use the tag.
61  * \sa RandomForest<>::learn();
62  */
64 {
65  private:
66  RF_DEFAULT()
67  {}
68  public:
70 
71  /** ok workaround for automatic choice of the decisiontree
72  * stackentry.
73  */
74 };
75 
76 /**\brief chooses between default type and type supplied
77  *
78  * This is an internal class and you shouldn't really care about it.
79  * Just pass on used in RandomForest.learn()
80  * Usage:
81  *\code
82  * // example: use container type supplied by user or ArrayVector if
83  * // rf_default() was specified as argument;
84  * template<class Container_t>
85  * void do_some_foo(Container_t in)
86  * {
87  * typedef ArrayVector<int> Default_Container_t;
88  * Default_Container_t default_value;
89  * Value_Chooser<Container_t, Default_Container_t>
90  * choose(in, default_value);
91  *
92  * // if the user didn't care and the in was of type
93  * // RF_DEFAULT then default_value is used.
94  * do_some_more_foo(choose.value());
95  * }
96  * Value_Chooser choose_val<Type, Default_Type>
97  *\endcode
98  */
99 template<class T, class C>
101 {
102 public:
103  typedef T type;
104  static T & choose(T & t, C &)
105  {
106  return t;
107  }
108 };
109 
110 template<class C>
111 class Value_Chooser<detail::RF_DEFAULT, C>
112 {
113 public:
114  typedef C type;
115 
116  static C & choose(detail::RF_DEFAULT &, C & c)
117  {
118  return c;
119  }
120 };
121 
122 
123 
124 
125 } //namespace detail
126 
127 
128 /**\brief factory function to return a RF_DEFAULT tag
129  * \sa RandomForest<>::learn()
130  */
132 {
133  static detail::RF_DEFAULT result;
134  return result;
135 }
136 
/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag   { RF_EQUAL,         // stratification: equal #samples per class
                      RF_PROPORTIONAL,  // stratification/sampling: proportional to class size
                      RF_EXTERNAL,      // strata weights supplied externally
                      RF_NONE,          // no stratification
                      RF_FUNCTION,      // value computed by a user-supplied callback
                      RF_LOG,           // mtry = log(column count)
                      RF_SQRT,          // mtry = sqrt(column count)
                      RF_CONST,         // value given as an explicit constant
                      RF_ALL};          // use all columns
149 
150 
151 /** \addtogroup MachineLearning
152 **/
153 //@{
154 
155 /**\brief Options object for the random forest
156  *
157  * usage:
158  * RandomForestOptions a = RandomForestOptions()
159  * .param1(value1)
160  * .param2(value2)
161  * ...
162  *
163  * This class only contains options/parameters that are not problem
164  * dependent. The ProblemSpec class contains methods to set class weights
165  * if necessary.
166  *
167  * Note that the return value of all methods is *this which makes
168  * concatenating of options as above possible.
169  */
171 {
172  public:
173  /**\name sampling options*/
174  /*\{*/
175  // look at the member access functions for documentation
176  double training_set_proportion_;
177  int training_set_size_;
178  int (*training_set_func_)(int);
180  training_set_calc_switch_;
181 
182  bool sample_with_replacement_;
184  stratification_method_;
185 
186 
187  /**\name general random forest options
188  *
189  * these usually will be used by most split functors and
190  * stopping predicates
191  */
192  /*\{*/
193  RF_OptionTag mtry_switch_;
194  int mtry_;
195  int (*mtry_func_)(int) ;
196 
197  bool predict_weighted_;
198  int tree_count_;
199  int min_split_node_size_;
200  bool prepare_online_learning_;
201  /*\}*/
202 
203  int serialized_size() const
204  {
205  return 12;
206  }
207 
208 
209  bool operator==(RandomForestOptions & rhs) const
210  {
211  bool result = true;
212  #define COMPARE(field) result = result && (this->field == rhs.field);
213  COMPARE(training_set_proportion_);
214  COMPARE(training_set_size_);
215  COMPARE(training_set_calc_switch_);
216  COMPARE(sample_with_replacement_);
217  COMPARE(stratification_method_);
218  COMPARE(mtry_switch_);
219  COMPARE(mtry_);
220  COMPARE(tree_count_);
221  COMPARE(min_split_node_size_);
222  COMPARE(predict_weighted_);
223  #undef COMPARE
224 
225  return result;
226  }
227  bool operator!=(RandomForestOptions & rhs_) const
228  {
229  return !(*this == rhs_);
230  }
231  template<class Iter>
232  void unserialize(Iter const & begin, Iter const & end)
233  {
234  Iter iter = begin;
235  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
236  "RandomForestOptions::unserialize():"
237  "wrong number of parameters");
238  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
239  PULL(training_set_proportion_, double);
240  PULL(training_set_size_, int);
241  ++iter; //PULL(training_set_func_, double);
242  PULL(training_set_calc_switch_, (RF_OptionTag)int);
243  PULL(sample_with_replacement_, 0 != );
244  PULL(stratification_method_, (RF_OptionTag)int);
245  PULL(mtry_switch_, (RF_OptionTag)int);
246  PULL(mtry_, int);
247  ++iter; //PULL(mtry_func_, double);
248  PULL(tree_count_, int);
249  PULL(min_split_node_size_, int);
250  PULL(predict_weighted_, 0 !=);
251  #undef PULL
252  }
253  template<class Iter>
254  void serialize(Iter const & begin, Iter const & end) const
255  {
256  Iter iter = begin;
257  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
258  "RandomForestOptions::serialize():"
259  "wrong number of parameters");
260  #define PUSH(item_) *iter = double(item_); ++iter;
261  PUSH(training_set_proportion_);
262  PUSH(training_set_size_);
263  if(training_set_func_ != 0)
264  {
265  PUSH(1);
266  }
267  else
268  {
269  PUSH(0);
270  }
271  PUSH(training_set_calc_switch_);
272  PUSH(sample_with_replacement_);
273  PUSH(stratification_method_);
274  PUSH(mtry_switch_);
275  PUSH(mtry_);
276  if(mtry_func_ != 0)
277  {
278  PUSH(1);
279  }
280  else
281  {
282  PUSH(0);
283  }
284  PUSH(tree_count_);
285  PUSH(min_split_node_size_);
286  PUSH(predict_weighted_);
287  #undef PUSH
288  }
289 
290  void make_from_map(std::map<std::string, ArrayVector<double> > & in)
291  {
292  typedef MultiArrayShape<2>::type Shp;
293  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
294  #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
295  PULL(training_set_proportion_,double);
296  PULL(training_set_size_, int);
297  PULL(mtry_, int);
298  PULL(tree_count_, int);
299  PULL(min_split_node_size_, int);
300  PULLBOOL(sample_with_replacement_, bool);
301  PULLBOOL(prepare_online_learning_, bool);
302  PULLBOOL(predict_weighted_, bool);
303 
304  PULL(training_set_calc_switch_, (RF_OptionTag)int);
305  PULL(stratification_method_, (RF_OptionTag)int);
306  PULL(mtry_switch_, (RF_OptionTag)int);
307 
308  /*don't pull*/
309  //PULL(mtry_func_!=0, int);
310  //PULL(training_set_func,int);
311  #undef PULL
312  #undef PULLBOOL
313  }
314  void make_map(std::map<std::string, ArrayVector<double> > & in) const
315  {
316  typedef MultiArrayShape<2>::type Shp;
317  #define PUSH(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_));
318  #define PUSHFUNC(item_, type_) in[#item_] = ArrayVector<double>(1, double(item_!=0));
319  PUSH(training_set_proportion_,double);
320  PUSH(training_set_size_, int);
321  PUSH(mtry_, int);
322  PUSH(tree_count_, int);
323  PUSH(min_split_node_size_, int);
324  PUSH(sample_with_replacement_, bool);
325  PUSH(prepare_online_learning_, bool);
326  PUSH(predict_weighted_, bool);
327 
328  PUSH(training_set_calc_switch_, RF_OptionTag);
329  PUSH(stratification_method_, RF_OptionTag);
330  PUSH(mtry_switch_, RF_OptionTag);
331 
332  PUSHFUNC(mtry_func_, int);
333  PUSHFUNC(training_set_func_,int);
334  #undef PUSH
335  #undef PUSHFUNC
336  }
337 
338 
339  /**\brief create a RandomForestOptions object with default initialisation.
340  *
341  * look at the other member functions for more information on default
342  * values
343  */
345  :
346  training_set_proportion_(1.0),
347  training_set_size_(0),
348  training_set_func_(0),
349  training_set_calc_switch_(RF_PROPORTIONAL),
350  sample_with_replacement_(true),
351  stratification_method_(RF_NONE),
352  mtry_switch_(RF_SQRT),
353  mtry_(0),
354  mtry_func_(0),
355  predict_weighted_(false),
356  tree_count_(256),
357  min_split_node_size_(1),
358  prepare_online_learning_(false)
359  {}
360 
361  /**\brief specify stratification strategy
362  *
363  * default: RF_NONE
364  * possible values: RF_EQUAL, RF_PROPORTIONAL,
365  * RF_EXTERNAL, RF_NONE
366  * RF_EQUAL: get equal amount of samples per class.
367  * RF_PROPORTIONAL: sample proportional to fraction of class samples
368  * in population
369  * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object
370  * has been set externally. (defunct)
371  */
373  {
374  vigra_precondition(in == RF_EQUAL ||
375  in == RF_PROPORTIONAL ||
376  in == RF_EXTERNAL ||
377  in == RF_NONE,
378  "RandomForestOptions::use_stratification()"
379  "input must be RF_EQUAL, RF_PROPORTIONAL,"
380  "RF_EXTERNAL or RF_NONE");
381  stratification_method_ = in;
382  return *this;
383  }
384 
385  RandomForestOptions & prepare_online_learning(bool in)
386  {
387  prepare_online_learning_=in;
388  return *this;
389  }
390 
391  /**\brief sample from training population with or without replacement?
392  *
393  * <br> Default: true
394  */
396  {
397  sample_with_replacement_ = in;
398  return *this;
399  }
400 
401  /**\brief specify the fraction of the total number of samples
402  * used per tree for learning.
403  *
404  * This value should be in [0.0 1.0] if sampling without
405  * replacement has been specified.
406  *
407  * <br> default : 1.0
408  */
410  {
411  training_set_proportion_ = in;
412  training_set_calc_switch_ = RF_PROPORTIONAL;
413  return *this;
414  }
415 
416  /**\brief directly specify the number of samples per tree
417  */
419  {
420  training_set_size_ = in;
421  training_set_calc_switch_ = RF_CONST;
422  return *this;
423  }
424 
425  /**\brief use external function to calculate the number of samples each
426  * tree should be learnt with.
427  *
428  * \param in function pointer that takes the number of rows in the
429  * learning data and outputs the number samples per tree.
430  */
432  {
433  training_set_func_ = in;
434  training_set_calc_switch_ = RF_FUNCTION;
435  return *this;
436  }
437 
438  /**\brief weight each tree with number of samples in that node
439  */
441  {
442  predict_weighted_ = true;
443  return *this;
444  }
445 
446  /**\brief use built in mapping to calculate mtry
447  *
448  * Use one of the built in mappings to calculate mtry from the number
449  * of columns in the input feature data.
450  * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
451  * <br> default: RF_SQRT.
452  */
454  {
455  vigra_precondition(in == RF_LOG ||
456  in == RF_SQRT||
457  in == RF_ALL,
458  "RandomForestOptions()::features_per_node():"
459  "input must be of type RF_LOG or RF_SQRT");
460  mtry_switch_ = in;
461  return *this;
462  }
463 
464  /**\brief Set mtry to a constant value
465  *
466  * mtry is the number of columns/variates/variables randomly choosen
467  * to select the best split from.
468  *
469  */
471  {
472  mtry_ = in;
473  mtry_switch_ = RF_CONST;
474  return *this;
475  }
476 
477  /**\brief use a external function to calculate mtry
478  *
479  * \param in function pointer that takes int (number of columns
480  * of the and outputs int (mtry)
481  */
483  {
484  mtry_func_ = in;
485  mtry_switch_ = RF_FUNCTION;
486  return *this;
487  }
488 
489  /** How many trees to create?
490  *
491  * <br> Default: 255.
492  */
494  {
495  tree_count_ = in;
496  return *this;
497  }
498 
499  /**\brief Number of examples required for a node to be split.
500  *
501  * When the number of examples in a node is below this number,
502  * the node is not split even if class separation is not yet perfect.
503  * Instead, the node returns the proportion of each class
504  * (among the remaining examples) during the prediction phase.
505  * <br> Default: 1 (complete growing)
506  */
508  {
509  min_split_node_size_ = in;
510  return *this;
511  }
512 };
513 
514 
/** \brief problem types
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
518 
519 
520 /** \brief problem specification class for the random forest.
521  *
522  * This class contains all the problem specific parameters the random
523  * forest needs for learning. Specification of an instance of this class
524  * is optional as all necessary fields will be computed prior to learning
525  * if not specified.
526  *
527  * if needed usage is similar to that of RandomForestOptions
528  */
529 
530 template<class LabelType = double>
532 {
533 
534 
535 public:
536 
537  /** \brief problem class
538  */
539 
540  typedef LabelType Label_t;
541  ArrayVector<Label_t> classes;
542 
543  int column_count_; // number of features
544  int class_count_; // number of classes
545  int row_count_; // number of samples
546 
547  int actual_mtry_; // mtry used in training
548  int actual_msample_; // number if in-bag samples per tree
549 
550  Problem_t problem_type_; // classification or regression
551 
552  int used_; // this ProblemSpec is valid
553  ArrayVector<double> class_weights_; // if classes have different importance
554  int is_weighted_; // class_weights_ are used
555  double precision_; // termination criterion for regression loss
556 
557 
558  template<class T>
559  void to_classlabel(int index, T & out) const
560  {
561  out = T(classes[index]);
562  }
563  template<class T>
564  int to_classIndex(T index) const
565  {
566  return std::find(classes.begin(), classes.end(), index) - classes.begin();
567  }
568 
569  #define EQUALS(field) field(rhs.field)
570  ProblemSpec(ProblemSpec const & rhs)
571  :
572  EQUALS(column_count_),
573  EQUALS(class_count_),
574  EQUALS(row_count_),
575  EQUALS(actual_mtry_),
576  EQUALS(actual_msample_),
577  EQUALS(problem_type_),
578  EQUALS(used_),
579  EQUALS(class_weights_),
580  EQUALS(is_weighted_),
581  EQUALS(precision_)
582  {
583  std::back_insert_iterator<ArrayVector<Label_t> >
584  iter(classes);
585  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
586  }
587  #undef EQUALS
588  #define EQUALS(field) field(rhs.field)
589  template<class T>
590  ProblemSpec(ProblemSpec<T> const & rhs)
591  :
592  EQUALS(column_count_),
593  EQUALS(class_count_),
594  EQUALS(row_count_),
595  EQUALS(actual_mtry_),
596  EQUALS(actual_msample_),
597  EQUALS(problem_type_),
598  EQUALS(used_),
599  EQUALS(class_weights_),
600  EQUALS(is_weighted_),
601  EQUALS(precision_)
602  {
603  std::back_insert_iterator<ArrayVector<Label_t> >
604  iter(classes);
605  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
606  }
607  #undef EQUALS
608 
609  // for some reason the function below does not match
610  // the default copy constructor
611  #define EQUALS(field) (this->field = rhs.field);
612  ProblemSpec & operator=(ProblemSpec const & rhs)
613  {
614  EQUALS(column_count_);
615  EQUALS(class_count_);
616  EQUALS(row_count_);
617  EQUALS(actual_mtry_);
618  EQUALS(actual_msample_);
619  EQUALS(problem_type_);
620  EQUALS(used_);
621  EQUALS(is_weighted_);
622  EQUALS(precision_);
623  class_weights_.clear();
624  std::back_insert_iterator<ArrayVector<double> >
625  iter2(class_weights_);
626  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
627  classes.clear();
628  std::back_insert_iterator<ArrayVector<Label_t> >
629  iter(classes);
630  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
631  return *this;
632  }
633 
634  template<class T>
635  ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
636  {
637  EQUALS(column_count_);
638  EQUALS(class_count_);
639  EQUALS(row_count_);
640  EQUALS(actual_mtry_);
641  EQUALS(actual_msample_);
642  EQUALS(problem_type_);
643  EQUALS(used_);
644  EQUALS(is_weighted_);
645  EQUALS(precision_);
646  class_weights_.clear();
647  std::back_insert_iterator<ArrayVector<double> >
648  iter2(class_weights_);
649  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
650  classes.clear();
651  std::back_insert_iterator<ArrayVector<Label_t> >
652  iter(classes);
653  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
654  return *this;
655  }
656  #undef EQUALS
657 
658  template<class T>
659  bool operator==(ProblemSpec<T> const & rhs)
660  {
661  bool result = true;
662  #define COMPARE(field) result = result && (this->field == rhs.field);
663  COMPARE(column_count_);
664  COMPARE(class_count_);
665  COMPARE(row_count_);
666  COMPARE(actual_mtry_);
667  COMPARE(actual_msample_);
668  COMPARE(problem_type_);
669  COMPARE(is_weighted_);
670  COMPARE(precision_);
671  COMPARE(used_);
672  COMPARE(class_weights_);
673  COMPARE(classes);
674  #undef COMPARE
675  return result;
676  }
677 
678  bool operator!=(ProblemSpec & rhs)
679  {
680  return !(*this == rhs);
681  }
682 
683 
684  size_t serialized_size() const
685  {
686  return 9 + class_count_ *int(is_weighted_+1);
687  }
688 
689 
690  template<class Iter>
691  void unserialize(Iter const & begin, Iter const & end)
692  {
693  Iter iter = begin;
694  vigra_precondition(end - begin >= 9,
695  "ProblemSpec::unserialize():"
696  "wrong number of parameters");
697  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
698  PULL(column_count_,int);
699  PULL(class_count_, int);
700 
701  vigra_precondition(end - begin >= 9 + class_count_,
702  "ProblemSpec::unserialize(): 1");
703  PULL(row_count_, int);
704  PULL(actual_mtry_,int);
705  PULL(actual_msample_, int);
706  PULL(problem_type_, Problem_t);
707  PULL(is_weighted_, int);
708  PULL(used_, int);
709  PULL(precision_, double);
710  if(is_weighted_)
711  {
712  vigra_precondition(end - begin == 9 + 2*class_count_,
713  "ProblemSpec::unserialize(): 2");
714  class_weights_.insert(class_weights_.end(),
715  iter,
716  iter + class_count_);
717  iter += class_count_;
718  }
719  classes.insert(classes.end(), iter, end);
720  #undef PULL
721  }
722 
723 
724  template<class Iter>
725  void serialize(Iter const & begin, Iter const & end) const
726  {
727  Iter iter = begin;
728  vigra_precondition(end - begin == serialized_size(),
729  "RandomForestOptions::serialize():"
730  "wrong number of parameters");
731  #define PUSH(item_) *iter = double(item_); ++iter;
732  PUSH(column_count_);
733  PUSH(class_count_)
734  PUSH(row_count_);
735  PUSH(actual_mtry_);
736  PUSH(actual_msample_);
737  PUSH(problem_type_);
738  PUSH(is_weighted_);
739  PUSH(used_);
740  PUSH(precision_);
741  if(is_weighted_)
742  {
743  std::copy(class_weights_.begin(),
744  class_weights_.end(),
745  iter);
746  iter += class_count_;
747  }
748  std::copy(classes.begin(),
749  classes.end(),
750  iter);
751  #undef PUSH
752  }
753 
754  void make_from_map(std::map<std::string, ArrayVector<double> > & in)
755  {
756  typedef MultiArrayShape<2>::type Shp;
757  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
758  PULL(column_count_,int);
759  PULL(class_count_, int);
760  PULL(row_count_, int);
761  PULL(actual_mtry_,int);
762  PULL(actual_msample_, int);
763  PULL(problem_type_, (Problem_t)int);
764  PULL(is_weighted_, int);
765  PULL(used_, int);
766  PULL(precision_, double);
767  class_weights_ = in["class_weights_"];
768  #undef PUSH
769  }
770  void make_map(std::map<std::string, ArrayVector<double> > & in) const
771  {
772  typedef MultiArrayShape<2>::type Shp;
773  #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_));
774  PUSH(column_count_);
775  PUSH(class_count_)
776  PUSH(row_count_);
777  PUSH(actual_mtry_);
778  PUSH(actual_msample_);
779  PUSH(problem_type_);
780  PUSH(is_weighted_);
781  PUSH(used_);
782  PUSH(precision_);
783  in["class_weights_"] = class_weights_;
784  #undef PUSH
785  }
786 
787  /**\brief set default values (-> values not set)
788  */
790  : column_count_(0),
791  class_count_(0),
792  row_count_(0),
793  actual_mtry_(0),
794  actual_msample_(0),
795  problem_type_(CHECKLATER),
796  used_(false),
797  is_weighted_(false),
798  precision_(0.0)
799  {}
800 
801 
802  ProblemSpec & column_count(int in)
803  {
804  column_count_ = in;
805  return *this;
806  }
807 
808  /**\brief supply with class labels -
809  *
810  * the preprocessor will not calculate the labels needed in this case.
811  */
812  template<class C_Iter>
813  ProblemSpec & classes_(C_Iter begin, C_Iter end)
814  {
815  int size = end-begin;
816  for(int k=0; k<size; ++k, ++begin)
817  classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
818  class_count_ = size;
819  return *this;
820  }
821 
822  /** \brief supply with class weights -
823  *
824  * this is the only case where you would really have to
825  * create a ProblemSpec object.
826  */
827  template<class W_Iter>
828  ProblemSpec & class_weights(W_Iter begin, W_Iter end)
829  {
830  class_weights_.insert(class_weights_.end(), begin, end);
831  is_weighted_ = true;
832  return *this;
833  }
834 
835 
836 
837  void clear()
838  {
839  used_ = false;
840  classes.clear();
841  class_weights_.clear();
842  column_count_ = 0 ;
843  class_count_ = 0;
844  actual_mtry_ = 0;
845  actual_msample_ = 0;
846  problem_type_ = CHECKLATER;
847  is_weighted_ = false;
848  precision_ = 0.0;
849 
850  }
851 
852  bool used() const
853  {
854  return used_ != 0;
855  }
856 };
857 
858 
859 //@}
860 
861 
862 
863 /**\brief Standard early stopping criterion
864  *
865  * Stop if region.size() < min_split_node_size_;
866  */
868 {
869  public:
870  int min_split_node_size_;
871 
872  template<class Opt>
873  EarlyStoppStd(Opt opt)
874  : min_split_node_size_(opt.min_split_node_size_)
875  {}
876 
877  template<class T>
878  void set_external_parameters(ProblemSpec<T>const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
879  {}
880 
881  template<class Region>
882  bool operator()(Region& region)
883  {
884  return region.size() < min_split_node_size_;
885  }
886 
887  template<class WeightIter, class T, class C>
888  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
889  {
890  return false;
891  }
892 };
893 
894 
895 } // namespace vigra
896 
897 #endif //VIGRA_RF_COMMON_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Tue Jul 10 2012)