37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
50 #include "metaprogramming.hxx"
52 #include "functorexpression.hxx"
53 #include "random_forest/rf_common.hxx"
54 #include "random_forest/rf_nodeproxy.hxx"
55 #include "random_forest/rf_split.hxx"
56 #include "random_forest/rf_decisionTree.hxx"
57 #include "random_forest/rf_visitors.hxx"
58 #include "random_forest/rf_region.hxx"
59 #include "sampling.hxx"
60 #include "random_forest/rf_preprocessing.hxx"
61 #include "random_forest/rf_online_prediction_set.hxx"
62 #include "random_forest/rf_earlystopping.hxx"
63 #include "random_forest/rf_ridge_split.hxx"
83 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
85 SamplerOptions return_opt;
87 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
146 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
153 typedef detail::DecisionTree DecisionTree_t;
160 typedef LabelType LabelT;
227 template<
class TopologyIterator,
class ParameterIterator>
229 TopologyIterator topology_begin,
230 ParameterIterator parameter_begin,
234 trees_(treeCount, DecisionTree_t(problem_spec)),
235 ext_param_(problem_spec),
241 for(
int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243 trees_[k].topology_ = *topology_begin;
244 trees_[k].parameters_ = *parameter_begin;
262 vigra_precondition(ext_param_.used() ==
true,
263 "RandomForest::ext_param(): "
264 "Random forest has not been trained yet.");
281 vigra_precondition(ext_param_.used() ==
false,
282 "RandomForest::set_ext_param():"
283 "Random forest has been trained! Call reset()"
284 "before specifying new extrinsic parameters.");
308 DecisionTree_t
const &
tree(
int index)
const
310 return trees_[index];
315 DecisionTree_t &
tree(
int index)
317 return trees_[index];
325 return ext_param_.column_count_;
336 return ext_param_.column_count_;
344 return ext_param_.class_count_;
351 return options_.tree_count_;
392 template <
class U,
class C1,
403 Random_t
const & random);
405 template <
class U,
class C1,
426 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
427 void learn( MultiArrayView<2, U, C1>
const & features,
428 MultiArrayView<2, U2,C2>
const & labels,
438 template <
class U,
class C1,
class U2,
class C2,
439 class Visitor_t,
class Split_t>
440 void learn( MultiArrayView<2, U, C1>
const & features,
441 MultiArrayView<2, U2,C2>
const & labels,
470 template <
class U,
class C1,
class U2,
class C2>
482 template<
class U,
class C1,
495 bool adjust_thresholds=
false);
497 template <
class U,
class C1,
class U2,
class C2>
502 onlineLearn(features,
512 template<
class U,
class C1,
518 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
519 MultiArrayView<2,U2,C2>
const & response,
526 template<
class U,
class C1,
class U2,
class C2>
527 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
528 MultiArrayView<2, U2, C2>
const & labels,
531 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
561 template <
class U,
class C,
class Stop>
562 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features, Stop & stop)
const;
564 template <
class U,
class C>
565 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features)
575 template <
class U,
class C>
576 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
577 ArrayVectorView<double> prior)
const;
589 template <
class U,
class C1,
class T,
class C2>
593 vigra_precondition(features.
shape(0) == labels.
shape(0),
594 "RandomForest::predictLabels(): Label array has wrong size.");
595 for(
int k=0; k<features.
shape(0); ++k)
597 vigra_precondition(!detail::contains_nan(rowVector(features, k)),
598 "RandomForest::predictLabels(): NaN in feature matrix.");
613 template <
class U,
class C1,
class T,
class C2>
616 LabelType nanLabel)
const
618 vigra_precondition(features.
shape(0) == labels.
shape(0),
619 "RandomForest::predictLabels(): Label array has wrong size.");
620 for(
int k=0; k<features.
shape(0); ++k)
622 if(detail::contains_nan(rowVector(features, k)))
623 labels(k,0) = nanLabel;
638 template <
class U,
class C1,
class T,
class C2,
class Stop>
643 vigra_precondition(features.
shape(0) == labels.
shape(0),
644 "RandomForest::predictLabels(): Label array has wrong size.");
645 for(
int k=0; k<features.
shape(0); ++k)
646 labels(k,0) = detail::RequiresExplicitCast<T>::cast(
predictLabel(rowVector(features, k), stop));
660 template <
class U,
class C1,
class T,
class C2,
class Stop>
664 template <
class T1,
class T2,
class C>
674 template <
class U,
class C1,
class T,
class C2>
681 template <
class U,
class C1,
class T,
class C2>
691 template <
class LabelType,
class PreprocessorTag>
692 template<
class U,
class C1,
698 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
699 MultiArrayView<2,U2,C2>
const & response,
705 bool adjust_thresholds)
707 online_visitor_.activate();
708 online_visitor_.adjust_thresholds=adjust_thresholds;
712 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
713 typedef UniformIntRandomFunctor<Random_t>
720 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
721 Default_Stop_t default_stop(options_);
722 typename RF_CHOOSER(Stop_t)::type stop
723 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724 Default_Split_t default_split;
725 typename RF_CHOOSER(Split_t)::type split
726 = RF_CHOOSER(Split_t)::choose(split_, default_split);
727 rf::visitors::StopVisiting stopvisiting;
728 typedef rf::visitors::detail::VisitorNode
729 <rf::visitors::OnlineLearnVisitor,
730 typename RF_CHOOSER(Visitor_t)::type>
733 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
735 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
741 ext_param_.class_count_=0;
742 Preprocessor_t preprocessor( features, response,
743 options_, ext_param_);
746 RandFunctor_t randint ( random);
749 split.set_external_parameters(ext_param_);
750 stop.set_external_parameters(ext_param_);
754 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
760 for(
int ii = 0; ii < static_cast<int>(trees_.
size()); ++ii)
762 online_visitor_.tree_id=ii;
763 poisson_sampler.sample();
764 std::map<int,int> leaf_parents;
765 leaf_parents.clear();
767 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
769 int sample=poisson_sampler[s];
770 online_visitor_.current_label=preprocessor.response()(sample,0);
771 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772 int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
776 online_visitor_.add_to_index_list(ii,leaf,sample);
779 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
781 leaf_parents[leaf]=online_visitor_.last_node_id;
786 std::map<int,int>::iterator leaf_iterator;
787 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
789 int leaf=leaf_iterator->first;
790 int parent=leaf_iterator->second;
791 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
792 ArrayVector<Int32> indeces;
794 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795 StackEntry_t stack_entry(indeces.begin(),
797 ext_param_.class_count_);
802 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
808 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
809 stack_entry.rightParent=parent;
813 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
815 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
828 online_visitor_.deactivate();
831 template<
class LabelType,
class PreprocessorTag>
832 template<
class U,
class C1,
853 ext_param_.class_count_=0;
861 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
863 typename RF_CHOOSER(Stop_t)::type stop
864 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
866 typename RF_CHOOSER(Split_t)::type split
867 = RF_CHOOSER(Split_t)::choose(split_, default_split);
871 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
873 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
875 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876 online_visitor_.activate();
879 RandFunctor_t randint ( random);
885 Preprocessor_t preprocessor( features, response,
886 options_, ext_param_);
889 split.set_external_parameters(ext_param_);
890 stop.set_external_parameters(ext_param_);
897 preprocessor.strata().end(),
898 detail::make_sampler_opt(options_)
899 .sampleSize(ext_param().actual_msample_),
906 first_stack_entry( sampler.sampledIndices().begin(),
907 sampler.sampledIndices().end(),
908 ext_param_.class_count_);
910 .set_oob_range( sampler.oobIndices().begin(),
911 sampler.oobIndices().end());
913 online_visitor_.tree_id=treeId;
914 trees_[treeId].reset();
916 .learn( preprocessor.features(),
917 preprocessor.response(),
924 .visit_after_tree( *
this,
930 online_visitor_.deactivate();
933 template <
class LabelType,
class PreprocessorTag>
934 template <
class U,
class C1,
946 Random_t
const & random)
957 vigra_precondition(features.
shape(0) == response.
shape(0),
958 "RandomForest::learn(): shape mismatch between features and response.");
965 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
967 typename RF_CHOOSER(Stop_t)::type stop
968 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
970 typename RF_CHOOSER(Split_t)::type split
971 = RF_CHOOSER(Split_t)::choose(split_, default_split);
975 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
977 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
979 if(options_.prepare_online_learning_)
980 online_visitor_.activate();
982 online_visitor_.deactivate();
986 RandFunctor_t randint ( random);
993 Preprocessor_t preprocessor( features, response,
994 options_, ext_param_);
997 split.set_external_parameters(ext_param_);
998 stop.set_external_parameters(ext_param_);
1002 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
1005 preprocessor.strata().end(),
1006 detail::make_sampler_opt(options_)
1007 .sampleSize(ext_param().actual_msample_),
1010 visitor.visit_at_beginning(*
this, preprocessor);
1013 for(
int ii = 0; ii < static_cast<int>(trees_.
size()); ++ii)
1019 first_stack_entry( sampler.sampledIndices().begin(),
1020 sampler.sampledIndices().end(),
1021 ext_param_.class_count_);
1023 .set_oob_range( sampler.oobIndices().begin(),
1024 sampler.oobIndices().end());
1026 .learn( preprocessor.features(),
1027 preprocessor.response(),
1034 .visit_after_tree( *
this,
1041 visitor.visit_at_end(*
this, preprocessor);
1043 online_visitor_.deactivate();
1049 template <
class LabelType,
class Tag>
1050 template <
class U,
class C,
class Stop>
1054 vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1055 "RandomForestn::predictLabel():"
1056 " Too few columns in feature matrix.");
1057 vigra_precondition(rowCount(features) == 1,
1058 "RandomForestn::predictLabel():"
1059 " Feature matrix must have a singlerow.");
1062 predictProbabilities(features, probabilities, stop);
1063 ext_param_.to_classlabel(
argMax(probabilities), d);
1069 template <
class LabelType,
class PreprocessorTag>
1070 template <
class U,
class C>
1075 using namespace functor;
1076 vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1077 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078 vigra_precondition(rowCount(features) == 1,
1079 "RandomForestn::predictLabel():"
1080 " Feature matrix must have a single row.");
1082 predictProbabilities(features, prob);
1083 std::transform( prob.begin(), prob.end(),
1084 priors.
begin(), prob.begin(),
1087 ext_param_.to_classlabel(
argMax(prob), d);
1091 template<
class LabelType,
class PreprocessorTag>
1092 template <
class T1,
class T2,
class C>
1100 vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1101 "RandomFroest::predictProbabilities():"
1102 " Feature matrix and probability matrix size mismatch.");
1105 vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1106 "RandomForestn::predictProbabilities():"
1107 " Too few columns in feature matrix.");
1108 vigra_precondition( columnCount(prob)
1109 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1110 "RandomForestn::predictProbabilities():"
1111 " Probability matrix must have as many columns as there are classes.");
1114 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1117 for(
int k=0; k<options_.tree_count_; ++k)
1119 set_id=(set_id+1) % predictionSet.indices[0].size();
1120 typedef std::set<SampleRange<T1> > my_set;
1121 typedef typename my_set::iterator set_it;
1124 std::vector<std::pair<int,set_it> > stack;
1126 for(set_it i=predictionSet.ranges[set_id].begin();
1127 i!=predictionSet.ranges[set_id].end();++i)
1128 stack.push_back(std::pair<int,set_it>(2,i));
1130 int num_decisions=0;
1131 while(!stack.empty())
1133 set_it range=stack.back().second;
1134 int index=stack.back().first;
1138 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1141 trees_[k].parameters_,
1142 index).prob_begin();
1143 for(
int i=range->start;i!=range->end;++i)
1146 for(
int l=0; l<ext_param_.class_count_; ++l)
1148 prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
1150 totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
1157 if(trees_[k].topology_[index]!=i_ThresholdNode)
1159 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1161 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1162 if(range->min_boundaries[node.column()]>=node.threshold())
1165 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1168 if(range->max_boundaries[node.column()]<node.threshold())
1171 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1175 SampleRange<T1> new_range=*range;
1176 new_range.min_boundaries[node.column()]=FLT_MAX;
1177 range->max_boundaries[node.column()]=-FLT_MAX;
1178 new_range.start=new_range.end=range->end;
1180 while(i!=range->end)
1183 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1185 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1189 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1194 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1200 if(range->start==range->end)
1202 predictionSet.ranges[set_id].erase(range);
1206 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1209 if(new_range.start!=new_range.end)
1211 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1216 predictionSet.cumulativePredTime[k]=num_decisions;
1218 for(
unsigned int i=0;i<totalWeights.size();++i)
1222 for(
int l=0; l<ext_param_.class_count_; ++l)
1225 prob(i, l) /= totalWeights[i];
1227 assert(test==totalWeights[i]);
1228 assert(totalWeights[i]>0.0);
1232 template <
class LabelType,
class PreprocessorTag>
1233 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1236 MultiArrayView<2, T, C2> & prob,
1237 Stop_t & stop_)
const
1243 "RandomForestn::predictProbabilities():"
1244 " Feature matrix and probability matrix size mismatch.");
1248 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1249 "RandomForestn::predictProbabilities():"
1250 " Too few columns in feature matrix.");
1252 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1253 "RandomForestn::predictProbabilities():"
1254 " Probability matrix must have as many columns as there are classes.");
1256 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1257 Default_Stop_t default_stop(options_);
1258 typename RF_CHOOSER(Stop_t)::type & stop
1259 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1261 stop.set_external_parameters(ext_param_, tree_count());
1262 prob.init(NumericTraits<T>::zero());
1272 for(
int row=0; row <
rowCount(features); ++row)
1274 MultiArrayView<2, U, StridedArrayTag> currentRow(
rowVector(features, row));
1278 if(detail::contains_nan(currentRow))
1284 ArrayVector<double>::const_iterator weights;
1287 double totalWeight = 0.0;
1290 for(
int k=0; k<options_.tree_count_; ++k)
1293 weights = trees_[k ].predict(currentRow);
1296 int weighted = options_.predict_weighted_;
1297 for(
int l=0; l<ext_param_.class_count_; ++l)
1299 double cur_w = weights[l] * (weighted * (*(weights-1))
1301 prob(row, l) += static_cast<T>(cur_w);
1303 totalWeight += cur_w;
1305 if(stop.after_prediction(weights,
1315 for(
int l=0; l< ext_param_.class_count_; ++l)
1317 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1323 template <
class LabelType,
class PreprocessorTag>
1324 template <
class U,
class C1,
class T,
class C2>
1325 void RandomForest<LabelType, PreprocessorTag>
1326 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1327 MultiArrayView<2, T, C2> & prob)
const
1333 "RandomForestn::predictProbabilities():"
1334 " Feature matrix and probability matrix size mismatch.");
1338 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1339 "RandomForestn::predictProbabilities():"
1340 " Too few columns in feature matrix.");
1342 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1343 "RandomForestn::predictProbabilities():"
1344 " Probability matrix must have as many columns as there are classes.");
1346 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1347 prob.init(NumericTraits<T>::zero());
1357 for(
int row=0; row <
rowCount(features); ++row)
1359 ArrayVector<double>::const_iterator weights;
1362 double totalWeight = 0.0;
1365 for(
int k=0; k<options_.tree_count_; ++k)
1368 weights = trees_[k ].predict(
rowVector(features, row));
1371 int weighted = options_.predict_weighted_;
1372 for(
int l=0; l<ext_param_.class_count_; ++l)
1374 double cur_w = weights[l] * (weighted * (*(weights-1))
1376 prob(row, l) += static_cast<T>(cur_w);
1378 totalWeight += cur_w;
1382 prob/= options_.tree_count_;
1388 #include "random_forest/rf_algorithm.hxx"
1389 #endif // VIGRA_RANDOM_FOREST_HXX