35 #define VIGRA_RF_ALGORTIHM_HXX
38 #include "splices.hxx"
// choose(): helper templated on the input array type; `in` is taken by
// const reference. Only the head of the signature and one loop are visible
// in this excerpt.
57 template<
class OrigMultiArray,
60 void choose(OrigMultiArray
const & in,
// Walks the iterator range [b, e) while advancing the counter `ii` in
// lockstep with `iter`.
69 for(Iter iter = b; iter != e; ++iter, ++ii)
// Function template over the feature and response matrix types; `response`
// is taken by const reference. The function name, the remaining parameters
// and the body are not visible in this excerpt.
// NOTE(review): presumably the call operator of the default error callback
// (RFErrorCallback) used by init() - confirm against the full file.
99 template<
class Feature_t,
class Response_t>
101 Response_t
const & response)
// Container type of the `selected` member (a list of ints).
124 typedef std::vector<int> FeatureList_t;
// Container type of the `errors` member (one double per entry).
125 typedef std::vector<double> ErrorList_t;
// Iterator into a FeatureList_t; this is the type of the `pivot` member.
126 typedef FeatureList_t::iterator Pivot_t;
// init() overload that additionally receives an iterator range [b, e)
// (the declarations of b and e are not visible in this excerpt).
152 template<
class FeatureT,
155 class ErrorRateCallBack>
156 bool init(FeatureT
const & all_features,
157 ResponseT
const & response,
160 ErrorRateCallBack errorcallback)
// Delegate the basic set-up to the three-argument overload and keep its
// result.
162 bool ret_ = init(all_features, response, errorcallback);
// The supplied range must contain exactly as many entries as `selected`.
// NOTE(review): std::distance yields a signed difference while
// selected.size() is unsigned - this is a signed/unsigned comparison.
165 vigra_precondition(std::distance(b, e) ==
selected.size(),
166 "Number of features in ranking != number of features matrix");
// init() overload without an explicit callback parameter in the visible
// part of the signature.
171 template<
class FeatureT,
174 bool init(FeatureT
const & all_features,
175 ResponseT
const & response,
// Forwards to the overload taking a range [b, e) and a callback, passing
// `ecallback` (its construction is not visible in this excerpt).
180 return init(all_features, response, b, e, ecallback);
// Convenience init() overload taking only the feature and response
// matrices.
184 template<
class FeatureT,
186 bool init(FeatureT
const & all_features,
187 ResponseT
const & response)
// Forwards to the callback-taking overload with a default-constructed
// RFErrorCallback.
189 return init(all_features, response, RFErrorCallback());
// Core init() overload: sizes the `selected` and `errors` members from the
// feature matrix and fills in baseline error values.
201 template<
class FeatureT,
203 class ErrorRateCallBack>
204 bool init(FeatureT
const & all_features,
205 ResponseT
const & response,
206 ErrorRateCallBack errorcallback)
// One entry per column of all_features, initialised to 0; the loop below
// runs over every entry (its body is not visible in this excerpt).
213 selected.resize(all_features.shape(1), 0);
214 for(
unsigned int ii = 0; ii <
selected.size(); ++ii)
// One error slot per column, initialised to -1. The last slot receives the
// callback's result for the complete feature matrix.
// NOTE(review): errors.back() is undefined if all_features.shape(1) == 0.
216 errors.resize(all_features.shape(1), -1);
217 errors.back() = errorcallback(all_features, response);
// res_map assigns each distinct value in column 0 of `response` an index
// (`counter`); cts holds a count per index.
221 std::map<typename ResponseT::value_type, int> res_map;
222 std::vector<int> cts;
224 for(
int ii = 0; ii < response.shape(0); ++ii)
// First time this response value is seen: register it under `counter`.
226 if(res_map.find(response(ii, 0)) == res_map.end())
228 res_map[response(ii, 0)] = counter;
// Count one more row for this response value.
232 cts[res_map[response(ii,0)]] +=1;
// no_features is derived from the largest entry of cts divided by the
// number of response rows (the middle of this expression is not visible).
234 no_features = double(*(std::max_element(cts.begin(),
236 /
double(response.shape(0));
// Greedy selection driven by an error callback (excerpt - parts of the
// signature and several body lines are not visible here).
//
// `result.selected` holds the candidate list, `result.pivot` separates the
// entries already chosen from those still available, and `result.errors`
// receives the error measured for each chosen position.
291 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
293 ResponseT
const & response,
295 ErrorRateCallBack errorcallback)
297 VariableSelectionResult::FeatureList_t & selected = result.
selected;
298 VariableSelectionResult::ErrorList_t & errors = result.
errors;
299 VariableSelectionResult::Pivot_t & pivot = result.pivot;
300 int featureCount = features.shape(1);
// init() reports via its return value whether `result` was freshly set up;
// the branch taken on `false` continues below with a consistency check.
302 if(!result.init(features, response, errorcallback))
// A previously used result struct must match the feature matrix width.
306 vigra_precondition(selected.size() == featureCount,
307 "forward_selection(): Number of features in Feature "
308 "matrix and number of features in previously used "
309 "result struct mismatch!");
// Number of candidates from the pivot to the end of the list.
313 int not_selected_size = std::distance(pivot, selected.end());
314 while(not_selected_size > 1)
// Errors are produced as double by the callback and later copied into
// `errors` (a vector<double>). They must be stored as double: an int
// container would truncate every value before min_element compares them.
316 std::vector<double> current_errors;
317 VariableSelectionResult::Pivot_t next = pivot;
// Try every remaining candidate in the pivot slot in turn.
318 for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
// Temporarily move the candidate into the pivot position ...
320 std::swap(*pivot, *next);
322 detail::choose( features,
326 double error = errorcallback(cur_feats, response);
327 current_errors.push_back(error);
// ... and restore the original order before trying the next candidate.
328 std::swap(*pivot, *next);
// Index of the candidate that produced the smallest error.
330 int pos = std::distance(current_errors.begin(),
331 std::min_element(current_errors.begin(),
332 current_errors.end()));
// Commit that candidate to the pivot slot and record its error there.
334 std::advance(next, pos);
335 std::swap(*pivot, *next);
336 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
// Recompute the loop bound (the pivot update itself is not visible in this
// excerpt).
338 not_selected_size = std::distance(pivot, selected.end());
// Overload without an error-callback template parameter; takes the
// response matrix by const reference and writes into `result`. The body is
// not visible in this excerpt.
341 template<
class FeatureT,
class ResponseT>
343 ResponseT
const & response,
344 VariableSelectionResult & result)
// Greedy elimination driven by an error callback (excerpt - parts of the
// signature and several body lines are not visible here).
//
// `result.selected` holds the candidate list, `result.pivot` marks the
// slot currently being decided, and `result.errors` receives the error
// recorded for that slot.
389 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
391 ResponseT
const & response,
393 ErrorRateCallBack errorcallback)
395 int featureCount = features.shape(1);
396 VariableSelectionResult::FeatureList_t & selected = result.
selected;
397 VariableSelectionResult::ErrorList_t & errors = result.
errors;
398 VariableSelectionResult::Pivot_t & pivot = result.pivot;
// init() reports via its return value whether `result` was freshly set up;
// the branch taken on `false` continues below with a consistency check.
401 if(!result.init(features, response, errorcallback))
// A previously used result struct must match the feature matrix width.
405 vigra_precondition(selected.size() == featureCount,
406 "backward_elimination(): Number of features in Feature "
407 "matrix and number of features in previously used "
408 "result struct mismatch!");
// Start with the pivot on the last entry of the list.
410 pivot = selected.end() - 1;
// Number of entries in front of the pivot.
412 int selected_size = std::distance(selected.begin(), pivot);
413 while(selected_size > 1)
415 VariableSelectionResult::Pivot_t next = selected.begin();
// Errors are produced as double by the callback and later copied into
// `errors` (a vector<double>). They must be stored as double: an int
// container would truncate every value before max_element compares them.
416 std::vector<double> current_errors;
// Try every entry in front of the pivot in the pivot slot in turn.
417 for(
int ii = 0; ii < selected_size; ++ii, ++next)
// Temporarily exchange the candidate with the pivot entry ...
419 std::swap(*pivot, *next);
421 detail::choose( features,
425 double error = errorcallback(cur_feats, response);
426 current_errors.push_back(error);
// ... and restore the original order before trying the next candidate.
427 std::swap(*pivot, *next);
// Index of the candidate with the largest recorded error.
429 int pos = std::distance(current_errors.begin(),
430 std::max_element(current_errors.begin(),
431 current_errors.end()));
// Commit that candidate to the pivot slot and record its error there.
432 next = selected.begin();
433 std::advance(next, pos);
434 std::swap(*pivot, *next);
436 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
// Recompute the loop bound (the pivot update itself is not visible in this
// excerpt).
437 selected_size = std::distance(selected.begin(), pivot);
// Overload without an error-callback template parameter; takes the
// response matrix by const reference and writes into `result`. The body is
// not visible in this excerpt.
442 template<
class FeatureT,
class ResponseT>
444 ResponseT
const & response,
445 VariableSelectionResult & result)
// Selection routine that evaluates the entries of `result.selected` in
// their existing order (excerpt - the function name, parts of the
// signature and several body lines are not visible here).
482 template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
484 ResponseT
const & response,
486 ErrorRateCallBack errorcallback)
488 VariableSelectionResult::FeatureList_t & selected = result.
selected;
489 VariableSelectionResult::ErrorList_t & errors = result.
errors;
// The result's pivot is used directly as the loop iterator below.
490 VariableSelectionResult::Pivot_t & iter = result.pivot;
491 int featureCount = features.shape(1);
// init() reports via its return value whether `result` was freshly set up;
// the branch taken on `false` continues below with a consistency check.
493 if(!result.init(features, response, errorcallback))
// NOTE(review): this message names "forward_selection()" although the body
// differs from the routine above that uses the same text - looks like a
// copy/paste leftover; confirm the enclosing function name and adjust.
497 vigra_precondition(selected.size() == featureCount,
498 "forward_selection(): Number of features in Feature "
499 "matrix and number of features in previously used "
500 "result struct mismatch!");
// Advance the pivot entry by entry; each step evaluates the callback and
// stores the error at the pivot's position.
504 for(; iter != selected.end(); ++iter)
509 detail::choose( features,
513 double error = errorcallback(cur_feats, response);
514 errors[std::distance(selected.begin(), iter)] = error;
// Overload without an error-callback template parameter; takes the
// response matrix by const reference and writes into `result`. The body is
// not visible in this excerpt.
519 template<
class FeatureT,
class ResponseT>
521 ResponseT
const & response,
522 VariableSelectionResult & result)
// Type tags for cluster-tree entries: one value for leaves, one for nodes.
// NOTE(review): the reason for the specific values 95 and 99 is not
// visible in this excerpt.
529 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
// Default constructor: defers entirely to the NodeBase default constructor.
544 ClusterNode():NodeBase(){}
// Constructor taking a column count and mutable topology/parameter
// containers; passes nCol + 5 and 5 as the first two base arguments.
545 ClusterNode(
int nCol,
546 BT::T_Container_type & topology,
547 BT::P_Container_type & split_param)
548 : BT(nCol + 5, 5,topology, split_param)
// Constructor taking const topology/parameter containers plus `n` (its
// declaration is not visible in this excerpt), forwarded to NodeBase.
558 ClusterNode( BT::T_Container_type
const & topology,
559 BT::P_Container_type
const & split_param,
561 :
NodeBase(5 , 5,topology, split_param, n)
// Constructor from an existing base-type node.
567 ClusterNode( BT & node_)
// NOTE(review): adding 0 leaves parameter_size_ unchanged - this statement
// has no effect as written.
572 BT::parameter_size_ += 0;
// Setter for the node's index; body not visible in this excerpt.
578 void set_index(
int in)
// Constructs an entry by copying the four arguments into the members
// parent, level, addr and infm respectively.
604 HC_Entry(
int p,
int l,
int a,
bool in)
605 : parent(p), level(l), addr(a), infm(in)
// Combines two distances by returning the smaller of the two.
634 double dist_func(
double a,
double b)
636 return std::min(a, b);
// Traversal over the stored cluster tree using an explicit stack of node
// addresses, starting at begin_addr (excerpt - the function name, the
// functor invocation and the pop are not visible here).
642 template<
class Functor>
646 std::vector<int> stack;
647 stack.push_back(begin_addr);
648 while(!stack.empty())
// View the node stored at the address on top of the stack.
650 ClusterNode node(topology_, parameters_, stack.
back());
// Nodes covering more than one column have two children; schedule both.
654 if(node.columns_size() != 1)
656 stack.push_back(node.child(0));
657 stack.push_back(node.child(1));
// Traversal over the stored cluster tree using a FIFO queue of HC_Entry
// records, starting at begin_addr (excerpt - the function name, the
// initial values of parent/level/infm and the pop are not visible here).
665 template<
class Functor>
669 std::queue<HC_Entry> queue;
674 queue.push(
HC_Entry(parent,level,begin_addr, infm));
675 while(!queue.empty())
// Unpack the front entry into local variables.
677 level = queue.front().level;
678 parent = queue.front().parent;
679 addr = queue.front().addr;
680 infm = queue.front().infm;
681 ClusterNode node(topology_, parameters_, queue.
front().addr);
// View of the parent node, built from the parent address.
685 parnt = ClusterNode(topology_, parameters_, parent);
// The functor receives the node, its level, its parent and the flag
// carried in the queue entry; its result is handed on to the children.
688 bool istrue = tester(node, level, parnt, infm);
// Nodes covering more than one column have two children; enqueue both one
// level deeper with this node's address as their parent.
689 if(node.columns_size() != 1)
691 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
692 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
// Writes the clustering to an HDF5 file. Three datasets are written, named
// prefix + "topology", prefix + "parameters" and prefix + "begin_addr".
// Some of the writeHDF5 arguments are not visible in this excerpt.
699 void save(std::string file, std::string prefix)
// topology_ is written with shape (topology_.size(), 1).
702 vigra::writeHDF5(file.c_str(), (prefix +
"topology").c_str(),
704 Shp(topology_.
size(),1),
// parameters_ is written with shape (parameters_.size(), 1) from its raw
// data pointer.
706 vigra::writeHDF5(file.c_str(), (prefix +
"parameters").c_str(),
708 Shp(parameters_.
size(), 1),
709 parameters_.
data()));
710 vigra::writeHDF5(file.c_str(), (prefix +
"begin_addr").c_str(),
// Builds a binary cluster tree bottom-up from a distance matrix by
// repeatedly merging the closest pair of clusters (excerpt - the function
// name, its signature and several body lines are not visible here).
719 template<
class T,
class C>
// addr holds one (node address in topology_, distance-matrix index) pair
// per cluster that is still unmerged.
723 std::vector<std::pair<int, int> > addr;
724 typedef std::pair<int, int> Entry;
// Create one single-column leaf per row of `distance`.
726 for(
int ii = 0; ii < distance.
shape(0); ++ii)
// The leaf is appended to topology_, so its address is the current size.
728 addr.push_back(std::make_pair(topology_.
size(), ii));
729 ClusterNode leaf(1, topology_, parameters_);
730 leaf.set_index(index);
732 leaf.columns_begin()[0] = ii;
// Merge until a single cluster remains.
// NOTE(review): if `distance` has zero rows addr is empty and this
// condition is still true - confirm callers never pass an empty matrix.
735 while(addr.size() != 1)
740 double min_dist = dist((addr.begin()+ii_min)->second,
741 (addr.begin()+jj_min)->second);
// Scan all unordered pairs (ii < jj) for a smaller distance than the
// current best pair (ii_min, jj_min).
742 for(
unsigned int ii = 0; ii < addr.size(); ++ii)
744 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
746 if( dist((addr.begin()+ii_min)->second,
747 (addr.begin()+jj_min)->second)
748 > dist((addr.begin()+ii)->second,
749 (addr.begin()+jj)->second))
751 min_dist = dist((addr.begin()+ii)->second,
752 (addr.begin()+jj)->second);
// Look at the two clusters to be merged to learn the combined column count.
764 ClusterNode firstChild(topology_,
766 (addr.begin() +ii_min)->first);
767 ClusterNode secondChild(topology_,
769 (addr.begin() +jj_min)->first);
770 col_size = firstChild.columns_size() + secondChild.columns_size();
// The new parent is appended to topology_; remember its address as the
// current root.
772 int cur_addr = topology_.
size();
773 begin_addr = cur_addr;
775 ClusterNode parent(col_size,
778 ClusterNode firstChild(topology_,
780 (addr.begin() +ii_min)->first);
781 ClusterNode secondChild(topology_,
783 (addr.begin() +jj_min)->first);
// The merge distance is stored as the parent's first parameter.
784 parent.parameters_begin()[0] = min_dist;
785 parent.set_index(index);
// The parent's columns are the merged column lists of both children
// (std::merge expects both input ranges to be sorted).
787 std::merge(firstChild.columns_begin(), firstChild.columns_end(),
788 secondChild.columns_begin(),secondChild.columns_end(),
789 parent.columns_begin());
// The child whose first column equals the parent's first column becomes
// child(0); its addr entry is reused for the parent and the other entry
// is erased.
794 if(*parent.columns_begin() == *firstChild.columns_begin())
796 parent.child(0) = (addr.begin()+ii_min)->first;
797 parent.child(1) = (addr.begin()+jj_min)->first;
798 (addr.begin()+ii_min)->first = cur_addr;
800 to_keep = (addr.begin()+ii_min)->second;
801 to_desc = (addr.begin()+jj_min)->second;
802 addr.erase(addr.begin()+jj_min);
// Otherwise the roles of the two children are swapped.
806 parent.child(1) = (addr.begin()+ii_min)->first;
807 parent.child(0) = (addr.begin()+jj_min)->first;
808 (addr.begin()+jj_min)->first = cur_addr;
810 to_keep = (addr.begin()+jj_min)->second;
811 to_desc = (addr.begin()+ii_min)->second;
812 addr.erase(addr.begin()+ii_min);
// Update the distance from the merged cluster to every remaining cluster:
// combine the discarded and the kept cluster's distances with dist_func
// and write the result to both symmetric positions of the matrix.
816 for(
unsigned int jj = 0 ; jj < addr.size(); ++jj)
820 double bla = dist_func(
821 dist(to_desc, (addr.begin()+jj)->second),
822 dist((addr.begin()+ii_keep)->second,
823 (addr.begin()+jj)->second));
825 dist((addr.begin()+ii_keep)->second,
826 (addr.begin()+jj)->second) = bla;
827 dist((addr.begin()+jj)->second,
828 (addr.begin()+ii_keep)->second) = bla;
// Call operator taking a node by mutable reference and returning bool; the
// enclosing type and the body are not visible in this excerpt.
849 bool operator()(Node& node)
// Functor that permutes the columns of a cluster node in a private copy of
// the data and accumulates counts into perm_imp (excerpt - the class
// header, several members and several body lines are not visible here).
862 template<
class Iter,
class DT>
// Scratch copy of the feature rows that gets permuted in place.
867 Matrix<double> tmp_mem_;
870 Matrix<double> feats_;
// Constructor: sizes the matrices from the splice _spl(a, b) and the width
// of `feats`, then copies the selected rows of features and labels.
877 template<
class Feat_T,
class Label_T>
880 Feat_T
const & feats,
881 Label_T
const & labls,
886 :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
889 feats_(_spl(a,b).size(), feats.shape(1)),
890 labels_(_spl(a,b).size(),1),
896 copy_splice(_spl(a,b),
897 _spl(feats.shape(1)),
900 copy_splice(_spl(a,b),
901 _spl(labls.shape(1)),
// Per-node evaluation.
907 bool operator()(Node& node)
// The last column of perm_imp is used as an all-classes counter.
911 int class_count = perm_imp.
shape(1) - 1;
// Repeat the permutation nPerm times.
913 for(
int kk = 0; kk < nPerm; ++kk)
916 for(
int ii = 0; ii <
rowCount(feats_); ++ii)
919 for(
int jj = 0; jj < node.columns_size(); ++jj)
// Columns whose id equals the number of feature columns are skipped; for
// the others row ii receives the value of row `index` (how `index` is
// chosen is not visible in this excerpt).
921 if(node.columns_begin()[jj] != feats_.shape(1))
922 tmp_mem_(ii, node.columns_begin()[jj])
923 = tmp_mem_(index, node.columns_begin()[jj]);
927 for(
int ii = 0; ii <
rowCount(tmp_mem_); ++ii)
// Count per label and in the all-classes column (the condition guarding
// these increments is not visible in this excerpt).
934 ++perm_imp(index,labels_(ii, 0));
936 ++perm_imp(index, class_count);
// Node status contribution: the all-classes count averaged over nPerm,
// minus the unpermuted count, divided by oob_size, added to node.status().
940 double node_status = perm_imp(index, class_count);
941 node_status /= nPerm;
942 node_status -= orig_imp(0, class_count);
944 node_status /= oob_size;
945 node.status() += node_status;
// Writes to the HDF5 file `file` under the dataset name
// prefix + "_variables"; the data argument is not visible in this excerpt.
966 void save(std::string file, std::string prefix)
968 vigra::writeHDF5(file.c_str(), (prefix +
"_variables").c_str(),
// Copies the node's column ids into row `index` of `variables`, one per
// matrix column (how `index` advances is not visible in this excerpt).
974 bool operator()(Node& node)
976 for(
int ii = 0; ii < node.columns_size(); ++ii)
977 variables(index, ii) = node.columns_begin()[ii];
// Traversal functor: caps the current node's status at its parent's
// status, i.e. cur.status() becomes the smaller of the two values.
// `level` and `infm` are not used in the visible line.
991 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
994 cur.status() = std::min(parent.status(), cur.status());
// Output stream receiving the Graphviz description.
1021 std::ofstream graphviz;
// Constructor fragment: `gz` is the output file name; the stream is opened
// for writing and the features/labels arguments are stored in members.
1026 std::string
const gz)
1027 :features_(features), labels_(labels),
1028 graphviz(gz.c_str(), std::ios::out)
// Emit the opening of a digraph whose nodes use the "record" shape.
1030 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
// Emit the closing brace of the digraph (the enclosing function is not
// visible in this excerpt).
1034 graphviz <<
"\n}\n";
// Traversal functor: writes one Graphviz node statement for `cur` and one
// edge from `parent` to `cur`. `level` and `infm` are not used in the
// visible lines.
1039 bool operator()(Nde & cur,
int level, Nde parent,
bool infm)
// Node label: number of columns, status, then the column ids.
1041 graphviz <<
"node" << cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() <<
"\\n";
1042 graphviz <<
" status: " << cur.status() <<
"\\n";
1043 for(
int kk = 0; kk < cur.columns_size(); ++kk)
1045 graphviz << cur.columns_begin()[kk] <<
" ";
// The status value is written as the first of three color components.
1049 graphviz <<
"\"] [color = \"" <<cur.status() <<
" 1.000 1.000\"];\n";
// Edge from the parent node to this node.
1051 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" << cur.index() <<
"\";\n";
// Repetition count; set from the constructor argument rep_cnt below.
1071 int repetition_count_;
// Writes two HDF5 datasets to `filename`, named
// "cluster_importance_" + prefix and "vars_" + prefix. The data arguments
// are not visible in this excerpt.
1077 void save(std::string filename, std::string prefix)
1079 std::string prefix1 =
"cluster_importance_" + prefix;
1080 writeHDF5(filename.c_str(),
1083 prefix1 =
"vars_" + prefix;
1084 writeHDF5(filename.c_str(),
// Constructor initialiser list: stores the repetition count and the
// clustering passed as `clst`.
1091 : repetition_count_(rep_cnt), clustering(clst)
// Function template over the forest type RF and a second type PR; only two
// local constants are visible in this excerpt.
1097 template<
class RF,
class PR>
// Class count taken from the forest's external parameters.
1100 Int32 const class_count = rf.ext_param_.class_count_;
// One more than the forest's column count.
1101 Int32 const column_count = rf.ext_param_.column_count_+1;
// Per-tree importance computation: collects the rows not used by the
// sampler, counts correct predictions on them, and compares against counts
// obtained after permutation (excerpt - the function name, its signature
// and several body lines are not visible here).
1122 template<
class RF,
class PR,
class SM,
class ST>
1126 Int32 column_count = rf.ext_param_.column_count_ +1;
1127 Int32 class_count = rf.ext_param_.class_count_;
// Obtain a mutable reference to the features by casting away constness.
// NOTE(review): writing through this reference is only valid if the
// underlying object is not actually const - confirm.
1131 typename PR::Feature_t & features
1132 =
const_cast<typename PR::Feature_t &
>(pr.features());
// Subsampling branch, taken when the sample size is more than 10000 below
// the number of feature rows.
1139 if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
// Build the list of all row indices and shuffle it.
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17; std::shuffle with an explicit engine is the replacement.
1143 for(
int ii = 0; ii < pr.features().shape(0); ++ii)
1144 indices.push_back(ii);
1145 std::random_shuffle(indices.begin(), indices.end());
// Keep shuffled rows the sampler did not use, at most 3000 per response
// value.
// NOTE(review): the loop bound is row_count_ while `indices` has
// pr.features().shape(0) entries - presumably equal; confirm.
1146 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1148 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
1150 oob_indices.push_back(indices[ii]);
1151 ++cts[pr.response()(indices[ii], 0)];
// Otherwise keep every row the sampler did not use.
1157 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1158 if(!sm.is_used()[ii])
1159 oob_indices.push_back(ii);
// One row with a counter per class plus one all-classes counter.
1169 oob_right(Shp_t(1, class_count + 1));
// Count rows whose predicted label equals the response, per class and in
// the all-classes column.
1172 for(iter = oob_indices.
begin();
1173 iter != oob_indices.
end();
1177 .predictLabel(
rowVector(features, *iter))
1178 == pr.response()(*iter, 0))
1181 ++oob_right[pr.response()(*iter,0)];
1183 ++oob_right[class_count];
// Permuted counts: 2 * column_count - 1 rows, same column layout as above.
1188 perm_oob_right (Shp_t(2* column_count-1, class_count + 1));
1191 pc(oob_indices.
begin(), oob_indices.
end(),
// Average over the repetitions, subtract the unpermuted counts from every
// row, flip the sign, and divide by the number of collected rows.
1200 perm_oob_right /= repetition_count_;
1201 for(
int ii = 0; ii <
rowCount(perm_oob_right); ++ii)
1202 rowVector(perm_oob_right, ii) -= oob_right;
1204 perm_oob_right *= -1;
1205 perm_oob_right /= oob_indices.
size();
// Function template with four type parameters (RF, PR, SM, ST); only the
// template header is visible in this excerpt.
1214 template<
class RF,
class PR,
class SM,
class ST>
// Function template with two type parameters (RF, PR); only the template
// header is visible in this excerpt.
1222 template<
class RF,
class PR>
// Function template over feature and response matrix types (excerpt - the
// function name, most of the signature and most of the body are not
// visible here).
1262 template<
class FeatureT,
class ResponseT>
1264 ResponseT
const & response,
// Separate handling for inputs with more than 40000 rows; what the branch
// does is not visible in this excerpt.
1271 if(features.shape(0) > 40000)
// Trains the forest object RF on the given features and response.
1278 RF.
learn(features, response,
// Overload taking the response matrix by const reference and an
// HClustering by mutable reference; the body is not visible in this
// excerpt.
1307 template<
class FeatureT,
class ResponseT>
1309 ResponseT
const & response,
1310 HClustering & linkage)