[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_hdf5_impex.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2009 by Rahul Nair and Ullrich Koethe */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
38 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
39 
40 #include "config.hxx"
41 #include "random_forest.hxx"
42 #include "hdf5impex.hxx"
43 #include <cstdio>
44 #include <string>
45 
46 #ifdef HasHDF5
47 
48 namespace vigra
49 {
50 
51 namespace detail
52 {
53 
54 
55 /** shallow search the hdf5 group for containing elements
56  * returns negative value if unsuccessful
57  * \param grp_id hid_t containing path to group.
58  * \param cont reference to container that supports
59  * insert(). valuetype of cont must be
60  * std::string
61  */
62 template<class Container>
63 bool find_groups_hdf5(hid_t grp_id, Container &cont)
64 {
65 
66  //get group info
67 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
68  hsize_t size;
69  H5Gget_num_objs(grp_id, &size);
70 #else
71  hsize_t size;
72  H5G_info_t ginfo;
73  herr_t status;
74  status = H5Gget_info (grp_id , &ginfo);
75  if(status < 0)
76  std::runtime_error("find_groups_hdf5():"
77  "problem while getting group info");
78  size = ginfo.nlinks;
79 #endif
80  for(hsize_t ii = 0; ii < size; ++ii)
81  {
82 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
83  ssize_t buffer_size =
84  H5Gget_objname_by_idx(grp_id,
85  ii, NULL, 0 ) + 1;
86 #else
87  std::ptrdiff_t buffer_size =
88  H5Lget_name_by_idx(grp_id, ".",
89  H5_INDEX_NAME,
90  H5_ITER_INC,
91  ii, 0, 0, H5P_DEFAULT)+1;
92 #endif
93  ArrayVector<char> buffer(buffer_size);
94 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
95  buffer_size =
96  H5Gget_objname_by_idx(grp_id,
97  ii, buffer.data(),
98  (size_t)buffer_size );
99 #else
100  buffer_size =
101  H5Lget_name_by_idx(grp_id, ".",
102  H5_INDEX_NAME,
103  H5_ITER_INC,
104  ii, buffer.data(),
105  (size_t)buffer_size,
106  H5P_DEFAULT);
107 #endif
108  cont.insert(cont.end(), std::string(buffer.data()));
109  }
110  return true;
111 }
112 
113 
114 /** shallow search the hdf5 group for containing elements
115  * returns negative value if unsuccessful
116  * \param filename name of hdf5 file
117  * \param groupname path in hdf5 file
118  * \param cont reference to container that supports
119  * insert(). valuetype of cont must be
120  * std::string
121  */
122 template<class Container>
123 bool find_groups_hdf5(std::string filename,
124  std::string groupname,
125  Container &cont)
126 {
127  //check if file exists
128  FILE* pFile;
129  pFile = std::fopen ( filename.c_str(), "r" );
130  if ( pFile == NULL)
131  {
132  return 0;
133  }
134  std::fclose(pFile);
135  //open the file
136  HDF5Handle file_id(H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT),
137  &H5Fclose, "Unable to open HDF5 file");
138  HDF5Handle grp_id;
139  if(groupname == "")
140  {
141  grp_id = HDF5Handle(file_id, 0, "");
142  }
143  else
144  {
145  grp_id = HDF5Handle(H5Gopen(file_id, groupname.c_str(), H5P_DEFAULT),
146  &H5Gclose, "Unable to open group");
147 
148  }
149  bool res = find_groups_hdf5(grp_id, cont);
150  return res;
151 }
152 
153 VIGRA_EXPORT int get_number_of_digits(int in);
154 
155 VIGRA_EXPORT std::string make_padded_number(int number, int max_number);
156 
157 /** write a ArrayVector to a hdf5 dataset.
158  */
159 template<class U, class T>
160 void write_array_2_hdf5(hid_t & id,
161  ArrayVector<U> const & arr,
162  std::string const & name,
163  T type)
164 {
165  hsize_t size = arr.size();
166  vigra_postcondition(H5LTmake_dataset (id,
167  name.c_str(),
168  1,
169  &size,
170  type,
171  arr.begin())
172  >= 0,
173  "write_array_2_hdf5():"
174  "unable to write dataset");
175 }
176 
177 
178 template<class U, class T>
179 void write_hdf5_2_array(hid_t & id,
180  ArrayVector<U> & arr,
181  std::string const & name,
182  T type)
183 {
184  // The last three values of get_dataset_info can be NULL
185  // my EFFING FOOT! that is valid for HDF5 1.8 but not for
186  // 1.6 - but documented the other way around AAARRHGHGHH
187  hsize_t size;
188  H5T_class_t a;
189  size_t b;
190  vigra_postcondition(H5LTget_dataset_info(id,
191  name.c_str(),
192  &size,
193  &a,
194  &b) >= 0,
195  "write_hdf5_2_array(): "
196  "Unable to locate dataset");
197  arr.resize((typename ArrayVector<U>::size_type)size);
198  vigra_postcondition(H5LTread_dataset (id,
199  name.c_str(),
200  type,
201  arr.data()) >= 0,
202  "write_array_2_hdf5():"
203  "unable to read dataset");
204 }
205 
206 /*
207 inline void options_import_HDF5(hid_t & group_id,
208  RandomForestOptions & opt,
209  std::string name)
210 {
211  ArrayVector<double> serialized_options;
212  write_hdf5_2_array(group_id, serialized_options,
213  name, H5T_NATIVE_DOUBLE);
214  opt.unserialize(serialized_options.begin(),
215  serialized_options.end());
216 }
217 
218 inline void options_export_HDF5(hid_t & group_id,
219  RandomForestOptions const & opt,
220  std::string name)
221 {
222  ArrayVector<double> serialized_options(opt.serialized_size());
223  opt.serialize(serialized_options.begin(),
224  serialized_options.end());
225  write_array_2_hdf5(group_id, serialized_options,
226  name, H5T_NATIVE_DOUBLE);
227 }
228 */
229 
/** Tags for the element types supported for random-forest class labels.
 *  type_of() maps a C++ label type to its tag, and type_of_hid_t() reports
 *  the tag of a stored dataset. The explicit values may be shared with the
 *  compiled library (type_of_hid_t is exported), so do not renumber them.
 */
struct MyT
{
    enum type
    {
        INT8   = 1,
        INT16  = 2,
        INT32  = 3,
        INT64  = 4,
        UINT8  = 5,
        UINT16 = 6,
        UINT32 = 7,
        UINT64 = 8,
        FLOAT  = 9,
        DOUBLE = 10,
        OTHER  = 3294   // fallback tag, presumably for unsupported types
    };
};
236 
237 
238 
239 #define create_type_of(TYPE, ENUM) \
240 inline MyT::type type_of(TYPE)\
241 {\
242  return MyT::ENUM; \
243 }
244 create_type_of(Int8, INT8)
245 create_type_of(Int16, INT16)
246 create_type_of(Int32, INT32)
247 create_type_of(Int64, INT64)
248 create_type_of(UInt8, UINT8)
249 create_type_of(UInt16, UINT16)
250 create_type_of(UInt32, UINT32)
251 create_type_of(UInt64, UINT64)
252 create_type_of(float, FLOAT)
253 create_type_of(double, DOUBLE)
254 #undef create_type_of
255 
256 VIGRA_EXPORT MyT::type type_of_hid_t(hid_t group_id, std::string name);
257 
258 VIGRA_EXPORT void options_import_HDF5(hid_t & group_id,
259  RandomForestOptions & opt,
260  std::string name);
261 
262 VIGRA_EXPORT void options_export_HDF5(hid_t & group_id,
263  RandomForestOptions const & opt,
264  std::string name);
265 
266 template<class T>
267 void problemspec_import_HDF5(hid_t & group_id,
268  ProblemSpec<T> & param,
269  std::string name)
270 {
271  hid_t param_id = H5Gopen (group_id,
272  name.c_str(),
273  H5P_DEFAULT);
274 
275  vigra_postcondition(param_id >= 0,
276  "problemspec_import_HDF5():"
277  " Unable to open external parameters");
278 
279  //get a map containing all the double fields
280  std::set<std::string> ext_set;
281  find_groups_hdf5(param_id, ext_set);
282  std::map<std::string, ArrayVector <double> > ext_map;
283  std::set<std::string>::iterator iter;
284  if(ext_set.find(std::string("labels")) == ext_set.end())
285  std::runtime_error("labels are missing");
286  for(iter = ext_set.begin(); iter != ext_set.end(); ++ iter)
287  {
288  if(*iter != std::string("labels"))
289  {
290  ext_map[*iter] = ArrayVector<double>();
291  write_hdf5_2_array(param_id, ext_map[*iter],
292  *iter, H5T_NATIVE_DOUBLE);
293  }
294  }
295  param.make_from_map(ext_map);
296  //load_class_labels
297  switch(type_of_hid_t(param_id,"labels" ))
298  {
299  #define SOME_CASE(type_, enum_) \
300  case MyT::enum_ :\
301  {\
302  ArrayVector<type_> tmp;\
303  write_hdf5_2_array(param_id, tmp, "labels", H5T_NATIVE_##enum_);\
304  param.classes_(tmp.begin(), tmp.end());\
305  }\
306  break;
307  SOME_CASE(UInt8, UINT8);
308  SOME_CASE(UInt16, UINT16);
309  SOME_CASE(UInt32, UINT32);
310  SOME_CASE(UInt64, UINT64);
311  SOME_CASE(Int8, INT8);
312  SOME_CASE(Int16, INT16);
313  SOME_CASE(Int32, INT32);
314  SOME_CASE(Int64, INT64);
315  SOME_CASE(double, DOUBLE);
316  SOME_CASE(float, FLOAT);
317  default:
318  std::runtime_error("exportRF_HDF5(): unknown class type");
319  #undef SOME_CASE
320  }
321  H5Gclose(param_id);
322 }
323 
324 template<class T>
325 void problemspec_export_HDF5(hid_t & group_id,
326  ProblemSpec<T> const & param,
327  std::string name)
328 {
329  hid_t param_id = H5Gcreate(group_id, name.c_str(),
330  H5P_DEFAULT,
331  H5P_DEFAULT,
332  H5P_DEFAULT);
333  vigra_postcondition(param_id >= 0,
334  "problemspec_export_HDF5():"
335  " Unable to create external parameters");
336 
337  //get a map containing all the double fields
338  std::map<std::string, ArrayVector<double> > serialized_param;
339  param.make_map(serialized_param);
340  std::map<std::string, ArrayVector<double> >::iterator iter;
341  for(iter = serialized_param.begin(); iter != serialized_param.end(); ++iter)
342  write_array_2_hdf5(param_id, iter->second, iter->first, H5T_NATIVE_DOUBLE);
343 
344  //save class_labels
345  switch(type_of(param.classes[0]))
346  {
347  #define SOME_CASE(type) \
348  case MyT::type:\
349  write_array_2_hdf5(param_id, param.classes, "labels", H5T_NATIVE_##type);\
350  break;
351  SOME_CASE(UINT8);
352  SOME_CASE(UINT16);
353  SOME_CASE(UINT32);
354  SOME_CASE(UINT64);
355  SOME_CASE(INT8);
356  SOME_CASE(INT16);
357  SOME_CASE(INT32);
358  SOME_CASE(INT64);
359  SOME_CASE(DOUBLE);
360  SOME_CASE(FLOAT);
361  default:
362  std::runtime_error("exportRF_HDF5(): unknown class type");
363  #undef SOME_CASE
364  }
365  H5Gclose(param_id);
366 }
367 
368 VIGRA_EXPORT void dt_import_HDF5(hid_t & group_id,
369  detail::DecisionTree & tree,
370  std::string name);
371 
372 
373 VIGRA_EXPORT void dt_export_HDF5(hid_t & group_id,
374  detail::DecisionTree const & tree,
375  std::string name);
376 
377 } //namespace detail
378 
379 template<class T>
380 bool rf_export_HDF5(RandomForest<T> const &rf,
381  std::string filename,
382  std::string pathname = "",
383  bool overwriteflag = false)
384 {
385  using detail::make_padded_number;
386  using detail::options_export_HDF5;
387  using detail::problemspec_export_HDF5;
388  using detail::dt_export_HDF5;
389 
390  hid_t file_id;
391  //if file exists load it.
392  FILE* pFile = std::fopen ( filename.c_str(), "r" );
393  if ( pFile != NULL)
394  {
395  std::fclose(pFile);
396  file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR,
397  H5P_DEFAULT);
398  }
399  else
400  {
401  //create a new file.
402  file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC,
403  H5P_DEFAULT,
404  H5P_DEFAULT);
405  }
406  vigra_postcondition(file_id >= 0,
407  "rf_export_HDF5(): Unable to open file.");
408  //std::cerr << pathname.c_str()
409 
410  //if the group already exists this will cause an error
411  //we will have to use the overwriteflag to check for
412  //this, but i dont know how to delete groups...
413 
414  hid_t group_id = pathname== "" ?
415  file_id
416  : H5Gcreate(file_id, pathname.c_str(),
417  H5P_DEFAULT,
418  H5P_DEFAULT,
419  H5P_DEFAULT);
420 
421  vigra_postcondition(group_id >= 0,
422  "rf_export_HDF5(): Unable to create group");
423 
424  //save serialized options
425  options_export_HDF5(group_id, rf.options(), "_options");
426  //save external parameters
427  problemspec_export_HDF5(group_id, rf.ext_param(), "_ext_param");
428  //save trees
429 
430  int tree_count = rf.options_.tree_count_;
431  for(int ii = 0; ii < tree_count; ++ii)
432  {
433  std::string treename = "Tree_" +
434  make_padded_number(ii, tree_count -1);
435  dt_export_HDF5(group_id, rf.tree(ii), treename);
436  }
437 
438  //clean up the mess
439  if(pathname != "")
440  H5Gclose(group_id);
441  H5Fclose(file_id);
442 
443  return 1;
444 }
445 
446 
447 template<class T>
448 bool rf_import_HDF5(RandomForest<T> &rf,
449  std::string filename,
450  std::string pathname = "")
451 {
453  using detail::options_import_HDF5;
454  using detail::problemspec_import_HDF5;
455  using detail::dt_export_HDF5;
456  // check if file exists
457  FILE* pFile = std::fopen ( filename.c_str(), "r" );
458  if ( pFile == NULL)
459  return 0;
460  std::fclose(pFile);
461  //open file
462  hid_t file_id = H5Fopen (filename.c_str(),
463  H5F_ACC_RDONLY,
464  H5P_DEFAULT);
465 
466  vigra_postcondition(file_id >= 0,
467  "rf_import_HDF5(): Unable to open file.");
468  hid_t group_id = pathname== "" ?
469  file_id
470  : H5Gopen (file_id,
471  pathname.c_str(),
472  H5P_DEFAULT);
473 
474  vigra_postcondition(group_id >= 0,
475  "rf_export_HDF5(): Unable to create group");
476 
477  //get serialized options
478  options_import_HDF5(group_id, rf.options_, "_options");
479  //save external parameters
480  problemspec_import_HDF5(group_id, rf.ext_param_, "_ext_param");
481  // TREE SAVING TIME
482  // get all groups in base path
483 
484  std::set<std::string> tree_set;
485  std::set<std::string>::iterator iter;
486  find_groups_hdf5(filename, pathname, tree_set);
487 
488  for(iter = tree_set.begin(); iter != tree_set.end(); ++iter)
489  {
490  if((*iter)[0] != '_')
491  {
492  rf.trees_.push_back(detail::DecisionTree(rf.ext_param_));
493  dt_import_HDF5(group_id, rf.trees_.back(), *iter);
494  }
495  }
496 
497  //clean up the mess
498  if(pathname != "")
499  H5Gclose(group_id);
500  H5Fclose(file_id);
501  /*rf.tree_indices_.resize(rf.tree_count());
502  for(int ii = 0; ii < rf.tree_count(); ++ii)
503  rf.tree_indices_[ii] = ii; */
504  return 1;
505 }
506 } // namespace vigra
507 
508 #endif // HasHDF5
509 
510 #endif // VIGRA_RANDOM_FOREST_HDF5_IMPEX_HXX
511 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Thu Jun 14 2012)