// SHOGUN v1.1.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2010 Soeren Sonnenburg 00008 * Copyright (C) 2010 Berlin Institute of Technology 00009 */ 00010 00011 #ifndef _SCATTERKERNELNORMALIZER_H___ 00012 #define _SCATTERKERNELNORMALIZER_H___ 00013 00014 #include <shogun/kernel/KernelNormalizer.h> 00015 #include <shogun/kernel/IdentityKernelNormalizer.h> 00016 #include <shogun/kernel/Kernel.h> 00017 #include <shogun/features/Labels.h> 00018 #include <shogun/io/SGIO.h> 00019 00020 namespace shogun 00021 { 00023 class CScatterKernelNormalizer: public CKernelNormalizer 00024 { 00025 00026 public: 00028 CScatterKernelNormalizer() : CKernelNormalizer() 00029 { 00030 init(); 00031 } 00032 00035 CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag, 00036 CLabels* labels,CKernelNormalizer* normalizer=NULL) 00037 : CKernelNormalizer() 00038 { 00039 init(); 00040 00041 m_testing_class=-1; 00042 m_const_diag=const_diag; 00043 m_const_offdiag=const_offdiag; 00044 00045 ASSERT(labels) 00046 SG_REF(labels); 00047 m_labels=labels; 00048 00049 if (normalizer==NULL) 00050 normalizer=new CIdentityKernelNormalizer(); 00051 SG_REF(normalizer); 00052 m_normalizer=normalizer; 00053 00054 SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g" 00055 " const_offdiag=%g num_labels=%d and normalizer='%s'\n", 00056 const_diag, const_offdiag, labels->get_num_labels(), 00057 normalizer->get_name()); 00058 } 00059 00061 virtual ~CScatterKernelNormalizer() 00062 { 00063 SG_UNREF(m_labels); 00064 SG_UNREF(m_normalizer); 00065 } 00066 00069 virtual bool init(CKernel* k) 00070 { 00071 m_normalizer->init(k); 00072 return true; 00073 } 00074 00079 int32_t get_testing_class() 00080 { 00081 return m_testing_class; 00082 } 00083 
00088 void set_testing_class(int32_t c) 00089 { 00090 m_testing_class=c; 00091 } 00092 00098 inline virtual float64_t normalize(float64_t value, int32_t idx_lhs, 00099 int32_t idx_rhs) 00100 { 00101 value=m_normalizer->normalize(value, idx_lhs, idx_rhs); 00102 float64_t c=m_const_offdiag; 00103 00104 if (m_testing_class>=0) 00105 { 00106 if (m_labels->get_label(idx_lhs) == m_testing_class) 00107 c=m_const_diag; 00108 } 00109 else 00110 { 00111 if (m_labels->get_label(idx_lhs) == m_labels->get_label(idx_rhs)) 00112 c=m_const_diag; 00113 00114 } 00115 return value*c; 00116 } 00117 00122 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00123 { 00124 SG_ERROR("normalize_lhs not implemented"); 00125 return 0; 00126 } 00127 00132 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00133 { 00134 SG_ERROR("normalize_rhs not implemented"); 00135 return 0; 00136 } 00137 00139 inline virtual const char* get_name() const 00140 { 00141 return "ScatterKernelNormalizer"; 00142 } 00143 00144 private: 00145 void init() 00146 { 00147 m_const_diag = 1.0; 00148 m_const_offdiag = 1.0; 00149 00150 m_labels = NULL; 00151 m_normalizer = NULL; 00152 00153 m_testing_class = -1; 00154 00155 00156 m_parameters->add(&m_testing_class, "m_testing_class" 00157 "Testing Class."); 00158 m_parameters->add(&m_const_diag, "m_const_diag" 00159 "Factor to multiply to diagonal elements."); 00160 m_parameters->add(&m_const_offdiag, "m_const_offdiag" 00161 "Factor to multiply to off-diagonal elements."); 00162 00163 m_parameters->add((CSGObject**) &m_labels, "m_labels", "Labels"); 00164 m_parameters->add((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer."); 00165 } 00166 00167 protected: 00168 00170 float64_t m_const_diag; 00172 float64_t m_const_offdiag; 00173 00175 CLabels* m_labels; 00176 00178 CKernelNormalizer* m_normalizer; 00179 00181 int32_t m_testing_class; 00182 }; 00183 } 00184 #endif 00185