SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 00013 #include <stdio.h> 00014 #include <string.h> 00015 00016 #include <shogun/io/SGIO.h> 00017 00018 #include <shogun/structure/PlifArray.h> 00019 #include <shogun/structure/Plif.h> 00020 00021 //#define PLIFARRAY_DEBUG 00022 00023 using namespace shogun; 00024 00025 CPlifArray::CPlifArray() 00026 : CPlifBase() 00027 { 00028 min_value=-1e6; 00029 max_value=1e6; 00030 } 00031 00032 CPlifArray::~CPlifArray() 00033 { 00034 } 00035 00036 void CPlifArray::add_plif(CPlifBase* new_plif) 00037 { 00038 ASSERT(new_plif); 00039 m_array.append_element(new_plif) ; 00040 00041 min_value = -1e6 ; 00042 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00043 { 00044 ASSERT(m_array[i]); 00045 if (!m_array[i]->uses_svm_values()) 00046 min_value = CMath::max(min_value, m_array[i]->get_min_value()) ; 00047 } 00048 00049 max_value = 1e6 ; 00050 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00051 if (!m_array[i]->uses_svm_values()) 00052 max_value = CMath::min(max_value, m_array[i]->get_max_value()) ; 00053 } 00054 00055 void CPlifArray::clear() 00056 { 00057 m_array.clear_array(); 00058 min_value = -1e6 ; 00059 max_value = 1e6 ; 00060 } 00061 00062 float64_t CPlifArray::lookup_penalty( 00063 float64_t p_value, float64_t* svm_values) const 00064 { 00065 //min_value = -1e6 ; 00066 //max_value = 1e6 ; 00067 if (p_value<min_value || p_value>max_value) 00068 { 00069 //SG_WARNING("lookup_penalty: p_value: %i min_value: %f, max_value: %f\n",p_value, min_value, max_value); 00070 return -CMath::INFTY ; 00071 } 00072 float64_t ret = 0.0 ; 00073 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00074 ret += m_array[i]->lookup_penalty(p_value, svm_values) ; 00075 return ret ; 00076 } 00077 00078 float64_t CPlifArray::lookup_penalty( 00079 int32_t p_value, float64_t* svm_values) const 00080 { 00081 //min_value = -1e6 ; 00082 //max_value = 1e6 ; 00083 if (p_value<min_value || p_value>max_value) 00084 { 00085 //SG_WARNING("lookup_penalty: p_value: %i min_value: %f, max_value: %f\n",p_value, min_value, max_value); 00086 return -CMath::INFTY ; 00087 } 00088 float64_t ret = 0.0 ; 00089 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00090 { 00091 float64_t val = m_array[i]->lookup_penalty(p_value, svm_values) ; 00092 ret += val ; 00093 #ifdef PLIFARRAY_DEBUG 00094 CPlif * plif = (CPlif*)m_array[i] ; 00095 if (plif->get_use_svm()) 00096 SG_PRINT("penalty[%i]=%1.5f (use_svm=%i -> %1.5f)\n", i, val, plif->get_use_svm(), svm_values[plif->get_use_svm()-1]) ; 00097 else 00098 SG_PRINT("penalty[%i]=%1.5f\n", i, val) ; 00099 #endif 00100 } 00101 return ret ; 00102 } 00103 00104 void CPlifArray::penalty_clear_derivative() 00105 { 00106 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00107 m_array[i]->penalty_clear_derivative() ; 00108 } 00109 00110 void CPlifArray::penalty_add_derivative( 00111 float64_t p_value, float64_t* svm_values, float64_t factor) 00112 { 00113 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00114 m_array[i]->penalty_add_derivative(p_value, svm_values, factor) ; 00115 } 00116 00117 bool CPlifArray::uses_svm_values() const 00118 { 00119 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00120 if (m_array[i]->uses_svm_values()) 00121 return true ; 00122 return false ; 00123 } 00124 00125 int32_t CPlifArray::get_max_id() const 00126 { 00127 int32_t max_id = 0 ; 00128 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00129 max_id = CMath::max(max_id, m_array[i]->get_max_id()) ; 00130 return max_id ; 00131 } 00132 00133 void CPlifArray::get_used_svms(int32_t* num_svms, int32_t* svm_ids) 00134 { 00135 SG_PRINT("get_used_svms: num: %i \n",m_array.get_num_elements()); 00136 for (int32_t i=0; i<m_array.get_num_elements(); i++) 00137 { 00138 m_array[i]->get_used_svms(num_svms, svm_ids); 00139 } 00140 SG_PRINT("\n"); 00141 }