SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/classifier/svm/MultiClassSVM.h> 00014 00015 using namespace shogun; 00016 00017 CMultiClassSVM::CMultiClassSVM() 00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL) 00019 { 00020 init(); 00021 } 00022 00023 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type) 00024 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00025 { 00026 init(); 00027 } 00028 00029 CMultiClassSVM::CMultiClassSVM( 00030 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab) 00031 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00032 { 00033 init(); 00034 } 00035 00036 CMultiClassSVM::~CMultiClassSVM() 00037 { 00038 cleanup(); 00039 } 00040 00041 void CMultiClassSVM::init() 00042 { 00043 m_parameters->add((machine_int_t*) &multiclass_type, 00044 "multiclass_type", "Type of MultiClassSVM."); 00045 m_parameters->add(&m_num_classes, "m_num_classes", 00046 "Number of classes."); 00047 m_parameters->add_vector((CSGObject***) &m_svms, 00048 &m_num_svms, "m_svms"); 00049 } 00050 00051 void CMultiClassSVM::cleanup() 00052 { 00053 for (int32_t i=0; i<m_num_svms; i++) 00054 SG_UNREF(m_svms[i]); 00055 00056 SG_FREE(m_svms); 00057 m_num_svms=0; 00058 m_svms=NULL; 00059 } 00060 00061 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes) 00062 { 00063 if (num_classes>0) 00064 { 00065 cleanup(); 00066 00067 m_num_classes=num_classes; 00068 00069 if (multiclass_type==ONE_VS_REST) 00070 m_num_svms=num_classes; 00071 else if (multiclass_type==ONE_VS_ONE) 00072 m_num_svms=num_classes*(num_classes-1)/2; 00073 else 00074 SG_ERROR("unknown multiclass type\n"); 00075 00076 m_svms=SG_MALLOC(CSVM*, m_num_svms); 00077 if (m_svms) 00078 { 00079 memset(m_svms,0, m_num_svms*sizeof(CSVM*)); 00080 return true; 00081 } 00082 } 00083 return false; 00084 } 00085 00086 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm) 00087 { 00088 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm) 00089 { 00090 SG_REF(svm); 00091 m_svms[num]=svm; 00092 return true; 00093 } 00094 return false; 00095 } 00096 00097 CLabels* CMultiClassSVM::apply() 00098 { 00099 if (multiclass_type==ONE_VS_REST) 00100 return classify_one_vs_rest(); 00101 else if (multiclass_type==ONE_VS_ONE) 00102 return classify_one_vs_one(); 00103 else 00104 SG_ERROR("unknown multiclass type\n"); 00105 00106 return NULL; 00107 } 00108 00109 CLabels* CMultiClassSVM::classify_one_vs_one() 00110 { 00111 ASSERT(m_num_svms>0); 00112 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00113 CLabels* result=NULL; 00114 00115 if (!kernel) 00116 { 00117 SG_ERROR( "SVM can not proceed without kernel!\n"); 00118 return false ; 00119 } 00120 00121 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00122 { 00123 int32_t num_vectors=kernel->get_num_vec_rhs(); 00124 00125 result=new CLabels(num_vectors); 00126 SG_REF(result); 00127 00128 ASSERT(num_vectors==result->get_num_labels()); 00129 CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms); 00130 00131 for (int32_t i=0; i<m_num_svms; i++) 00132 { 00133 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]); 00134 ASSERT(m_svms[i]); 00135 m_svms[i]->set_kernel(kernel); 00136 outputs[i]=m_svms[i]->apply(); 00137 } 00138 00139 int32_t* votes=SG_MALLOC(int32_t, m_num_classes); 00140 for (int32_t v=0; v<num_vectors; v++) 00141 { 00142 int32_t s=0; 00143 memset(votes, 0, sizeof(int32_t)*m_num_classes); 00144 00145 for (int32_t i=0; i<m_num_classes; i++) 00146 { 00147 for (int32_t j=i+1; j<m_num_classes; j++) 00148 { 00149 if (outputs[s++]->get_label(v)>0) 00150 votes[i]++; 00151 else 00152 votes[j]++; 00153 } 00154 } 00155 00156 int32_t winner=0; 00157 int32_t max_votes=votes[0]; 00158 00159 for (int32_t i=1; i<m_num_classes; i++) 00160 { 00161 if (votes[i]>max_votes) 00162 { 00163 max_votes=votes[i]; 00164 winner=i; 00165 } 00166 } 00167 00168 result->set_label(v, winner); 00169 } 00170 00171 SG_FREE(votes); 00172 00173 for (int32_t i=0; i<m_num_svms; i++) 00174 SG_UNREF(outputs[i]); 00175 SG_FREE(outputs); 00176 } 00177 00178 return result; 00179 } 00180 00181 CLabels* CMultiClassSVM::classify_one_vs_rest() 00182 { 00183 ASSERT(m_num_svms>0); 00184 CLabels* result=NULL; 00185 00186 if (!kernel) 00187 { 00188 SG_ERROR( "SVM can not proceed without kernel!\n"); 00189 return false ; 00190 } 00191 00192 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00193 { 00194 int32_t num_vectors=kernel->get_num_vec_rhs(); 00195 00196 result=new CLabels(num_vectors); 00197 SG_REF(result); 00198 00199 ASSERT(num_vectors==result->get_num_labels()); 00200 CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms); 00201 00202 for (int32_t i=0; i<m_num_svms; i++) 00203 { 00204 ASSERT(m_svms[i]); 00205 m_svms[i]->set_kernel(kernel); 00206 outputs[i]=m_svms[i]->apply(); 00207 } 00208 00209 for (int32_t i=0; i<num_vectors; i++) 00210 { 00211 int32_t winner=0; 00212 float64_t max_out=outputs[0]->get_label(i); 00213 00214 for (int32_t j=1; j<m_num_svms; j++) 00215 { 00216 float64_t out=outputs[j]->get_label(i); 00217 00218 if (out>max_out) 00219 { 00220 winner=j; 00221 max_out=out; 00222 } 00223 } 00224 00225 result->set_label(i, winner); 00226 } 00227 00228 for (int32_t i=0; i<m_num_svms; i++) 00229 SG_UNREF(outputs[i]); 00230 00231 SG_FREE(outputs); 00232 } 00233 00234 return result; 00235 } 00236 00237 float64_t CMultiClassSVM::apply(int32_t num) 00238 { 00239 if (multiclass_type==ONE_VS_REST) 00240 return classify_example_one_vs_rest(num); 00241 else if (multiclass_type==ONE_VS_ONE) 00242 return classify_example_one_vs_one(num); 00243 else 00244 SG_ERROR("unknown multiclass type\n"); 00245 00246 return 0; 00247 } 00248 00249 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num) 00250 { 00251 ASSERT(m_num_svms>0); 00252 float64_t* outputs=SG_MALLOC(float64_t, m_num_svms); 00253 int32_t winner=0; 00254 float64_t max_out=m_svms[0]->apply(num); 00255 00256 for (int32_t i=1; i<m_num_svms; i++) 00257 { 00258 outputs[i]=m_svms[i]->apply(num); 00259 if (outputs[i]>max_out) 00260 { 00261 winner=i; 00262 max_out=outputs[i]; 00263 } 00264 } 00265 SG_FREE(outputs); 00266 00267 return winner; 00268 } 00269 00270 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num) 00271 { 00272 ASSERT(m_num_svms>0); 00273 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00274 00275 int32_t* votes=SG_MALLOC(int32_t, m_num_classes); 00276 int32_t s=0; 00277 00278 for (int32_t i=0; i<m_num_classes; i++) 00279 { 00280 for (int32_t j=i+1; j<m_num_classes; j++) 00281 { 00282 if (m_svms[s++]->apply(num)>0) 00283 votes[i]++; 00284 else 00285 votes[j]++; 00286 } 00287 } 00288 00289 int32_t winner=0; 00290 int32_t max_votes=votes[0]; 00291 00292 for (int32_t i=1; i<m_num_classes; i++) 00293 { 00294 if (votes[i]>max_votes) 00295 { 00296 max_votes=votes[i]; 00297 winner=i; 00298 } 00299 } 00300 00301 SG_FREE(votes); 00302 00303 return winner; 00304 } 00305 00306 bool CMultiClassSVM::load(FILE* modelfl) 00307 { 00308 bool result=true; 00309 char char_buffer[1024]; 00310 int32_t int_buffer; 00311 float64_t double_buffer; 00312 int32_t line_number=1; 00313 int32_t svm_idx=-1; 00314 00315 SG_SET_LOCALE_C; 00316 00317 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF) 00318 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00319 else 00320 { 00321 char_buffer[15]='\0'; 00322 if (strcmp("%MultiClassSVM", char_buffer)!=0) 00323 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number); 00324 00325 line_number++; 00326 } 00327 00328 int_buffer=0; 00329 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1) 00330 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00331 00332 if (!feof(modelfl)) 00333 line_number++; 00334 00335 if (int_buffer != multiclass_type) 00336 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type); 00337 00338 int_buffer=0; 00339 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1) 00340 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00341 00342 if (!feof(modelfl)) 00343 line_number++; 00344 00345 if (int_buffer < 2) 00346 SG_ERROR("less than 2 classes - how is this multiclass?\n"); 00347 00348 create_multiclass_svm(int_buffer); 00349 00350 int_buffer=0; 00351 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1) 00352 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00353 00354 if (!feof(modelfl)) 00355 line_number++; 00356 00357 if (m_num_svms != int_buffer) 00358 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer); 00359 00360 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00361 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00362 00363 if (!feof(modelfl)) 00364 line_number++; 00365 00366 for (int32_t n=0; n<m_num_svms; n++) 00367 { 00368 svm_idx=-1; 00369 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF) 00370 { 00371 result=false; 00372 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00373 } 00374 else 00375 { 00376 char_buffer[4]='\0'; 00377 if (strncmp("%SVM", char_buffer, 4)!=0) 00378 { 00379 result=false; 00380 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00381 } 00382 00383 if (svm_idx != n) 00384 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00385 00386 line_number++; 00387 } 00388 00389 int_buffer=0; 00390 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2) 00391 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00392 00393 if (svm_idx != n) 00394 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00395 00396 if (!feof(modelfl)) 00397 line_number++; 00398 00399 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx); 00400 CSVM* svm=new CSVM(int_buffer); 00401 00402 double_buffer=0; 00403 00404 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2) 00405 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00406 00407 if (svm_idx != n) 00408 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00409 00410 if (!feof(modelfl)) 00411 line_number++; 00412 00413 svm->set_bias(double_buffer); 00414 00415 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1) 00416 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00417 00418 if (svm_idx != n) 00419 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00420 00421 if (!feof(modelfl)) 00422 line_number++; 00423 00424 for (int32_t i=0; i<svm->get_num_support_vectors(); i++) 00425 { 00426 double_buffer=0; 00427 int_buffer=0; 00428 00429 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00430 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00431 00432 if (!feof(modelfl)) 00433 line_number++; 00434 00435 svm->set_support_vector(i, int_buffer); 00436 svm->set_alpha(i, double_buffer); 00437 } 00438 00439 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00440 { 00441 result=false; 00442 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00443 } 00444 else 00445 { 00446 char_buffer[3]='\0'; 00447 if (strcmp("];", char_buffer)!=0) 00448 { 00449 result=false; 00450 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00451 } 00452 line_number++; 00453 } 00454 00455 set_svm(n, svm); 00456 } 00457 00458 svm_loaded=result; 00459 00460 SG_RESET_LOCALE; 00461 return result; 00462 } 00463 00464 bool CMultiClassSVM::save(FILE* modelfl) 00465 { 00466 SG_SET_LOCALE_C; 00467 00468 if (!kernel) 00469 SG_ERROR("Kernel not defined!\n"); 00470 00471 if (!m_svms || m_num_svms<1 || m_num_classes <=2) 00472 SG_ERROR("Multiclass SVM not trained!\n"); 00473 00474 SG_INFO( "Writing model file..."); 00475 fprintf(modelfl,"%%MultiClassSVM\n"); 00476 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type); 00477 fprintf(modelfl,"num_classes=%d;\n", m_num_classes); 00478 fprintf(modelfl,"num_svms=%d;\n", m_num_svms); 00479 fprintf(modelfl,"kernel='%s';\n", kernel->get_name()); 00480 00481 for (int32_t i=0; i<m_num_svms; i++) 00482 { 00483 CSVM* svm=m_svms[i]; 00484 ASSERT(svm); 00485 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1); 00486 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors()); 00487 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias()); 00488 00489 fprintf(modelfl, "alphas%d=[\n", i); 00490 00491 for(int32_t j=0; j<svm->get_num_support_vectors(); j++) 00492 { 00493 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00494 svm->get_alpha(j), svm->get_support_vector(j)); 00495 } 00496 00497 fprintf(modelfl, "];\n"); 00498 } 00499 00500 SG_RESET_LOCALE; 00501 SG_DONE(); 00502 return true ; 00503 }