36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
40 #include "rf_common.hxx"
61 template<
class Tag,
class LabelType,
class T1,
class C1,
class T2,
class C2>
76 switch(options.mtry_switch_)
79 ext_param.actual_mtry_ =
81 std::sqrt(
double(ext_param.column_count_))
86 ext_param.actual_mtry_ =
87 int(1+(
std::log(
double(ext_param.column_count_))
91 ext_param.actual_mtry_ =
92 options.mtry_func_(ext_param.column_count_);
95 ext_param.actual_mtry_ = ext_param.column_count_;
98 ext_param.actual_mtry_ =
102 switch(options.training_set_calc_switch_)
105 ext_param.actual_msample_ =
106 options.training_set_size_;
108 case RF_PROPORTIONAL:
109 ext_param.actual_msample_ =
110 (int)
std::ceil( options.training_set_proportion_ *
111 ext_param.row_count_);
114 ext_param.actual_msample_ =
115 options.training_set_func_(ext_param.row_count_);
118 vigra_precondition(1!= 1,
"unexpected error");
126 template<
unsigned int N,
class T,
class C>
129 for(
int ii = 0; ii < in.
size(); ++ii)
137 template<
unsigned int N,
class T,
class C>
140 if(!std::numeric_limits<T>::has_infinity)
142 for(
int ii = 0; ii < in.
size(); ++ii)
143 if(in[ii] == std::numeric_limits<T>::infinity())
156 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
157 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
160 typedef Int32 LabelInt;
185 ext_param.column_count_ = features.
shape(1);
186 ext_param.row_count_ = features.
shape(0);
187 ext_param.problem_type_ = CLASSIFICATION;
188 ext_param.used_ =
true;
189 intLabels_.reshape(response.
shape());
192 if(ext_param.class_count_ == 0)
196 std::set<T2> labelToInt;
198 labelToInt.insert(response(k,0));
199 std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
200 ext_param.
classes_(tmp_.begin(), tmp_.end());
204 if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
206 throw std::runtime_error(
"unknown label type");
209 intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
210 - ext_param.classes.begin();
213 if(ext_param.class_weights_.size() == 0)
216 tmp((std::size_t)ext_param.class_count_,
217 NumericTraits<T2>::one());
225 strata_ = intLabels_;
263 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
264 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
287 ext_param_(ext_param)
323 #endif //VIGRA_RF_PREPROCESSING_HXX