SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/CPLEXSVM.h> 00012 #include <shogun/lib/common.h> 00013 00014 #ifdef USE_CPLEX 00015 #include <shogun/io/SGIO.h> 00016 #include <shogun/mathematics/Math.h> 00017 #include <shogun/mathematics/Cplex.h> 00018 #include <shogun/features/Labels.h> 00019 00020 using namespace shogun; 00021 00022 CCPLEXSVM::CCPLEXSVM() 00023 : CSVM() 00024 { 00025 } 00026 00027 CCPLEXSVM::~CCPLEXSVM() 00028 { 00029 } 00030 00031 bool CCPLEXSVM::train_machine(CFeatures* data) 00032 { 00033 bool result = false; 00034 CCplex cplex; 00035 00036 if (data) 00037 { 00038 if (labels->get_num_labels() != data->get_num_vectors()) 00039 SG_ERROR("Number of training vectors does not match number of labels\n"); 00040 kernel->init(data, data); 00041 } 00042 00043 if (cplex.init(E_QP)) 00044 { 00045 int32_t n,m; 00046 int32_t num_label=0; 00047 float64_t* y = labels->get_labels(num_label); 00048 float64_t* H = kernel->get_kernel_matrix<float64_t>(m, n, NULL); 00049 ASSERT(n>0 && n==m && n==num_label); 00050 float64_t* alphas=SG_MALLOC(float64_t, n); 00051 float64_t* lb=SG_MALLOC(float64_t, n); 00052 float64_t* ub=SG_MALLOC(float64_t, n); 00053 00054 //hessian y'y.*K 00055 for (int32_t i=0; i<n; i++) 00056 { 00057 lb[i]=0; 00058 ub[i]=get_C1(); 00059 00060 for (int32_t j=0; j<n; j++) 00061 H[i*n+j]*=y[j]*y[i]; 00062 } 00063 00064 //feed qp to cplex 00065 00066 00067 int32_t j=0; 00068 for (int32_t i=0; i<n; i++) 00069 { 00070 if (alphas[i]>0) 00071 { 00072 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]); 00073 set_alpha(j, alphas[i]*labels->get_label(i)); 00074 set_support_vector(j, i); 00075 j++; 00076 } 00077 } 00078 //compute_objective(); 00079 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias()); 00080 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors()); 00081 00082 SG_FREE(alphas); 00083 SG_FREE(lb); 00084 SG_FREE(ub); 00085 SG_FREE(H); 00086 00087 result = true; 00088 } 00089 00090 if (!result) 00091 SG_ERROR( "cplex svm failed"); 00092 00093 return result; 00094 } 00095 #endif