35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
39 # include "vigra/hdf5impex.hxx"
42 # include "vigra/multi_array.hxx"
43 # include "vigra/multi_impex.hxx"
44 # include "vigra/inspectimage.hxx"
46 #include <vigra/windows.h>
49 #include <vigra/timing.hxx>
144 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
150 Feature_t & features,
163 template<
class RF,
class PR,
class SM,
class ST>
173 template<
class RF,
class PR>
183 template<
class RF,
class PR>
199 template<
class TR,
class IntT,
class TopT,
class Feat>
207 template<
class TR,
class IntT,
class TopT,
class Feat>
246 template <
class Visitor,
class Next = StopVisiting>
256 next_(next), visitor_(visitor)
261 next_(stop_), visitor_(visitor)
264 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
265 void visit_after_split( Tree & tree,
270 Feature_t & features,
273 if(visitor_.is_active())
274 visitor_.visit_after_split(tree, split,
275 parent, leftChild, rightChild,
277 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
// Chain step for visit_after_tree: the wrapped visitor_ is called only while
// visitor_.is_active() is true; the same arguments (including the tree index)
// are then forwarded to next_ unconditionally, so the remaining visitors in
// the chain are notified even when this one is inactive.
281 template<
class RF,
class PR,
class SM,
class ST>
282 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
284 if(visitor_.is_active())
285 visitor_.visit_after_tree(rf, pr, sm, st, index);
286 next_.visit_after_tree(rf, pr, sm, st, index);
// Chain step for visit_at_beginning: calls visitor_.visit_at_beginning only if
// visitor_.is_active(), then always forwards (rf, pr) to next_.
289 template<
class RF,
class PR>
290 void visit_at_beginning(RF & rf, PR & pr)
292 if(visitor_.is_active())
293 visitor_.visit_at_beginning(rf, pr);
294 next_.visit_at_beginning(rf, pr);
// Chain step for visit_at_end: calls visitor_.visit_at_end only if
// visitor_.is_active(), then always forwards (rf, pr) to next_.
296 template<
class RF,
class PR>
297 void visit_at_end(RF & rf, PR & pr)
299 if(visitor_.is_active())
300 visitor_.visit_at_end(rf, pr);
301 next_.visit_at_end(rf, pr);
// Chain step for visit_external_node: passes the tree, node index, node type
// and features to visitor_ only if visitor_.is_active(), then always forwards
// the same arguments to next_.
304 template<
class TR,
class IntT,
class TopT,
class Feat>
305 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
307 if(visitor_.is_active())
308 visitor_.visit_external_node(tr, index, node_t,features);
309 next_.visit_external_node(tr, index, node_t,features);
// Chain step for visit_internal_node: passes the tree, node index, node type
// and features to visitor_ only if visitor_.is_active(), then always forwards
// the same arguments to next_.
311 template<
class TR,
class IntT,
class TopT,
class Feat>
312 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
314 if(visitor_.is_active())
315 visitor_.visit_internal_node(tr, index, node_t,features);
316 next_.visit_internal_node(tr, index, node_t,features);
321 if(visitor_.is_active() && visitor_.has_value())
322 return visitor_.return_val();
323 return next_.return_val();
347 template<
class A,
class B>
348 detail::VisitorNode<A, detail::VisitorNode<B> >
361 template<
class A,
class B,
class C>
362 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
377 template<
class A,
class B,
class C,
class D>
378 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
379 detail::VisitorNode<D> > > >
396 template<
class A,
class B,
class C,
class D,
class E>
397 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
398 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
418 template<
class A,
class B,
class C,
class D,
class E,
420 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
421 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
443 template<
class A,
class B,
class C,
class D,
class E,
445 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
446 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
447 detail::VisitorNode<G> > > > > > >
449 D & d, E & e, F & f, G & g)
471 template<
class A,
class B,
class C,
class D,
class E,
472 class F,
class G,
class H>
473 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
474 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
475 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
502 template<
class A,
class B,
class C,
class D,
class E,
503 class F,
class G,
class H,
class I>
504 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
505 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
506 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
534 template<
class A,
class B,
class C,
class D,
class E,
535 class F,
class G,
class H,
class I,
class J>
536 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
537 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
538 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
539 detail::VisitorNode<J> > > > > > > > > >
580 bool adjust_thresholds;
588 struct MarginalDistribution
591 Int32 leftTotalCounts;
593 Int32 rightTotalCounts;
600 struct TreeOnlineInformation
602 std::vector<MarginalDistribution> mag_distributions;
603 std::vector<IndexList> index_lists;
605 std::map<int,int> interior_to_index;
607 std::map<int,int> exterior_to_index;
611 std::vector<TreeOnlineInformation> trees_online_information;
615 template<
class RF,
class PR>
619 trees_online_information.resize(rf.options_.tree_count_);
626 trees_online_information[tree_id].mag_distributions.clear();
627 trees_online_information[tree_id].index_lists.clear();
628 trees_online_information[tree_id].interior_to_index.clear();
629 trees_online_information[tree_id].exterior_to_index.clear();
634 template<
class RF,
class PR,
class SM,
class ST>
640 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
646 Feature_t & features,
650 int addr=tree.topology_.size();
651 if(split.createNode().typeID() == i_ThresholdNode)
653 if(adjust_thresholds)
656 linear_index=trees_online_information[tree_id].mag_distributions.size();
657 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
658 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
660 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
661 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
663 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
664 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
666 double gap_left,gap_right;
668 gap_left=features(leftChild[0],split.bestSplitColumn());
669 for(i=1;i<leftChild.size();++i)
670 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
671 gap_left=features(leftChild[i],split.bestSplitColumn());
672 gap_right=features(rightChild[0],split.bestSplitColumn());
673 for(i=1;i<rightChild.size();++i)
674 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
675 gap_right=features(rightChild[i],split.bestSplitColumn());
676 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
677 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
683 linear_index=trees_online_information[tree_id].index_lists.size();
684 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
686 trees_online_information[tree_id].index_lists.push_back(
IndexList());
688 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
689 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
692 void add_to_index_list(
int tree,
int node,
int index)
696 TreeOnlineInformation &ti=trees_online_information[tree];
697 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
699 void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
703 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
704 trees_online_information[src_tree].exterior_to_index.erase(src_index);
711 template<
class TR,
class IntT,
class TopT,
class Feat>
715 if(adjust_thresholds)
717 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
719 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
720 TreeOnlineInformation &ti=trees_online_information[tree_id];
721 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
722 if(value>m.gap_left && value<m.gap_right)
725 if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
735 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
738 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
740 ++m.rightTotalCounts;
741 ++m.rightCounts[current_label];
746 ++m.rightCounts[current_label];
794 template<
class RF,
class PR,
class SM,
class ST>
798 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
800 oobCount.resize(rf.ext_param_.row_count_, 0);
801 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
804 for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
811 .predictLabel(
rowVector(pr.features(), l))
812 != pr.response()(l,0))
823 template<
class RF,
class PR>
827 for(
int l=0; l < (int)rf.ext_param_.row_count_; ++l)
831 oobError += double(oobErrorCount[l]) / oobCount[l];
869 void save(std::string filen, std::string pathn)
871 if(*(pathn.end()-1) !=
'/')
873 const char* filename = filen.c_str();
876 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
882 template<
class RF,
class PR>
883 void visit_at_beginning(RF & rf, PR & pr)
885 class_count = rf.class_count();
886 tmp_prob.
reshape(Shp(1, class_count), 0);
887 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
888 is_weighted = rf.options().predict_weighted_;
889 indices.resize(rf.ext_param().row_count_);
890 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
892 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
894 for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
900 template<
class RF,
class PR,
class SM,
class ST>
909 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
913 std::random_shuffle(indices.
begin(), indices.
end());
914 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
916 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
918 oob_indices.push_back(indices[ii]);
919 ++cts[pr.response()(indices[ii], 0)];
922 for(
unsigned int ll = 0; ll < oob_indices.
size(); ++ll)
925 ++oobCount[oob_indices[ll]];
930 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),oob_indices[ll]));
932 rf.tree(index).parameters_,
935 for(
int ii = 0; ii < class_count; ++ii)
937 tmp_prob[ii] = node.prob_begin()[ii];
941 for(
int ii = 0; ii < class_count; ++ii)
942 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
944 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
945 int label =
argMax(tmp_prob);
950 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
953 if(!sm.is_used()[ll])
961 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
963 rf.tree(index).parameters_,
966 for(
int ii = 0; ii < class_count; ++ii)
968 tmp_prob[ii] = node.prob_begin()[ii];
972 for(
int ii = 0; ii < class_count; ++ii)
973 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
976 int label =
argMax(tmp_prob);
986 template<
class RF,
class PR>
990 int totalOobCount =0;
991 int breimanstyle = 0;
992 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1067 void save(std::string filen, std::string pathn)
1069 if(*(pathn.end()-1) !=
'/')
1071 const char* filename = filen.c_str();
1073 writeHDF5(filename, (pathn +
"oob_per_tree").c_str(),
oob_per_tree);
1074 writeHDF5(filename, (pathn +
"oobroc_per_tree").c_str(),
oobroc_per_tree);
1075 writeHDF5(filename, (pathn +
"breiman_per_tree").c_str(),
breiman_per_tree);
1077 writeHDF5(filename, (pathn +
"per_tree_error").c_str(), temp);
1079 writeHDF5(filename, (pathn +
"per_tree_error_std").c_str(), temp);
1081 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
1083 writeHDF5(filename, (pathn +
"ulli_error").c_str(), temp);
1089 template<
class RF,
class PR>
1090 void visit_at_beginning(RF & rf, PR & pr)
1092 class_count = rf.class_count();
1093 if(class_count == 2)
1097 tmp_prob.
reshape(Shp(1, class_count), 0);
1098 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1099 is_weighted = rf.options().predict_weighted_;
1103 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
1105 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1106 oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
1110 template<
class RF,
class PR,
class SM,
class ST>
1116 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1119 if(!sm.is_used()[ll])
1127 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
1129 rf.tree(index).parameters_,
1132 for(
int ii = 0; ii < class_count; ++ii)
1134 tmp_prob[ii] = node.prob_begin()[ii];
1138 for(
int ii = 0; ii < class_count; ++ii)
1139 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1142 int label =
argMax(tmp_prob);
1144 if(label != pr.response()(ll, 0))
1149 ++oobErrorCount[ll];
1153 int breimanstyle = 0;
1154 int totalOobCount = 0;
1155 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1174 for(
int gg = 0; gg < current_roc.
shape(2); ++gg)
1176 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1180 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.
shape(2)))?
1182 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1185 current_roc.
bindOuter(gg)/= totalOobCount;
1189 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1195 template<
class RF,
class PR>
1200 int totalOobCount =0;
1201 int breimanstyle = 0;
1202 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1253 int repetition_count_;
1257 void save(std::string filename, std::string prefix)
1259 prefix =
"variable_importance_" + prefix;
1260 writeHDF5(filename.c_str(),
1271 : repetition_count_(rep_cnt)
1278 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1283 Region & rightChild,
1284 Feature_t & features,
1289 Int32 const class_count = tree.ext_param_.class_count_;
1290 Int32 const column_count = tree.ext_param_.column_count_;
1299 if(split.createNode().typeID() == i_ThresholdNode)
1301 Node<i_ThresholdNode> node(split.createNode());
1303 += split.region_gini_ - split.minGini();
1313 template<
class RF,
class PR,
class SM,
class ST>
1317 Int32 column_count = rf.ext_param_.column_count_;
1318 Int32 class_count = rf.ext_param_.class_count_;
1328 typename PR::FeatureWithMemory_t features = pr.features();
1334 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1335 if(!sm.is_used()[ii])
1336 oob_indices.push_back(ii);
1339 std::vector<double> backup_column;
1342 #ifdef CLASSIFIER_TEST
1353 oob_right(Shp_t(1, class_count + 1));
1355 perm_oob_right (Shp_t(1, class_count + 1));
1359 for(iter = oob_indices.
begin();
1360 iter != oob_indices.
end();
1364 .predictLabel(
rowVector(features, *iter))
1365 == pr.response()(*iter, 0))
1368 ++oob_right[pr.response()(*iter,0)];
1370 ++oob_right[class_count];
1374 for(
int ii = 0; ii < column_count; ++ii)
1376 perm_oob_right.
init(0.0);
1378 backup_column.clear();
1379 for(iter = oob_indices.
begin();
1380 iter != oob_indices.
end();
1383 backup_column.push_back(features(*iter,ii));
1387 for(
int rr = 0; rr < repetition_count_; ++rr)
1390 int n = oob_indices.
size();
1391 for(
int jj = 1; jj < n; ++jj)
1392 std::swap(features(oob_indices[jj], ii),
1393 features(oob_indices[randint(jj+1)], ii));
1396 for(iter = oob_indices.
begin();
1397 iter != oob_indices.
end();
1401 .predictLabel(
rowVector(features, *iter))
1402 == pr.response()(*iter, 0))
1405 ++perm_oob_right[pr.response()(*iter, 0)];
1407 ++perm_oob_right[class_count];
1414 perm_oob_right /= repetition_count_;
1415 perm_oob_right -=oob_right;
1416 perm_oob_right *= -1;
1417 perm_oob_right /= oob_indices.
size();
1420 Shp_t(ii+1,class_count+1)) += perm_oob_right;
1422 for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
1423 features(oob_indices[jj], ii) = backup_column[jj];
1432 template<
class RF,
class PR,
class SM,
class ST>
1440 template<
class RF,
class PR>
1453 template<
class RF,
class PR,
class SM,
class ST>
1455 if(index != rf.options().tree_count_-1) {
1456 std::cout <<
"\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 <<
"%]"
1457 <<
" (" << index+1 <<
" of " << rf.options().tree_count_ <<
") done" << std::flush;
1460 std::cout <<
"\r[" << std::setw(10) << 100.0 <<
"%]" << std::endl;
1464 template<
class RF,
class PR>
1466 std::string a = TOCS;
1467 std::cout <<
"all " << rf.options().tree_count_ <<
" trees have been learned in " << a << std::endl;
1470 template<
class RF,
class PR>
1473 std::cout <<
"growing random forest, which will have " << rf.options().tree_count_ <<
" trees" << std::endl;
1521 void save(std::string file, std::string prefix)
1538 template<
class RF,
class PR>
1539 void visit_at_beginning(RF
const & rf, PR & pr)
1542 int n = rf.ext_param_.column_count_;
1545 corr_l.
reshape(Shp(n +1, 10));
1548 noise_l.
reshape(Shp(pr.features().shape(0), 10));
1550 for(
int ii = 0; ii <
noise.
size(); ++ii)
1552 noise[ii] = random.uniform53();
1553 noise_l[ii] = random.uniform53() > 0.5;
1555 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1556 tmp_labels.
reshape(pr.response().shape());
1561 template<
class RF,
class PR>
1571 for(
int jj = 0; jj < rC-1; ++jj)
1576 for(
int jj = 0; jj < rC; ++jj)
1585 for(
int jj = 0; jj < rC; ++jj)
1592 for(
int jj = 0; jj < rC; ++jj)
1597 for(
int jj = 0; jj < rC; ++jj)
1603 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1608 Region & rightChild,
1609 Feature_t & features,
1612 if(split.createNode().typeID() == i_ThresholdNode)
1616 for(
int ii = 0; ii < parent.size(); ++ii)
1618 tmp_labels[parent[ii]]
1619 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1620 ++tmp_cc[tmp_labels[parent[ii]]];
1622 double region_gini = bgfunc.loss_of_region(tmp_labels,
1627 int n = split.bestSplitColumn();
1631 for(
int k = 0; k < features.shape(1); ++k)
1636 parent.
begin(), parent.end(),
1638 wgini = (region_gini - bgfunc.min_gini_);
1642 for(
int k = 0; k < 10; ++k)
1647 parent.
begin(), parent.end(),
1649 wgini = (region_gini - bgfunc.min_gini_);
1654 for(
int k = 0; k < 10; ++k)
1659 parent.
begin(), parent.end(),
1661 wgini = (region_gini - bgfunc.min_gini_);
1665 bgfunc(labels,0, tmp_labels, parent.
begin(), parent.end(),tmp_cc);
1666 wgini = (region_gini - bgfunc.min_gini_);
1670 region_gini = split.region_gini_;
1672 Node<i_ThresholdNode> node(split.createNode());
1675 +=split.region_gini_ - split.minGini();
1677 for(
int k = 0; k < 10; ++k)
1682 parent.begin(), parent.end(),
1683 parent.classCounts());
1689 for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1691 wgini = region_gini - split.min_gini_[k];
1694 split.splitColumns[k])
1698 for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1700 split.bgfunc(
columnVector(features, split.splitColumns[k]),
1702 parent.begin(), parent.end(),
1703 parent.classCounts());
1704 wgini = region_gini - split.bgfunc.min_gini_;
1706 split.splitColumns[k]) += wgini;
1714 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1715 std::partition(parent.begin(), parent.end(), sorter);
1726 #endif // RF_VISITORS_HXX