presage  0.8.7
ARPAPredictor.h
Go to the documentation of this file.
00001 
00002 /******************************************************
00003  *  Presage, an extensible predictive text entry system
00004  *  ---------------------------------------------------
00005  *
00006  *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
00007 
00008     This program is free software; you can redistribute it and/or modify
00009     it under the terms of the GNU General Public License as published by
00010     the Free Software Foundation; either version 2 of the License, or
00011     (at your option) any later version.
00012 
00013     This program is distributed in the hope that it will be useful,
00014     but WITHOUT ANY WARRANTY; without even the implied warranty of
00015     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016     GNU General Public License for more details.
00017 
00018     You should have received a copy of the GNU General Public License along
00019     with this program; if not, write to the Free Software Foundation, Inc.,
00020     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
00021                                                                              *
00022                                                                 **********(*)*/
00023 
00024 
00025 #ifndef PRESAGE_ARPAPREDICTOR
00026 #define PRESAGE_ARPAPREDICTOR
00027 
00028 #include "predictor.h"
00029 #include "../core/logger.h"
00030 #include "../core/progress.h"
00031 #include "../core/dispatcher.h"
00032 
00033 #include <assert.h>
00034 #include <fstream>
00035 #include <iomanip>
00036 
00037 
/**
 * Comparison functor that orders floats from largest to smallest,
 * i.e. the reverse of the ordering produced by operator<.
 */
class cmp {
public:
    /** Returns true when lhs is strictly greater than rhs. */
    bool operator()(const float& lhs, const float& rhs) const
    {
        return lhs > rhs;
    }
};
00044 
/**
 * Pair of float values stored per n-gram entry: logProb and logAlfa.
 *
 * NOTE(review): by the names these are presumably the log probability and
 * the log back-off weight of an ARPA language model entry -- confirm
 * against the code that fills the tables.
 */
class ARPAData
{
  public:
    /**
     * Default constructor. Both fields are zeroed so that a
     * default-constructed entry (for example one created implicitly by
     * std::map::operator[]) never exposes indeterminate values.
     */
    ARPAData() : logProb(0.0f), logAlfa(0.0f) {}

    /**
     * @param lp value stored in logProb
     * @param la value stored in logAlfa
     */
    ARPAData(float lp, float la) : logProb(lp), logAlfa(la) {}

    float logProb;
    float logAlfa;
};
00053 
/**
 * Key made of three integers, ordered lexicographically on
 * (key1, key2, key3) so it can be used as a std::map key.
 */
class TrigramKey
{
  public:
    TrigramKey(int wd1, int wd2, int wd3) : key1(wd1), key2(wd2), key3(wd3) {}

    /**
     * Lexicographic "less than": the first component that differs
     * decides the result; equal keys compare as not-less.
     */
    bool operator<(const TrigramKey& right) const
    {
        if (key1 != right.key1) {
            return key1 < right.key1;
        }
        if (key2 != right.key2) {
            return key2 < right.key2;
        }
        return key3 < right.key3;
    }

    /** True when all three components are equal. */
    bool operator==(const TrigramKey& right) const
    {
        if (key1 != right.key1) {
            return false;
        }
        if (key2 != right.key2) {
            return false;
        }
        return key3 == right.key3;
    }

    int key1;
    int key2;
    int key3;
};
00083 
/**
 * Key made of two integers, ordered lexicographically on (key1, key2)
 * so it can be used as a std::map key.
 */
class BigramKey
{
  public:
    BigramKey(int wd1, int wd2) : key1(wd1), key2(wd2) {}

    /**
     * Lexicographic "less than": key1 decides unless both key1 values
     * are equal, in which case key2 decides.
     */
    bool operator<(const BigramKey &right) const
    {
      if(key1 < right.key1)
        return true;

      if(key1 == right.key1)
        if(key2 < right.key2 )
          return true;

      return false;
    }

    /**
     * True when both components are equal.
     *
     * The right-hand operand is a BigramKey. It was previously declared
     * as a TrigramKey, which made bigram-to-bigram comparison impossible
     * and silently ignored the trigram's third component.
     */
    bool operator==(const BigramKey &right) const
    {
      return (key1 == right.key1 && key2 == right.key2);
    }

    int key1;
    int key2;
};
00108 
00112 class ARPAPredictor : public Predictor, public Observer {
00113 
00114 public:
00115     ARPAPredictor(Configuration*, ContextTracker*);
00116     ~ARPAPredictor();
00117 
00118     virtual Prediction predict(const size_t size, const char** filter) const;
00119 
00120     virtual void learn(const std::vector<std::string>& change);
00121 
00122     virtual void update (const Observable* variable);
00123 
00124     void set_vocab_filename (const std::string& value);
00125     void set_arpa_filename (const std::string& value);
00126     void set_timeout (const std::string& value);
00127 
00128 private:
00129     static const char* LOGGER;
00130     static const char* ARPAFILENAME;
00131     static const char* VOCABFILENAME;
00132     static const char* TIMEOUT;
00133 
00134     std::string arpaFilename;
00135     std::string vocabFilename;
00136     int timeout;
00137 
00138     std::map<std::string,int> vocabCode;
00139     std::map<int,std::string> vocabDecode;
00140 
00141     std::map<int,ARPAData> unigramMap;
00142     std::map<BigramKey,ARPAData>bigramMap;
00143     std::map<TrigramKey,float>trigramMap;
00144 
00145     void loadVocabulary();
00146     void createARPATable();
00147     bool matchesPrefixAndFilter(std::string , std::string , const char**  ) const;
00148 
00149     void addUnigram(std::string);
00150     void addBigram(std::string);
00151     void addTrigram(std::string);
00152 
00153     inline float computeTrigramBackoff(int,int,int) const;
00154     inline float computeBigramBackoff(int,int) const;
00155 
00156     int unigramCount;
00157     int bigramCount;
00158     int trigramCount;
00159 
00160     int unigramTot;
00161     int bigramTot;
00162     int trigramTot;
00163 
00164     ProgressBar<char>* unigramProg;
00165     ProgressBar<char>* bigramProg;
00166     ProgressBar<char>* trigramProg;
00167 
00168     Dispatcher<ARPAPredictor> dispatcher;
00169 };
00170 
00171 #endif // PRESAGE_ARPAPREDICTOR