presage
0.8.7
|
00001 00002 /****************************************************** 00003 * Presage, an extensible predictive text entry system 00004 * --------------------------------------------------- 00005 * 00006 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk> 00007 00008 This program is free software; you can redistribute it and/or modify 00009 it under the terms of the GNU General Public License as published by 00010 the Free Software Foundation; either version 2 of the License, or 00011 (at your option) any later version. 00012 00013 This program is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00016 GNU General Public License for more details. 00017 00018 You should have received a copy of the GNU General Public License along 00019 with this program; if not, write to the Free Software Foundation, Inc., 00020 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 00021 * 00022 **********(*)*/ 00023 00024 00025 #include "ARPAPredictor.h" 00026 00027 00028 #include <sstream> 00029 #include <algorithm> 00030 #include <cmath> 00031 00032 const char* ARPAPredictor::LOGGER = "Presage.Predictors.ARPAPredictor.LOGGER"; 00033 const char* ARPAPredictor::ARPAFILENAME = "Presage.Predictors.ARPAPredictor.ARPAFILENAME"; 00034 const char* ARPAPredictor::VOCABFILENAME = "Presage.Predictors.ARPAPredictor.VOCABFILENAME"; 00035 const char* ARPAPredictor::TIMEOUT = "Presage.Predictors.ARPAPredictor.TIMEOUT"; 00036 00037 00038 #define OOV "<UNK>" 00039 00040 00041 00042 ARPAPredictor::ARPAPredictor(Configuration* config, ContextTracker* ct) 00043 : Predictor(config, 00044 ct, 00045 "ARPAPredictor", 00046 "ARPAPredictor, a predictor relying on an ARPA language model", 00047 "ARPAPredictor, long description." 00048 ), 00049 dispatcher (this) 00050 { 00051 // build notification dispatch map 00052 dispatcher.map (config->find (LOGGER), & ARPAPredictor::set_logger); 00053 dispatcher.map (config->find (VOCABFILENAME), & ARPAPredictor::set_vocab_filename); 00054 dispatcher.map (config->find (ARPAFILENAME), & ARPAPredictor::set_arpa_filename); 00055 dispatcher.map (config->find (TIMEOUT), & ARPAPredictor::set_timeout); 00056 00057 loadVocabulary(); 00058 createARPATable(); 00059 } 00060 00061 void ARPAPredictor::set_vocab_filename (const std::string& value) 00062 { 00063 logger << INFO << "VOCABFILENAME: " << value << endl; 00064 vocabFilename = value; 00065 } 00066 00067 void ARPAPredictor::set_arpa_filename (const std::string& value) 00068 { 00069 logger << INFO << "ARPAFILENAME: " << value << endl; 00070 arpaFilename = value; 00071 } 00072 00073 void ARPAPredictor::set_timeout (const std::string& value) 00074 { 00075 logger << INFO << "TIMEOUT: " << value << endl; 00076 timeout = atoi(value.c_str()); 00077 } 00078 00079 void ARPAPredictor::loadVocabulary() 00080 { 00081 std::ifstream vocabFile; 00082 vocabFile.open(vocabFilename.c_str()); 00083 if(!vocabFile) 00084 logger << ERROR << "Error opening vocabulary file: " << vocabFilename << endl; 00085 00086 assert(vocabFile); 00087 std::string row; 00088 int code = 0; 00089 while(std::getline(vocabFile,row)) 00090 { 00091 if(row[0]=='#') 00092 continue; 00093 00094 vocabCode[row]=code; 00095 vocabDecode[code]=row; 00096 00097 //logger << DEBUG << "["<<row<<"] -> "<< code<<endl; 00098 00099 code++; 00100 } 00101 00102 logger << DEBUG << "Loaded "<<code<<" words from vocabulary" <<endl; 00103 00104 } 00105 00106 void ARPAPredictor::createARPATable() 00107 { 00108 std::ifstream arpaFile; 00109 arpaFile.open(arpaFilename.c_str()); 00110 00111 if(!arpaFile) 00112 logger << ERROR << "Error opening ARPA model file: " << arpaFilename << endl; 00113 00114 assert(arpaFile); 00115 std::string row; 00116 00117 int currOrder = 0; 00118 00119 unigramCount = 0; 00120 bigramCount = 0; 00121 trigramCount = 0; 00122 00123 int lineNum =0; 00124 bool startData = false; 00125 00126 while(std::getline(arpaFile,row)) 00127 { 00128 lineNum++; 00129 if(row.empty()) 00130 continue; 00131 00132 if(row == "\\end\\") 00133 break; 00134 00135 if(row == "\\data\\") 00136 { 00137 startData = true; 00138 continue; 00139 } 00140 00141 00142 if( startData == true && currOrder == 0) 00143 { 00144 if( row.find("ngram 1")==0 ) 00145 { 00146 unigramTot = atoi(row.substr(8).c_str()); 00147 logger << DEBUG << "tot unigram = "<<unigramTot<<endl; 00148 continue; 00149 } 00150 00151 if( row.find("ngram 2")==0) 00152 { 00153 bigramTot = atoi(row.substr(8).c_str()); 00154 logger << DEBUG << "tot bigram = "<<bigramTot<<endl; 00155 continue; 00156 } 00157 00158 if( row.find("ngram 3")==0) 00159 { 00160 trigramTot = atoi(row.substr(8).c_str()); 00161 logger << DEBUG << "tot trigram = "<<trigramTot<<endl; 00162 continue; 00163 } 00164 } 00165 00166 if( row == "\\1-grams:" && startData) 00167 { 00168 currOrder = 1; 00169 std::cerr << std::endl << "ARPA loading unigrams:" << std::endl; 00170 unigramProg = new ProgressBar<char>(std::cerr); 00171 continue; 00172 } 00173 00174 if( row == "\\2-grams:" && startData) 00175 { 00176 currOrder = 2; 00177 std::cerr << std::endl << std::endl << "ARPA loading bigrams:" << std::endl; 00178 bigramProg = new ProgressBar<char>(std::cerr); 00179 continue; 00180 } 00181 00182 if( row == "\\3-grams:" && startData) 00183 { 00184 currOrder = 3; 00185 std::cerr << std::endl << std::endl << "ARPA loading trigrams:" << std::endl; 00186 trigramProg = new ProgressBar<char>(std::cerr); 00187 continue; 00188 } 00189 00190 if(currOrder == 0) 00191 continue; 00192 00193 switch(currOrder) 00194 { 00195 case 1: addUnigram(row); 00196 break; 00197 00198 case 2: addBigram(row); 00199 break; 00200 00201 case 3: addTrigram(row); 00202 break; 00203 } 00204 00205 } 00206 00207 std::cerr << std::endl << std::endl; 00208 00209 logger << DEBUG << "loaded unigrams: "<< unigramCount << endl; 00210 logger << DEBUG << "loaded bigrams: " << bigramCount << endl; 00211 logger << DEBUG << "loaded trigrams: "<< trigramCount << endl; 00212 } 00213 00214 void ARPAPredictor::addUnigram(std::string row) 00215 { 00216 std::stringstream str(row); 00217 float logProb = 0; 00218 float logAlfa = 0; 00219 std::string wd1Str; 00220 00221 str >> logProb; 00222 str >> wd1Str; 00223 str >> logAlfa; 00224 00225 00226 if(wd1Str != OOV ) 00227 { 00228 int wd1 = vocabCode[wd1Str]; 00229 00230 unigramMap[wd1]= ARPAData(logProb,logAlfa); 00231 00232 //logger << DEBUG << "adding unigram ["<<wd1Str<< "] -> "<<logProb<<" "<<logAlfa<<endl; 00233 } 00234 00235 00236 unigramCount++; 00237 00238 unigramProg->update((float)unigramCount/(float)unigramTot); 00239 } 00240 00241 void ARPAPredictor::addBigram(std::string row) 00242 { 00243 std::stringstream str(row); 00244 float logProb = 0; 00245 float logAlfa = 0; 00246 std::string wd1Str; 00247 std::string wd2Str; 00248 00249 str >> logProb; 00250 str >> wd1Str; 00251 str >> wd2Str; 00252 str >> logAlfa; 00253 00254 if(wd1Str != OOV && wd2Str != OOV) 00255 { 00256 int wd1 = vocabCode[wd1Str]; 00257 int wd2 = vocabCode[wd2Str]; 00258 00259 bigramMap[BigramKey(wd1,wd2)]=ARPAData(logProb,logAlfa); 00260 00261 //logger << DEBUG << "adding bigram ["<<wd1Str<< "] ["<<wd2Str<< "] -> "<<logProb<<" "<<logAlfa<<endl; 00262 } 00263 00264 bigramCount++; 00265 bigramProg->update((float)bigramCount/(float)bigramTot); 00266 } 00267 00268 void ARPAPredictor::addTrigram(std::string row) 00269 { 00270 std::stringstream str(row); 00271 float logProb = 0; 00272 00273 std::string wd1Str; 00274 std::string wd2Str; 00275 std::string wd3Str; 00276 00277 str >> logProb; 00278 str >> wd1Str; 00279 str >> wd2Str; 00280 str >> wd3Str; 00281 00282 if(wd1Str != OOV && wd2Str != OOV && wd3Str != OOV) 00283 { 00284 int wd1 = vocabCode[wd1Str]; 00285 int wd2 = vocabCode[wd2Str]; 00286 int wd3 = vocabCode[wd3Str]; 00287 00288 trigramMap[TrigramKey(wd1,wd2,wd3)]=logProb; 00289 //logger << DEBUG << "adding trigram ["<<wd1Str<< "] ["<<wd2Str<< "] ["<<wd3Str<< "] -> "<<logProb <<endl; 00290 00291 } 00292 00293 trigramCount++; 00294 trigramProg->update((float)trigramCount/(float)trigramTot); 00295 } 00296 00297 00298 ARPAPredictor::~ARPAPredictor() 00299 { 00300 delete unigramProg; 00301 delete bigramProg; 00302 delete trigramProg; 00303 } 00304 00305 bool ARPAPredictor::matchesPrefixAndFilter(std::string word, std::string prefix, const char** filter ) const 00306 { 00307 if(filter == 0) 00308 return word.find(prefix)==0; 00309 00310 for(int j = 0; filter[j] != 0; j++) 00311 { 00312 std::string pattern = prefix+std::string(filter[j]); 00313 if(word.find(pattern)==0) 00314 return true; 00315 } 00316 00317 return false; 00318 } 00319 00320 Prediction ARPAPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const 00321 { 00322 logger << DEBUG << "predict()" << endl; 00323 Prediction prediction; 00324 00325 int cardinality = 3; 00326 std::vector<std::string> tokens(cardinality); 00327 00328 std::string prefix = Utility::strtolower(contextTracker->getToken(0)); 00329 std::string wd2Str = Utility::strtolower(contextTracker->getToken(1)); 00330 std::string wd1Str = Utility::strtolower(contextTracker->getToken(2)); 00331 00332 std::multimap< float, std::string, cmp > result; 00333 00334 logger << DEBUG << "["<<wd1Str<<"]"<<" ["<<wd2Str<<"] "<<"["<<prefix<<"]"<<endl; 00335 00336 //search for the past tokens in the vocabulary 00337 std::map<std::string,int>::const_iterator wd1It,wd2It; 00338 wd1It = vocabCode.find(wd1Str); 00339 wd2It = vocabCode.find(wd2Str); 00340 00346 //we have two valid past tokens available 00347 if(wd1It!=vocabCode.end() && wd2It!=vocabCode.end()) 00348 { 00349 //iterate over all vocab words 00350 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++) 00351 { 00352 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set 00353 if(matchesPrefixAndFilter(it->second,prefix,filter)) 00354 { 00355 std::pair<const float,std::string> p (computeTrigramBackoff(wd1It->second,wd2It->second,it->first), 00356 it->second); 00357 result.insert(p); 00358 } 00359 } 00360 } 00361 00362 //we have one valid past token available 00363 else if(wd2It!=vocabCode.end()) 00364 { 00365 //iterate over all vocab words 00366 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++) 00367 { 00368 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set 00369 if(matchesPrefixAndFilter(it->second,prefix,filter)) 00370 { 00371 std::pair<const float,std::string> p(computeBigramBackoff(wd2It->second,it->first), 00372 it->second); 00373 result.insert(p); 00374 } 00375 } 00376 } 00377 00378 //we have no valid past token available 00379 else 00380 { 00381 //iterate over all vocab words 00382 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++) 00383 { 00384 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set 00385 if(matchesPrefixAndFilter(it->second,prefix,filter)) 00386 { 00387 std::pair<const float,std::string> p (unigramMap.find(it->first)->second.logProb, 00388 it->second); 00389 result.insert(p); 00390 } 00391 } 00392 } 00393 00394 00395 size_t numSuggestions = 0; 00396 for(std::multimap< float, std::string, cmp >::const_iterator it = result.begin(); 00397 it != result.end() && numSuggestions < max_partial_prediction_size; 00398 it++) 00399 { 00400 prediction.addSuggestion(Suggestion(it->second,exp(it->first))); 00401 numSuggestions++; 00402 } 00403 00404 return prediction; 00405 } 00409 float ARPAPredictor::computeTrigramBackoff(int wd1,int wd2,int wd3) const 00410 { 00411 logger << DEBUG << "computing P( ["<<vocabDecode.find(wd3)->second<< "] | ["<<vocabDecode.find(wd1)->second<<"] ["<<vocabDecode.find(wd2)->second<<"] )"<<endl; 00412 00413 //trigram exist 00414 std::map<TrigramKey,float>::const_iterator trigramIt =trigramMap.find(TrigramKey(wd1,wd2,wd3)); 00415 if(trigramIt!=trigramMap.end()) 00416 { 00417 logger << DEBUG << "trigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] ["<<vocabDecode.find(wd3)->second<< "] exists" <<endl; 00418 logger << DEBUG << "returning "<<trigramIt->second <<endl; 00419 return trigramIt->second; 00420 } 00421 00422 //bigram exist 00423 std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2)); 00424 if(bigramIt!=bigramMap.end()) 00425 { 00426 logger << DEBUG << "bigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] exists" <<endl; 00427 float prob = bigramIt->second.logAlfa + computeBigramBackoff(wd2,wd3); 00428 logger << DEBUG << "returning "<<prob<<endl; 00429 return prob; 00430 } 00431 00432 //else 00433 logger << DEBUG << "no bigram w1,w2 exist" <<endl; 00434 float prob = computeBigramBackoff(wd2,wd3); 00435 logger << DEBUG << "returning "<<prob<<endl; 00436 return prob; 00437 00438 } 00439 00443 float ARPAPredictor::computeBigramBackoff(int wd1, int wd2) const 00444 { 00445 //bigram exist 00446 std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2)); 00447 if(bigramIt!=bigramMap.end()) 00448 return bigramIt->second.logProb; 00449 00450 //else 00451 return unigramMap.find(wd1)->second.logAlfa +unigramMap.find(wd2)->second.logProb; 00452 00453 } 00454 00455 void ARPAPredictor::learn(const std::vector<std::string>& change) 00456 { 00457 logger << DEBUG << "learn() method called" << endl; 00458 logger << DEBUG << "learn() method exited" << endl; 00459 } 00460 00461 void ARPAPredictor::update (const Observable* var) 00462 { 00463 logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl; 00464 dispatcher.dispatch (var); 00465 }