[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #else
41 # include "vigra/impex.hxx"
42 # include "vigra/multi_array.hxx"
43 # include "vigra/multi_impex.hxx"
44 # include "vigra/inspectimage.hxx"
45 #endif // HasHDF5
46 #include <vigra/windows.h>
47 #include <iostream>
48 #include <iomanip>
49 #include <vigra/timing.hxx>
50 
51 namespace vigra
52 {
53 namespace rf
54 {
55 /** \addtogroup MachineLearning Machine Learning
56 **/
57 //@{
58 
59 /**
60  This namespace contains all classes and methods related to extracting information during
61  learning of the random forest. All Visitors share the same interface defined in
62  visitors::VisitorBase. The member methods are invoked at certain points of the main code in
63  the order they were supplied.
64 
65  For the Random Forest the Visitor concept is implemented as a statically linked list
66  (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
67  VisitorNode object calls the Next Visitor after one of its visit() methods have terminated.
68 
69  To simplify usage create_visitor() factory methods are supplied.
70  Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
71  It is possible to supply more than one visitor. They will then be invoked in serial order.
72 
73  The calculated information are stored as public data members of the class. - see documentation
74  of the individual visitors
75 
76  While creating a new visitor the new class should therefore publicly inherit from this class
77  (i.e.: see visitors::OOB_Error).
78 
79  \code
80 
81  typedef xxx feature_t \\ replace xxx with whichever type
82  typedef yyy label_t \\ meme chose.
83  MultiArrayView<2, feature_t> f = get_some_features();
84  MultiArrayView<2, label_t> l = get_some_labels();
85  RandomForest<> rf()
86 
87  //calculate OOB Error
88  visitors::OOB_Error oob_v;
89  //calculate Variable Importance
90  visitors::VariableImportanceVisitor varimp_v;
91 
 92  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
93  //the data can be found in the attributes of oob_v and varimp_v now
94 
95  \endcode
96 */
97 namespace visitors
98 {
99 
100 
/** Base Class from which all Visitors derive. Can be used as a template to create new
 *  Visitors.
 *
 *  Every hook is a no-op here; derived visitors override only the hooks they need.
 *  (NOTE(review): the class declaration line was dropped by the doc extraction and
 *  has been restored; stale per-line numbering was removed.)
 */
class VisitorBase
{
  public:
    bool active_;   // visitor is skipped by VisitorNode when false

    /** whether any of the visit methods should be invoked */
    bool is_active()
    {
        return active_;
    }

    /** whether return_val() yields a meaningful value (false for the base class) */
    bool has_value()
    {
        return false;
    }

    VisitorBase()
    : active_(true)
    {}

    void deactivate()
    {
        active_ = false;
    }

    void activate()
    {
        active_ = true;
    }

    /** do something after the Split has decided how to process the Region
     * (Stack entry)
     *
     * \param tree       reference to the tree that is currently being learned
     * \param split      reference to the split object
     * \param parent     current stack entry which was used to decide the split
     * \param leftChild  left stack entry that will be pushed
     * \param rightChild right stack entry that will be pushed
     * \param features   features matrix
     * \param labels     label matrix
     * \sa RF_Traits::StackEntry_t
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree          & tree,
                            Split         & split,
                            Region        & parent,
                            Region        & leftChild,
                            Region        & rightChild,
                            Feature_t     & features,
                            Label_t       & labels)
    {}

    /** do something after each tree has been learned
     *
     * \param rf    reference to the random forest object that called this visitor
     * \param pr    reference to the preprocessor that processed the input
     * \param sm    reference to the sampler object
     * \param st    reference to the first stack entry
     * \param index index of current tree
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {}

    /** do something after all trees have been learned
     *
     * \param rf reference to the random forest object that called this visitor
     * \param pr reference to the preprocessor that processed the input
     */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr)
    {}

    /** do something before learning starts
     *
     * \param rf reference to the random forest object that called this visitor
     * \param pr reference to the Processor class used.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr)
    {}

    /** do something while traversing tree after it has been learned
     * (external nodes)
     *
     * \param tr       reference to the tree object that called this visitor
     * \param index    index in the topology_ array we currently are at
     * \param node_t   type of node we have (will be e_.... - )
     * \param features feature matrix
     * \sa NodeTags;
     *
     * you can create the node by using a switch on node_tag and using the
     * corresponding Node objects. Or - if you do not care about the type -
     * use the NodeBase class.
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
    {}

    /** do something when visiting an internal node after it has been learned
     *
     * \sa visit_external_node
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_internal_node(TR & tr, IntT index, TopT node_t, Feat & features)
    {}

    /** return a double value. The value of the first
     *  visitor encountered that has a return value is returned with the
     *  RandomForest::learn() method - or -1.0 if no return value visitor
     *  existed. This functionality basically only exists so that the
     *  OOB - visitor can return the oob error rate like in the old version
     *  of the random forest.
     */
    double return_val()
    {
        return -1.0;
    }
};
223 
224 
225 /** Last Visitor that should be called to stop the recursion.
226  */
228 {
229  public:
230  bool has_value()
231  {
232  return true;
233  }
234  double return_val()
235  {
236  return -1.0;
237  }
238 };
239 namespace detail
240 {
241 /** Container elements of the statically linked Visitor list.
242  *
243  * use the create_visitor() factory functions to create visitors up to size 10;
244  *
245  */
246 template <class Visitor, class Next = StopVisiting>
248 {
249  public:
250 
251  StopVisiting stop_;
252  Next next_;
253  Visitor & visitor_;
254  VisitorNode(Visitor & visitor, Next & next)
255  :
256  next_(next), visitor_(visitor)
257  {}
258 
259  VisitorNode(Visitor & visitor)
260  :
261  next_(stop_), visitor_(visitor)
262  {}
263 
264  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
265  void visit_after_split( Tree & tree,
266  Split & split,
267  Region & parent,
268  Region & leftChild,
269  Region & rightChild,
270  Feature_t & features,
271  Label_t & labels)
272  {
273  if(visitor_.is_active())
274  visitor_.visit_after_split(tree, split,
275  parent, leftChild, rightChild,
276  features, labels);
277  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
278  features, labels);
279  }
280 
281  template<class RF, class PR, class SM, class ST>
282  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
283  {
284  if(visitor_.is_active())
285  visitor_.visit_after_tree(rf, pr, sm, st, index);
286  next_.visit_after_tree(rf, pr, sm, st, index);
287  }
288 
289  template<class RF, class PR>
290  void visit_at_beginning(RF & rf, PR & pr)
291  {
292  if(visitor_.is_active())
293  visitor_.visit_at_beginning(rf, pr);
294  next_.visit_at_beginning(rf, pr);
295  }
296  template<class RF, class PR>
297  void visit_at_end(RF & rf, PR & pr)
298  {
299  if(visitor_.is_active())
300  visitor_.visit_at_end(rf, pr);
301  next_.visit_at_end(rf, pr);
302  }
303 
304  template<class TR, class IntT, class TopT,class Feat>
305  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
306  {
307  if(visitor_.is_active())
308  visitor_.visit_external_node(tr, index, node_t,features);
309  next_.visit_external_node(tr, index, node_t,features);
310  }
311  template<class TR, class IntT, class TopT,class Feat>
312  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
313  {
314  if(visitor_.is_active())
315  visitor_.visit_internal_node(tr, index, node_t,features);
316  next_.visit_internal_node(tr, index, node_t,features);
317  }
318 
319  double return_val()
320  {
321  if(visitor_.is_active() && visitor_.has_value())
322  return visitor_.return_val();
323  return next_.return_val();
324  }
325 };
326 
327 } //namespace detail
328 
329 //////////////////////////////////////////////////////////////////////////////
330 // Visitor Factory function up to 10 visitors //
331 //////////////////////////////////////////////////////////////////////////////
332 
333 /** factory method to to be used with RandomForest::learn()
334  */
335 template<class A>
338 {
339  typedef detail::VisitorNode<A> _0_t;
340  _0_t _0(a);
341  return _0;
342 }
343 
344 
345 /** factory method to to be used with RandomForest::learn()
346  */
347 template<class A, class B>
348 detail::VisitorNode<A, detail::VisitorNode<B> >
349 create_visitor(A & a, B & b)
350 {
351  typedef detail::VisitorNode<B> _1_t;
352  _1_t _1(b);
353  typedef detail::VisitorNode<A, _1_t> _0_t;
354  _0_t _0(a, _1);
355  return _0;
356 }
357 
358 
359 /** factory method to to be used with RandomForest::learn()
360  */
361 template<class A, class B, class C>
362 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
363 create_visitor(A & a, B & b, C & c)
364 {
365  typedef detail::VisitorNode<C> _2_t;
366  _2_t _2(c);
367  typedef detail::VisitorNode<B, _2_t> _1_t;
368  _1_t _1(b, _2);
369  typedef detail::VisitorNode<A, _1_t> _0_t;
370  _0_t _0(a, _1);
371  return _0;
372 }
373 
374 
375 /** factory method to to be used with RandomForest::learn()
376  */
377 template<class A, class B, class C, class D>
378 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
379  detail::VisitorNode<D> > > >
380 create_visitor(A & a, B & b, C & c, D & d)
381 {
382  typedef detail::VisitorNode<D> _3_t;
383  _3_t _3(d);
384  typedef detail::VisitorNode<C, _3_t> _2_t;
385  _2_t _2(c, _3);
386  typedef detail::VisitorNode<B, _2_t> _1_t;
387  _1_t _1(b, _2);
388  typedef detail::VisitorNode<A, _1_t> _0_t;
389  _0_t _0(a, _1);
390  return _0;
391 }
392 
393 
394 /** factory method to to be used with RandomForest::learn()
395  */
396 template<class A, class B, class C, class D, class E>
397 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
398  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
399 create_visitor(A & a, B & b, C & c,
400  D & d, E & e)
401 {
402  typedef detail::VisitorNode<E> _4_t;
403  _4_t _4(e);
404  typedef detail::VisitorNode<D, _4_t> _3_t;
405  _3_t _3(d, _4);
406  typedef detail::VisitorNode<C, _3_t> _2_t;
407  _2_t _2(c, _3);
408  typedef detail::VisitorNode<B, _2_t> _1_t;
409  _1_t _1(b, _2);
410  typedef detail::VisitorNode<A, _1_t> _0_t;
411  _0_t _0(a, _1);
412  return _0;
413 }
414 
415 
416 /** factory method to to be used with RandomForest::learn()
417  */
418 template<class A, class B, class C, class D, class E,
419  class F>
420 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
421  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
422 create_visitor(A & a, B & b, C & c,
423  D & d, E & e, F & f)
424 {
425  typedef detail::VisitorNode<F> _5_t;
426  _5_t _5(f);
427  typedef detail::VisitorNode<E, _5_t> _4_t;
428  _4_t _4(e, _5);
429  typedef detail::VisitorNode<D, _4_t> _3_t;
430  _3_t _3(d, _4);
431  typedef detail::VisitorNode<C, _3_t> _2_t;
432  _2_t _2(c, _3);
433  typedef detail::VisitorNode<B, _2_t> _1_t;
434  _1_t _1(b, _2);
435  typedef detail::VisitorNode<A, _1_t> _0_t;
436  _0_t _0(a, _1);
437  return _0;
438 }
439 
440 
441 /** factory method to to be used with RandomForest::learn()
442  */
443 template<class A, class B, class C, class D, class E,
444  class F, class G>
445 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
446  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
447  detail::VisitorNode<G> > > > > > >
448 create_visitor(A & a, B & b, C & c,
449  D & d, E & e, F & f, G & g)
450 {
451  typedef detail::VisitorNode<G> _6_t;
452  _6_t _6(g);
453  typedef detail::VisitorNode<F, _6_t> _5_t;
454  _5_t _5(f, _6);
455  typedef detail::VisitorNode<E, _5_t> _4_t;
456  _4_t _4(e, _5);
457  typedef detail::VisitorNode<D, _4_t> _3_t;
458  _3_t _3(d, _4);
459  typedef detail::VisitorNode<C, _3_t> _2_t;
460  _2_t _2(c, _3);
461  typedef detail::VisitorNode<B, _2_t> _1_t;
462  _1_t _1(b, _2);
463  typedef detail::VisitorNode<A, _1_t> _0_t;
464  _0_t _0(a, _1);
465  return _0;
466 }
467 
468 
469 /** factory method to to be used with RandomForest::learn()
470  */
471 template<class A, class B, class C, class D, class E,
472  class F, class G, class H>
473 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
474  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
475  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
476 create_visitor(A & a, B & b, C & c,
477  D & d, E & e, F & f,
478  G & g, H & h)
479 {
480  typedef detail::VisitorNode<H> _7_t;
481  _7_t _7(h);
482  typedef detail::VisitorNode<G, _7_t> _6_t;
483  _6_t _6(g, _7);
484  typedef detail::VisitorNode<F, _6_t> _5_t;
485  _5_t _5(f, _6);
486  typedef detail::VisitorNode<E, _5_t> _4_t;
487  _4_t _4(e, _5);
488  typedef detail::VisitorNode<D, _4_t> _3_t;
489  _3_t _3(d, _4);
490  typedef detail::VisitorNode<C, _3_t> _2_t;
491  _2_t _2(c, _3);
492  typedef detail::VisitorNode<B, _2_t> _1_t;
493  _1_t _1(b, _2);
494  typedef detail::VisitorNode<A, _1_t> _0_t;
495  _0_t _0(a, _1);
496  return _0;
497 }
498 
499 
500 /** factory method to to be used with RandomForest::learn()
501  */
502 template<class A, class B, class C, class D, class E,
503  class F, class G, class H, class I>
504 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
505  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
506  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
507 create_visitor(A & a, B & b, C & c,
508  D & d, E & e, F & f,
509  G & g, H & h, I & i)
510 {
511  typedef detail::VisitorNode<I> _8_t;
512  _8_t _8(i);
513  typedef detail::VisitorNode<H, _8_t> _7_t;
514  _7_t _7(h, _8);
515  typedef detail::VisitorNode<G, _7_t> _6_t;
516  _6_t _6(g, _7);
517  typedef detail::VisitorNode<F, _6_t> _5_t;
518  _5_t _5(f, _6);
519  typedef detail::VisitorNode<E, _5_t> _4_t;
520  _4_t _4(e, _5);
521  typedef detail::VisitorNode<D, _4_t> _3_t;
522  _3_t _3(d, _4);
523  typedef detail::VisitorNode<C, _3_t> _2_t;
524  _2_t _2(c, _3);
525  typedef detail::VisitorNode<B, _2_t> _1_t;
526  _1_t _1(b, _2);
527  typedef detail::VisitorNode<A, _1_t> _0_t;
528  _0_t _0(a, _1);
529  return _0;
530 }
531 
532 /** factory method to to be used with RandomForest::learn()
533  */
534 template<class A, class B, class C, class D, class E,
535  class F, class G, class H, class I, class J>
536 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
537  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
538  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
539  detail::VisitorNode<J> > > > > > > > > >
540 create_visitor(A & a, B & b, C & c,
541  D & d, E & e, F & f,
542  G & g, H & h, I & i,
543  J & j)
544 {
545  typedef detail::VisitorNode<J> _9_t;
546  _9_t _9(j);
547  typedef detail::VisitorNode<I, _9_t> _8_t;
548  _8_t _8(i, _9);
549  typedef detail::VisitorNode<H, _8_t> _7_t;
550  _7_t _7(h, _8);
551  typedef detail::VisitorNode<G, _7_t> _6_t;
552  _6_t _6(g, _7);
553  typedef detail::VisitorNode<F, _6_t> _5_t;
554  _5_t _5(f, _6);
555  typedef detail::VisitorNode<E, _5_t> _4_t;
556  _4_t _4(e, _5);
557  typedef detail::VisitorNode<D, _4_t> _3_t;
558  _3_t _3(d, _4);
559  typedef detail::VisitorNode<C, _3_t> _2_t;
560  _2_t _2(c, _3);
561  typedef detail::VisitorNode<B, _2_t> _1_t;
562  _1_t _1(b, _2);
563  typedef detail::VisitorNode<A, _1_t> _0_t;
564  _0_t _0(a, _1);
565  return _0;
566 }
567 
568 //////////////////////////////////////////////////////////////////////////////
569 // Visitors of communal interest. //
570 //////////////////////////////////////////////////////////////////////////////
571 
572 
573 /** Visitor to gain information, later needed for online learning.
574  */
575 
577 {
578 public:
579  //Set if we adjust thresholds
580  bool adjust_thresholds;
581  //Current tree id
582  int tree_id;
583  //Last node id for finding parent
584  int last_node_id;
585  //Need to now the label for interior node visiting
586  vigra::Int32 current_label;
587  //marginal distribution for interior nodes
588  struct MarginalDistribution
589  {
590  ArrayVector<Int32> leftCounts;
591  Int32 leftTotalCounts;
592  ArrayVector<Int32> rightCounts;
593  Int32 rightTotalCounts;
594  double gap_left;
595  double gap_right;
596  };
598 
599  //All information for one tree
600  struct TreeOnlineInformation
601  {
602  std::vector<MarginalDistribution> mag_distributions;
603  std::vector<IndexList> index_lists;
604  //map for linear index of mag_distiributions
605  std::map<int,int> interior_to_index;
606  //map for linear index of index_lists
607  std::map<int,int> exterior_to_index;
608  };
609 
610  //All trees
611  std::vector<TreeOnlineInformation> trees_online_information;
612 
613  /** Initilize, set the number of trees
614  */
615  template<class RF,class PR>
616  void visit_at_beginning(RF & rf,const PR & pr)
617  {
618  tree_id=0;
619  trees_online_information.resize(rf.options_.tree_count_);
620  }
621 
622  /** Reset a tree
623  */
624  void reset_tree(int tree_id)
625  {
626  trees_online_information[tree_id].mag_distributions.clear();
627  trees_online_information[tree_id].index_lists.clear();
628  trees_online_information[tree_id].interior_to_index.clear();
629  trees_online_information[tree_id].exterior_to_index.clear();
630  }
631 
632  /** simply increase the tree count
633  */
634  template<class RF, class PR, class SM, class ST>
635  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
636  {
637  tree_id++;
638  }
639 
640  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
641  void visit_after_split( Tree & tree,
642  Split & split,
643  Region & parent,
644  Region & leftChild,
645  Region & rightChild,
646  Feature_t & features,
647  Label_t & labels)
648  {
649  int linear_index;
650  int addr=tree.topology_.size();
651  if(split.createNode().typeID() == i_ThresholdNode)
652  {
653  if(adjust_thresholds)
654  {
655  //Store marginal distribution
656  linear_index=trees_online_information[tree_id].mag_distributions.size();
657  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
658  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
659 
660  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
661  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
662 
663  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
664  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
665  //Store the gap
666  double gap_left,gap_right;
667  int i;
668  gap_left=features(leftChild[0],split.bestSplitColumn());
669  for(i=1;i<leftChild.size();++i)
670  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
671  gap_left=features(leftChild[i],split.bestSplitColumn());
672  gap_right=features(rightChild[0],split.bestSplitColumn());
673  for(i=1;i<rightChild.size();++i)
674  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
675  gap_right=features(rightChild[i],split.bestSplitColumn());
676  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
677  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
678  }
679  }
680  else
681  {
682  //Store index list
683  linear_index=trees_online_information[tree_id].index_lists.size();
684  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
685 
686  trees_online_information[tree_id].index_lists.push_back(IndexList());
687 
688  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
689  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
690  }
691  }
692  void add_to_index_list(int tree,int node,int index)
693  {
694  if(!this->active_)
695  return;
696  TreeOnlineInformation &ti=trees_online_information[tree];
697  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
698  }
699  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
700  {
701  if(!this->active_)
702  return;
703  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
704  trees_online_information[src_tree].exterior_to_index.erase(src_index);
705  }
706  /** do something when visiting a internal node during getToLeaf
707  *
708  * remember as last node id, for finding the parent of the last external node
709  * also: adjust class counts and borders
710  */
711  template<class TR, class IntT, class TopT,class Feat>
712  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
713  {
714  last_node_id=index;
715  if(adjust_thresholds)
716  {
717  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
718  //Check if we are in the gap
719  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
720  TreeOnlineInformation &ti=trees_online_information[tree_id];
721  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
722  if(value>m.gap_left && value<m.gap_right)
723  {
724  //Check which site we want to go
725  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
726  {
727  //We want to go left
728  m.gap_left=value;
729  }
730  else
731  {
732  //We want to go right
733  m.gap_right=value;
734  }
735  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
736  }
737  //Adjust class counts
738  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
739  {
740  ++m.rightTotalCounts;
741  ++m.rightCounts[current_label];
742  }
743  else
744  {
745  ++m.leftTotalCounts;
746  ++m.rightCounts[current_label];
747  }
748  }
749  }
750  /** do something when visiting a extern node during getToLeaf
751  *
752  * Store the new index!
753  */
754 };
755 
756 //////////////////////////////////////////////////////////////////////////////
757 // Out of Bag Error estimates //
758 //////////////////////////////////////////////////////////////////////////////
759 
760 
761 /** Visitor that calculates the oob error of each individual randomized
762  * decision tree.
763  *
764  * After training a tree, all those samples that are OOB for this particular tree
765  * are put down the tree and the error estimated.
766  * the per tree oob error is the average of the individual error estimates.
767  * (oobError = average error of one randomized tree)
768  * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
769  * visitor)
770  */
772 {
773 public:
774  /** Average error of one randomized decision tree
775  */
776  double oobError;
777 
778  int totalOobCount;
779  ArrayVector<int> oobCount,oobErrorCount;
780 
782  : oobError(0.0),
783  totalOobCount(0)
784  {}
785 
786 
787  bool has_value()
788  {
789  return true;
790  }
791 
792 
793  /** does the basic calculation per tree*/
794  template<class RF, class PR, class SM, class ST>
795  void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index)
796  {
797  //do the first time called.
798  if(int(oobCount.size()) != rf.ext_param_.row_count_)
799  {
800  oobCount.resize(rf.ext_param_.row_count_, 0);
801  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
802  }
803  // go through the samples
804  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
805  {
806  // if the lth sample is oob...
807  if(!sm.is_used()[l])
808  {
809  ++oobCount[l];
810  if( rf.tree(index)
811  .predictLabel(rowVector(pr.features(), l))
812  != pr.response()(l,0))
813  {
814  ++oobErrorCount[l];
815  }
816  }
817 
818  }
819  }
820 
821  /** Does the normalisation
822  */
823  template<class RF, class PR>
824  void visit_at_end(RF & rf, PR & pr)
825  {
826  // do some normalisation
827  for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
828  {
829  if(oobCount[l])
830  {
831  oobError += double(oobErrorCount[l]) / oobCount[l];
832  ++totalOobCount;
833  }
834  }
835  oobError/=totalOobCount;
836  }
837 
838 };
839 
840 /** Visitor that calculates the oob error of the ensemble
841  * This rate should be used to estimate the crossvalidation
842  * error rate.
843  * Here each sample is put down those trees, for which this sample
844  * is OOB i.e. if sample #1 is OOB for trees 1, 3 and 5 we calculate
845  * the output using the ensemble consisting only of trees 1 3 and 5.
846  *
847  * Using normal bagged sampling each sample is OOB for approx. 33% of trees
848  * The error rate obtained as such therefore corresponds to crossvalidation
849  * rate obtained using a ensemble containing 33% of the trees.
850  */
851 class OOB_Error : public VisitorBase
852 {
854  int class_count;
855  bool is_weighted;
856  MultiArray<2,double> tmp_prob;
857  public:
858 
859  MultiArray<2, double> prob_oob;
860  /** Ensemble oob error rate
861  */
862  double oob_breiman;
863 
864  MultiArray<2, double> oobCount;
865  ArrayVector< int> indices;
866  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
867 
868 #ifdef HasHDF5
869  void save(std::string filen, std::string pathn)
870  {
871  if(*(pathn.end()-1) != '/')
872  pathn += "/";
873  const char* filename = filen.c_str();
874  MultiArray<2, double> temp(Shp(1,1), 0.0);
875  temp[0] = oob_breiman;
876  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
877  }
878 #endif
879  // negative value if sample was ib, number indicates how often.
880  // value >=0 if sample was oob, 0 means fail 1, corrrect
881 
882  template<class RF, class PR>
883  void visit_at_beginning(RF & rf, PR & pr)
884  {
885  class_count = rf.class_count();
886  tmp_prob.reshape(Shp(1, class_count), 0);
887  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
888  is_weighted = rf.options().predict_weighted_;
889  indices.resize(rf.ext_param().row_count_);
890  if(int(oobCount.size()) != rf.ext_param_.row_count_)
891  {
892  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
893  }
894  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
895  {
896  indices[ii] = ii;
897  }
898  }
899 
900  template<class RF, class PR, class SM, class ST>
901  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
902  {
903  // go through the samples
904  int total_oob =0;
905  int wrong_oob =0;
906  // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
907  // (i.e. the OOB sample ist very large)
908  // 40000: use at most 40000 OOB samples per class for OOB error estimate
909  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
910  {
911  ArrayVector<int> oob_indices;
912  ArrayVector<int> cts(class_count, 0);
913  std::random_shuffle(indices.begin(), indices.end());
914  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
915  {
916  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
917  {
918  oob_indices.push_back(indices[ii]);
919  ++cts[pr.response()(indices[ii], 0)];
920  }
921  }
922  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
923  {
924  // update number of trees in which current sample is oob
925  ++oobCount[oob_indices[ll]];
926 
927  // update number of oob samples in this tree.
928  ++total_oob;
929  // get the predicted votes ---> tmp_prob;
930  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
931  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
932  rf.tree(index).parameters_,
933  pos);
934  tmp_prob.init(0);
935  for(int ii = 0; ii < class_count; ++ii)
936  {
937  tmp_prob[ii] = node.prob_begin()[ii];
938  }
939  if(is_weighted)
940  {
941  for(int ii = 0; ii < class_count; ++ii)
942  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
943  }
944  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
945  int label = argMax(tmp_prob);
946 
947  }
948  }else
949  {
950  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
951  {
952  // if the lth sample is oob...
953  if(!sm.is_used()[ll])
954  {
955  // update number of trees in which current sample is oob
956  ++oobCount[ll];
957 
958  // update number of oob samples in this tree.
959  ++total_oob;
960  // get the predicted votes ---> tmp_prob;
961  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
962  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
963  rf.tree(index).parameters_,
964  pos);
965  tmp_prob.init(0);
966  for(int ii = 0; ii < class_count; ++ii)
967  {
968  tmp_prob[ii] = node.prob_begin()[ii];
969  }
970  if(is_weighted)
971  {
972  for(int ii = 0; ii < class_count; ++ii)
973  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
974  }
975  rowVector(prob_oob, ll) += tmp_prob;
976  int label = argMax(tmp_prob);
977 
978  }
979  }
980  }
981  // go through the ib samples;
982  }
983 
984  /** Normalise variable importance after the number of trees is known.
985  */
986  template<class RF, class PR>
987  void visit_at_end(RF & rf, PR & pr)
988  {
989  // ullis original metric and breiman style stuff
990  int totalOobCount =0;
991  int breimanstyle = 0;
992  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
993  {
994  if(oobCount[ll])
995  {
996  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
997  ++breimanstyle;
998  ++totalOobCount;
999  }
1000  }
1001  oob_breiman = double(breimanstyle)/totalOobCount;
1002  }
1003 };
1004 
1005 
1006 /** Visitor that calculates different OOB error statistics
1007  */
1009 {
1010  typedef MultiArrayShape<2>::type Shp;
1011  int class_count;
1012  bool is_weighted;
1013  MultiArray<2,double> tmp_prob;
1014  public:
1015 
1016  /** OOB Error rate of each individual tree
1017  */
1019  /** Mean of oob_per_tree
1020  */
1021  double oob_mean;
1022  /**Standard deviation of oob_per_tree
1023  */
1024  double oob_std;
1025 
1026  MultiArray<2, double> prob_oob;
1027  /** Ensemble OOB error
1028  *
1029  * \sa OOB_Error
1030  */
1031  double oob_breiman;
1032 
1033  MultiArray<2, double> oobCount;
1034  MultiArray<2, double> oobErrorCount;
1035  /** Per Tree OOB error calculated as in OOB_PerTreeError
1036  * (Ulli's version)
1037  */
1039 
1040  /**Column containing the development of the Ensemble
1041  * error rate with increasing number of trees
1042  */
1044  /** 4 dimensional array containing the development of confusion matrices
1045  * with number of trees - can be used to estimate ROC curves etc.
1046  *
1047  * oobroc_per_tree(ii,jj,kk,ll)
1048  * corresponds true label = ii
1049  * predicted label = jj
1050  * confusion matrix after ll trees
1051  *
1052  * explaination of third index:
1053  *
1054  * Two class case:
1055  * kk = 0 - (treeCount-1)
1056  * Threshold is on Probability for class 0 is kk/(treeCount-1);
1057  * More classes:
1058  * kk = 0. Threshold on probability set by argMax of the probability array.
1059  */
1061 
1063 
1064 #ifdef HasHDF5
1065  /** save to HDF5 file
1066  */
1067  void save(std::string filen, std::string pathn)
1068  {
1069  if(*(pathn.end()-1) != '/')
1070  pathn += "/";
1071  const char* filename = filen.c_str();
1072  MultiArray<2, double> temp(Shp(1,1), 0.0);
1073  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1074  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1075  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1076  temp[0] = oob_mean;
1077  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1078  temp[0] = oob_std;
1079  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1080  temp[0] = oob_breiman;
1081  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1082  temp[0] = oob_per_tree2;
1083  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1084  }
1085 #endif
    // negative value if the sample was in-bag (ib); the magnitude indicates how often.
    // value >= 0 if the sample was out-of-bag (oob): 0 means misclassified, 1 correct.
1088 
    /** Allocate and zero all statistics buffers before the first tree
     *  is learned.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF & rf, PR & pr)
    {
        class_count = rf.class_count();
        // two-class problems get a full threshold sweep (one threshold per
        // tree) for ROC estimation; multi-class problems only track a single
        // argMax-based confusion matrix per tree
        if(class_count == 2)
            oobroc_per_tree.reshape(MultiArrayShape<4>::type(2, 2, rf.tree_count(), rf.tree_count()));
        else
            oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(), rf.class_count(), 1, rf.tree_count()));
        tmp_prob.reshape(Shp(1, class_count), 0);
        prob_oob.reshape(Shp(rf.ext_param().row_count_, class_count), 0);
        is_weighted = rf.options().predict_weighted_;
        oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
        breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
        // allocate the per-sample counters only on the first call
        if(int(oobCount.size()) != rf.ext_param_.row_count_)
        {
            oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
            oobErrorCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
        }
    }
1109 
    /** Update all OOB statistics with the votes of the tree just learned:
     *  per-sample OOB counters, accumulated class votes, per-tree error,
     *  running ensemble (Breiman) error and the confusion/ROC matrices.
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        // go through the samples
        int total_oob = 0;   // number of oob samples in this tree
        int wrong_oob = 0;   // number of misclassified oob samples in this tree
        for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
        {
            // if the lth sample is oob...
            if(!sm.is_used()[ll])
            {
                // update number of trees in which current sample is oob
                ++oobCount[ll];

                // update number of oob samples in this tree.
                ++total_oob;
                // get the predicted votes ---> tmp_prob;
                int pos = rf.tree(index).getToLeaf(rowVector(pr.features(), ll));
                Node<e_ConstProbNode> node(rf.tree(index).topology_,
                                           rf.tree(index).parameters_,
                                           pos);
                tmp_prob.init(0);
                for(int ii = 0; ii < class_count; ++ii)
                {
                    tmp_prob[ii] = node.prob_begin()[ii];
                }
                if(is_weighted)
                {
                    // prob_begin()[-1] holds the leaf weight
                    for(int ii = 0; ii < class_count; ++ii)
                        tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
                }
                rowVector(prob_oob, ll) += tmp_prob;
                int label = argMax(tmp_prob);

                if(label != pr.response()(ll, 0))
                {
                    // update number of wrong oob samples in this tree.
                    ++wrong_oob;
                    // update number of trees in which current sample is wrong oob
                    ++oobErrorCount[ll];
                }
            }
        }
        // ensemble (majority-vote) error over all samples that have been
        // oob at least once so far
        int breimanstyle = 0;
        int totalOobCount = 0;
        for(int ll = 0; ll < (int)rf.ext_param_.row_count_; ++ll)
        {
            if(oobCount[ll])
            {
                if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                    ++breimanstyle;
                ++totalOobCount;
                if(oobroc_per_tree.shape(2) == 1)
                {
                    // multi-class case: single argMax-based confusion matrix
                    oobroc_per_tree(pr.response()(ll, 0), argMax(rowVector(prob_oob, ll)), 0, index)++;
                }
            }
        }
        if(oobroc_per_tree.shape(2) == 1)
            oobroc_per_tree.bindOuter(index) /= totalOobCount;
        if(oobroc_per_tree.shape(2) > 1)
        {
            // two-class case: sweep a threshold over the class-1 vote to get
            // one confusion matrix per threshold (ROC support)
            MultiArrayView<3, double> current_roc
                = oobroc_per_tree.bindOuter(index);
            for(int gg = 0; gg < current_roc.shape(2); ++gg)
            {
                for(int ll = 0; ll < (int)rf.ext_param_.row_count_; ++ll)
                {
                    if(oobCount[ll])
                    {
                        // NOTE(review): prob_oob holds accumulated (unnormalised)
                        // votes; the threshold gg/shape(2) is compared against it
                        // directly — confirm the intended scaling upstream.
                        int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2))) ?
                                        1 : 0;
                        current_roc(pr.response()(ll, 0), pred, gg) += 1;
                    }
                }
                current_roc.bindOuter(gg) /= totalOobCount;
            }
        }
        breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
        oob_per_tree[index] = double(wrong_oob)/double(total_oob);
    }
1192 
    /** Compute the final ensemble statistics once the number of trees is
     *  known: per-sample averaged tree error (oob_per_tree2), ensemble
     *  majority-vote error (oob_breiman) and mean/std of the per-tree errors.
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        // ullis original metric and breiman style stuff
        oob_per_tree2 = 0;
        int totalOobCount = 0;
        int breimanstyle = 0;
        for(int ll = 0; ll < (int)rf.ext_param_.row_count_; ++ll)
        {
            if(oobCount[ll])
            {
                if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                    ++breimanstyle;
                // fraction of trees that misclassified this oob sample
                oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
                ++totalOobCount;
            }
        }
        // NOTE(review): totalOobCount == 0 would yield NaN here — only
        // possible for degenerate (tiny) training sets.
        oob_per_tree2 /= totalOobCount;
        oob_breiman = double(breimanstyle)/totalOobCount;
        // mean error of each tree (written directly into oob_mean/oob_std)
        MultiArrayView<2, double> mean(Shp(1, 1), &oob_mean);
        MultiArrayView<2, double> stdDev(Shp(1, 1), &oob_std);
        rowStatistics(oob_per_tree, mean, stdDev);
    }
1219 };
1220 
1221 /** calculate variable importance while learning.
1222  */
1224 {
1225  public:
1226 
1227  /** This Array has the same entries as the R - random forest variable
1228  * importance.
1229  * Matrix is featureCount by (classCount +2)
1230  * variable_importance_(ii,jj) is the variable importance measure of
1231  * the ii-th variable according to:
1232  * jj = 0 - (classCount-1)
1233  * classwise permutation importance
1234  * jj = rowCount(variable_importance_) -2
1235  * permutation importance
1236  * jj = rowCount(variable_importance_) -1
1237  * gini decrease importance.
1238  *
1239  * permutation importance:
1240  * The difference between the fraction of OOB samples classified correctly
1241  * before and after permuting (randomizing) the ii-th column is calculated.
1242  * The ii-th column is permuted rep_cnt times.
1243  *
1244  * class wise permutation importance:
1245  * same as permutation importance. We only look at those OOB samples whose
1246  * response corresponds to class jj.
1247  *
1248  * gini decrease importance:
1249  * row ii corresponds to the sum of all gini decreases induced by variable ii
1250  * in each node of the random forest.
1251  */
1253  int repetition_count_;
1254  bool in_place_;
1255 
1256 #ifdef HasHDF5
1257  void save(std::string filename, std::string prefix)
1258  {
1259  prefix = "variable_importance_" + prefix;
1260  writeHDF5(filename.c_str(),
1261  prefix.c_str(),
1263  }
1264 #endif
1265  /** Constructor
1266  * \param rep_cnt (defautl: 10) how often should
1267  * the permutation take place. Set to 1 to make calculation faster (but
1268  * possibly more instable)
1269  */
1270  VariableImportanceVisitor(int rep_cnt = 10)
1271  : repetition_count_(rep_cnt)
1272 
1273  {}
1274 
1275  /** calculates impurity decrease based variable importance after every
1276  * split.
1277  */
1278  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1279  void visit_after_split( Tree & tree,
1280  Split & split,
1281  Region & parent,
1282  Region & leftChild,
1283  Region & rightChild,
1284  Feature_t & features,
1285  Label_t & labels)
1286  {
1287  //resize to right size when called the first time
1288 
1289  Int32 const class_count = tree.ext_param_.class_count_;
1290  Int32 const column_count = tree.ext_param_.column_count_;
1291  if(variable_importance_.size() == 0)
1292  {
1293 
1295  .reshape(MultiArrayShape<2>::type(column_count,
1296  class_count+2));
1297  }
1298 
1299  if(split.createNode().typeID() == i_ThresholdNode)
1300  {
1301  Node<i_ThresholdNode> node(split.createNode());
1302  variable_importance_(node.column(),class_count+1)
1303  += split.region_gini_ - split.minGini();
1304  }
1305  }
1306 
1307  /**compute permutation based var imp.
1308  * (Only an Array of size oob_sample_count x 1 is created.
1309  * - apposed to oob_sample_count x feature_count in the other method.
1310  *
1311  * \sa FieldProxy
1312  */
1313  template<class RF, class PR, class SM, class ST>
1314  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index)
1315  {
1316  typedef MultiArrayShape<2>::type Shp_t;
1317  Int32 column_count = rf.ext_param_.column_count_;
1318  Int32 class_count = rf.ext_param_.class_count_;
1319 
1320  /* This solution saves memory uptake but not multithreading
1321  * compatible
1322  */
1323  // remove the const cast on the features (yep , I know what I am
1324  // doing here.) data is not destroyed.
1325  //typename PR::Feature_t & features
1326  // = const_cast<typename PR::Feature_t &>(pr.features());
1327 
1328  typename PR::FeatureWithMemory_t features = pr.features();
1329 
1330  //find the oob indices of current tree.
1331  ArrayVector<Int32> oob_indices;
1333  iter;
1334  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1335  if(!sm.is_used()[ii])
1336  oob_indices.push_back(ii);
1337 
1338  //create space to back up a column
1339  std::vector<double> backup_column;
1340 
1341  // Random foo
1342 #ifdef CLASSIFIER_TEST
1343  RandomMT19937 random(1);
1344 #else
1345  RandomMT19937 random(RandomSeed);
1346 #endif
1348  randint(random);
1349 
1350 
1351  //make some space for the results
1353  oob_right(Shp_t(1, class_count + 1));
1355  perm_oob_right (Shp_t(1, class_count + 1));
1356 
1357 
1358  // get the oob success rate with the original samples
1359  for(iter = oob_indices.begin();
1360  iter != oob_indices.end();
1361  ++iter)
1362  {
1363  if(rf.tree(index)
1364  .predictLabel(rowVector(features, *iter))
1365  == pr.response()(*iter, 0))
1366  {
1367  //per class
1368  ++oob_right[pr.response()(*iter,0)];
1369  //total
1370  ++oob_right[class_count];
1371  }
1372  }
1373  //get the oob rate after permuting the ii'th dimension.
1374  for(int ii = 0; ii < column_count; ++ii)
1375  {
1376  perm_oob_right.init(0.0);
1377  //make backup of orinal column
1378  backup_column.clear();
1379  for(iter = oob_indices.begin();
1380  iter != oob_indices.end();
1381  ++iter)
1382  {
1383  backup_column.push_back(features(*iter,ii));
1384  }
1385 
1386  //get the oob rate after permuting the ii'th dimension.
1387  for(int rr = 0; rr < repetition_count_; ++rr)
1388  {
1389  //permute dimension.
1390  int n = oob_indices.size();
1391  for(int jj = 1; jj < n; ++jj)
1392  std::swap(features(oob_indices[jj], ii),
1393  features(oob_indices[randint(jj+1)], ii));
1394 
1395  //get the oob sucess rate after permuting
1396  for(iter = oob_indices.begin();
1397  iter != oob_indices.end();
1398  ++iter)
1399  {
1400  if(rf.tree(index)
1401  .predictLabel(rowVector(features, *iter))
1402  == pr.response()(*iter, 0))
1403  {
1404  //per class
1405  ++perm_oob_right[pr.response()(*iter, 0)];
1406  //total
1407  ++perm_oob_right[class_count];
1408  }
1409  }
1410  }
1411 
1412 
1413  //normalise and add to the variable_importance array.
1414  perm_oob_right /= repetition_count_;
1415  perm_oob_right -=oob_right;
1416  perm_oob_right *= -1;
1417  perm_oob_right /= oob_indices.size();
1419  .subarray(Shp_t(ii,0),
1420  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1421  //copy back permuted dimension
1422  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1423  features(oob_indices[jj], ii) = backup_column[jj];
1424  }
1425  }
1426 
    /** calculate permutation based impurity after every tree has been
     * learned default behaviour is that this happens out of place.
     * If you have very big data sets and want to avoid copying of data
     * set the in_place_ flag to true.
     */
    // NOTE(review): only the in-place implementation is visible in this
    // chunk and the in_place_ flag is not consulted here — confirm against
    // the full file.
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        after_tree_ip_impl(rf, pr, sm, st, index);
    }
1437 
    /** Normalise variable importance after the number of trees is known.
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        // importance was accumulated per tree; average over the ensemble
        variable_importance_ /= rf.trees_.size();
    }
1445 };
1446 
1447 /** Verbose output
1448  */
1450  public:
1452 
1453  template<class RF, class PR, class SM, class ST>
1454  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){
1455  if(index != rf.options().tree_count_-1) {
1456  std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1457  << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1458  }
1459  else {
1460  std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1461  }
1462  }
1463 
    /** report the total learning time once all trees are done. */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr) {
        // TOCS (from USETICTOC below) yields the elapsed time as a string
        std::string a = TOCS;
        std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
    }
1469 
    /** start the timer and announce the run. */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr) {
        TIC;   // start timing (USETICTOC macro pair)
        std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
    }
1475 
1476  private:
1477  USETICTOC;
1478 };
1479 
1480 
1481 /** Computes Correlation/Similarity Matrix of features while learning
1482  * random forest.
1483  */
1485 {
1486  public:
1487  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1488  * created on variable ii(when variable ii was chosen)
1489  */
1491  MultiArray<2, int> tmp_labels;
1492  /** additional noise features.
1493  */
1495  MultiArray<2, double> noise_l;
1496  /** how well can a noise column describe a partition created on variable ii.
1497  */
1499  MultiArray<2, double> corr_l;
1500 
1501  /** Similarity Matrix
1502  *
1503  * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1504  * gini_missc
1505  * - row normalized by the number of times the column was chosen
1506  * - mean of corr_noise subtracted
1507  * - and symmetrised.
1508  *
1509  */
1511  /** Distance Matrix 1-similarity
1512  */
1514  ArrayVector<int> tmp_cc;
1515 
1516  /** How often was variable ii chosen
1517  */
    /** save to HDF5 — currently a no-op: the implementation below is
     *  commented out (presumably disabled together with the old
     *  writeToHDF5File API); signature kept for interface compatibility.
     */
    void save(std::string file, std::string prefix)
    {
        /*
        std::string tmp;
#define VAR_WRITE(NAME) \
        tmp = #NAME;\
        tmp += "_";\
        tmp += prefix;\
        vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
        VAR_WRITE(gini_missc);
        VAR_WRITE(corr_noise);
        VAR_WRITE(distance);
        VAR_WRITE(similarity);
        vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
#undef VAR_WRITE
*/
    }
    /** Allocate all statistics matrices and the noise reference columns
     *  before learning starts. Matrices have n+1 rows/columns: one per
     *  feature plus an extra slot for the true labels / overall totals.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR & pr)
    {
        typedef MultiArrayShape<2>::type Shp;
        int n = rf.ext_param_.column_count_;
        gini_missc.reshape(Shp(n + 1, n + 1));
        corr_noise.reshape(Shp(n + 1, 10));   // 10 noise reference columns
        corr_l.reshape(Shp(n + 1, 10));

        noise.reshape(Shp(pr.features().shape(0), 10));
        noise_l.reshape(Shp(pr.features().shape(0), 10));
        RandomMT19937 random(RandomSeed);
        // real-valued and binary noise columns
        for(int ii = 0; ii < noise.size(); ++ii)
        {
            noise[ii] = random.uniform53();
            noise_l[ii] = random.uniform53() > 0.5;
        }
        bgfunc = ColumnDecisionFunctor(rf.ext_param_);
        tmp_labels.reshape(pr.response().shape());
        tmp_cc.resize(2);
        numChoices.resize(n + 1);
        // look at all axes
    }
1561  template<class RF, class PR>
1562  void visit_at_end(RF const & rf, PR const & pr)
1563  {
1564  typedef MultiArrayShape<2>::type Shp;
1567  MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
1568  rowStatistics(corr_noise, mean_noise);
1569  mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
1570  int rC = similarity.shape(0);
1571  for(int jj = 0; jj < rC-1; ++jj)
1572  {
1573  rowVector(similarity, jj) /= numChoices[jj];
1574  rowVector(similarity, jj) -= mean_noise(jj, 0);
1575  }
1576  for(int jj = 0; jj < rC; ++jj)
1577  {
1578  similarity(rC -1, jj) /= numChoices[jj];
1579  }
1580  rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1582  FindMinMax<double> minmax;
1583  inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1584 
1585  for(int jj = 0; jj < rC; ++jj)
1586  similarity(jj, jj) = minmax.max;
1587 
1588  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1589  += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1590  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1591  columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
1592  for(int jj = 0; jj < rC; ++jj)
1593  similarity(jj, jj) = 0;
1594 
1595  FindMinMax<double> minmax2;
1596  inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1597  for(int jj = 0; jj < rC; ++jj)
1598  similarity(jj, jj) = minmax2.max;
1599  distance.reshape(gini_missc.shape(), minmax2.max);
1600  distance -= similarity;
1601  }
1602 
1603  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1604  void visit_after_split( Tree & tree,
1605  Split & split,
1606  Region & parent,
1607  Region & leftChild,
1608  Region & rightChild,
1609  Feature_t & features,
1610  Label_t & labels)
1611  {
1612  if(split.createNode().typeID() == i_ThresholdNode)
1613  {
1614  double wgini;
1615  tmp_cc.init(0);
1616  for(int ii = 0; ii < parent.size(); ++ii)
1617  {
1618  tmp_labels[parent[ii]]
1619  = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1620  ++tmp_cc[tmp_labels[parent[ii]]];
1621  }
1622  double region_gini = bgfunc.loss_of_region(tmp_labels,
1623  parent.begin(),
1624  parent.end(),
1625  tmp_cc);
1626 
1627  int n = split.bestSplitColumn();
1628  ++numChoices[n];
1629  ++(*(numChoices.end()-1));
1630  //this functor does all the work
1631  for(int k = 0; k < features.shape(1); ++k)
1632  {
1633  bgfunc(columnVector(features, k),
1634  0,
1635  tmp_labels,
1636  parent.begin(), parent.end(),
1637  tmp_cc);
1638  wgini = (region_gini - bgfunc.min_gini_);
1639  gini_missc(n, k)
1640  += wgini;
1641  }
1642  for(int k = 0; k < 10; ++k)
1643  {
1644  bgfunc(columnVector(noise, k),
1645  0,
1646  tmp_labels,
1647  parent.begin(), parent.end(),
1648  tmp_cc);
1649  wgini = (region_gini - bgfunc.min_gini_);
1650  corr_noise(n, k)
1651  += wgini;
1652  }
1653 
1654  for(int k = 0; k < 10; ++k)
1655  {
1656  bgfunc(columnVector(noise_l, k),
1657  0,
1658  tmp_labels,
1659  parent.begin(), parent.end(),
1660  tmp_cc);
1661  wgini = (region_gini - bgfunc.min_gini_);
1662  corr_l(n, k)
1663  += wgini;
1664  }
1665  bgfunc(labels,0, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1666  wgini = (region_gini - bgfunc.min_gini_);
1668  += wgini;
1669 
1670  region_gini = split.region_gini_;
1671 #if 1
1672  Node<i_ThresholdNode> node(split.createNode());
1674  node.column())
1675  +=split.region_gini_ - split.minGini();
1676 #endif
1677  for(int k = 0; k < 10; ++k)
1678  {
1679  split.bgfunc(columnVector(noise, k),
1680  0,
1681  labels,
1682  parent.begin(), parent.end(),
1683  parent.classCounts());
1685  k)
1686  += wgini;
1687  }
1688 #if 0
1689  for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1690  {
1691  wgini = region_gini - split.min_gini_[k];
1692 
1694  split.splitColumns[k])
1695  += wgini;
1696  }
1697 
1698  for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1699  {
1700  split.bgfunc(columnVector(features, split.splitColumns[k]),
1701  labels,
1702  parent.begin(), parent.end(),
1703  parent.classCounts());
1704  wgini = region_gini - split.bgfunc.min_gini_;
1706  split.splitColumns[k]) += wgini;
1707  }
1708 #endif
1709  // remember to partition the data according to the best.
1711  columnCount(gini_missc)-1)
1712  += region_gini;
1714  sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1715  std::partition(parent.begin(), parent.end(), sorter);
1716  }
1717  }
1718 };
1719 
1720 
1721 } // namespace visitors
1722 } // namespace rf
1723 } // namespace vigra
1724 
1725 //@}
1726 #endif // RF_VISITORS_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Thu Jun 14 2012)