36 #ifndef VIGRA_RANDOM_FOREST_DT_HXX
37 #define VIGRA_RANDOM_FOREST_DT_HXX
42 #include "vigra/multi_array.hxx"
43 #include "vigra/mathutil.hxx"
44 #include "vigra/array_vector.hxx"
45 #include "vigra/sized_int.hxx"
46 #include "vigra/matrix.hxx"
47 #include "vigra/random.hxx"
48 #include "vigra/functorexpression.hxx"
51 #include "rf_common.hxx"
52 #include "rf_visitors.hxx"
53 #include "rf_nodeproxy.hxx"
91 unsigned int classCount_;
99 ext_param_(ext_param),
100 classCount_(ext_param.class_count_)
// reset(): stores the supplied class count in the classCount_ member.
// classCount is optional and defaults to 0.
// NOTE(review): only this single body statement is visible in this listing;
// whether other members are reset as well cannot be confirmed from here.
105 void reset(
unsigned int classCount = 0)
108 classCount_ = classCount;
121 template <
class U,
class C,
130 StackEntry_t
const & stack_entry,
135 template <
class U,
class C,
144 StackEntry_t
const & stack_entry,
150 int garbaged_child=-1);
155 return (in & LeafNodeTag) == LeafNodeTag;
// Tree descent with visitor callbacks; the fail message below identifies this
// routine as DecisionTree::getToLeaf(). It is a const member template over U,
// C and Visitor_t, and the visitor is taken by non-const reference.
// NOTE(review): the signature line, the loop header and the case terminators
// are missing from this listing, so the loop condition and the return value
// are deliberately not described here.
163 template<
class U,
class C,
class Visitor_t>
165 Visitor_t & visitor)
const
// Hand the node currently at `index` to the visitor as an internal node.
170 visitor.visit_internal_node(*
this, index, topology_[index],features);
// Dispatch on the node type tag stored in topology_[index].
171 switch(topology_[index])
// Each visible case builds the matching Node<> proxy over
// (topology_, parameters_, index) and replaces `index` with node.next(features).
173 case i_ThresholdNode:
175 Node<i_ThresholdNode>
176 node(topology_, parameters_, index);
177 index = node.next(features);
180 case i_HyperplaneNode:
182 Node<i_HyperplaneNode>
183 node(topology_, parameters_, index);
184 index = node.next(features);
187 case i_HypersphereNode:
189 Node<i_HypersphereNode>
190 node(topology_, parameters_, index);
191 index = node.next(features);
// NOTE(review): this branch passes `parameters`, while the three cases above
// pass `parameters_`. Its case label and node type are not visible here --
// confirm whether this branch is actually compiled and whether the spelling
// is intended.
199 node(topology_, parameters, index);
200 index = node.next(features);
// Hard failure with a fixed message; presumably the fallback for a type tag
// that none of the cases handles (the label itself is not visible).
204 vigra_fail(
"DecisionTree::getToLeaf():"
205 "encountered unknown internal Node Type");
// Hand the node at `index` to the visitor as an external node.
208 visitor.visit_external_node(*
this, index, topology_[index],features);
// Visitor-driven pass over topology_. The function's name and signature line
// are not visible in this listing.
216 template<
class Visitor_t>
// Loop while `index` is still inside topology_.
// NOTE(review): no statement that changes `index` is visible among these
// lines -- confirm the loop actually advances.
221 while(index < topology_.
size())
// One branch reports the node at `index` through visit_external_node.
226 .visit_external_node(*
this, index, topology_[index]);
// NOTE(review): this call is spelled `_internal_node`, whereas every other
// visitor call in this listing uses `visit_internal_node`. Looks like a typo
// that would only surface once this template is instantiated -- confirm.
231 ._internal_node(*
this, index, topology_[index]);
// traverse_post_order(): explicit-stack walk over the tree that reports nodes
// to `visitor` and accumulates one double per node on a side stack.
// The visitor is taken by value. `start` defaults to 2.
// NOTE(review): in the visible lines `start` is never read -- the work stack
// is seeded with the literal 2 instead. Confirm whether `start` should be used.
236 template<
class Visitor_t>
237 void traverse_post_order(Visitor_t visitor,
TreeInt start = 2)
const
// `stack` holds Entry items: element [0] is used as a node address and
// element [1] is compared against 1 as a state flag. `result_stack` holds the
// per-node double results.
240 std::vector<Entry > stack;
241 std::vector<double> result_stack;
242 stack.push_back(Entry(2, 0));
244 while(!stack.empty())
// Look at the most recently pushed entry and wrap its node.
246 addr = stack.back()[0];
247 NodeBase node(topology_, parameters_, stack.back()[0]);
// Flag == 1 branch: combine the two results on top of result_stack.
248 if(stack.back()[1] == 1)
// NOTE(review): both values are read from result_stack.back() BEFORE either
// pop_back(), so leftRes and rightRes are always the same value and the
// second-from-top result is discarded unread. The pushed sum is therefore
// twice the top value. Presumably a pop_back() was meant to sit between the
// two reads -- confirm and fix.
251 double leftRes = result_stack.back();
252 double rightRes = result_stack.back();
253 result_stack.pop_back();
254 result_stack.pop_back();
255 result_stack.push_back(rightRes+ leftRes);
256 visitor.visit_internal_node(*
this,
// Other branch (conditions not visible): report to visit_external_node and
// push node.weights() as this node's result.
265 visitor.visit_external_node(*
this,
270 result_stack.push_back(node.weights());
// Schedule both children with flag 0. child(1) is pushed last, so it is the
// next entry seen through stack.back().
275 stack.push_back(Entry(node.child(0), 0));
276 stack.push_back(Entry(node.child(1), 0));
284 template<
class U,
class C>
// Leaf lookup; the fail message below identifies this routine as
// DecisionTree::predict(). Member template over U and C.
// NOTE(review): the signature and the statement that computes `nodeindex`
// are not visible in this listing.
292 template <
class U,
class C>
// Dispatch on the type tag stored in topology_[nodeindex].
297 switch(topology_[nodeindex])
// For the two visible external node kinds the result is prob_begin() of a
// node proxy built at `nodeindex`. The proxy type for the e_ConstProbNode case
// is on a line not shown; presumably Node<e_ConstProbNode>.
299 case e_ConstProbNode:
302 nodeindex).prob_begin();
306 case e_LogRegProbNode:
307 return Node<e_LogRegProbNode>(topology_,
309 nodeindex).prob_begin();
// Hard failure with a fixed message; presumably the fallback for any other
// type tag (the label itself is not visible).
312 vigra_fail(
"DecisionTree::predict() :"
313 " encountered unknown external Node Type");
// predictLabel(): returns, as Int32, the offset of the element selected by
// argMax within the first classCount_ values of the range that
// predict(features) returns.
// `features` is forwarded unchanged to predict().
// NOTE(review): argMax is presumably the helper that returns an iterator to
// the largest element of [first, last); its definition is not in this listing.
320 template <
class U,
class C>
321 Int32 predictLabel(MultiArrayView<2, U, C>
const & features)
const
// Start of the per-class values for the leaf reached by `features`.
323 ArrayVector<double>::const_iterator weights = predict(features);
// Iterator difference converts the selected position into a 0-based offset.
324 return argMax(weights, weights+classCount_) - weights;
// Training entry point (name and most of the signature are not visible in
// this listing). Prepares the two storage arrays, writes two header entries
// into topology_, then delegates to continueLearn().
330 template <
class U,
class C,
339 StackEntry_t
const & stack_entry,
// Reserve capacity for 256 entries in each array up front.
346 topology_.reserve(256);
347 parameters_.reserve(256);
// Header entries: features.shape(1) first, then classCount_.
// NOTE(review): these two entries presumably explain why the traversals in
// this file start at index 2 -- confirm.
348 topology_.push_back(features.
shape(1));
349 topology_.push_back(classCount_);
// All arguments are passed through unchanged.
350 continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
// Stack-driven tree construction. NOTE(review): the function name is not
// visible in this listing; it is presumably continueLearn(), which is called
// elsewhere in this file with a stack_entry argument and declared with
// `int garbaged_child=-1`.
353 template <
class U,
class C,
362 StackEntry_t
const & stack_entry,
// Work stack of regions still to be processed, seeded with the caller's entry.
370 std::vector<StackEntry_t> stack;
373 stack.push_back(stack_entry);
374 size_t last_node_pos = 0;
375 StackEntry_t top=stack.back();
377 while(!stack.empty())
// Clear both child descriptors before they are filled for the current entry.
385 child_stack_entry[0].reset();
386 child_stack_entry[1].reset();
// NodeID comes from either makeTerminalNode() or findBestSplit(); the
// condition choosing between them is not visible here.
396 NodeID = split.makeTerminalNode(features,
403 NodeID = split.findBestSplit(features,
413 visitor.visit_after_split(*
this, split, top,
414 child_stack_entry[0],
415 child_stack_entry[1],
// Remember the current end of topology_ and store it in the parent's child
// slot: child(0) when leftParent is set, otherwise child(1) when rightParent
// is set.
423 last_node_pos = topology_.
size();
424 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
428 top.leftParent).
child(0) = last_node_pos;
430 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
434 top.rightParent).child(1) = last_node_pos;
// Both children record the current end of topology_ as their parent: entry 0
// through leftParent, entry 1 through rightParent. The unused side is set to -1.
// NOTE(review): the checks above compare against
// StackEntry_t::DecisionTreeNoParent while these lines write the literal -1;
// confirm the two are the same value.
443 child_stack_entry[0].leftParent = topology_.
size();
444 child_stack_entry[1].rightParent = topology_.
size();
445 child_stack_entry[0].rightParent = -1;
446 child_stack_entry[1].leftParent = -1;
447 stack.push_back(child_stack_entry[0]);
448 stack.push_back(child_stack_entry[1]);
// Builds a NodeBase from split.createNode() together with topology_ and
// parameters_; presumably this appends the new node to both arrays.
453 NodeBase(split.createNode(), topology_, parameters_ );
// When a garbaged_child index was supplied (anything other than -1), the node
// at last_node_pos and the slot at garbaged_child are involved in a copy()
// (presumably into the garbaged_child slot), after which topology_ is cut back
// to last_node_pos and parameters_ shrinks by that node's parameters_size().
455 if(garbaged_child!=-1)
457 Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));
459 int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
460 topology_.resize(last_node_pos);
461 parameters_.resize(parameters_.size() - last_parameter_size);
// Re-point the parent's child slot at garbaged_child, since the slot at
// last_node_pos no longer exists after the resize above.
463 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
466 top.leftParent).child(0) = garbaged_child;
467 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
470 top.rightParent).child(1) = garbaged_child;
478 #endif //VIGRA_RANDOM_FOREST_DT_HXX