/* SHOGUN v1.1.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Written (W) 2009 Marius Kloft 00009 * Copyright (C) 2009 TU Berlin and Max-Planck-Society 00010 */ 00011 00012 00013 #include <shogun/kernel/Kernel.h> 00014 #include <shogun/classifier/svm/ScatterSVM.h> 00015 #include <shogun/kernel/ScatterKernelNormalizer.h> 00016 #include <shogun/io/SGIO.h> 00017 00018 using namespace shogun; 00019 00020 CScatterSVM::CScatterSVM() 00021 : CMultiClassSVM(ONE_VS_REST), scatter_type(NO_BIAS_LIBSVM), 00022 model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00023 { 00024 SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n"); 00025 } 00026 00027 CScatterSVM::CScatterSVM(SCATTER_TYPE type) 00028 : CMultiClassSVM(ONE_VS_REST), scatter_type(type), model(NULL), 00029 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00030 { 00031 } 00032 00033 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab) 00034 : CMultiClassSVM(ONE_VS_REST, C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL), 00035 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00036 { 00037 } 00038 00039 CScatterSVM::~CScatterSVM() 00040 { 00041 SG_FREE(norm_wc); 00042 SG_FREE(norm_wcw); 00043 } 00044 00045 bool CScatterSVM::train_machine(CFeatures* data) 00046 { 00047 ASSERT(labels && labels->get_num_labels()); 00048 m_num_classes = labels->get_num_classes(); 00049 int32_t num_vectors = labels->get_num_labels(); 00050 00051 if (data) 00052 { 00053 if (labels->get_num_labels() != data->get_num_vectors()) 00054 SG_ERROR("Number of training vectors does not match number of labels\n"); 00055 kernel->init(data, data); 00056 } 00057 00058 int32_t* numc=SG_MALLOC(int32_t, m_num_classes); 00059 
CMath::fill_vector(numc, m_num_classes, 0); 00060 00061 for (int32_t i=0; i<num_vectors; i++) 00062 numc[(int32_t) labels->get_int_label(i)]++; 00063 00064 int32_t Nc=0; 00065 int32_t Nmin=num_vectors; 00066 for (int32_t i=0; i<m_num_classes; i++) 00067 { 00068 if (numc[i]>0) 00069 { 00070 Nc++; 00071 Nmin=CMath::min(Nmin, numc[i]); 00072 } 00073 00074 } 00075 SG_FREE(numc); 00076 m_num_classes=m_num_classes; 00077 00078 bool result=false; 00079 00080 if (scatter_type==NO_BIAS_LIBSVM) 00081 { 00082 result=train_no_bias_libsvm(); 00083 } 00084 00085 else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2) 00086 { 00087 float64_t nu_min=((float64_t) Nc)/num_vectors; 00088 float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors; 00089 00090 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max); 00091 00092 if (get_nu()<nu_min || get_nu()>nu_max) 00093 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max); 00094 00095 result=train_testrule12(); 00096 } 00097 else 00098 SG_ERROR("Unknown Scatter type\n"); 00099 00100 return result; 00101 } 00102 00103 bool CScatterSVM::train_no_bias_libsvm() 00104 { 00105 struct svm_node* x_space; 00106 00107 problem.l=labels->get_num_labels(); 00108 SG_INFO( "%d trainlabels\n", problem.l); 00109 00110 problem.y=SG_MALLOC(float64_t, problem.l); 00111 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00112 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00113 00114 for (int32_t i=0; i<problem.l; i++) 00115 { 00116 problem.y[i]=+1; 00117 problem.x[i]=&x_space[2*i]; 00118 x_space[2*i].index=i; 00119 x_space[2*i+1].index=-1; 00120 } 00121 00122 int32_t weights_label[2]={-1,+1}; 00123 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00124 00125 ASSERT(kernel && kernel->has_features()); 00126 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00127 00128 param.svm_type=C_SVC; // Nu MC SVM 00129 param.kernel_type = LINEAR; 00130 param.degree = 3; 00131 param.gamma = 0; // 1/k 00132 param.coef0 = 0; 00133 param.nu = get_nu(); // Nu 
00134 CKernelNormalizer* prev_normalizer=kernel->get_normalizer(); 00135 kernel->set_normalizer(new CScatterKernelNormalizer( 00136 m_num_classes-1, -1, labels, prev_normalizer)); 00137 param.kernel=kernel; 00138 param.cache_size = kernel->get_cache_size(); 00139 param.C = 0; 00140 param.eps = epsilon; 00141 param.p = 0.1; 00142 param.shrinking = 0; 00143 param.nr_weight = 2; 00144 param.weight_label = weights_label; 00145 param.weight = weights; 00146 param.nr_class=m_num_classes; 00147 param.use_bias = get_bias_enabled(); 00148 00149 const char* error_msg = svm_check_parameter(&problem,¶m); 00150 00151 if(error_msg) 00152 SG_ERROR("Error: %s\n",error_msg); 00153 00154 model = svm_train(&problem, ¶m); 00155 kernel->set_normalizer(prev_normalizer); 00156 SG_UNREF(prev_normalizer); 00157 00158 if (model) 00159 { 00160 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00161 00162 ASSERT(model->nr_class==m_num_classes); 00163 create_multiclass_svm(m_num_classes); 00164 00165 rho=model->rho[0]; 00166 00167 SG_FREE(norm_wcw); 00168 norm_wcw = SG_MALLOC(float64_t, m_num_svms); 00169 00170 for (int32_t i=0; i<m_num_classes; i++) 00171 { 00172 int32_t num_sv=model->nSV[i]; 00173 00174 CSVM* svm=new CSVM(num_sv); 00175 svm->set_bias(model->rho[i+1]); 00176 norm_wcw[i]=model->normwcw[i]; 00177 00178 00179 for (int32_t j=0; j<num_sv; j++) 00180 { 00181 svm->set_alpha(j, model->sv_coef[i][j]); 00182 svm->set_support_vector(j, model->SV[i][j].index); 00183 } 00184 00185 set_svm(i, svm); 00186 } 00187 00188 SG_FREE(problem.x); 00189 SG_FREE(problem.y); 00190 SG_FREE(x_space); 00191 for (int32_t i=0; i<m_num_classes; i++) 00192 { 00193 SG_FREE(model->SV[i]); 00194 model->SV[i]=NULL; 00195 } 00196 svm_destroy_model(model); 00197 00198 if (scatter_type==TEST_RULE2) 00199 compute_norm_wc(); 00200 00201 model=NULL; 00202 return true; 00203 } 00204 else 00205 return false; 00206 } 00207 00208 00209 00210 bool CScatterSVM::train_testrule12() 00211 
{ 00212 struct svm_node* x_space; 00213 problem.l=labels->get_num_labels(); 00214 SG_INFO( "%d trainlabels\n", problem.l); 00215 00216 problem.y=SG_MALLOC(float64_t, problem.l); 00217 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00218 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00219 00220 for (int32_t i=0; i<problem.l; i++) 00221 { 00222 problem.y[i]=labels->get_label(i); 00223 problem.x[i]=&x_space[2*i]; 00224 x_space[2*i].index=i; 00225 x_space[2*i+1].index=-1; 00226 } 00227 00228 int32_t weights_label[2]={-1,+1}; 00229 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00230 00231 ASSERT(kernel && kernel->has_features()); 00232 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00233 00234 param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM 00235 param.kernel_type = LINEAR; 00236 param.degree = 3; 00237 param.gamma = 0; // 1/k 00238 param.coef0 = 0; 00239 param.nu = get_nu(); // Nu 00240 param.kernel=kernel; 00241 param.cache_size = kernel->get_cache_size(); 00242 param.C = 0; 00243 param.eps = epsilon; 00244 param.p = 0.1; 00245 param.shrinking = 0; 00246 param.nr_weight = 2; 00247 param.weight_label = weights_label; 00248 param.weight = weights; 00249 param.nr_class=m_num_classes; 00250 param.use_bias = get_bias_enabled(); 00251 00252 const char* error_msg = svm_check_parameter(&problem,¶m); 00253 00254 if(error_msg) 00255 SG_ERROR("Error: %s\n",error_msg); 00256 00257 model = svm_train(&problem, ¶m); 00258 00259 if (model) 00260 { 00261 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00262 00263 ASSERT(model->nr_class==m_num_classes); 00264 create_multiclass_svm(m_num_classes); 00265 00266 rho=model->rho[0]; 00267 00268 SG_FREE(norm_wcw); 00269 norm_wcw = SG_MALLOC(float64_t, m_num_svms); 00270 00271 for (int32_t i=0; i<m_num_classes; i++) 00272 { 00273 int32_t num_sv=model->nSV[i]; 00274 00275 CSVM* svm=new CSVM(num_sv); 00276 svm->set_bias(model->rho[i+1]); 00277 norm_wcw[i]=model->normwcw[i]; 00278 00279 00280 for 
(int32_t j=0; j<num_sv; j++) 00281 { 00282 svm->set_alpha(j, model->sv_coef[i][j]); 00283 svm->set_support_vector(j, model->SV[i][j].index); 00284 } 00285 00286 set_svm(i, svm); 00287 } 00288 00289 SG_FREE(problem.x); 00290 SG_FREE(problem.y); 00291 SG_FREE(x_space); 00292 for (int32_t i=0; i<m_num_classes; i++) 00293 { 00294 SG_FREE(model->SV[i]); 00295 model->SV[i]=NULL; 00296 } 00297 svm_destroy_model(model); 00298 00299 if (scatter_type==TEST_RULE2) 00300 compute_norm_wc(); 00301 00302 model=NULL; 00303 return true; 00304 } 00305 else 00306 return false; 00307 } 00308 00309 void CScatterSVM::compute_norm_wc() 00310 { 00311 SG_FREE(norm_wc); 00312 norm_wc = SG_MALLOC(float64_t, m_num_svms); 00313 for (int32_t i=0; i<m_num_svms; i++) 00314 norm_wc[i]=0; 00315 00316 00317 for (int c=0; c<m_num_svms; c++) 00318 { 00319 CSVM* svm=m_svms[c]; 00320 int32_t num_sv = svm->get_num_support_vectors(); 00321 00322 for (int32_t i=0; i<num_sv; i++) 00323 { 00324 int32_t ii=svm->get_support_vector(i); 00325 for (int32_t j=0; j<num_sv; j++) 00326 { 00327 int32_t jj=svm->get_support_vector(j); 00328 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j); 00329 } 00330 } 00331 } 00332 00333 for (int32_t i=0; i<m_num_svms; i++) 00334 norm_wc[i]=CMath::sqrt(norm_wc[i]); 00335 00336 CMath::display_vector(norm_wc, m_num_svms, "norm_wc"); 00337 } 00338 00339 CLabels* CScatterSVM::classify_one_vs_rest() 00340 { 00341 CLabels* output=NULL; 00342 if (!kernel) 00343 { 00344 SG_ERROR( "SVM can not proceed without kernel!\n"); 00345 return false ; 00346 } 00347 00348 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00349 { 00350 int32_t num_vectors=kernel->get_num_vec_rhs(); 00351 00352 output=new CLabels(num_vectors); 00353 SG_REF(output); 00354 00355 if (scatter_type == TEST_RULE1) 00356 { 00357 ASSERT(m_num_svms>0); 00358 for (int32_t i=0; i<num_vectors; i++) 00359 output->set_label(i, apply(i)); 00360 } 00361 00362 else 00363 { 00364 
ASSERT(m_num_svms>0); 00365 ASSERT(num_vectors==output->get_num_labels()); 00366 CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms); 00367 00368 for (int32_t i=0; i<m_num_svms; i++) 00369 { 00370 //SG_PRINT("svm %d\n", i); 00371 ASSERT(m_svms[i]); 00372 m_svms[i]->set_kernel(kernel); 00373 m_svms[i]->set_labels(labels); 00374 outputs[i]=m_svms[i]->apply(); 00375 } 00376 00377 for (int32_t i=0; i<num_vectors; i++) 00378 { 00379 int32_t winner=0; 00380 float64_t max_out=outputs[0]->get_label(i)/norm_wc[0]; 00381 00382 for (int32_t j=1; j<m_num_svms; j++) 00383 { 00384 float64_t out=outputs[j]->get_label(i)/norm_wc[j]; 00385 00386 if (out>max_out) 00387 { 00388 winner=j; 00389 max_out=out; 00390 } 00391 } 00392 00393 output->set_label(i, winner); 00394 } 00395 00396 for (int32_t i=0; i<m_num_svms; i++) 00397 SG_UNREF(outputs[i]); 00398 00399 SG_FREE(outputs); 00400 } 00401 } 00402 00403 return output; 00404 } 00405 00406 float64_t CScatterSVM::apply(int32_t num) 00407 { 00408 ASSERT(m_num_svms>0); 00409 float64_t* outputs=SG_MALLOC(float64_t, m_num_svms); 00410 int32_t winner=0; 00411 00412 if (scatter_type == TEST_RULE1) 00413 { 00414 for (int32_t c=0; c<m_num_svms; c++) 00415 outputs[c]=m_svms[c]->get_bias()-rho; 00416 00417 for (int32_t c=0; c<m_num_svms; c++) 00418 { 00419 float64_t v=0; 00420 00421 for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++) 00422 { 00423 float64_t alpha=m_svms[c]->get_alpha(i); 00424 int32_t svidx=m_svms[c]->get_support_vector(i); 00425 v += alpha*kernel->kernel(svidx, num); 00426 } 00427 00428 outputs[c] += v; 00429 for (int32_t j=0; j<m_num_svms; j++) 00430 outputs[j] -= v/m_num_svms; 00431 } 00432 00433 for (int32_t j=0; j<m_num_svms; j++) 00434 outputs[j]/=norm_wcw[j]; 00435 00436 float64_t max_out=outputs[0]; 00437 for (int32_t j=0; j<m_num_svms; j++) 00438 { 00439 if (outputs[j]>max_out) 00440 { 00441 max_out=outputs[j]; 00442 winner=j; 00443 } 00444 } 00445 } 00446 00447 else 00448 { 00449 float64_t 
max_out=m_svms[0]->apply(num)/norm_wc[0]; 00450 00451 for (int32_t i=1; i<m_num_svms; i++) 00452 { 00453 outputs[i]=m_svms[i]->apply(num)/norm_wc[i]; 00454 if (outputs[i]>max_out) 00455 { 00456 winner=i; 00457 max_out=outputs[i]; 00458 } 00459 } 00460 } 00461 00462 SG_FREE(outputs); 00463 return winner; 00464 }