presage  0.8.7
ARPAPredictor.cpp
Go to the documentation of this file.
00001 
00002 /******************************************************
00003  *  Presage, an extensible predictive text entry system
00004  *  ---------------------------------------------------
00005  *
00006  *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
00007 
00008     This program is free software; you can redistribute it and/or modify
00009     it under the terms of the GNU General Public License as published by
00010     the Free Software Foundation; either version 2 of the License, or
00011     (at your option) any later version.
00012 
00013     This program is distributed in the hope that it will be useful,
00014     but WITHOUT ANY WARRANTY; without even the implied warranty of
00015     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016     GNU General Public License for more details.
00017 
00018     You should have received a copy of the GNU General Public License along
00019     with this program; if not, write to the Free Software Foundation, Inc.,
00020     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
00021                                                                              *
00022                                                                 **********(*)*/
00023 
00024 
00025 #include "ARPAPredictor.h"
00026 
00027 
00028 #include <sstream>
00029 #include <algorithm>
00030 #include <cmath>
00031 
00032 const char* ARPAPredictor::LOGGER     = "Presage.Predictors.ARPAPredictor.LOGGER";
00033 const char* ARPAPredictor::ARPAFILENAME = "Presage.Predictors.ARPAPredictor.ARPAFILENAME";
00034 const char* ARPAPredictor::VOCABFILENAME = "Presage.Predictors.ARPAPredictor.VOCABFILENAME";
00035 const char* ARPAPredictor::TIMEOUT = "Presage.Predictors.ARPAPredictor.TIMEOUT";
00036 
00037 
// Out-of-vocabulary token as written in ARPA files; n-grams that contain it
// are not stored when the model is loaded. A typed, translation-unit-local
// constant rather than a macro, so it obeys scope rules and has a real type
// (std::string comparisons such as `word != OOV` work unchanged).
namespace {
const char OOV[] = "<UNK>";
}
00039 
00040 
00041 
00042 ARPAPredictor::ARPAPredictor(Configuration* config, ContextTracker* ct)
00043     : Predictor(config,
00044                 ct,
00045                 "ARPAPredictor",
00046                 "ARPAPredictor, a predictor relying on an ARPA language model",
00047                 "ARPAPredictor, long description."
00048         ),
00049       dispatcher (this)
00050 {
00051     // build notification dispatch map
00052     dispatcher.map (config->find (LOGGER), & ARPAPredictor::set_logger);
00053     dispatcher.map (config->find (VOCABFILENAME), & ARPAPredictor::set_vocab_filename);
00054     dispatcher.map (config->find (ARPAFILENAME), & ARPAPredictor::set_arpa_filename);
00055     dispatcher.map (config->find (TIMEOUT), & ARPAPredictor::set_timeout);
00056 
00057     loadVocabulary();
00058     createARPATable();
00059 }
00060 
00061 void ARPAPredictor::set_vocab_filename (const std::string& value)
00062 {
00063     logger << INFO << "VOCABFILENAME: " << value << endl;
00064     vocabFilename = value;
00065 }
00066 
00067 void ARPAPredictor::set_arpa_filename (const std::string& value)
00068 {
00069     logger << INFO << "ARPAFILENAME: " << value << endl;
00070     arpaFilename = value;
00071 }
00072 
00073 void ARPAPredictor::set_timeout (const std::string& value)
00074 {
00075     logger << INFO << "TIMEOUT: " << value << endl;
00076     timeout = atoi(value.c_str());
00077 }
00078 
00079 void ARPAPredictor::loadVocabulary()
00080 {
00081     std::ifstream vocabFile;
00082     vocabFile.open(vocabFilename.c_str());
00083     if(!vocabFile)
00084         logger << ERROR << "Error opening vocabulary file: " << vocabFilename << endl;
00085 
00086     assert(vocabFile);
00087     std::string row;
00088     int code = 0;
00089     while(std::getline(vocabFile,row))
00090     {
00091         if(row[0]=='#')
00092             continue;
00093 
00094         vocabCode[row]=code;
00095         vocabDecode[code]=row;
00096 
00097         //logger << DEBUG << "["<<row<<"] -> "<< code<<endl;
00098 
00099         code++;
00100     }
00101 
00102     logger << DEBUG << "Loaded "<<code<<" words from vocabulary" <<endl;
00103 
00104 }
00105 
00106 void ARPAPredictor::createARPATable()
00107 {
00108     std::ifstream arpaFile;
00109     arpaFile.open(arpaFilename.c_str());
00110 
00111     if(!arpaFile)
00112         logger << ERROR << "Error opening ARPA model file: " << arpaFilename << endl;
00113 
00114     assert(arpaFile);
00115     std::string row;
00116 
00117     int currOrder = 0;
00118 
00119     unigramCount = 0;
00120     bigramCount  = 0;
00121     trigramCount = 0;
00122 
00123     int lineNum =0;
00124     bool startData = false;
00125 
00126     while(std::getline(arpaFile,row))
00127     {
00128         lineNum++;
00129         if(row.empty())
00130             continue;
00131 
00132         if(row == "\\end\\")
00133             break;
00134 
00135         if(row == "\\data\\")
00136         {
00137             startData = true;
00138             continue;
00139         }
00140 
00141 
00142         if( startData == true && currOrder == 0)
00143         {
00144             if( row.find("ngram 1")==0 )
00145             {
00146                 unigramTot = atoi(row.substr(8).c_str());
00147                 logger << DEBUG << "tot unigram = "<<unigramTot<<endl;
00148                 continue;
00149             }
00150 
00151             if( row.find("ngram 2")==0)
00152             {
00153                 bigramTot = atoi(row.substr(8).c_str());
00154                 logger << DEBUG << "tot bigram = "<<bigramTot<<endl;
00155                 continue;
00156             }
00157 
00158             if( row.find("ngram 3")==0)
00159             {
00160                 trigramTot = atoi(row.substr(8).c_str());
00161                 logger << DEBUG << "tot trigram = "<<trigramTot<<endl;
00162                 continue;
00163             }
00164         }
00165 
00166         if( row == "\\1-grams:" && startData)
00167         {
00168             currOrder = 1;
00169             std::cerr << std::endl << "ARPA loading unigrams:" << std::endl;
00170             unigramProg = new ProgressBar<char>(std::cerr);
00171             continue;
00172         }
00173 
00174         if( row == "\\2-grams:" && startData)
00175         {
00176             currOrder = 2;
00177             std::cerr << std::endl << std::endl << "ARPA loading bigrams:" << std::endl;
00178             bigramProg = new ProgressBar<char>(std::cerr);
00179             continue;
00180         }
00181 
00182         if( row == "\\3-grams:" && startData)
00183         {
00184             currOrder = 3;
00185             std::cerr << std::endl << std::endl << "ARPA loading trigrams:" << std::endl;
00186             trigramProg = new ProgressBar<char>(std::cerr);
00187             continue;
00188         }
00189 
00190         if(currOrder == 0)
00191             continue;
00192 
00193         switch(currOrder)
00194         {
00195         case 1:   addUnigram(row);
00196             break;
00197 
00198         case 2:   addBigram(row);
00199             break;
00200 
00201         case 3:   addTrigram(row);
00202             break;
00203         }
00204 
00205     }
00206 
00207     std::cerr << std::endl << std::endl;
00208 
00209     logger << DEBUG << "loaded unigrams: "<< unigramCount << endl;
00210     logger << DEBUG << "loaded bigrams: " << bigramCount << endl;
00211     logger << DEBUG << "loaded trigrams: "<< trigramCount << endl;
00212 }
00213 
00214 void ARPAPredictor::addUnigram(std::string row)
00215 {
00216     std::stringstream str(row);
00217     float logProb = 0;
00218     float logAlfa = 0;
00219     std::string wd1Str;
00220 
00221     str >> logProb;
00222     str >> wd1Str;
00223     str >> logAlfa;
00224 
00225 
00226     if(wd1Str != OOV )
00227     {
00228         int wd1 = vocabCode[wd1Str];
00229 
00230         unigramMap[wd1]= ARPAData(logProb,logAlfa);
00231 
00232         //logger << DEBUG << "adding unigram ["<<wd1Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
00233     }
00234 
00235 
00236     unigramCount++;
00237 
00238     unigramProg->update((float)unigramCount/(float)unigramTot);
00239 }
00240 
00241 void ARPAPredictor::addBigram(std::string row)
00242 {
00243     std::stringstream str(row);
00244     float logProb = 0;
00245     float logAlfa = 0;
00246     std::string wd1Str;
00247     std::string wd2Str;
00248 
00249     str >> logProb;
00250     str >> wd1Str;
00251     str >> wd2Str;
00252     str >> logAlfa;
00253 
00254     if(wd1Str != OOV && wd2Str != OOV)
00255     {
00256         int wd1 = vocabCode[wd1Str];
00257         int wd2 = vocabCode[wd2Str];
00258 
00259         bigramMap[BigramKey(wd1,wd2)]=ARPAData(logProb,logAlfa);
00260 
00261         //logger << DEBUG << "adding bigram ["<<wd1Str<< "] ["<<wd2Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
00262     }
00263 
00264     bigramCount++;
00265     bigramProg->update((float)bigramCount/(float)bigramTot);
00266 }
00267 
00268 void ARPAPredictor::addTrigram(std::string row)
00269 {
00270     std::stringstream str(row);
00271     float logProb = 0;
00272 
00273     std::string wd1Str;
00274     std::string wd2Str;
00275     std::string wd3Str;
00276 
00277     str >> logProb;
00278     str >> wd1Str;
00279     str >> wd2Str;
00280     str >> wd3Str;
00281 
00282     if(wd1Str != OOV && wd2Str != OOV && wd3Str != OOV)
00283     {
00284         int wd1 = vocabCode[wd1Str];
00285         int wd2 = vocabCode[wd2Str];
00286         int wd3 = vocabCode[wd3Str];
00287 
00288         trigramMap[TrigramKey(wd1,wd2,wd3)]=logProb;
00289         //logger << DEBUG << "adding trigram ["<<wd1Str<< "] ["<<wd2Str<< "] ["<<wd3Str<< "] -> "<<logProb <<endl;
00290 
00291     }
00292 
00293     trigramCount++;
00294     trigramProg->update((float)trigramCount/(float)trigramTot);
00295 }
00296 
00297 
00298 ARPAPredictor::~ARPAPredictor()
00299 {
00300     delete unigramProg;
00301     delete bigramProg;
00302     delete trigramProg;
00303 }
00304 
00305 bool ARPAPredictor::matchesPrefixAndFilter(std::string word, std::string prefix, const char** filter ) const
00306 {
00307     if(filter == 0)
00308         return word.find(prefix)==0;
00309 
00310     for(int j = 0; filter[j] != 0; j++)
00311     {
00312         std::string pattern = prefix+std::string(filter[j]);
00313         if(word.find(pattern)==0)
00314             return true;
00315     }
00316 
00317     return false;
00318 }
00319 
00320 Prediction ARPAPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
00321 {
00322     logger << DEBUG << "predict()" << endl;
00323     Prediction prediction;
00324 
00325     int cardinality = 3;
00326     std::vector<std::string> tokens(cardinality);
00327 
00328     std::string prefix = Utility::strtolower(contextTracker->getToken(0));
00329     std::string wd2Str = Utility::strtolower(contextTracker->getToken(1));
00330     std::string wd1Str = Utility::strtolower(contextTracker->getToken(2));
00331 
00332     std::multimap< float, std::string, cmp > result;
00333 
00334     logger << DEBUG << "["<<wd1Str<<"]"<<" ["<<wd2Str<<"] "<<"["<<prefix<<"]"<<endl;
00335 
00336     //search for the past tokens in the vocabulary
00337     std::map<std::string,int>::const_iterator wd1It,wd2It;
00338     wd1It = vocabCode.find(wd1Str);
00339     wd2It = vocabCode.find(wd2Str);
00340 
00346     //we have two valid past tokens available
00347     if(wd1It!=vocabCode.end() && wd2It!=vocabCode.end())
00348     {
00349         //iterate over all vocab words
00350         for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++)
00351         {
00352             //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
00353             if(matchesPrefixAndFilter(it->second,prefix,filter))
00354             {
00355                 std::pair<const float,std::string> p (computeTrigramBackoff(wd1It->second,wd2It->second,it->first),
00356                                                       it->second);
00357                 result.insert(p);
00358             }
00359         }
00360     }
00361 
00362     //we have one valid past token available
00363     else if(wd2It!=vocabCode.end())
00364     {
00365         //iterate over all vocab words
00366         for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++)
00367         {
00368             //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
00369             if(matchesPrefixAndFilter(it->second,prefix,filter))
00370             {
00371                 std::pair<const float,std::string> p(computeBigramBackoff(wd2It->second,it->first),
00372                                                      it->second);
00373                 result.insert(p);
00374             }
00375         }
00376     }
00377 
00378     //we have no valid past token available
00379     else
00380     {
00381         //iterate over all vocab words
00382         for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); it++)
00383         {
00384             //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
00385             if(matchesPrefixAndFilter(it->second,prefix,filter))
00386             {
00387                 std::pair<const float,std::string> p (unigramMap.find(it->first)->second.logProb,
00388                                                       it->second);
00389                 result.insert(p);
00390             }
00391         }
00392     }
00393 
00394 
00395     size_t numSuggestions = 0;
00396     for(std::multimap< float, std::string, cmp >::const_iterator it = result.begin();
00397         it != result.end() && numSuggestions < max_partial_prediction_size;
00398         it++)
00399     {
00400         prediction.addSuggestion(Suggestion(it->second,exp(it->first)));
00401         numSuggestions++;
00402     }
00403 
00404     return prediction;
00405 }
00409 float ARPAPredictor::computeTrigramBackoff(int wd1,int wd2,int wd3) const
00410 {
00411     logger << DEBUG << "computing P( ["<<vocabDecode.find(wd3)->second<< "] | ["<<vocabDecode.find(wd1)->second<<"] ["<<vocabDecode.find(wd2)->second<<"] )"<<endl;
00412 
00413     //trigram exist
00414     std::map<TrigramKey,float>::const_iterator trigramIt =trigramMap.find(TrigramKey(wd1,wd2,wd3));
00415     if(trigramIt!=trigramMap.end())
00416     {
00417         logger << DEBUG << "trigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] ["<<vocabDecode.find(wd3)->second<< "] exists" <<endl;
00418         logger << DEBUG << "returning "<<trigramIt->second <<endl;
00419         return trigramIt->second;
00420     }
00421 
00422     //bigram exist
00423     std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
00424     if(bigramIt!=bigramMap.end())
00425     {
00426         logger << DEBUG << "bigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] exists" <<endl;
00427         float prob = bigramIt->second.logAlfa + computeBigramBackoff(wd2,wd3);
00428         logger << DEBUG << "returning "<<prob<<endl;
00429         return prob;
00430     }
00431 
00432     //else
00433     logger << DEBUG << "no bigram w1,w2 exist" <<endl;
00434     float prob = computeBigramBackoff(wd2,wd3);
00435     logger << DEBUG << "returning "<<prob<<endl;
00436     return prob;
00437 
00438 }
00439 
00443 float ARPAPredictor::computeBigramBackoff(int wd1, int wd2) const
00444 {
00445     //bigram exist
00446     std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
00447     if(bigramIt!=bigramMap.end())
00448         return bigramIt->second.logProb;
00449 
00450     //else
00451     return unigramMap.find(wd1)->second.logAlfa +unigramMap.find(wd2)->second.logProb;
00452 
00453 }
00454 
00455 void ARPAPredictor::learn(const std::vector<std::string>& change)
00456 {
00457     logger << DEBUG << "learn() method called" << endl;
00458     logger << DEBUG << "learn() method exited" << endl;
00459 }
00460 
00461 void ARPAPredictor::update (const Observable* var)
00462 {
00463     logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
00464     dispatcher.dispatch (var);
00465 }