[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_3_hdf5_impex.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2009,2014, 2015 by Sven Peter, Philip Schill, */
4 /* Rahul Nair and Ullrich Koethe */
5 /* */
6 /* This file is part of the VIGRA computer vision library. */
7 /* The VIGRA Website is */
8 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
9 /* Please direct questions, bug reports, and contributions to */
10 /* ullrich.koethe@iwr.uni-heidelberg.de or */
11 /* vigra@informatik.uni-hamburg.de */
12 /* */
13 /* Permission is hereby granted, free of charge, to any person */
14 /* obtaining a copy of this software and associated documentation */
15 /* files (the "Software"), to deal in the Software without */
16 /* restriction, including without limitation the rights to use, */
17 /* copy, modify, merge, publish, distribute, sublicense, and/or */
18 /* sell copies of the Software, and to permit persons to whom the */
19 /* Software is furnished to do so, subject to the following */
20 /* conditions: */
21 /* */
22 /* The above copyright notice and this permission notice shall be */
23 /* included in all copies or substantial portions of the */
24 /* Software. */
25 /* */
26 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
27 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
28 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
29 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
30 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
31 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
32 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
33 /* OTHER DEALINGS IN THE SOFTWARE. */
34 /* */
35 /************************************************************************/
36 
37 #ifndef VIGRA_RF3_IMPEX_HDF5_HXX
38 #define VIGRA_RF3_IMPEX_HDF5_HXX
39 
#include <cstddef>
#include <iomanip>
#include <numeric>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>
44 
45 #include "config.hxx"
46 #include "random_forest_3/random_forest.hxx"
47 #include "random_forest_3/random_forest_common.hxx"
48 #include "random_forest_3/random_forest_visitors.hxx"
49 #include "hdf5impex.hxx"
50 
51 namespace vigra
52 {
53 namespace rf3
54 {
55 
// Names used for the on-disk HDF5 layout of a random forest.
// They have to stay in sync with random_forest_hdf5_impex for backwards compatibility.
static char const * const rf_hdf5_ext_param     = "_ext_param";
static char const * const rf_hdf5_options       = "_options";
static char const * const rf_hdf5_topology      = "topology";
static char const * const rf_hdf5_parameters    = "parameters";
static char const * const rf_hdf5_tree          = "Tree_";

// Location and name of the version attribute, and the version written by this file.
static char const * const rf_hdf5_version_group = ".";
static char const * const rf_hdf5_version_tag   = "vigra_random_forest_version";
static double const       rf_hdf5_version       = 0.1;
65 
// Tag values found in the first topology entry of every tree node.
// Keep in sync with include/vigra/random_forest/rf_nodeproxy.hxx.
enum NodeTags
{
    // Marker value and flag bits.
    rf_UnFilledNode     = 42,
    rf_AllColumns       = 0x00000000,
    rf_ToBePrunedTag    = 0x80000000,
    rf_LeafNodeTag      = 0x40000000,

    // Node kinds that do not carry the leaf flag.
    rf_i_ThresholdNode  = 0,
    rf_i_HyperplaneNode = 1,
    rf_i_HypersphereNode = 2,

    // Node kinds combined with the leaf flag.
    rf_e_ConstProbNode  = rf_LeafNodeTag | 0,
    rf_e_LogRegProbNode = rf_LeafNodeTag | 1
};
80 
// Bits of a topology tag that hold the flag bits and the node kind, respectively.
static unsigned int const rf_tag_mask  = 0xf0000000;
static unsigned int const rf_type_mask = 0x00000003;
// All remaining bits; the importer requires them to be zero.
static unsigned int const rf_zero_mask = 0xffffffff & ~(rf_tag_mask | rf_type_mask);
84 
85 namespace detail
86 {
87  inline std::string get_cwd(HDF5File & h5context)
88  {
89  return h5context.get_absolute_path(h5context.pwd());
90  }
91 }
92 
93 template <typename FEATURES, typename LABELS>
94 typename DefaultRF<FEATURES, LABELS>::type
95 random_forest_import_HDF5(HDF5File & h5ctx, std::string const & pathname = "")
96 {
97  typedef typename DefaultRF<FEATURES, LABELS>::type RF;
98  typedef typename RF::Graph Graph;
99  typedef typename RF::Node Node;
100  typedef typename RF::SplitTests SplitTest;
101  typedef typename LABELS::value_type LabelType;
102  typedef typename RF::AccInputType AccInputType;
103  typedef typename AccInputType::value_type AccValueType;
104 
105  std::string cwd;
106 
107  if (pathname.size()) {
108  cwd = detail::get_cwd(h5ctx);
109  h5ctx.cd(pathname);
110  }
111 
112  if (h5ctx.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag)) {
113  double version;
114  h5ctx.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag, version);
115  vigra_precondition(version <= rf_hdf5_version, "random_forest_import_HDF5(): unexpected file format version.");
116  }
117 
118  // Read ext params.
119  size_t actual_mtry;
120  size_t num_instances;
121  size_t num_features;
122  size_t num_classes;
123  size_t msample;
124  int is_weighted_int;
125  MultiArray<1, LabelType> distinct_labels_marray;
126  MultiArray<1, double> class_weights_marray;
127 
128  h5ctx.cd(rf_hdf5_ext_param);
129  h5ctx.read("actual_msample_", msample);
130  h5ctx.read("actual_mtry_", actual_mtry);
131  h5ctx.read("class_count_", num_classes);
132  h5ctx.readAndResize("class_weights_", class_weights_marray);
133  h5ctx.read("column_count_", num_features);
134  h5ctx.read("is_weighted_", is_weighted_int);
135  h5ctx.readAndResize("labels", distinct_labels_marray);
136  h5ctx.read("row_count_", num_instances);
137  h5ctx.cd_up();
138 
139  bool is_weighted = is_weighted_int == 1 ? true : false;
140 
141  // Read options.
142  size_t min_num_instances;
143  int mtry;
144  int mtry_switch_int;
145  int bootstrap_sampling_int;
146  int tree_count;
147  h5ctx.cd(rf_hdf5_options);
148  h5ctx.read("min_split_node_size_", min_num_instances);
149  h5ctx.read("mtry_", mtry);
150  h5ctx.read("mtry_switch_", mtry_switch_int);
151  h5ctx.read("sample_with_replacement_", bootstrap_sampling_int);
152  h5ctx.read("tree_count_", tree_count);
153  h5ctx.cd_up();
154 
155  RandomForestOptionTags mtry_switch = (RandomForestOptionTags)mtry_switch_int;
156  bool bootstrap_sampling = bootstrap_sampling_int == 1 ? true : false;
157 
158  std::vector<LabelType> const distinct_labels(distinct_labels_marray.begin(), distinct_labels_marray.end());
159  std::vector<double> const class_weights(class_weights_marray.begin(), class_weights_marray.end());
160 
161  auto const pspec = ProblemSpec<LabelType>()
162  .num_features(num_features)
163  .num_instances(num_instances)
164  .num_classes(num_classes)
165  .distinct_classes(distinct_labels)
166  .actual_mtry(actual_mtry)
167  .actual_msample(msample);
168 
169  auto options = RandomForestOptions()
170  .min_num_instances(min_num_instances)
171  .bootstrap_sampling(bootstrap_sampling)
172  .tree_count(tree_count);
173  options.features_per_node_switch_ = mtry_switch;
174  options.features_per_node_ = mtry;
175  if (is_weighted)
176  options.class_weights(class_weights);
177 
178  Graph gr;
179  typename RF::template NodeMap<SplitTest>::type split_tests;
180  typename RF::template NodeMap<AccInputType>::type leaf_responses;
181 
182  auto const groups = h5ctx.ls();
183  for (auto const & groupname : groups) {
184  if (groupname.substr(0, std::char_traits<char>::length(rf_hdf5_tree)).compare(rf_hdf5_tree) != 0) {
185  continue;
186  }
187 
188  MultiArray<1, unsigned int> topology;
189  MultiArray<1, double> parameters;
190  h5ctx.cd(groupname);
191  h5ctx.readAndResize(rf_hdf5_topology, topology);
192  h5ctx.readAndResize(rf_hdf5_parameters, parameters);
193  h5ctx.cd_up();
194 
195  vigra_precondition(topology[0] == num_features, "random_forest_import_HDF5(): number of features mismatch.");
196  vigra_precondition(topology[1] == num_classes, "random_forest_import_HDF5(): number of classes mismatch.");
197 
198  Node const n = gr.addNode();
199 
200  std::queue<std::pair<unsigned int, Node> > q;
201  q.emplace(2, n);
202  while (!q.empty()) {
203  auto const el = q.front();
204 
205  unsigned int const index = el.first;
206  Node const parent = el.second;
207 
208  vigra_precondition((topology[index] & rf_zero_mask) == 0, "random_forest_import_HDF5(): unexpected node type: type & zero_mask > 0");
209 
210  if (topology[index] & rf_LeafNodeTag) {
211  unsigned int const probs_start = topology[index+1] + 1;
212 
213  vigra_precondition((topology[index] & rf_tag_mask) == rf_LeafNodeTag, "random_forest_import_HDF5(): unexpected node type: additional tags in leaf node");
214 
215  std::vector<AccValueType> node_response;
216 
217  for (unsigned int i = 0; i < num_classes; ++i) {
218  node_response.push_back(parameters[probs_start + i]);
219  }
220 
221  leaf_responses.insert(parent, node_response);
222 
223  } else {
224  vigra_precondition(topology[index] == rf_i_ThresholdNode, "random_forest_import_HDF5(): unexpected node type.");
225 
226  Node const left = gr.addNode();
227  Node const right = gr.addNode();
228 
229  gr.addArc(parent, left);
230  gr.addArc(parent, right);
231 
232  split_tests.insert(parent, SplitTest(topology[index+4], parameters[topology[index+1]+1]));
233 
234  q.push(std::make_pair(topology[index+2], left));
235  q.push(std::make_pair(topology[index+3], right));
236  }
237 
238  q.pop();
239  }
240  }
241 
242  if (cwd.size()) {
243  h5ctx.cd(cwd);
244  }
245 
246  RF rf(gr, split_tests, leaf_responses, pspec);
247  rf.options_ = options;
248  return rf;
249 }
250 
namespace detail
{
    /** Formats integers as zero-padded decimal strings of a common width.

        The width is the number of characters needed to print <tt>n-1</tt>,
        i.e. the largest index of a sequence with \a n elements, so that
        the strings for 0 ... n-1 all have the same length.

        The object holds no stream state: each call formats into a fresh
        local stream, so the const call operator does not modify the object.
    */
    class PaddedNumberString
    {
    public:

        /** \param n  number of elements in the sequence to be numbered */
        PaddedNumberString(int n)
        : width_(0)
        {
            std::ostringstream ss;
            ss << (n-1);
            width_ = static_cast<int>(ss.str().size());
        }

        /** \return \a k as a decimal string, left-padded with '0' to the common width */
        std::string operator()(int k) const
        {
            std::ostringstream ss;
            ss << std::setw(width_) << std::setfill('0') << k;
            return ss.str();
        }

    private:

        int width_; // field width handed to std::setw
    };
}
276 
277 template <typename RF>
278 void random_forest_export_HDF5(
279  RF const & rf,
280  HDF5File & h5context,
281  std::string const & pathname = ""
282 ){
283  typedef typename RF::LabelType LabelType;
284  typedef typename RF::Node Node;
285 
286  std::string cwd;
287  if (pathname.size()) {
288  cwd = detail::get_cwd(h5context);
289  h5context.cd_mk(pathname);
290  }
291 
292  // version attribute
293  h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
294  rf_hdf5_version);
295 
296 
297  auto const & p = rf.problem_spec_;
298  auto const & opts = rf.options_;
299  MultiArray<1, LabelType> distinct_classes(Shape1(p.distinct_classes_.size()), p.distinct_classes_.data());
300  MultiArray<1, double> class_weights(Shape1(p.num_classes_), 1.0);
301  int is_weighted = 0;
302  if (opts.class_weights_.size() > 0)
303  {
304  is_weighted = 1;
305  for (size_t i = 0; i < opts.class_weights_.size(); ++i)
306  class_weights(i) = opts.class_weights_[i];
307  }
308 
309  // Save external parameters.
310  h5context.cd_mk(rf_hdf5_ext_param);
311  h5context.write("column_count_", p.num_features_);
312  h5context.write("row_count_", p.num_instances_);
313  h5context.write("class_count_", p.num_classes_);
314  h5context.write("actual_mtry_", p.actual_mtry_);
315  h5context.write("actual_msample_", p.actual_msample_);
316  h5context.write("labels", distinct_classes);
317  h5context.write("is_weighted_", is_weighted);
318  h5context.write("class_weights_", class_weights);
319  h5context.write("precision_", 0.0);
320  h5context.write("problem_type_", 1.0);
321  h5context.write("response_size_", 1.0);
322  h5context.write("used_", 1.0);
323  h5context.cd_up();
324 
325  // Save the options.
326  h5context.cd_mk(rf_hdf5_options);
327  h5context.write("min_split_node_size_", opts.min_num_instances_);
328  h5context.write("mtry_", opts.features_per_node_);
329  h5context.write("mtry_func_", 0.0);
330  h5context.write("mtry_switch_", opts.features_per_node_switch_);
331  h5context.write("predict_weighted_", 0.0);
332  h5context.write("prepare_online_learning_", 0.0);
333  h5context.write("sample_with_replacement_", opts.bootstrap_sampling_ ? 1 : 0);
334  h5context.write("stratification_method_", 3.0);
335  h5context.write("training_set_calc_switch_", 1.0);
336  h5context.write("training_set_func_", 0.0);
337  h5context.write("training_set_proportion_", 1.0);
338  h5context.write("training_set_size_", 0.0);
339  h5context.write("tree_count_", opts.tree_count_);
340  h5context.cd_up();
341 
342  // Save the trees.
343  detail::PaddedNumberString tree_number(rf.num_trees());
344  for (size_t i = 0; i < rf.num_trees(); ++i)
345  {
346  // Create the topology and parameters arrays.
347  std::vector<UInt32> topology;
348  std::vector<double> parameters;
349  topology.push_back(p.num_features_);
350  topology.push_back(p.num_classes_);
351 
352  auto const & probs = rf.node_responses_;
353  auto const & splits = rf.split_tests_;
354  auto const & gr = rf.graph_;
355  auto const root = gr.getRoot(i);
356 
357  // Write the tree nodes using a depth-first search.
358  // When a node is created, the indices of the child nodes are unknown.
359  // Therefore, they have to be updated once the child nodes are created.
360  // The stack holds the node and the topology-index that must be updated.
361  std::stack<std::pair<Node, std::ptrdiff_t> > stack;
362  stack.emplace(root, -1);
363  while (!stack.empty())
364  {
365  auto const n = stack.top().first; // the node descriptor
366  auto const i = stack.top().second; // index from the parent node that must be updated
367  stack.pop();
368 
369  // Update the index in the parent node.
370  if (i != -1)
371  topology[i] = topology.size();
372 
373  if (gr.numChildren(n) == 0)
374  {
375  // The node is a leaf.
376  // Topology: leaf node tag, index of weight in parameters array.
377  // Parameters: node weight, class probabilities.
378  topology.push_back(rf_LeafNodeTag);
379  topology.push_back(parameters.size());
380  auto const & prob = probs.at(n);
381  auto const weight = std::accumulate(prob.begin(), prob.end(), 0.0);
382  parameters.push_back(weight);
383  parameters.insert(parameters.end(), prob.begin(), prob.end());
384  }
385  else
386  {
387  // The node is an inner node.
388  // Topology: threshold tag, index of weight in parameters array, index of left child, index of right child, split dimension.
389  // Parameters: node weight, split value.
390  topology.push_back(rf_i_ThresholdNode);
391  topology.push_back(parameters.size());
392  topology.push_back(-1); // index of left children (currently unknown, will be updated when the child node is taken from the stack)
393  topology.push_back(-1); // index of right children (see above)
394  topology.push_back(splits.at(n).dim_);
395  parameters.push_back(1.0); // inner nodes have the weight 1.
396  parameters.push_back(splits.at(n).val_);
397 
398  // Place the children on the stack.
399  stack.emplace(gr.getChild(n, 0), topology.size()-3);
400  stack.emplace(gr.getChild(n, 1), topology.size()-2);
401  }
402  }
403 
404  // Convert the vectors to multi arrays.
405  MultiArray<1, UInt32> topo(Shape1(topology.size()), topology.data());
406  MultiArray<1, double> para(Shape1(parameters.size()), parameters.data());
407 
408  auto const name = rf_hdf5_tree + tree_number(i);
409  h5context.cd_mk(name);
410  h5context.write(rf_hdf5_topology, topo);
411  h5context.write(rf_hdf5_parameters, para);
412  h5context.cd_up();
413  }
414 
415  if (pathname.size())
416  h5context.cd(cwd);
417 }
418 
419 
420 
421 } // namespace rf3
422 } // namespace vigra
423 
424 #endif // VIGRA_RF3_IMPEX_HDF5_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.0 (Fri May 19 2017)