SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _LINEARHMM_H__ 00013 #define _LINEARHMM_H__ 00014 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/features/Labels.h> 00017 #include <shogun/distributions/Distribution.h> 00018 00019 namespace shogun 00020 { 00039 class CLinearHMM : public CDistribution 00040 { 00041 public: 00043 CLinearHMM(); 00044 00049 CLinearHMM(CStringFeatures<uint16_t>* f); 00050 00056 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols); 00057 00058 virtual ~CLinearHMM(); 00059 00068 virtual bool train(CFeatures* data=NULL); 00069 00077 bool train( 00078 const int32_t* indizes, int32_t num_indizes, 00079 float64_t pseudo_count); 00080 00087 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len); 00088 00095 float64_t get_likelihood_example(uint16_t* vector, int32_t len); 00096 00102 virtual float64_t get_log_likelihood_example(int32_t num_example); 00103 00110 virtual float64_t get_log_derivative( 00111 int32_t num_param, int32_t num_example); 00112 00119 virtual inline float64_t get_log_derivative_obsolete( 00120 uint16_t obs, int32_t pos) 00121 { 00122 return 1.0/transition_probs[pos*num_symbols+obs]; 00123 } 00124 00131 virtual inline float64_t get_derivative_obsolete( 00132 uint16_t* vector, int32_t len, int32_t pos) 00133 { 00134 ASSERT(pos<len); 00135 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]]; 00136 } 00137 00142 virtual inline int32_t get_sequence_length() { return sequence_length; } 00143 00148 virtual inline int32_t get_num_symbols() { return num_symbols; } 00149 00154 virtual inline int32_t get_num_model_parameters() { return num_params; } 00155 00162 virtual inline float64_t get_positional_log_parameter( 00163 uint16_t obs, int32_t position) 00164 { 00165 return log_transition_probs[position*num_symbols+obs]; 00166 } 00167 00173 virtual inline float64_t get_log_model_parameter(int32_t num_param) 00174 { 00175 ASSERT(log_transition_probs); 00176 ASSERT(num_param<num_params); 00177 00178 return log_transition_probs[num_param]; 00179 } 00180 00185 virtual SGVector<float64_t> get_log_transition_probs(); 00186 00192 virtual bool set_log_transition_probs(SGVector<float64_t> probs); 00193 00198 virtual SGVector<float64_t> get_transition_probs(); 00199 00205 virtual bool set_transition_probs(SGVector<float64_t> probs); 00206 00208 inline virtual const char* get_name() const { return "LinearHMM"; } 00209 00210 protected: 00211 virtual void load_serializable_post() throw (ShogunException); 00212 00213 private: 00214 void init(); 00215 00216 protected: 00218 int32_t sequence_length; 00220 int32_t num_symbols; 00222 int32_t num_params; 00224 float64_t* transition_probs; 00226 float64_t* log_transition_probs; 00227 }; 00228 } 00229 #endif