ScatterSVM.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2009 Soeren Sonnenburg
 * Written (W) 2009 Marius Kloft
 * Copyright (C) 2009 TU Berlin and Max-Planck-Society
 */

#include <shogun/kernel/Kernel.h>
#include <shogun/classifier/svm/ScatterSVM.h>
#include <shogun/kernel/ScatterKernelNormalizer.h>
#include <shogun/io/SGIO.h>

using namespace shogun;

CScatterSVM::CScatterSVM()
: CMultiClassSVM(ONE_VS_REST), scatter_type(NO_BIAS_LIBSVM),
  model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
    SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n");
}

CScatterSVM::CScatterSVM(SCATTER_TYPE type)
: CMultiClassSVM(ONE_VS_REST), scatter_type(type), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
: CMultiClassSVM(ONE_VS_REST, C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::~CScatterSVM()
{
    SG_FREE(norm_wc);
    SG_FREE(norm_wcw);
}

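/* Train the scatter SVM: count the examples per class, validate nu for the
 * test-rule variants and dispatch to the solver selected by scatter_type. */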
bool CScatterSVM::train_machine(CFeatures* data)
{
    ASSERT(labels && labels->get_num_labels());
    m_num_classes = labels->get_num_classes();
    int32_t num_vectors = labels->get_num_labels();

    if (data)
    {
        if (labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n");
        kernel->init(data, data);
    }

    /* count the number of examples in each class */
    int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
    CMath::fill_vector(numc, m_num_classes, 0);

    for (int32_t i=0; i<num_vectors; i++)
        numc[(int32_t) labels->get_int_label(i)]++;

    /* Nc: number of non-empty classes, Nmin: size of the smallest one */
    int32_t Nc=0;
    int32_t Nmin=num_vectors;
    for (int32_t i=0; i<m_num_classes; i++)
    {
        if (numc[i]>0)
        {
            Nc++;
            Nmin=CMath::min(Nmin, numc[i]);
        }
    }
    SG_FREE(numc);

    bool result=false;

    if (scatter_type==NO_BIAS_LIBSVM)
    {
        result=train_no_bias_libsvm();
    }
    else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2)
    {
        float64_t nu_min=((float64_t) Nc)/num_vectors;
        float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;

        SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);

        if (get_nu()<nu_min || get_nu()>nu_max)
            SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);

        result=train_testrule12();
    }
    else
        SG_ERROR("Unknown scatter type\n");

    return result;
}

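/* Train via LibSVM's C-SVC solver on a kernel wrapped in a
 * CScatterKernelNormalizer (no-bias formulation); afterwards one CSVM per
 * class is extracted from the LibSVM model. */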
bool CScatterSVM::train_no_bias_libsvm()
{
    struct svm_node* x_space;

    problem.l=labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l);

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=+1;
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0, get_C2()/get_C1()};

    ASSERT(kernel && kernel->has_features());
    ASSERT(kernel->get_num_vec_lhs()==problem.l);

    param.svm_type = C_SVC;
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0;    // 1/k
    param.coef0 = 0;
    param.nu = get_nu();

    /* wrap the current kernel in a scatter normalizer while training */
    CKernelNormalizer* prev_normalizer=kernel->get_normalizer();
    kernel->set_normalizer(new CScatterKernelNormalizer(
                m_num_classes-1, -1, labels, prev_normalizer));

    param.kernel=kernel;
    param.cache_size = kernel->get_cache_size();
    param.C = 0;
    param.eps = epsilon;
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class = m_num_classes;
    param.use_bias = get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem, &param);

    if (error_msg)
        SG_ERROR("Error: %s\n", error_msg);

    model = svm_train(&problem, &param);
    kernel->set_normalizer(prev_normalizer);
    SG_UNREF(prev_normalizer);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
        ASSERT(model->nr_class==m_num_classes);

        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_num_svms);

        /* extract one CSVM per class from the LibSVM model */
        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        if (scatter_type==TEST_RULE2)
            compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

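/* Train via LibSVM's NU_MULTICLASS_SVC solver on the original labels; used by
 * the TEST_RULE1/TEST_RULE2 variants. */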
bool CScatterSVM::train_testrule12()
{
    struct svm_node* x_space;

    problem.l=labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l);

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=labels->get_label(i);
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0, get_C2()/get_C1()};

    ASSERT(kernel && kernel->has_features());
    ASSERT(kernel->get_num_vec_lhs()==problem.l);

    param.svm_type = NU_MULTICLASS_SVC; // Nu MC SVM
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0;    // 1/k
    param.coef0 = 0;
    param.nu = get_nu();
    param.kernel=kernel;
    param.cache_size = kernel->get_cache_size();
    param.C = 0;
    param.eps = epsilon;
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class = m_num_classes;
    param.use_bias = get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem, &param);

    if (error_msg)
        SG_ERROR("Error: %s\n", error_msg);

    model = svm_train(&problem, &param);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
        ASSERT(model->nr_class==m_num_classes);

        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_num_svms);

        /* extract one CSVM per class from the LibSVM model */
        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        if (scatter_type==TEST_RULE2)
            compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

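/* Compute ||w_c|| for each per-class SVM from its alphas and the kernel;
 * the norms are used to scale the per-class outputs at prediction time. */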
void CScatterSVM::compute_norm_wc()
{
    SG_FREE(norm_wc);
    norm_wc = SG_MALLOC(float64_t, m_num_svms);
    for (int32_t i=0; i<m_num_svms; i++)
        norm_wc[i]=0;

    for (int32_t c=0; c<m_num_svms; c++)
    {
        CSVM* svm=m_svms[c];
        int32_t num_sv = svm->get_num_support_vectors();

        for (int32_t i=0; i<num_sv; i++)
        {
            int32_t ii=svm->get_support_vector(i);
            for (int32_t j=0; j<num_sv; j++)
            {
                int32_t jj=svm->get_support_vector(j);
                norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j);
            }
        }
    }

    for (int32_t i=0; i<m_num_svms; i++)
        norm_wc[i]=CMath::sqrt(norm_wc[i]);

    CMath::display_vector(norm_wc, m_num_svms, "norm_wc");
}

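/* Classify all vectors currently attached to the kernel's rhs: TEST_RULE1
 * delegates to apply(num) per vector; otherwise each per-class SVM output is
 * divided by ||w_c|| and the class with the largest value wins. */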
CLabels* CScatterSVM::classify_one_vs_rest()
{
    CLabels* output=NULL;
    if (!kernel)
    {
        SG_ERROR("SVM can not proceed without kernel!\n");
        return NULL;
    }

    if (kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
    {
        int32_t num_vectors=kernel->get_num_vec_rhs();

        output=new CLabels(num_vectors);
        SG_REF(output);

        if (scatter_type == TEST_RULE1)
        {
            ASSERT(m_num_svms>0);
            for (int32_t i=0; i<num_vectors; i++)
                output->set_label(i, apply(i));
        }
        else
        {
            ASSERT(m_num_svms>0);
            ASSERT(num_vectors==output->get_num_labels());
            CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);

            /* evaluate each per-class SVM on all vectors */
            for (int32_t i=0; i<m_num_svms; i++)
            {
                ASSERT(m_svms[i]);
                m_svms[i]->set_kernel(kernel);
                m_svms[i]->set_labels(labels);
                outputs[i]=m_svms[i]->apply();
            }

            /* pick the class with the largest normalized output */
            for (int32_t i=0; i<num_vectors; i++)
            {
                int32_t winner=0;
                float64_t max_out=outputs[0]->get_label(i)/norm_wc[0];

                for (int32_t j=1; j<m_num_svms; j++)
                {
                    float64_t out=outputs[j]->get_label(i)/norm_wc[j];

                    if (out>max_out)
                    {
                        winner=j;
                        max_out=out;
                    }
                }

                output->set_label(i, winner);
            }

            for (int32_t i=0; i<m_num_svms; i++)
                SG_UNREF(outputs[i]);

            SG_FREE(outputs);
        }
    }

    return output;
}

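/* Classify a single example and return the index of the winning class.
 * TEST_RULE1: per-class score = bias_c - rho plus a centered kernel
 * expansion, divided by norm_wcw[c]; otherwise the per-class SVM outputs
 * divided by norm_wc[c] are compared. */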
float64_t CScatterSVM::apply(int32_t num)
{
    ASSERT(m_num_svms>0);
    float64_t* outputs=SG_MALLOC(float64_t, m_num_svms);
    int32_t winner=0;

    if (scatter_type == TEST_RULE1)
    {
        for (int32_t c=0; c<m_num_svms; c++)
            outputs[c]=m_svms[c]->get_bias()-rho;

        for (int32_t c=0; c<m_num_svms; c++)
        {
            float64_t v=0;

            for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++)
            {
                float64_t alpha=m_svms[c]->get_alpha(i);
                int32_t svidx=m_svms[c]->get_support_vector(i);
                v += alpha*kernel->kernel(svidx, num);
            }

            /* add the expansion to class c and subtract its mean from all classes */
            outputs[c] += v;
            for (int32_t j=0; j<m_num_svms; j++)
                outputs[j] -= v/m_num_svms;
        }

        for (int32_t j=0; j<m_num_svms; j++)
            outputs[j]/=norm_wcw[j];

        float64_t max_out=outputs[0];
        for (int32_t j=0; j<m_num_svms; j++)
        {
            if (outputs[j]>max_out)
            {
                max_out=outputs[j];
                winner=j;
            }
        }
    }
    else
    {
        float64_t max_out=m_svms[0]->apply(num)/norm_wc[0];

        for (int32_t i=1; i<m_num_svms; i++)
        {
            outputs[i]=m_svms[i]->apply(num)/norm_wc[i];
            if (outputs[i]>max_out)
            {
                winner=i;
                max_out=outputs[i];
            }
        }
    }

    SG_FREE(outputs);
    return winner;
}
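/* A minimal usage sketch (hypothetical, not part of this file). It assumes
 * `features` (a CFeatures*) and `labels` (a CLabels*) were created elsewhere
 * with the usual SHOGUN 1.x API, and that nu lies inside the [nu_min, nu_max]
 * interval reported by train_machine(). Only methods visible above or in the
 * corresponding base classes are called. */
#if 0
CKernel* kernel = new CGaussianKernel(10, 2.0);   // cache size, kernel width
CScatterSVM* svm = new CScatterSVM(TEST_RULE2);   // pick the scatter variant
svm->set_kernel(kernel);
svm->set_labels(labels);
svm->set_nu(0.3);                                 // must be inside [nu_min, nu_max]

if (svm->train(features))                         // dispatches to train_machine()
{
    CLabels* predictions = svm->apply();          // one-vs-rest multiclass output
    SG_UNREF(predictions);
}

SG_UNREF(svm);
#endif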