SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 00012 #include <stdio.h> 00013 #include <string.h> 00014 00015 #include <shogun/lib/config.h> 00016 #include <shogun/io/SGIO.h> 00017 #include <shogun/structure/Plif.h> 00018 00019 //#define PLIF_DEBUG 00020 00021 using namespace shogun; 00022 00023 CPlif::CPlif(int32_t l) 00024 : CPlifBase() 00025 { 00026 limits=NULL; 00027 penalties=NULL; 00028 cum_derivatives=NULL; 00029 id=-1; 00030 transform=T_LINEAR; 00031 name=NULL; 00032 max_value=0; 00033 min_value=0; 00034 cache=NULL; 00035 use_svm=0; 00036 use_cache=false; 00037 len=0; 00038 do_calc = true; 00039 if (l>0) 00040 set_plif_length(l); 00041 } 00042 00043 CPlif::~CPlif() 00044 { 00045 SG_FREE(limits); 00046 SG_FREE(penalties); 00047 SG_FREE(name); 00048 SG_FREE(cache); 00049 SG_FREE(cum_derivatives); 00050 } 00051 00052 bool CPlif::set_transform_type(const char *type_str) 00053 { 00054 invalidate_cache(); 00055 00056 if (strcmp(type_str, "linear")==0) 00057 transform = T_LINEAR ; 00058 else if (strcmp(type_str, "")==0) 00059 transform = T_LINEAR ; 00060 else if (strcmp(type_str, "log")==0) 00061 transform = T_LOG ; 00062 else if (strcmp(type_str, "log(+1)")==0) 00063 transform = T_LOG_PLUS1 ; 00064 else if (strcmp(type_str, "log(+3)")==0) 00065 transform = T_LOG_PLUS3 ; 00066 else if (strcmp(type_str, "(+3)")==0) 00067 transform = T_LINEAR_PLUS3 ; 00068 else 00069 { 00070 SG_ERROR( "unknown transform type (%s)\n", type_str) ; 00071 return false ; 00072 } 00073 return true ; 00074 } 00075 00076 void CPlif::init_penalty_struct_cache() 00077 { 00078 if (!use_cache) 00079 return ; 00080 if (cache || use_svm) 00081 return ; 00082 if (max_value<=0) 00083 return ; 00084 00085 float64_t* local_cache=SG_MALLOC(float64_t, ((int32_t) max_value) + 2); 00086 00087 if (local_cache) 00088 { 00089 for (int32_t i=0; i<=max_value; i++) 00090 { 00091 if (i<min_value) 00092 local_cache[i] = -CMath::INFTY ; 00093 else 00094 local_cache[i] = lookup_penalty(i, NULL) ; 00095 } 00096 } 00097 this->cache=local_cache ; 00098 } 00099 00100 void CPlif::set_plif_name(char *p_name) 00101 { 00102 SG_FREE(name); 00103 name=SG_MALLOC(char, strlen(p_name)+3); 00104 strcpy(name,p_name) ; 00105 } 00106 00107 void CPlif::delete_penalty_struct(CPlif** PEN, int32_t P) 00108 { 00109 for (int32_t i=0; i<P; i++) 00110 delete PEN[i] ; 00111 SG_FREE(PEN); 00112 } 00113 00114 float64_t CPlif::lookup_penalty_svm( 00115 float64_t p_value, float64_t *d_values) const 00116 { 00117 ASSERT(use_svm>0); 00118 float64_t d_value=d_values[use_svm-1] ; 00119 #ifdef PLIF_DEBUG 00120 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ; 00121 #endif 00122 00123 if (!do_calc) 00124 return d_value; 00125 switch (transform) 00126 { 00127 case T_LINEAR: 00128 break ; 00129 case T_LOG: 00130 d_value = log(d_value) ; 00131 break ; 00132 case T_LOG_PLUS1: 00133 d_value = log(d_value+1) ; 00134 break ; 00135 case T_LOG_PLUS3: 00136 d_value = log(d_value+3) ; 00137 break ; 00138 case T_LINEAR_PLUS3: 00139 d_value = d_value+3 ; 00140 break ; 00141 default: 00142 SG_ERROR("unknown transform\n"); 00143 break ; 00144 } 00145 00146 int32_t idx = 0 ; 00147 float64_t ret ; 00148 for (int32_t i=0; i<len; i++) 00149 if (limits[i]<=d_value) 00150 idx++ ; 00151 else 00152 break ; // assume it is monotonically increasing 00153 00154 #ifdef PLIF_DEBUG 00155 SG_PRINT(" -> idx = %i ", idx) ; 00156 #endif 00157 00158 if (idx==0) 00159 ret=penalties[0] ; 00160 else if (idx==len) 00161 ret=penalties[len-1] ; 00162 else 00163 { 00164 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00165 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00166 #ifdef PLIF_DEBUG 00167 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ; 00168 #endif 00169 } 00170 #ifdef PLIF_DEBUG 00171 SG_PRINT(" -> ret=%1.3f\n", ret) ; 00172 #endif 00173 00174 return ret ; 00175 } 00176 00177 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const 00178 { 00179 if (use_svm) 00180 return lookup_penalty_svm(p_value, svm_values) ; 00181 00182 if ((p_value<min_value) || (p_value>max_value)) 00183 { 00184 //SG_PRINT("Feature:%s, %s.lookup_penalty(%i): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) ; 00185 return -CMath::INFTY ; 00186 } 00187 if (!do_calc) 00188 return p_value; 00189 if (cache!=NULL && (p_value>=0) && (p_value<=max_value)) 00190 { 00191 float64_t ret=cache[p_value] ; 00192 return ret ; 00193 } 00194 return lookup_penalty((float64_t) p_value, svm_values) ; 00195 } 00196 00197 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const 00198 { 00199 if (use_svm) 00200 return lookup_penalty_svm(p_value, svm_values) ; 00201 00202 #ifdef PLIF_DEBUG 00203 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ; 00204 #endif 00205 00206 00207 if ((p_value<min_value) || (p_value>max_value)) 00208 { 00209 //SG_PRINT("Feature:%s, %s.lookup_penalty(%f): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) ; 00210 return -CMath::INFTY ; 00211 } 00212 00213 if (!do_calc) 00214 return p_value; 00215 00216 float64_t d_value = (float64_t) p_value ; 00217 switch (transform) 00218 { 00219 case T_LINEAR: 00220 break ; 00221 case T_LOG: 00222 d_value = log(d_value) ; 00223 break ; 00224 case T_LOG_PLUS1: 00225 d_value = log(d_value+1) ; 00226 break ; 00227 case T_LOG_PLUS3: 00228 d_value = log(d_value+3) ; 00229 break ; 00230 case T_LINEAR_PLUS3: 00231 d_value = d_value+3 ; 00232 break ; 00233 default: 00234 SG_ERROR( "unknown transform\n") ; 00235 break ; 00236 } 00237 00238 #ifdef PLIF_DEBUG 00239 SG_PRINT(" -> value = %1.4f ", d_value) ; 00240 #endif 00241 00242 int32_t idx = 0 ; 00243 float64_t ret ; 00244 for (int32_t i=0; i<len; i++) 00245 if (limits[i]<=d_value) 00246 idx++ ; 00247 else 00248 break ; // assume it is monotonically increasing 00249 00250 #ifdef PLIF_DEBUG 00251 SG_PRINT(" -> idx = %i ", idx) ; 00252 #endif 00253 00254 if (idx==0) 00255 ret=penalties[0] ; 00256 else if (idx==len) 00257 ret=penalties[len-1] ; 00258 else 00259 { 00260 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00261 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00262 #ifdef PLIF_DEBUG 00263 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ; 00264 #endif 00265 } 00266 //if (p_value>=30 && p_value<150) 00267 //SG_PRINT("%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ; 00268 #ifdef PLIF_DEBUG 00269 SG_PRINT(" -> ret=%1.3f\n", ret) ; 00270 #endif 00271 00272 return ret ; 00273 } 00274 00275 void CPlif::penalty_clear_derivative() 00276 { 00277 for (int32_t i=0; i<len; i++) 00278 cum_derivatives[i]=0.0 ; 00279 } 00280 00281 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor) 00282 { 00283 if (use_svm) 00284 { 00285 penalty_add_derivative_svm(p_value, svm_values, factor) ; 00286 return ; 00287 } 00288 00289 if ((p_value<min_value) || (p_value>max_value)) 00290 { 00291 return ; 00292 } 00293 float64_t d_value = (float64_t) p_value ; 00294 switch (transform) 00295 { 00296 case T_LINEAR: 00297 break ; 00298 case T_LOG: 00299 d_value = log(d_value) ; 00300 break ; 00301 case T_LOG_PLUS1: 00302 d_value = log(d_value+1) ; 00303 break ; 00304 case T_LOG_PLUS3: 00305 d_value = log(d_value+3) ; 00306 break ; 00307 case T_LINEAR_PLUS3: 00308 d_value = d_value+3 ; 00309 break ; 00310 default: 00311 SG_ERROR( "unknown transform\n") ; 00312 break ; 00313 } 00314 00315 int32_t idx = 0 ; 00316 for (int32_t i=0; i<len; i++) 00317 if (limits[i]<=d_value) 00318 idx++ ; 00319 else 00320 break ; // assume it is monotonically increasing 00321 00322 if (idx==0) 00323 cum_derivatives[0]+= factor ; 00324 else if (idx==len) 00325 cum_derivatives[len-1]+= factor ; 00326 else 00327 { 00328 cum_derivatives[idx] += factor * (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00329 cum_derivatives[idx-1]+= factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00330 } 00331 } 00332 00333 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values, float64_t factor) 00334 { 00335 ASSERT(use_svm>0); 00336 float64_t d_value=d_values[use_svm-1] ; 00337 00338 if (d_value<-1e+20) 00339 return; 00340 00341 switch (transform) 00342 { 00343 case T_LINEAR: 00344 break ; 00345 case T_LOG: 00346 d_value = log(d_value) ; 00347 break ; 00348 case T_LOG_PLUS1: 00349 d_value = log(d_value+1) ; 00350 break ; 00351 case T_LOG_PLUS3: 00352 d_value = log(d_value+3) ; 00353 break ; 00354 case T_LINEAR_PLUS3: 00355 d_value = d_value+3 ; 00356 break ; 00357 default: 00358 SG_ERROR( "unknown transform\n") ; 00359 break ; 00360 } 00361 00362 int32_t idx = 0 ; 00363 for (int32_t i=0; i<len; i++) 00364 if (limits[i]<=d_value) 00365 idx++ ; 00366 else 00367 break ; // assume it is monotonically increasing 00368 00369 if (idx==0) 00370 cum_derivatives[0]+=factor ; 00371 else if (idx==len) 00372 cum_derivatives[len-1]+=factor ; 00373 else 00374 { 00375 cum_derivatives[idx] += factor*(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00376 cum_derivatives[idx-1] += factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00377 } 00378 } 00379 00380 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids) 00381 { 00382 if (use_svm) 00383 { 00384 svm_ids[(*num_svms)] = use_svm; 00385 (*num_svms)++; 00386 } 00387 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type()); 00388 } 00389 00390 bool CPlif::get_do_calc() 00391 { 00392 return do_calc; 00393 } 00394 00395 void CPlif::set_do_calc(bool b) 00396 { 00397 do_calc = b;; 00398 }