12 #ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
13 #define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
15 #include "../sampling.hxx"
16 #include "rf_split.hxx"
17 #include "rf_nodeproxy.hxx"
18 #include "../regression.hxx"
/** Debug helpers that print an expression's source text followed by its value
    to std::cout.
    - outm(v)  writes "v: <value>" and then std::endl.
    - outm2(v) writes "v: <value>, " with no line break, so several values can
      share one line.
    Note: each expansion already ends with its own ';'. The stringized
    argument is joined to ": " by adjacent string-literal concatenation. */
#define outm(v) std::cout << #v ": " << (v) << std::endl;
#define outm2(v) std::cout << #v ": " << (v) << ", ";
70 template<
class ColumnDecisionFunctor,
class Tag = ClassificationTag>
71 class RidgeSplit:
public SplitBase<Tag>
76 typedef SplitBase<Tag> SB;
78 ArrayVector<Int32> splitColumns;
79 ColumnDecisionFunctor bgfunc;
82 ArrayVector<double> min_gini_;
83 ArrayVector<ptrdiff_t> min_indices_;
84 ArrayVector<double> min_thresholds_;
89 bool m_bDoScalingInTraining;
90 bool m_bDoBestLambdaBasedOnGini;
93 :m_bDoScalingInTraining(true),
94 m_bDoBestLambdaBasedOnGini(true)
98 double minGini()
const
100 return min_gini_[bestSplitIndex];
103 int bestSplitColumn()
const
105 return splitColumns[bestSplitIndex];
108 bool& doScalingInTraining()
109 {
return m_bDoScalingInTraining; }
111 bool& doBestLambdaBasedOnGini()
112 {
return m_bDoBestLambdaBasedOnGini; }
115 void set_external_parameters(ProblemSpec<T>
const & in)
118 bgfunc.set_external_parameters(in);
119 int featureCount_ = in.column_count_;
120 splitColumns.resize(featureCount_);
121 for(
int k=0; k<featureCount_; ++k)
123 min_gini_.resize(featureCount_);
124 min_indices_.resize(featureCount_);
125 min_thresholds_.resize(featureCount_);
129 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
131 MultiArrayView<2, T2, C2> multiClassLabels,
133 ArrayVector<Region>& childRegions,
138 typedef typename Region::IndexIterator IndexIterator;
139 typedef typename MultiArrayView <2, T, C>::difference_type fShape;
145 if(std::accumulate(region.classCounts().begin(),
146 region.classCounts().end(), 0) != region.size())
148 RandomForestClassCounter< MultiArrayView<2,T2, C2>,
149 ArrayVector<double> >
150 counter(multiClassLabels, region.classCounts());
151 std::for_each( region.begin(), region.end(), counter);
152 region.classCountsIsValid =
true;
159 if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2)
163 for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
164 std::swap(splitColumns[ii],
165 splitColumns[ii+ randint(features.shape(1) - ii)]);
168 MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1));
171 for(
int n=0; n<(int)region.classCounts().size(); n++)
172 nNumClasses+=((region.classCounts()[n]>0) ? 1:0);
178 int nMaxClassCounts=0;
179 for(
int n=0; n<(int)region.classCounts().size(); n++)
183 if(region.classCounts()[n]>nMaxClassCounts)
185 nMaxClassCounts=region.classCounts()[n];
191 for(
int n=0; n<multiClassLabels.shape(0); n++)
192 labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0);
195 labels=multiClassLabels;
229 MultiArrayView<2, T, C> cVector;
230 MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_));
232 MultiArray<2, double> regrLabels(dShape(region.size(),1));
235 MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1));
236 MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1));
237 for(
int m=0; m<SB::ext_param_.actual_mtry_; m++)
242 double dCurrFeatureColumnMean=0.0;
243 double dCurrFeatureColumnStd=1.0;
246 for(
int n=0; n<region.size(); n++)
247 dCurrFeatureColumnMean+=cVector[region[n]];
248 dCurrFeatureColumnMean/=region.size();
250 if(m_bDoScalingInTraining)
252 for(
int n=0; n<region.size(); n++)
254 dCurrFeatureColumnStd+=
255 (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean);
258 dCurrFeatureColumnStd=
sqrt(dCurrFeatureColumnStd/(region.size()-1));
261 stdMatrix(m,0)=dCurrFeatureColumnStd;
263 meanMatrix(m,0)=dCurrFeatureColumnMean;
267 for(
int n=0; n<region.size(); n++)
268 xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd;
273 for(
int n=0; n<region.size(); n++)
278 regrLabels(n,0)=((labels[region[n]]==0) ? -1:1);
281 MultiArray<2, double> dLambdas(dShape(11,1));
283 for(
int nLambda=-5; nLambda<=5; nLambda++)
284 dLambdas[nCounter++]=
pow(10.0,nLambda);
286 MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11));
289 double dMaxRidgeSum=NumericTraits<double>::min();
290 double dCurrRidgeSum;
291 int nMaxRidgeSumAtLambdaInd=0;
293 for(
int nLambdaInd=0; nLambdaInd<11; nLambdaInd++)
301 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
303 for(
int n=0; n<region.oob_size(); n++)
305 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
306 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
308 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
309 features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd);
313 double dCurrIntercept=0.0;
314 if(m_bDoBestLambdaBasedOnGini)
317 bgfunc(dDistanceFromHyperplane,
320 region.oob_begin(), region.oob_end(),
321 region.classCounts());
322 dCurrIntercept=bgfunc.min_threshold_;
326 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
327 dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd);
330 for(
int n=0; n<region.oob_size(); n++)
333 int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0);
334 dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0);
336 if(dCurrRidgeSum>dMaxRidgeSum)
338 dMaxRidgeSum=dCurrRidgeSum;
339 nMaxRidgeSumAtLambdaInd=nLambdaInd;
345 Node<i_HyperplaneNode> node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data);
349 MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1));
350 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
351 dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0);
354 double dVnorm=
columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm();
356 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
357 node.weights()[n]=dCoeffVector(n,0)/dVnorm;
361 node.column_data()[0]=SB::ext_param_.actual_mtry_;
362 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
363 node.column_data()[n+1]=splitColumns[n];
369 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
371 for(
int n=0; n<region.size(); n++)
373 dDistanceFromHyperplane(region[n],0)=0.0;
374 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
376 dDistanceFromHyperplane(region[n],0)+=
377 features(region[n],m)*node.weights()[m];
380 for(
int n=0; n<region.oob_size(); n++)
382 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
383 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
385 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
386 features(region.oob_begin()[n],m)*node.weights()[m];
391 bgfunc(dDistanceFromHyperplane,
394 region.begin(), region.end(),
395 region.classCounts());
402 node.intercept() = bgfunc.min_threshold_;
405 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
406 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
407 childRegions[0].classCountsIsValid =
true;
408 childRegions[1].classCountsIsValid =
true;
411 childRegions[0].setRange( region.begin() , region.begin() + bgfunc.min_index_ );
412 childRegions[0].rule = region.rule;
413 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
414 childRegions[1].setRange( region.begin() + bgfunc.min_index_ , region.end() );
415 childRegions[1].rule = region.rule;
416 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
421 std::sort(region.oob_begin(), region.oob_end(),
422 SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0));
426 for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++)
428 if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept())
432 childRegions[0].set_oob_range( region.oob_begin() , region.oob_begin() + nOOBindx );
433 childRegions[1].set_oob_range( region.oob_begin() + nOOBindx , region.oob_end() );
439 return i_HyperplaneNode;
449 #endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H