[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef VIGRA_RF3_RANDOM_FOREST_HXX
36 #define VIGRA_RF3_RANDOM_FOREST_HXX
37 
38 #include <type_traits>
39 #include <thread>
40 
41 #include "../multi_shape.hxx"
42 #include "../binary_forest.hxx"
43 #include "../threadpool.hxx"
44 #include "random_forest_common.hxx"
45 
46 
47 
48 namespace vigra
49 {
50 
51 namespace rf3
52 {
53 
54 /********************************************************/
55 /* */
56 /* rf3::RandomForest */
57 /* */
58 /********************************************************/
59 
60 /** \brief Random forest version 3.
61 
62  vigra::rf3::RandomForest is typicall constructed via the factory function \ref vigra::rf3::random_forest().
63 */
64 template <typename FEATURES,
65  typename LABELS,
66  typename SPLITTESTS = LessEqualSplitTest<typename FEATURES::value_type>,
67  typename ACCTYPE = ArgMaxVectorAcc<double>>
69 {
70 public:
71 
72  typedef FEATURES Features;
73  typedef typename Features::value_type FeatureType;
74  typedef LABELS Labels;
75  typedef typename Labels::value_type LabelType;
76  typedef SPLITTESTS SplitTests;
77  typedef ACCTYPE ACC;
78  typedef typename ACC::input_type AccInputType;
79  typedef BinaryForest Graph;
80  typedef Graph::Node Node;
81 
82  static ContainerTag const container_tag = VectorTag;
83 
84  // FIXME:
85  // Once the support for Visual Studio 2012 is dropped, replace this struct with
86  // template <typename T>
87  // using NodeMap = PropertyMap<Node, T, container_tag>;
88  // Then the verbose typename NodeMap<T>::type, which typically shows up on NodeMap usages,
89  // can be replace with NodeMap<T>.
90  template <typename T>
91  struct NodeMap
92  {
94  };
95 
96  // Default (empty) constructor.
97  RandomForest();
98 
99  // Default constructor (copy all of the given stuff).
100  RandomForest(
101  Graph const & graph,
102  typename NodeMap<SplitTests>::type const & split_tests,
103  typename NodeMap<AccInputType>::type const & node_responses,
104  ProblemSpec<LabelType> const & problem_spec
105  );
106 
107  /// \brief Grow this forest by incorporating the other.
108  void merge(
109  RandomForest const & other
110  );
111 
112  /// \brief Predict the given data and return the average number of split comparisons.
113  /// \note labels must be a 1-D array with size <tt>features.shape(0)</tt>.
114  void predict(
115  FEATURES const & features,
116  LABELS & labels,
117  int n_threads = -1,
118  const std::vector<size_t> & tree_indices = std::vector<size_t>()
119  ) const;
120 
121  /// \brief Predict the probabilities of the given data and return the average number of split comparisons.
122  /// \note probs should have the shape (features.shape()[0], num_classes).
123  template <typename PROBS>
125  FEATURES const & features,
126  PROBS & probs,
127  int n_threads = -1,
128  const std::vector<size_t> & tree_indices = std::vector<size_t>()
129  ) const;
130 
131  /// \brief For each data point in features, compute the corresponding leaf ids and return the average number of split comparisons.
132  /// \note ids should have the shape (features.shape()[0], num_trees).
133  template <typename IDS>
134  double leaf_ids(
135  FEATURES const & features,
136  IDS & ids,
137  int n_threads = -1,
138  const std::vector<size_t> tree_indices = std::vector<size_t>()
139  ) const;
140 
141  /// \brief Return the number of nodes.
142  size_t num_nodes() const
143  {
144  return graph_.numNodes();
145  }
146 
147  /// \brief Return the number of trees.
148  size_t num_trees() const
149  {
150  return graph_.numRoots();
151  }
152 
153  /// \brief Return the number of classes.
154  size_t num_classes() const
155  {
156  return problem_spec_.num_classes_;
157  }
158 
159  /// \brief Return the number of classes.
160  size_t num_features() const
161  {
162  return problem_spec_.num_features_;
163  }
164 
165  /// \brief The graph structure.
167 
168  /// \brief Contains a test for each internal node, that is used to determine whether given data goes to the left or the right child.
169  typename NodeMap<SplitTests>::type split_tests_;
170 
171  /// \brief Contains the responses of each node (for example the most frequent label).
172  typename NodeMap<AccInputType>::type node_responses_;
173 
174  /// \brief The specifications.
175  ProblemSpec<LabelType> problem_spec_;
176 
177  /// \brief The options that were used for training.
179 
180 private:
181 
182  /// \brief Compute the leaf ids of the instances in [from, to).
183  template <typename IDS, typename INDICES>
184  double leaf_ids_impl(
185  FEATURES const & features,
186  IDS & ids,
187  size_t from,
188  size_t to,
189  INDICES const & tree_indices
190  ) const;
191 
192  template<typename PROBS>
193  void predict_probabilities_impl(
194  FEATURES const & features,
195  PROBS & probs,
196  const size_t i,
197  const std::vector<size_t> & tree_indices) const;
198 
199 };
200 
201 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
203  :
204  graph_(),
205  split_tests_(),
206  node_responses_(),
207  problem_spec_()
208 {}
209 
210 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
212  Graph const & graph,
213  typename NodeMap<SplitTests>::type const & split_tests,
214  typename NodeMap<AccInputType>::type const & node_responses,
215  ProblemSpec<LabelType> const & problem_spec
216 ) :
217  graph_(graph),
218  split_tests_(split_tests),
219  node_responses_(node_responses),
220  problem_spec_(problem_spec)
221 {}
222 
223 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
225  RandomForest const & other
226 ){
227  vigra_precondition(problem_spec_ == other.problem_spec_,
228  "RandomForest::merge(): You cannot merge with different problem specs.");
229 
230  // FIXME: Eventually compare the options and only fix if the forests are compatible.
231 
232  size_t const offset = num_nodes();
233  graph_.merge(other.graph_);
234  for (auto const & p : other.split_tests_)
235  {
236  split_tests_.insert(Node(p.first.id()+offset), p.second);
237  }
238  for (auto const & p : other.node_responses_)
239  {
240  node_responses_.insert(Node(p.first.id()+offset), p.second);
241  }
242 }
243 
244 // FIXME TODO we don't support the selection of tree indices any more in predict_probabilities, might be a good idea
245 // to re-enable this.
246 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
248  FEATURES const & features,
249  LABELS & labels,
250  int n_threads,
251  const std::vector<size_t> & tree_indices
252 ) const {
253  vigra_precondition(features.shape()[0] == labels.shape()[0],
254  "RandomForest::predict(): Shape mismatch between features and labels.");
255  vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
256  "RandomForest::predict(): Number of features in prediction differs from training.");
257 
258  MultiArray<2, double> probs(Shape2(features.shape()[0], problem_spec_.num_classes_));
259  predict_probabilities(features, probs, n_threads, tree_indices);
260  for (size_t i = 0; i < (size_t)features.shape()[0]; ++i)
261  {
262  auto const sub_probs = probs.template bind<0>(i);
263  auto it = std::max_element(sub_probs.begin(), sub_probs.end());
264  size_t const label = std::distance(sub_probs.begin(), it);
265  labels(i) = problem_spec_.distinct_classes_[label];
266  }
267 }
268 
269 
270 // FIXME TODO we don't support the selection of tree indices any more in predict_probabilities, might be a good idea
271 // to re-enable this.
272 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
273 template <typename PROBS>
275  FEATURES const & features,
276  PROBS & probs,
277  int n_threads,
278  const std::vector<size_t> & tree_indices
279 ) const {
280  vigra_precondition(features.shape()[0] == probs.shape()[0],
281  "RandomForest::predict_probabilities(): Shape mismatch between features and probabilities.");
282  vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
283  "RandomForest::predict_probabilities(): Number of features in prediction differs from training.");
284  vigra_precondition((size_t)probs.shape()[1] == problem_spec_.num_classes_,
285  "RandomForest::predict_probabilities(): Number of labels in probabilities differs from training.");
286 
287  // By default, actual_tree_indices is empty. In that case we want to use all trees.
288  // We need to make a copy. I really don't know how the old code did compile...
289  std::vector<size_t> tree_indices_cpy(tree_indices);
290  if (tree_indices_cpy.size() == 0)
291  {
292  tree_indices_cpy.resize(graph_.numRoots());
293  std::iota(tree_indices_cpy.begin(), tree_indices_cpy.end(), 0);
294  }
295  else {
296  // Check the tree indices.
297  std::sort(tree_indices_cpy.begin(), tree_indices_cpy.end());
298  tree_indices_cpy.erase(std::unique(tree_indices_cpy.begin(), tree_indices_cpy.end()), tree_indices_cpy.end());
299  for (auto i : tree_indices_cpy)
300  vigra_precondition(i < graph_.numRoots(), "RandomForest::leaf_ids(): Tree index out of range.");
301  }
302 
303  size_t const num_instances = features.shape()[0];
304 
305  if (n_threads == -1)
306  n_threads = std::thread::hardware_concurrency();
307  if (n_threads < 1)
308  n_threads = 1;
309 
311  n_threads,
312  num_instances,
313  [&features,&probs,&tree_indices_cpy,this](size_t, size_t i) {
314  this->predict_probabilities_impl(features, probs, i, tree_indices_cpy);
315  }
316  );
317 }
318 
319 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
320 template <typename PROBS>
322  FEATURES const & features,
323  PROBS & probs,
324  const size_t i,
325  const std::vector<size_t> & tree_indices
326 ) const {
327 
328  // instantiate the accumulation function and the vector to store the tree node results
329  ACC acc;
330  std::vector<AccInputType> tree_results;
331  tree_results.reserve(tree_indices.size());
332  auto const sub_features = features.template bind<0>(i);
333 
334  // loop over the trees
335  for (auto k : tree_indices)
336  {
337  Node node = graph_.getRoot(k);
338  while (graph_.outDegree(node) > 0)
339  {
340  size_t const child_index = split_tests_.at(node)(sub_features);
341  node = graph_.getChild(node, child_index);
342  }
343  tree_results.emplace_back(node_responses_.at(node));
344  }
345 
346  // write the tree results into the probabilities
347  auto sub_probs = probs.template bind<0>(i);
348  acc(tree_results.begin(), tree_results.end(), sub_probs.begin());
349 }
350 
351 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
352 template <typename IDS>
354  FEATURES const & features,
355  IDS & ids,
356  int n_threads,
357  std::vector<size_t> tree_indices
358 ) const {
359  vigra_precondition(features.shape()[0] == ids.shape()[0],
360  "RandomForest::leaf_ids(): Shape mismatch between features and probabilities.");
361  vigra_precondition((size_t)features.shape()[1] == problem_spec_.num_features_,
362  "RandomForest::leaf_ids(): Number of features in prediction differs from training.");
363  vigra_precondition(ids.shape()[1] == graph_.numRoots(),
364  "RandomForest::leaf_ids(): Leaf array has wrong shape.");
365 
366  // Check the tree indices.
367  std::sort(tree_indices.begin(), tree_indices.end());
368  tree_indices.erase(std::unique(tree_indices.begin(), tree_indices.end()), tree_indices.end());
369  for (auto i : tree_indices)
370  vigra_precondition(i < graph_.numRoots(), "RandomForest::leaf_ids(): Tree index out of range.");
371 
372  // By default, actual_tree_indices is empty. In that case we want to use all trees.
373  if (tree_indices.size() == 0)
374  {
375  tree_indices.resize(graph_.numRoots());
376  std::iota(tree_indices.begin(), tree_indices.end(), 0);
377  }
378 
379  size_t const num_instances = features.shape()[0];
380  if (n_threads == -1)
381  n_threads = std::thread::hardware_concurrency();
382  if (n_threads < 1)
383  n_threads = 1;
384  std::vector<double> split_comparisons(n_threads, 0.0);
385  std::vector<size_t> indices(num_instances);
386  std::iota(indices.begin(), indices.end(), 0);
387  std::fill(ids.begin(), ids.end(), -1);
389  n_threads,
390  indices.begin(),
391  indices.end(),
392  [this, &features, &ids, &split_comparisons, &tree_indices](size_t thread_id, size_t i) {
393  split_comparisons[thread_id] += this->leaf_ids_impl(features, ids, i, i+1, tree_indices);
394  }
395  );
396 
397  double const sum_split_comparisons = std::accumulate(split_comparisons.begin(), split_comparisons.end(), 0.0);
398  return sum_split_comparisons / features.shape()[0];
399 }
400 
401 template <typename FEATURES, typename LABELS, typename SPLITTESTS, typename ACC>
402 template <typename IDS, typename INDICES>
404  FEATURES const & features,
405  IDS & ids,
406  size_t from,
407  size_t to,
408  INDICES const & tree_indices
409 ) const {
410  vigra_precondition(features.shape()[0] == ids.shape()[0],
411  "RandomForest::leaf_ids_impl(): Shape mismatch between features and labels.");
412  vigra_precondition(features.shape()[1] == problem_spec_.num_features_,
413  "RandomForest::leaf_ids_impl(): Number of Features in prediction differs from training.");
414  vigra_precondition(from >= 0 && from <= to && to <= (size_t)features.shape()[0],
415  "RandomForest::leaf_ids_impl(): Indices out of range.");
416  vigra_precondition(ids.shape()[1] == graph_.numRoots(),
417  "RandomForest::leaf_ids_impl(): Leaf array has wrong shape.");
418 
419  double split_comparisons = 0.0;
420  for (size_t i = from; i < to; ++i)
421  {
422  auto const sub_features = features.template bind<0>(i);
423  for (auto k : tree_indices)
424  {
425  Node node = graph_.getRoot(k);
426  while (graph_.outDegree(node) > 0)
427  {
428  size_t const child_index = split_tests_.at(node)(sub_features);
429  node = graph_.getChild(node, child_index);
430  split_comparisons += 1.0;
431  }
432  ids(i, k) = node.id();
433  }
434  }
435  return split_comparisons;
436 }
437 
438 
439 
440 } // namespace rf3
441 } // namespace vigra
442 
443 #endif
void predict_probabilities(FEATURES const &features, PROBS &probs, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the class probabilities of the given data (no value is returned).
Definition: random_forest.hxx:274
void predict(FEATURES const &features, LABELS &labels, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const
Predict the labels of the given data (no value is returned).
Definition: random_forest.hxx:247
The PropertyMap is used to store Node or Arc information of graphs.
Definition: graphs.hxx:410
void merge(RandomForest const &other)
Grow this forest by incorporating the other.
Definition: random_forest.hxx:224
detail::NodeDescriptor< index_type > Node
Node descriptor type of the present graph.
Definition: binary_forest.hxx:70
size_t numNodes() const
Return the number of nodes (equivalent to maxNodeId()+1).
Definition: binary_forest.hxx:289
NodeMap< SplitTests >::type split_tests_
Contains a test for each internal node, that is used to determine whether given data goes to the left...
Definition: random_forest.hxx:169
problem specification class for the random forest.
Definition: rf_common.hxx:538
size_t num_features() const
Return the number of features.
Definition: random_forest.hxx:160
Graph graph_
The graph structure.
Definition: random_forest.hxx:166
Random forest version 3.
Definition: random_forest.hxx:68
RandomForestOptions options_
The options that were used for training.
Definition: random_forest.hxx:178
double leaf_ids(FEATURES const &features, IDS &ids, int n_threads=-1, const std::vector< size_t > tree_indices=std::vector< size_t >()) const
For each data point in features, compute the corresponding leaf ids and return the average number of ...
Definition: random_forest.hxx:353
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
ProblemSpec< LabelType > problem_spec_
The specifications.
Definition: random_forest.hxx:175
void parallel_foreach(...)
Apply a functor to all items in a range in parallel.
BinaryForest stores a collection of rooted binary trees.
Definition: binary_forest.hxx:64
NodeMap< AccInputType >::type node_responses_
Contains the responses of each node (for example the most frequent label).
Definition: random_forest.hxx:172
size_t numRoots() const
Return the number of trees in the forest.
Definition: binary_forest.hxx:332
size_t num_nodes() const
Return the number of nodes.
Definition: random_forest.hxx:142
Options class for vigra::rf3::RandomForest version 3.
Definition: random_forest_common.hxx:582
size_t num_classes() const
Return the number of classes.
Definition: random_forest.hxx:154
size_t num_trees() const
Return the number of trees.
Definition: random_forest.hxx:148

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.0 (Fri May 19 2017)