SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _ONLINELINEARCLASSIFIER_H__ 00012 #define _ONLINELINEARCLASSIFIER_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/Labels.h> 00016 #include <shogun/features/StreamingDotFeatures.h> 00017 #include <shogun/machine/Machine.h> 00018 00019 #include <stdio.h> 00020 00021 namespace shogun 00022 { 00049 class COnlineLinearMachine : public CMachine 00050 { 00051 public: 00053 COnlineLinearMachine(); 00054 virtual ~COnlineLinearMachine(); 00055 00061 virtual inline void get_w(float32_t*& dst_w, int32_t& dst_dims) 00062 { 00063 ASSERT(w && w_dim>0); 00064 dst_w=w; 00065 dst_dims=w_dim; 00066 } 00067 00074 virtual void get_w(float64_t*& dst_w, int32_t& dst_dims) 00075 { 00076 ASSERT(w && w_dim>0); 00077 dst_w=SG_MALLOC(float64_t, w_dim); 00078 for (int32_t i=0; i<w_dim; i++) 00079 dst_w[i]=w[i]; 00080 dst_dims=w_dim; 00081 } 00082 00087 virtual inline SGVector<float32_t> get_w() 00088 { 00089 return SGVector<float32_t>(w, w_dim); 00090 } 00091 00097 virtual inline void set_w(float32_t* src_w, int32_t src_w_dim) 00098 { 00099 SG_FREE(w); 00100 w=SG_MALLOC(float32_t, src_w_dim); 00101 memcpy(w, src_w, size_t(src_w_dim)*sizeof(float32_t)); 00102 w_dim=src_w_dim; 00103 } 00104 00111 virtual void set_w(float64_t* src_w, int32_t src_w_dim) 00112 { 00113 SG_FREE(w); 00114 w=SG_MALLOC(float32_t, src_w_dim); 00115 for (int32_t i=0; i<src_w_dim; i++) 00116 w[i] = src_w[i]; 00117 w_dim=src_w_dim; 00118 } 00119 00124 virtual inline void set_bias(float32_t b) 00125 { 00126 bias=b; 00127 } 00128 00133 virtual inline float32_t get_bias() 00134 { 00135 return bias; 00136 } 00137 00143 virtual bool load(FILE* srcfile); 00144 00150 virtual bool save(FILE* dstfile); 00151 00156 virtual inline void set_features(CStreamingDotFeatures* feat) 00157 { 00158 if (features) 00159 SG_UNREF(features); 00160 SG_REF(feat); 00161 features=feat; 00162 } 00163 00168 virtual CLabels* apply(); 00169 00175 virtual CLabels* apply(CFeatures* data); 00176 00178 virtual float64_t apply(int32_t vec_idx) 00179 { 00180 SG_NOTIMPLEMENTED; 00181 return CMath::INFTY; 00182 } 00183 00192 virtual float32_t apply(float32_t* vec, int32_t len); 00193 00199 virtual float32_t apply_to_current_example(); 00200 00205 virtual CStreamingDotFeatures* get_features() { SG_REF(features); return features; } 00206 00212 virtual const char* get_name() const { return "OnlineLinearMachine"; } 00213 00214 protected: 00216 int32_t w_dim; 00218 float32_t* w; 00220 float32_t bias; 00222 CStreamingDotFeatures* features; 00223 }; 00224 } 00225 #endif