SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _PLUGINESTIMATE_H___ 00012 #define _PLUGINESTIMATE_H___ 00013 00014 #include <shogun/machine/Machine.h> 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/features/Labels.h> 00017 #include <shogun/distributions/LinearHMM.h> 00018 00019 namespace shogun 00020 { 00034 class CPluginEstimate: public CMachine 00035 { 00036 public: 00041 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10); 00042 virtual ~CPluginEstimate(); 00043 00048 CLabels* apply(); 00049 00055 virtual CLabels* apply(CFeatures* data); 00056 00061 virtual inline void set_features(CStringFeatures<uint16_t>* feat) 00062 { 00063 SG_UNREF(features); 00064 SG_REF(feat); 00065 features=feat; 00066 } 00067 00072 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; } 00073 00075 float64_t apply(int32_t vec_idx); 00076 00083 inline float64_t posterior_log_odds_obsolete( 00084 uint16_t* vector, int32_t len) 00085 { 00086 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len); 00087 } 00088 00095 inline float64_t get_parameterwise_log_odds( 00096 uint16_t obs, int32_t position) 00097 { 00098 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position); 00099 } 00100 00107 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos) 00108 { 00109 return pos_model->get_log_derivative_obsolete(obs, pos); 00110 } 00111 00118 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos) 00119 { 00120 return neg_model->get_log_derivative_obsolete(obs, pos); 00121 } 00122 00131 inline bool get_model_params( 00132 float64_t*& pos_params, float64_t*& neg_params, 00133 int32_t &seq_length, int32_t &num_symbols) 00134 { 00135 if ((!pos_model) || (!neg_model)) 00136 { 00137 SG_ERROR( "no model available\n"); 00138 return false; 00139 } 00140 00141 SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs(); 00142 pos_params = log_pos_trans.vector; 00143 SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs(); 00144 neg_params = log_neg_trans.vector; 00145 00146 seq_length = pos_model->get_sequence_length(); 00147 num_symbols = pos_model->get_num_symbols(); 00148 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters()); 00149 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols()); 00150 return true; 00151 } 00152 00159 inline void set_model_params( 00160 float64_t* pos_params, float64_t* neg_params, 00161 int32_t seq_length, int32_t num_symbols) 00162 { 00163 int32_t num_params; 00164 00165 SG_UNREF(pos_model); 00166 pos_model=new CLinearHMM(seq_length, num_symbols); 00167 SG_REF(pos_model); 00168 00169 00170 SG_UNREF(neg_model); 00171 neg_model=new CLinearHMM(seq_length, num_symbols); 00172 SG_REF(neg_model); 00173 00174 num_params=pos_model->get_num_model_parameters(); 00175 ASSERT(seq_length*num_symbols==num_params); 00176 ASSERT(num_params==neg_model->get_num_model_parameters()); 00177 00178 pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params)); 00179 neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params)); 00180 } 00181 00186 inline int32_t get_num_params() 00187 { 00188 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters(); 00189 } 00190 00195 inline bool check_models() 00196 { 00197 return ( (pos_model!=NULL) && (neg_model!=NULL) ); 00198 } 00199 00201 inline virtual const char* get_name() const { return "PluginEstimate"; } 00202 00203 protected: 00212 virtual bool train_machine(CFeatures* data=NULL); 00213 00214 protected: 00216 float64_t m_pos_pseudo; 00218 float64_t m_neg_pseudo; 00219 00221 CLinearHMM* pos_model; 00223 CLinearHMM* neg_model; 00224 00226 CStringFeatures<uint16_t>* features; 00227 }; 00228 } 00229 #endif