SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <shogun/ui/GUIHMM.h> 00013 #include <shogun/ui/SGInterface.h> 00014 00015 #include <shogun/lib/config.h> 00016 #include <shogun/lib/common.h> 00017 #include <shogun/features/StringFeatures.h> 00018 #include <shogun/features/Labels.h> 00019 00020 #include <unistd.h> 00021 00022 using namespace shogun; 00023 00024 CGUIHMM::CGUIHMM(CSGInterface* ui_) 00025 : CSGObject(), ui(ui_) 00026 { 00027 working=NULL; 00028 00029 pos=NULL; 00030 neg=NULL; 00031 test=NULL; 00032 00033 PSEUDO=1e-10; 00034 M=4; 00035 } 00036 00037 CGUIHMM::~CGUIHMM() 00038 { 00039 SG_UNREF(working); 00040 } 00041 00042 bool CGUIHMM::new_hmm(int32_t n, int32_t m) 00043 { 00044 SG_UNREF(working); 00045 working=new CHMM(n, m, NULL, PSEUDO); 00046 M=m; 00047 return true; 00048 } 00049 00050 bool CGUIHMM::baum_welch_train() 00051 { 00052 if (!working) 00053 SG_ERROR("Create HMM first.\n"); 00054 00055 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00056 if (!trainfeatures) 00057 SG_ERROR("Assign train features first.\n"); 00058 if (trainfeatures->get_feature_type()!=F_WORD || 00059 trainfeatures->get_feature_class()!=C_STRING) 00060 SG_ERROR("Features must be STRING of type WORD.\n"); 00061 00062 CStringFeatures<uint16_t>* sf=(CStringFeatures<uint16_t>*) trainfeatures; 00063 SG_DEBUG("Stringfeatures have %ld orig_symbols %ld symbols %d order %ld max_symbols\n", (int64_t) sf->get_original_num_symbols(), (int64_t) sf->get_num_symbols(), sf->get_order(), (int64_t) sf->get_max_num_symbols()); 00064 00065 working->set_observations(sf); 00066 00067 return working->baum_welch_viterbi_train(BW_NORMAL); 00068 } 00069 00070 00071 bool CGUIHMM::baum_welch_trans_train() 00072 { 00073 if (!working) 00074 SG_ERROR("Create HMM first.\n"); 00075 00076 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00077 if (!trainfeatures) 00078 SG_ERROR("Assign train features first.\n"); 00079 if (trainfeatures->get_feature_type()!=F_WORD || 00080 trainfeatures->get_feature_class()!=C_STRING) 00081 SG_ERROR("Features must be STRING of type WORD.\n"); 00082 00083 working->set_observations((CStringFeatures<uint16_t>*) trainfeatures); 00084 00085 return working->baum_welch_viterbi_train(BW_TRANS); 00086 } 00087 00088 00089 bool CGUIHMM::baum_welch_train_defined() 00090 { 00091 if (!working) 00092 SG_ERROR("Create HMM first.\n"); 00093 if (!working->get_observations()) 00094 SG_ERROR("Assign observation first.\n"); 00095 00096 return working->baum_welch_viterbi_train(BW_DEFINED); 00097 } 00098 00099 bool CGUIHMM::viterbi_train() 00100 { 00101 if (!working) 00102 SG_ERROR("Create HMM first.\n"); 00103 if (!working->get_observations()) 00104 SG_ERROR("Assign observation first.\n"); 00105 00106 return working->baum_welch_viterbi_train(VIT_NORMAL); 00107 } 00108 00109 bool CGUIHMM::viterbi_train_defined() 00110 { 00111 if (!working) 00112 SG_ERROR("Create HMM first.\n"); 00113 if (!working->get_observations()) 00114 SG_ERROR("Assign observation first.\n"); 00115 00116 return working->baum_welch_viterbi_train(VIT_DEFINED); 00117 } 00118 00119 bool CGUIHMM::linear_train(char align) 00120 { 00121 if (!working) 00122 SG_ERROR("Create HMM first.\n"); 00123 00124 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00125 if (!trainfeatures) 00126 SG_ERROR("Assign train features first.\n"); 00127 if (trainfeatures->get_feature_type()!=F_WORD || 00128 trainfeatures->get_feature_class()!=C_STRING) 00129 SG_ERROR("Features must be STRING of type WORD.\n"); 00130 00131 working->set_observations((CStringFeatures<uint16_t>*) ui-> 00132 ui_features->get_train_features()); 00133 00134 bool right_align=false; 00135 if (align=='r') 00136 { 00137 SG_INFO("Using alignment to right.\n"); 00138 right_align=true; 00139 } 00140 else 00141 SG_INFO("Using alignment to left.\n"); 00142 working->linear_train(right_align); 00143 00144 return true; 00145 } 00146 00147 CLabels* CGUIHMM::classify(CLabels* result) 00148 { 00149 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00150 ui_features->get_test_features(); 00151 ASSERT(obs); 00152 int32_t num_vec=obs->get_num_vectors(); 00153 00154 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00155 //CStringFeatures<uint16_t>* old_neg=neg->get_observations(); 00156 00157 pos->set_observations(obs); 00158 neg->set_observations(obs); 00159 00160 if (!result) 00161 result=new CLabels(num_vec); 00162 00163 for (int32_t i=0; i<num_vec; i++) 00164 result->set_label(i, pos->model_probability(i) - neg->model_probability(i)); 00165 00166 //pos->set_observations(old_pos); 00167 //neg->set_observations(old_neg); 00168 return result; 00169 } 00170 00171 float64_t CGUIHMM::classify_example(int32_t idx) 00172 { 00173 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00174 ui_features->get_test_features(); 00175 ASSERT(obs); 00176 00177 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00178 //CStringFeatures<uint16_t>* old_neg=neg->get_observations(); 00179 00180 pos->set_observations(obs); 00181 neg->set_observations(obs); 00182 00183 float64_t result=pos->model_probability(idx) - neg->model_probability(idx); 00184 //pos->set_observations(old_pos); 00185 //neg->set_observations(old_neg); 00186 return result; 00187 } 00188 00189 CLabels* CGUIHMM::one_class_classify(CLabels* result) 00190 { 00191 ASSERT(working); 00192 00193 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00194 ui_features->get_test_features(); 00195 ASSERT(obs); 00196 int32_t num_vec=obs->get_num_vectors(); 00197 00198 //CStringFeatures<uint16_t>* old_pos=working->get_observations(); 00199 working->set_observations(obs); 00200 00201 if (!result) 00202 result=new CLabels(num_vec); 00203 00204 for (int32_t i=0; i<num_vec; i++) 00205 result->set_label(i, working->model_probability(i)); 00206 00207 //working->set_observations(old_pos); 00208 return result; 00209 } 00210 00211 CLabels* CGUIHMM::linear_one_class_classify(CLabels* result) 00212 { 00213 ASSERT(working); 00214 00215 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00216 ui_features->get_test_features(); 00217 ASSERT(obs); 00218 int32_t num_vec=obs->get_num_vectors(); 00219 00220 //CStringFeatures<uint16_t>* old_pos=working->get_observations(); 00221 working->set_observations(obs); 00222 00223 if (!result) 00224 result=new CLabels(num_vec); 00225 00226 for (int32_t i=0; i<num_vec; i++) 00227 result->set_label(i, working->linear_model_probability(i)); 00228 00229 //working->set_observations(old_pos); 00230 return result; 00231 } 00232 00233 00234 float64_t CGUIHMM::one_class_classify_example(int32_t idx) 00235 { 00236 ASSERT(working); 00237 00238 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00239 ui_features->get_test_features(); 00240 ASSERT(obs); 00241 00242 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00243 00244 pos->set_observations(obs); 00245 neg->set_observations(obs); 00246 00247 float64_t result=working->model_probability(idx); 00248 //working->set_observations(old_pos); 00249 return result; 00250 } 00251 00252 bool CGUIHMM::append_model(char* filename, int32_t base1, int32_t base2) 00253 { 00254 if (!working) 00255 SG_ERROR("Create HMM first.\n"); 00256 if (!filename) 00257 SG_ERROR("Invalid filename.\n"); 00258 00259 FILE* model_file=fopen(filename, "r"); 00260 if (!model_file) 00261 SG_ERROR("Opening file %s failed.\n", filename); 00262 00263 CHMM* h=new CHMM(model_file,PSEUDO); 00264 if (!h || !h->get_status()) 00265 { 00266 SG_UNREF(h); 00267 fclose(model_file); 00268 SG_ERROR("Reading file %s failed.\n", filename); 00269 } 00270 00271 fclose(model_file); 00272 SG_INFO("File %s successfully read.\n", filename); 00273 00274 SG_DEBUG("h %d , M: %d\n", h, h->get_M()); 00275 if (base1!=-1 && base2!=-1) 00276 { 00277 float64_t* cur_o=SG_MALLOC(float64_t, h->get_M()); 00278 float64_t* app_o=SG_MALLOC(float64_t, h->get_M()); 00279 00280 for (int32_t i=0; i<h->get_M(); i++) 00281 { 00282 if (i==base1) 00283 cur_o[i]=0; 00284 else 00285 cur_o[i]=-1000; 00286 00287 if (i==base2) 00288 app_o[i]=0; 00289 else 00290 app_o[i]=-1000; 00291 } 00292 00293 working->append_model(h, cur_o, app_o); 00294 00295 SG_FREE(cur_o); 00296 SG_FREE(app_o); 00297 } 00298 else 00299 working->append_model(h); 00300 00301 SG_UNREF(h); 00302 SG_INFO("New model has %i states.\n", working->get_N()); 00303 return true; 00304 } 00305 00306 bool CGUIHMM::add_states(int32_t num_states, float64_t value) 00307 { 00308 if (!working) 00309 SG_ERROR("Create HMM first.\n"); 00310 00311 working->add_states(num_states, value); 00312 SG_INFO("New model has %i states, value %f.\n", working->get_N(), value); 00313 return true; 00314 } 00315 00316 bool CGUIHMM::set_pseudo(float64_t pseudo) 00317 { 00318 PSEUDO=pseudo; 00319 SG_INFO("Current setting: pseudo=%e.\n", PSEUDO); 00320 return true; 00321 } 00322 00323 bool CGUIHMM::convergence_criteria(int32_t num_iterations, float64_t epsilon) 00324 { 00325 if (!working) 00326 SG_ERROR("Create HMM first.\n"); 00327 00328 working->set_iterations(num_iterations); 00329 working->set_epsilon(epsilon); 00330 00331 SG_INFO("Current HMM convergence criteria: iterations=%i, epsilon=%e\n", working->get_iterations(), working->get_epsilon()); 00332 return true; 00333 } 00334 00335 bool CGUIHMM::set_hmm_as(char* target) 00336 { 00337 if (!working) 00338 SG_ERROR("Create HMM first!\n"); 00339 00340 if (strncmp(target, "POS", 3)==0) 00341 { 00342 SG_UNREF(pos); 00343 pos=working; 00344 working=NULL; 00345 } 00346 else if (strncmp(target, "NEG", 3)==0) 00347 { 00348 SG_UNREF(neg); 00349 neg=working; 00350 working=NULL; 00351 } 00352 else if (strncmp(target, "TEST", 4)==0) 00353 { 00354 SG_UNREF(test); 00355 test=working; 00356 working=NULL; 00357 } 00358 else 00359 SG_ERROR("Target POS|NEG|TEST is missing.\n"); 00360 00361 return true; 00362 } 00363 00364 bool CGUIHMM::load(char* filename) 00365 { 00366 bool result=false; 00367 00368 FILE* model_file=fopen(filename, "r"); 00369 if (!model_file) 00370 SG_ERROR("Opening file %s failed.\n", filename); 00371 00372 SG_UNREF(working); 00373 working=new CHMM(model_file, PSEUDO); 00374 fclose(model_file); 00375 00376 if (working && working->get_status()) 00377 { 00378 SG_INFO("Loaded HMM successfully from file %s.\n", filename); 00379 result=true; 00380 } 00381 00382 M=working->get_M(); 00383 00384 return result; 00385 } 00386 00387 bool CGUIHMM::save(char* filename, bool is_binary) 00388 { 00389 bool result=false; 00390 00391 if (!working) 00392 SG_ERROR("Create HMM first.\n"); 00393 00394 FILE* file=fopen(filename, "w"); 00395 if (file) 00396 { 00397 if (is_binary) 00398 result=working->save_model_bin(file); 00399 else 00400 result=working->save_model(file); 00401 } 00402 00403 if (!file || !result) 00404 SG_ERROR("Writing to file %s failed!\n", filename); 00405 else 00406 SG_INFO("Successfully written model into %s!\n", filename); 00407 00408 if (file) 00409 fclose(file); 00410 00411 return result; 00412 } 00413 00414 bool CGUIHMM::load_definitions(char* filename, bool do_init) 00415 { 00416 if (!working) 00417 SG_ERROR("Create HMM first.\n"); 00418 00419 bool result=false; 00420 FILE* def_file=fopen(filename, "r"); 00421 if (!def_file) 00422 SG_ERROR("Opening file %s failed\n", filename); 00423 00424 if (working->load_definitions(def_file, true, do_init)) 00425 { 00426 SG_INFO("Definitions successfully read from %s.\n", filename); 00427 result=true; 00428 } 00429 else 00430 SG_ERROR("Couldn't load definitions form file %s.\n", filename); 00431 00432 fclose(def_file); 00433 return result; 00434 } 00435 00436 bool CGUIHMM::save_likelihood(char* filename, bool is_binary) 00437 { 00438 bool result=false; 00439 00440 if (!working) 00441 SG_ERROR("Create HMM first\n"); 00442 00443 FILE* file=fopen(filename, "w"); 00444 if (file) 00445 { 00447 //if (binary) 00448 // result=working->save_model_bin(file); 00449 //else 00450 00451 result=working->save_likelihood(file); 00452 } 00453 00454 if (!file || !result) 00455 SG_ERROR("Writing to file %s failed!\n", filename); 00456 else 00457 SG_INFO("Successfully written likelihoods into %s!\n", filename); 00458 00459 if (file) 00460 fclose(file); 00461 00462 return result; 00463 } 00464 00465 bool CGUIHMM::save_path(char* filename, bool is_binary) 00466 { 00467 bool result=false; 00468 if (!working) 00469 SG_ERROR("Create HMM first.\n"); 00470 00471 FILE* file=fopen(filename, "w"); 00472 if (file) 00473 { 00475 //if (binary) 00476 //_train()/ result=working->save_model_bin(file); 00477 //else 00478 CStringFeatures<uint16_t>* obs=(CStringFeatures<uint16_t>*) ui-> 00479 ui_features->get_test_features(); 00480 ASSERT(obs); 00481 working->set_observations(obs); 00482 00483 result=working->save_path(file); 00484 } 00485 00486 if (!file || !result) 00487 SG_ERROR("Writing to file %s failed!\n", filename); 00488 else 00489 SG_INFO("Successfully written path into %s!\n", filename); 00490 00491 if (file) 00492 fclose(file); 00493 00494 return result; 00495 } 00496 00497 bool CGUIHMM::chop(float64_t value) 00498 { 00499 if (!working) 00500 SG_ERROR("Create HMM first.\n"); 00501 00502 working->chop(value); 00503 return true; 00504 } 00505 00506 bool CGUIHMM::likelihood() 00507 { 00508 if (!working) 00509 SG_ERROR("Create HMM first!\n"); 00510 00511 working->output_model(false); 00512 return true; 00513 } 00514 00515 bool CGUIHMM::output_hmm() 00516 { 00517 if (!working) 00518 SG_ERROR("Create HMM first!\n"); 00519 00520 working->output_model(true); 00521 return true; 00522 } 00523 00524 bool CGUIHMM::output_hmm_defined() 00525 { 00526 if (!working) 00527 SG_ERROR("Create HMM first!\n"); 00528 00529 working->output_model_defined(true); 00530 return true; 00531 } 00532 00533 bool CGUIHMM::best_path(int32_t from, int32_t to) 00534 { 00535 // FIXME: from unused??? 00536 00537 if (!working) 00538 SG_ERROR("Create HMM first.\n"); 00539 00540 //get path 00541 working->best_path(0); 00542 00543 for (int32_t t=0; t<working->get_observations()->get_vector_length(0)-1 && t<to; t++) 00544 SG_PRINT("%d ", working->get_best_path_state(0, t)); 00545 SG_PRINT("\n"); 00546 00547 //for (t=0; t<p_observations->get_vector_length(0)-1 && t<to; t++) 00548 // SG_PRINT( "%d ", PATH(0)[t]); 00549 // 00550 return true; 00551 } 00552 00553 bool CGUIHMM::normalize(bool keep_dead_states) 00554 { 00555 if (!working) 00556 SG_ERROR("Create HMM first.\n"); 00557 00558 working->normalize(keep_dead_states); 00559 return true; 00560 } 00561 00562 bool CGUIHMM::relative_entropy(float64_t*& values, int32_t& len) 00563 { 00564 if (!pos || !neg) 00565 SG_ERROR("Set pos and neg HMM first!\n"); 00566 00567 int32_t pos_N=pos->get_N(); 00568 int32_t neg_N=neg->get_N(); 00569 int32_t pos_M=pos->get_M(); 00570 int32_t neg_M=neg->get_M(); 00571 if (pos_M!=neg_M || pos_N!=neg_N) 00572 SG_ERROR("Pos and neg HMM's differ in number of emissions or states.\n"); 00573 00574 float64_t* p=SG_MALLOC(float64_t, pos_M); 00575 float64_t* q=SG_MALLOC(float64_t, neg_M); 00576 00577 SG_FREE(values); 00578 values=SG_MALLOC(float64_t, pos_N); 00579 00580 for (int32_t i=0; i<pos_N; i++) 00581 { 00582 for (int32_t j=0; j<pos_M; j++) 00583 { 00584 p[j]=pos->get_b(i, j); 00585 q[j]=neg->get_b(i, j); 00586 } 00587 00588 values[i]=CMath::relative_entropy(p, q, pos_M); 00589 } 00590 SG_FREE(p); 00591 SG_FREE(q); 00592 00593 len=pos_N; 00594 return true; 00595 } 00596 00597 bool CGUIHMM::entropy(float64_t*& values, int32_t& len) 00598 { 00599 if (!working) 00600 SG_ERROR("Create HMM first!\n"); 00601 00602 int32_t n=working->get_N(); 00603 int32_t m=working->get_M(); 00604 float64_t* p=SG_MALLOC(float64_t, m); 00605 00606 SG_FREE(values); 00607 values=SG_MALLOC(float64_t, n); 00608 00609 for (int32_t i=0; i<n; i++) 00610 { 00611 for (int32_t j=0; j<m; j++) 00612 p[j]=working->get_b(i, j); 00613 00614 values[i]=CMath::entropy(p, m); 00615 } 00616 SG_FREE(p); 00617 00618 len=m; 00619 return true; 00620 } 00621 00622 bool CGUIHMM::permutation_entropy(int32_t width, int32_t seq_num) 00623 { 00624 if (!working) 00625 SG_ERROR("Create hmm first.\n"); 00626 00627 if (!working->get_observations()) 00628 SG_ERROR("Set observations first.\n"); 00629 00630 return working->permutation_entropy(width, seq_num); 00631 }