35 #ifndef VIGRA_RF3_RANDOM_FOREST_HXX
36 #define VIGRA_RF3_RANDOM_FOREST_HXX
38 #include <type_traits>
41 #include "../multi_shape.hxx"
42 #include "../binary_forest.hxx"
43 #include "../threadpool.hxx"
44 #include "random_forest_common.hxx"
64 template <
typename FEATURES,
66 typename SPLITTESTS = LessEqualSplitTest<typename FEATURES::value_type>,
67 typename ACCTYPE = ArgMaxVectorAcc<double>>
72 typedef FEATURES Features;
73 typedef typename Features::value_type FeatureType;
74 typedef LABELS Labels;
75 typedef typename Labels::value_type LabelType;
76 typedef SPLITTESTS SplitTests;
78 typedef typename ACC::input_type AccInputType;
82 static ContainerTag
const container_tag = VectorTag;
102 typename NodeMap<SplitTests>::type
const & split_tests,
103 typename NodeMap<AccInputType>::type
const & node_responses,
104 ProblemSpec<LabelType>
const & problem_spec
115 FEATURES
const & features,
118 const std::vector<size_t> & tree_indices = std::vector<size_t>()
123 template <
typename PROBS>
125 FEATURES
const & features,
128 const std::vector<size_t> & tree_indices = std::vector<size_t>()
133 template <
typename IDS>
135 FEATURES
const & features,
138 const std::vector<size_t> tree_indices = std::vector<size_t>()
183 template <
typename IDS,
typename INDICES>
184 double leaf_ids_impl(
185 FEATURES
const & features,
189 INDICES
const & tree_indices
192 template<
typename PROBS>
193 void predict_probabilities_impl(
194 FEATURES
const & features,
197 const std::vector<size_t> & tree_indices)
const;
201 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
210 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
213 typename NodeMap<SplitTests>::type
const & split_tests,
214 typename NodeMap<AccInputType>::type
const & node_responses,
218 split_tests_(split_tests),
219 node_responses_(node_responses),
220 problem_spec_(problem_spec)
223 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
228 "RandomForest::merge(): You cannot merge with different problem specs.");
232 size_t const offset = num_nodes();
233 graph_.merge(other.
graph_);
236 split_tests_.insert(Node(p.first.id()+offset), p.second);
240 node_responses_.insert(Node(p.first.id()+offset), p.second);
246 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
248 FEATURES
const & features,
251 const std::vector<size_t> & tree_indices
253 vigra_precondition(features.shape()[0] == labels.shape()[0],
254 "RandomForest::predict(): Shape mismatch between features and labels.");
255 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
256 "RandomForest::predict(): Number of features in prediction differs from training.");
259 predict_probabilities(features, probs, n_threads, tree_indices);
260 for (
size_t i = 0; i < (size_t)features.shape()[0]; ++i)
262 auto const sub_probs = probs.template bind<0>(i);
263 auto it = std::max_element(sub_probs.begin(), sub_probs.end());
264 size_t const label = std::distance(sub_probs.begin(), it);
265 labels(i) = problem_spec_.distinct_classes_[label];
272 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
273 template <
typename PROBS>
275 FEATURES
const & features,
278 const std::vector<size_t> & tree_indices
280 vigra_precondition(features.shape()[0] == probs.shape()[0],
281 "RandomForest::predict_probabilities(): Shape mismatch between features and probabilities.");
282 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
283 "RandomForest::predict_probabilities(): Number of features in prediction differs from training.");
284 vigra_precondition((
size_t)probs.shape()[1] == problem_spec_.num_classes_,
285 "RandomForest::predict_probabilities(): Number of labels in probabilities differs from training.");
289 std::vector<size_t> tree_indices_cpy(tree_indices);
290 if (tree_indices_cpy.size() == 0)
292 tree_indices_cpy.resize(graph_.numRoots());
293 std::iota(tree_indices_cpy.begin(), tree_indices_cpy.end(), 0);
297 std::sort(tree_indices_cpy.begin(), tree_indices_cpy.end());
298 tree_indices_cpy.erase(std::unique(tree_indices_cpy.begin(), tree_indices_cpy.end()), tree_indices_cpy.end());
299 for (
auto i : tree_indices_cpy)
300 vigra_precondition(i < graph_.numRoots(),
"RandomForest::leaf_ids(): Tree index out of range.");
303 size_t const num_instances = features.shape()[0];
306 n_threads = std::thread::hardware_concurrency();
313 [&features,&probs,&tree_indices_cpy,
this](
size_t,
size_t i) {
314 this->predict_probabilities_impl(features, probs, i, tree_indices_cpy);
319 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
320 template <
typename PROBS>
322 FEATURES
const & features,
325 const std::vector<size_t> & tree_indices
330 std::vector<AccInputType> tree_results;
331 tree_results.reserve(tree_indices.size());
332 auto const sub_features = features.template bind<0>(i);
335 for (
auto k : tree_indices)
337 Node node = graph_.getRoot(k);
338 while (graph_.outDegree(node) > 0)
340 size_t const child_index = split_tests_.at(node)(sub_features);
341 node = graph_.getChild(node, child_index);
343 tree_results.emplace_back(node_responses_.at(node));
347 auto sub_probs = probs.template bind<0>(i);
348 acc(tree_results.begin(), tree_results.end(), sub_probs.begin());
351 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
352 template <
typename IDS>
354 FEATURES
const & features,
357 std::vector<size_t> tree_indices
359 vigra_precondition(features.shape()[0] == ids.shape()[0],
360 "RandomForest::leaf_ids(): Shape mismatch between features and probabilities.");
361 vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
362 "RandomForest::leaf_ids(): Number of features in prediction differs from training.");
363 vigra_precondition(ids.shape()[1] == graph_.numRoots(),
364 "RandomForest::leaf_ids(): Leaf array has wrong shape.");
367 std::sort(tree_indices.begin(), tree_indices.end());
368 tree_indices.erase(std::unique(tree_indices.begin(), tree_indices.end()), tree_indices.end());
369 for (
auto i : tree_indices)
370 vigra_precondition(i < graph_.numRoots(),
"RandomForest::leaf_ids(): Tree index out of range.");
373 if (tree_indices.size() == 0)
375 tree_indices.resize(graph_.numRoots());
376 std::iota(tree_indices.begin(), tree_indices.end(), 0);
379 size_t const num_instances = features.shape()[0];
381 n_threads = std::thread::hardware_concurrency();
384 std::vector<double> split_comparisons(n_threads, 0.0);
385 std::vector<size_t> indices(num_instances);
386 std::iota(indices.begin(), indices.end(), 0);
387 std::fill(ids.begin(), ids.end(), -1);
392 [
this, &features, &ids, &split_comparisons, &tree_indices](
size_t thread_id,
size_t i) {
393 split_comparisons[thread_id] += this->leaf_ids_impl(features, ids, i, i+1, tree_indices);
397 double const sum_split_comparisons = std::accumulate(split_comparisons.begin(), split_comparisons.end(), 0.0);
398 return sum_split_comparisons / features.shape()[0];
401 template <
typename FEATURES,
typename LABELS,
typename SPLITTESTS,
typename ACC>
402 template <
typename IDS,
typename INDICES>
404 FEATURES
const & features,
408 INDICES
const & tree_indices
410 vigra_precondition(features.shape()[0] == ids.shape()[0],
411 "RandomForest::leaf_ids_impl(): Shape mismatch between features and labels.");
412 vigra_precondition(features.shape()[1] == problem_spec_.num_features_,
413 "RandomForest::leaf_ids_impl(): Number of Features in prediction differs from training.");
414 vigra_precondition(from >= 0 && from <= to && to <= (
size_t)features.shape()[0],
415 "RandomForest::leaf_ids_impl(): Indices out of range.");
416 vigra_precondition(ids.shape()[1] == graph_.numRoots(),
417 "RandomForest::leaf_ids_impl(): Leaf array has wrong shape.");
419 double split_comparisons = 0.0;
420 for (
size_t i = from; i < to; ++i)
422 auto const sub_features = features.template bind<0>(i);
423 for (
auto k : tree_indices)
425 Node node = graph_.getRoot(k);
426 while (graph_.outDegree(node) > 0)
428 size_t const child_index = split_tests_.at(node)(sub_features);
429 node = graph_.getChild(node, child_index);
430 split_comparisons += 1.0;
432 ids(i, k) = node.id();
435 return split_comparisons;
void predict_probabilities(FEATURES const &features, PROBS &probs, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the class probabilities of the given data and write them into probs.
Definition: random_forest.hxx:274
void predict(FEATURES const &features, LABELS &labels, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the labels of the given data and write them into labels.
Definition: random_forest.hxx:247
The PropertyMap is used to store Node or Arc information of graphs.
Definition: graphs.hxx:410
void merge(RandomForest const &other)
Grow this forest by incorporating the other.
Definition: random_forest.hxx:224
detail::NodeDescriptor< index_type > Node
Node descriptor type of the present graph.
Definition: binary_forest.hxx:70
size_t numNodes() const
Return the number of nodes (equivalent to maxNodeId()+1).
Definition: binary_forest.hxx:289
NodeMap< SplitTests >::type split_tests_
Contains a test for each internal node that is used to determine whether given data goes to the left...
Definition: random_forest.hxx:169
problem specification class for the random forest.
Definition: rf_common.hxx:538
size_t num_features() const
Return the number of features.
Definition: random_forest.hxx:160
Graph graph_
The graph structure.
Definition: random_forest.hxx:166
Random forest version 3.
Definition: random_forest.hxx:68
RandomForestOptions options_
The options that were used for training.
Definition: random_forest.hxx:178
double leaf_ids(FEATURES const &features, IDS &ids, int n_threads=-1, const std::vector< size_t > tree_indices=std::vector< size_t >()) const
For each data point in features, compute the corresponding leaf ids and return the average number of ...
Definition: random_forest.hxx:353
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
ProblemSpec< LabelType > problem_spec_
The problem specification (number of features, number of classes, distinct class labels).
Definition: random_forest.hxx:175
void parallel_foreach(...)
Apply a functor to all items in a range in parallel.
BinaryForest stores a collection of rooted binary trees.
Definition: binary_forest.hxx:64
NodeMap< AccInputType >::type node_responses_
Contains the responses of each node (for example the most frequent label).
Definition: random_forest.hxx:172
size_t numRoots() const
Return the number of trees in the forest.
Definition: binary_forest.hxx:332
size_t num_nodes() const
Return the number of nodes.
Definition: random_forest.hxx:142
Options class for vigra::rf3::RandomForest version 3.
Definition: random_forest_common.hxx:582
size_t num_classes() const
Return the number of classes.
Definition: random_forest.hxx:154
size_t num_trees() const
Return the number of trees.
Definition: random_forest.hxx:148