// SHOGUN v1.1.0
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Adaptation of Vowpal Wabbit v5.1. 00013 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00014 */ 00015 00016 #ifndef _VOWPALWABBIT_H__ 00017 #define _VOWPALWABBIT_H__ 00018 00019 #include <shogun/classifier/vw/vw_common.h> 00020 #include <shogun/classifier/vw/learners/VwAdaptiveLearner.h> 00021 #include <shogun/classifier/vw/learners/VwNonAdaptiveLearner.h> 00022 #include <shogun/classifier/vw/VwRegressor.h> 00023 00024 #include <shogun/features/StreamingVwFeatures.h> 00025 #include <shogun/machine/OnlineLinearMachine.h> 00026 00027 namespace shogun 00028 { 00038 class CVowpalWabbit: public COnlineLinearMachine 00039 { 00040 public: 00044 CVowpalWabbit(); 00045 00052 CVowpalWabbit(CStreamingVwFeatures* feat); 00053 00057 ~CVowpalWabbit(); 00058 00063 void reinitialize_weights(); 00064 00073 void set_no_training(bool dont_train) { no_training = dont_train; } 00074 00080 void set_adaptive(bool adaptive_learning); 00081 00088 void set_exact_adaptive_norm(bool exact_adaptive); 00089 00095 void set_num_passes(int32_t passes) 00096 { 00097 env->num_passes = passes; 00098 } 00099 00105 void load_regressor(char* file_name); 00106 00113 void set_regressor_out(char* file_name, bool is_text = true); 00114 00120 void set_prediction_out(char* file_name); 00121 00128 void add_quadratic_pair(char* pair); 00129 00135 virtual bool train_machine(CFeatures* feat = NULL); 00136 00144 virtual float32_t predict_and_finalize(VwExample* 
ex); 00145 00154 float32_t compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x); 00155 00168 float32_t compute_exact_norm_quad(float32_t* weights, VwFeature& page_feature, v_array<VwFeature> &offer_features, 00169 vw_size_t mask, float32_t g, float32_t& sum_abs_x); 00170 00176 virtual CVwEnvironment* get_env() 00177 { 00178 SG_REF(env); 00179 return env; 00180 } 00181 00187 virtual const char* get_name() const { return "VowpalWabbit"; } 00188 00189 private: 00195 virtual void init(CStreamingVwFeatures* feat = NULL); 00196 00201 virtual void set_learner(); 00202 00210 virtual float32_t inline_l1_predict(VwExample* &ex); 00211 00219 virtual float32_t inline_predict(VwExample* &ex); 00220 00228 virtual float32_t finalize_prediction(float32_t ret); 00229 00235 virtual void output_example(VwExample* &ex); 00236 00242 virtual void print_update(VwExample* &ex); 00243 00252 virtual void output_prediction(int32_t f, float32_t res, float32_t weight, v_array<char> tag); 00253 00259 void set_verbose(bool verbose); 00260 00261 protected: 00263 CStreamingVwFeatures* features; 00264 00266 CVwEnvironment* env; 00267 00269 CVwLearner* learner; 00270 00272 CVwRegressor* reg; 00273 00274 private: 00276 bool quiet; 00277 00279 bool no_training; 00280 00282 float32_t dump_interval; 00284 float32_t sum_loss_since_last_dump; 00286 float64_t old_weighted_examples; 00287 00289 char* reg_name; 00291 bool reg_dump_text; 00292 00294 bool save_predictions; 00296 int32_t prediction_fd; 00297 }; 00298 00299 } 00300 #endif // _VOWPALWABBIT_H__