SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/CrossValidation.h> 00012 #include <shogun/machine/Machine.h> 00013 #include <shogun/evaluation/Evaluation.h> 00014 #include <shogun/evaluation/SplittingStrategy.h> 00015 #include <shogun/base/Parameter.h> 00016 #include <shogun/mathematics/Statistics.h> 00017 00018 using namespace shogun; 00019 00020 CCrossValidation::CCrossValidation() 00021 { 00022 init(); 00023 } 00024 00025 CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features, 00026 CLabels* labels, CSplittingStrategy* splitting_strategy, 00027 CEvaluation* evaluation_criterium) 00028 { 00029 init(); 00030 00031 m_machine=machine; 00032 m_features=features; 00033 m_labels=labels; 00034 m_splitting_strategy=splitting_strategy; 00035 m_evaluation_criterium=evaluation_criterium; 00036 00037 SG_REF(m_machine); 00038 SG_REF(m_features); 00039 SG_REF(m_labels); 00040 SG_REF(m_splitting_strategy); 00041 SG_REF(m_evaluation_criterium); 00042 } 00043 00044 CCrossValidation::~CCrossValidation() 00045 { 00046 SG_UNREF(m_machine); 00047 SG_UNREF(m_features); 00048 SG_UNREF(m_labels); 00049 SG_UNREF(m_splitting_strategy); 00050 SG_UNREF(m_evaluation_criterium); 00051 } 00052 00053 EEvaluationDirection CCrossValidation::get_evaluation_direction() 00054 { 00055 return m_evaluation_criterium->get_evaluation_direction(); 00056 } 00057 00058 void CCrossValidation::init() 00059 { 00060 m_machine=NULL; 00061 m_features=NULL; 00062 m_labels=NULL; 00063 m_splitting_strategy=NULL; 00064 m_evaluation_criterium=NULL; 00065 m_num_runs=1; 00066 m_conf_int_alpha=0; 00067 00068 m_parameters->add((CSGObject**) &m_machine, "machine", 00069 "Used learning machine"); 00070 m_parameters->add((CSGObject**) &m_features, "features", "Used features"); 00071 m_parameters->add((CSGObject**) &m_labels, "labels", "Used labels"); 00072 m_parameters->add((CSGObject**) &m_splitting_strategy, 00073 "splitting_strategy", "Used splitting strategy"); 00074 m_parameters->add((CSGObject**) &m_evaluation_criterium, 00075 "evaluation_criterium", "Used evaluation criterium"); 00076 m_parameters->add(&m_num_runs, "num_runs", "Number of repetitions"); 00077 m_parameters->add(&m_conf_int_alpha, "conf_int_alpha", "alpha-value of confidence " 00078 "interval"); 00079 } 00080 00081 CMachine* CCrossValidation::get_machine() const 00082 { 00083 SG_REF(m_machine); 00084 return m_machine; 00085 } 00086 00087 CrossValidationResult CCrossValidation::evaluate() 00088 { 00089 SGVector<float64_t> results(m_num_runs); 00090 00091 for (index_t i=0; i<m_num_runs; ++i) 00092 results.vector[i]=evaluate_one_run(); 00093 00094 /* construct evaluation result */ 00095 CrossValidationResult result; 00096 result.has_conf_int=m_conf_int_alpha!=0; 00097 result.conf_int_alpha=m_conf_int_alpha; 00098 00099 if (result.has_conf_int) 00100 { 00101 result.conf_int_alpha=m_conf_int_alpha; 00102 result.mean=CStatistics::confidence_intervals_mean(results, 00103 result.conf_int_alpha, result.conf_int_low, result.conf_int_up); 00104 } 00105 else 00106 { 00107 result.mean=CStatistics::mean(results); 00108 result.conf_int_low=0; 00109 result.conf_int_up=0; 00110 } 00111 00112 SG_FREE(results.vector); 00113 00114 return result; 00115 } 00116 00117 void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha) 00118 { 00119 if (conf_int_alpha<0 || conf_int_alpha>=1) 00120 { 00121 SG_ERROR("%f is an illegal alpha-value for confidence interval of " 00122 "cross-validation\n", conf_int_alpha); 00123 } 00124 00125 m_conf_int_alpha=conf_int_alpha; 00126 } 00127 00128 void CCrossValidation::set_num_runs(int32_t num_runs) 00129 { 00130 if (num_runs<1) 00131 SG_ERROR("%d is an illegal number of repetitions\n", num_runs); 00132 00133 m_num_runs=num_runs; 00134 } 00135 00136 float64_t CCrossValidation::evaluate_one_run() 00137 { 00138 index_t num_subsets=m_splitting_strategy->get_num_subsets(); 00139 float64_t* results=SG_MALLOC(float64_t, num_subsets); 00140 00141 /* set labels to machine */ 00142 m_machine->set_labels(m_labels); 00143 00144 /* tell machine to store model internally 00145 * (otherwise changing subset of features will kaboom the classifier) */ 00146 m_machine->set_store_model_features(true); 00147 00148 /* do actual cross-validation */ 00149 for (index_t i=0; i<num_subsets; ++i) 00150 { 00151 /* set feature subset for training */ 00152 SGVector<index_t> inverse_subset_indices= 00153 m_splitting_strategy->generate_subset_inverse(i); 00154 m_features->set_subset(new CSubset(inverse_subset_indices)); 00155 00156 /* set label subset for training (copy data before) */ 00157 SGVector<index_t> inverse_subset_indices_copy( 00158 inverse_subset_indices.vlen); 00159 memcpy(inverse_subset_indices_copy.vector, 00160 inverse_subset_indices.vector, 00161 inverse_subset_indices.vlen*sizeof(index_t)); 00162 m_labels->set_subset(new CSubset(inverse_subset_indices_copy)); 00163 00164 /* train machine on training features */ 00165 m_machine->train(m_features); 00166 00167 /* set feature subset for testing (subset method that stores pointer) */ 00168 SGVector<index_t> subset_indices= 00169 m_splitting_strategy->generate_subset_indices(i); 00170 m_features->set_subset(new CSubset(subset_indices)); 00171 00172 /* apply machine to test features */ 00173 CLabels* result_labels=m_machine->apply(m_features); 00174 SG_REF(result_labels); 00175 00176 /* set label subset for testing (copy data before) */ 00177 SGVector<index_t> subset_indices_copy(subset_indices.vlen); 00178 memcpy(subset_indices_copy.vector, subset_indices.vector, 00179 subset_indices.vlen*sizeof(index_t)); 00180 m_labels->set_subset(new CSubset(subset_indices_copy)); 00181 00182 /* evaluate */ 00183 results[i]=m_evaluation_criterium->evaluate(result_labels, m_labels); 00184 00185 /* clean up, reset subsets */ 00186 SG_UNREF(result_labels); 00187 m_features->remove_subset(); 00188 m_labels->remove_subset(); 00189 } 00190 00191 /* build arithmetic mean of results */ 00192 float64_t mean=CStatistics::mean(SGVector<float64_t>(results, num_subsets)); 00193 00194 /* clean up */ 00195 SG_FREE(results); 00196 00197 return mean; 00198 }