// SHOGUN v1.1.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/kernel/StringKernel.h> 00017 #include <shogun/kernel/WeightedDegreeStringKernel.h> 00018 #include <shogun/lib/Trie.h> 00019 00020 namespace shogun 00021 { 00022 00023 class CSVM; 00024 00048 class CWeightedDegreePositionStringKernel: public CStringKernel<char> 00049 { 00050 public: 00052 CWeightedDegreePositionStringKernel(); 00053 00061 CWeightedDegreePositionStringKernel( 00062 int32_t size, int32_t degree, 00063 int32_t max_mismatch=0, int32_t mkl_stepsize=1); 00064 00075 CWeightedDegreePositionStringKernel( 00076 int32_t size, float64_t* weights, int32_t degree, 00077 int32_t max_mismatch, int32_t* shift, int32_t shift_len, 00078 int32_t mkl_stepsize=1); 00079 00086 CWeightedDegreePositionStringKernel( 00087 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree); 00088 00089 virtual ~CWeightedDegreePositionStringKernel(); 00090 00097 virtual bool init(CFeatures* l, CFeatures* r); 00098 00100 virtual void cleanup(); 00101 00106 virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; } 00107 00112 virtual const char* get_name() const { return "WeightedDegreePositionStringKernel"; } 00113 00121 inline virtual bool init_optimization( 00122 int32_t p_count, int32_t *IDX, float64_t * alphas) 00123 { 00124 return init_optimization(p_count, IDX, alphas, -1); 00125 } 00126 00138 virtual bool 
init_optimization( 00139 int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num, 00140 int32_t upto_tree=-1); 00141 00146 virtual bool delete_optimization(); 00147 00153 inline virtual float64_t compute_optimized(int32_t idx) 00154 { 00155 ASSERT(get_is_initialized()); 00156 ASSERT(alphabet); 00157 ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA); 00158 return compute_by_tree(idx); 00159 } 00160 00165 static void* compute_batch_helper(void* p); 00166 00177 virtual void compute_batch( 00178 int32_t num_vec, int32_t* vec_idx, float64_t* target, 00179 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, 00180 float64_t factor=1.0); 00181 00185 inline virtual void clear_normal() 00186 { 00187 if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes())) 00188 { 00189 tries.set_use_compact_terminal_nodes(false) ; 00190 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ; 00191 } 00192 00193 if (get_is_initialized()) 00194 { 00195 if (opt_type==SLOWBUTMEMEFFICIENT) 00196 tries.delete_trees(true); 00197 else if (opt_type==FASTBUTMEMHUNGRY) 00198 tries.delete_trees(false); // still buggy 00199 else 00200 SG_ERROR( "unknown optimization type\n"); 00201 00202 set_is_initialized(false); 00203 } 00204 } 00205 00211 inline virtual void add_to_normal(int32_t idx, float64_t weight) 00212 { 00213 add_example_to_tree(idx, weight); 00214 set_is_initialized(true); 00215 } 00216 00221 inline virtual int32_t get_num_subkernels() 00222 { 00223 if (position_weights!=NULL) 00224 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ; 00225 if (length==0) 00226 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize); 00227 return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ; 00228 } 00229 00235 inline void compute_by_subkernel( 00236 int32_t idx, float64_t * subkernel_contrib) 00237 { 00238 if (get_is_initialized()) 00239 { 00240 compute_by_tree(idx, subkernel_contrib); 00241 return ; 00242 } 00243 00244 SG_ERROR( 
"CWeightedDegreePositionStringKernel optimization not initialized\n") ; 00245 } 00246 00252 inline const float64_t* get_subkernel_weights(int32_t& num_weights) 00253 { 00254 num_weights = get_num_subkernels() ; 00255 00256 SG_FREE(weights_buffer); 00257 weights_buffer = SG_MALLOC(float64_t, num_weights); 00258 00259 if (position_weights!=NULL) 00260 for (int32_t i=0; i<num_weights; i++) 00261 weights_buffer[i] = position_weights[i*mkl_stepsize] ; 00262 else 00263 for (int32_t i=0; i<num_weights; i++) 00264 weights_buffer[i] = weights[i*mkl_stepsize] ; 00265 00266 return weights_buffer ; 00267 } 00268 00274 virtual void set_subkernel_weights(SGVector<float64_t> w) 00275 { 00276 float64_t* weights2=w.vector; 00277 int32_t num_weights2=w.vlen; 00278 00279 int32_t num_weights = get_num_subkernels() ; 00280 if (num_weights!=num_weights2) 00281 SG_ERROR( "number of weights do not match\n") ; 00282 00283 if (position_weights!=NULL) 00284 for (int32_t i=0; i<num_weights; i++) 00285 for (int32_t j=0; j<mkl_stepsize; j++) 00286 { 00287 if (i*mkl_stepsize+j<seq_length) 00288 position_weights[i*mkl_stepsize+j] = weights2[i] ; 00289 } 00290 else if (length==0) 00291 { 00292 for (int32_t i=0; i<num_weights; i++) 00293 for (int32_t j=0; j<mkl_stepsize; j++) 00294 if (i*mkl_stepsize+j<get_degree()) 00295 weights[i*mkl_stepsize+j] = weights2[i] ; 00296 } 00297 else 00298 { 00299 for (int32_t i=0; i<num_weights; i++) 00300 for (int32_t j=0; j<mkl_stepsize; j++) 00301 if (i*mkl_stepsize+j<get_degree()*length) 00302 weights[i*mkl_stepsize+j] = weights2[i] ; 00303 } 00304 } 00305 00306 // other kernel tree operations 00312 float64_t* compute_abs_weights(int32_t & len); 00313 00318 bool is_tree_initialized() { return tree_initialized; } 00319 00324 inline int32_t get_max_mismatch() { return max_mismatch; } 00325 00330 inline int32_t get_degree() { return degree; } 00331 00337 inline float64_t *get_degree_weights(int32_t& d, int32_t& len) 00338 { 00339 d=degree; 00340 len=length; 00341 
return weights; 00342 } 00343 00349 inline float64_t *get_weights(int32_t& num_weights) 00350 { 00351 if (position_weights!=NULL) 00352 { 00353 num_weights = seq_length ; 00354 return position_weights ; 00355 } 00356 if (length==0) 00357 num_weights = degree ; 00358 else 00359 num_weights = degree*length ; 00360 return weights; 00361 } 00362 00368 inline float64_t *get_position_weights(int32_t& len) 00369 { 00370 len=seq_length; 00371 return position_weights; 00372 } 00373 00378 void set_shifts(SGVector<int32_t> shifts); 00379 00384 bool set_weights(SGMatrix<float64_t> new_weights); 00385 00390 virtual bool set_wd_weights(); 00391 00397 virtual void set_position_weights(SGVector<float64_t> pws); 00398 00406 bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num); 00407 00415 bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num); 00416 00421 bool init_block_weights(); 00422 00427 bool init_block_weights_from_wd(); 00428 00433 bool init_block_weights_from_wd_external(); 00434 00439 bool init_block_weights_const(); 00440 00445 bool init_block_weights_linear(); 00446 00451 bool init_block_weights_sqpoly(); 00452 00457 bool init_block_weights_cubicpoly(); 00458 00463 bool init_block_weights_exp(); 00464 00469 bool init_block_weights_log(); 00470 00475 bool delete_position_weights() 00476 { 00477 SG_FREE(position_weights); 00478 position_weights=NULL; 00479 return true; 00480 } 00481 00486 bool delete_position_weights_lhs() 00487 { 00488 SG_FREE(position_weights_lhs); 00489 position_weights_lhs=NULL; 00490 return true; 00491 } 00492 00497 bool delete_position_weights_rhs() 00498 { 00499 SG_FREE(position_weights_rhs); 00500 position_weights_rhs=NULL; 00501 return true; 00502 } 00503 00509 virtual float64_t compute_by_tree(int32_t idx); 00510 00516 virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib); 00517 00530 float64_t* compute_scoring( 00531 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00532 float64_t* 
target, int32_t num_suppvec, int32_t* IDX, 00533 float64_t* weights); 00534 00543 char* compute_consensus( 00544 int32_t &num_feat, int32_t num_suppvec, int32_t* IDX, 00545 float64_t* alphas); 00546 00558 float64_t* extract_w( 00559 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00560 float64_t* w_result, int32_t num_suppvec, int32_t* IDX, 00561 float64_t* alphas); 00562 00575 float64_t* compute_POIM( 00576 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00577 float64_t* poim_result, int32_t num_suppvec, int32_t* IDX, 00578 float64_t* alphas, float64_t* distrib); 00579 00586 void prepare_POIM2( 00587 float64_t* distrib, int32_t num_sym, int32_t num_feat); 00588 00595 void compute_POIM2(int32_t max_degree, CSVM* svm); 00596 00602 void get_POIM2(float64_t** poim, int32_t* result_len); 00603 00605 void cleanup_POIM2(); 00606 00607 protected: 00609 void create_empty_tries(); 00610 00616 virtual void add_example_to_tree( 00617 int32_t idx, float64_t weight); 00618 00625 void add_example_to_single_tree( 00626 int32_t idx, float64_t weight, int32_t tree_num); 00627 00636 virtual float64_t compute(int32_t idx_a, int32_t idx_b); 00637 00646 float64_t compute_with_mismatch( 00647 char* avec, int32_t alen, char* bvec, int32_t blen); 00648 00657 float64_t compute_without_mismatch( 00658 char* avec, int32_t alen, char* bvec, int32_t blen); 00659 00668 float64_t compute_without_mismatch_matrix( 00669 char* avec, int32_t alen, char* bvec, int32_t blen); 00670 00681 float64_t compute_without_mismatch_position_weights( 00682 char* avec, float64_t *posweights_lhs, int32_t alen, 00683 char* bvec, float64_t *posweights_rhs, int32_t blen); 00684 00686 virtual void remove_lhs(); 00687 00696 virtual void load_serializable_post() throw (ShogunException); 00697 00698 private: 00701 void init(); 00702 00703 protected: 00705 float64_t* weights; 00707 int32_t weights_degree; 00709 int32_t weights_length; 00710 00712 float64_t* position_weights; 00714 int32_t 
position_weights_len; 00715 00717 float64_t* position_weights_lhs; 00719 int32_t position_weights_lhs_len; 00721 float64_t* position_weights_rhs; 00723 int32_t position_weights_rhs_len; 00725 bool* position_mask; 00726 00728 float64_t* weights_buffer; 00730 int32_t mkl_stepsize; 00731 00733 int32_t degree; 00735 int32_t length; 00736 00738 int32_t max_mismatch; 00740 int32_t seq_length; 00741 00743 int32_t *shift; 00745 int32_t shift_len; 00747 int32_t max_shift; 00748 00750 bool block_computation; 00751 00753 float64_t* block_weights; 00755 EWDKernType type; 00757 int32_t which_degree; 00758 00760 CTrie<DNATrie> tries; 00762 CTrie<POIMTrie> poim_tries; 00763 00765 bool tree_initialized; 00767 bool use_poim_tries; 00768 00770 float64_t* m_poim_distrib; 00772 float64_t* m_poim; 00773 00775 int32_t m_poim_num_sym; 00777 int32_t m_poim_num_feat; 00779 int32_t m_poim_result_len; 00780 00782 CAlphabet* alphabet; 00783 }; 00784 } 00785 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H__ */