presage 0.8.7
smoothedNgramPredictor.cpp

/******************************************************
 *  Presage, an extensible predictive text entry system
 *  ---------------------------------------------------
 *
 *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
                                                                             *
                                                                **********(*)*/


#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>

const char* SmoothedNgramPredictor::LOGGER     = "Presage.Predictors.SmoothedNgramPredictor.LOGGER";
const char* SmoothedNgramPredictor::DBFILENAME = "Presage.Predictors.SmoothedNgramPredictor.DBFILENAME";
const char* SmoothedNgramPredictor::DELTAS     = "Presage.Predictors.SmoothedNgramPredictor.DELTAS";
const char* SmoothedNgramPredictor::LEARN      = "Presage.Predictors.SmoothedNgramPredictor.LEARN";
const char* SmoothedNgramPredictor::DATABASE_LOGGER = "Presage.Predictors.SmoothedNgramPredictor.DatabaseConnector.LOGGER";

SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct)
    : Predictor(config,
                ct,
                "SmoothedNgramPredictor",
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),
      dispatcher (this)
{
    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
}
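
// Note on the dispatch map built above: each configuration variable is
// mapped to the setter that consumes its value.  When the configuration
// framework notifies update() (at the bottom of this file) of a changed
// variable, the dispatcher invokes the corresponding setter, e.g. a
// change to the DBFILENAME variable triggers set_dbfilename().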



SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << filename << endl;

    delete db;

    if (dbloglevel.empty ()) {
        // open database connector
        db = new SqliteDatabaseConnector(dbfilename);

    } else {
        // open database connector with logger level
        db = new SqliteDatabaseConnector(dbfilename, dbloglevel);
    }
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
    }
    logger << INFO << "DELTAS: " << value << endl;
}
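
// Illustration (the value is hypothetical, not a shipped default): a
// DELTAS setting of "0.70 0.25 0.05" yields deltas = {0.70, 0.25, 0.05},
// i.e. a trigram predictor (cardinality 3) in which the unigram, bigram
// and trigram relative frequencies are weighted by 0.70, 0.25 and 0.05
// respectively in predict() below.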


void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    wanna_learn = Utility::isTrue (value);
    logger << INFO << "LEARN: " << value << endl;
}


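/* A note on count() semantics (the examples are illustrative only):
 * count(tokens, offset, n) looks up the database count of the n-gram
 * formed by the n consecutive tokens ending at tokens.end() + offset.
 * An offset of 0 selects the n-gram ending at the last cached token;
 * an offset of -1 shifts that window one token into the past.  With
 * tokens = {"the", "quick", "brown"}:
 *
 *     count(tokens,  0, 2)  ->  C("quick brown")
 *     count(tokens, -1, 2)  ->  C("the quick")
 *     count(tokens,  0, 0)  ->  sum of all unigram counts
 */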
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    assert(offset <= 0); // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());

        logger << DEBUG << "count ngram: ";
        for (size_t j = 0; j < ngram.size(); j++) {
            logger << DEBUG << ngram[j] << ' ';
        }
        logger << DEBUG << endl;

        return db->getNgramCount(ngram);
    } else {
        return db->getUnigramCountsSum();
    }
}

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // n-gram cardinality (i.e. what is the n in n-gram?)
    int cardinality = deltas.size();

    // Cache all the needed tokens.
    // tokens[cardinality - 1 - k] corresponds to w_{i-k} in the
    // generalized smoothed n-gram probability formula, so the current
    // (partial) token w_i sits at the end of the vector.
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = Utility::strtolower(contextTracker->getToken(i));
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }
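
    // For instance (illustration only), with cardinality == 3 and the
    // user having typed "The quick bro", the cache ends up as
    // tokens = {"the", "quick", "bro"}: tokens[2] holds the current
    // partial token w_i, and tokens[0..1] hold the two preceding words
    // (lowercased by Utility::strtolower above).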

    // Generate the list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, since in a well-constructed n-gram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, because the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // order n-gram table, falling back on lower order n-gram tables
    // if the initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create the n-gram used to retrieve the initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t m = 0; m < partial[j].size(); m++) {
                    logger << DEBUG << partial[j][m] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to the prefix
        // completion candidates array, filling it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; it points to an Ngram row, and
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }
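
    // Illustration (hypothetical context): with cardinality == 3 and
    // tokens = {"the", "quick", "bro"}, the first pass queries the
    // _3_gram table for completions matching "the quick bro%"; if fewer
    // than max_partial_prediction_size candidates are found, the
    // _2_gram table is consulted for "quick bro%", and finally the
    // _1_gram table for "bro%".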

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query;
    // caching its result here saves much time inside the loop below
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse the cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator:   " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency:   " << frequency << endl;
            logger << DEBUG << "delta:       " << deltas[k] << endl;

            // sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }
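
        // The loop above evaluates the linear interpolation
        //
        //     P(w_i | w_{i-n+1} ... w_{i-1}) =
        //         sum_{k=1}^{n} deltas[k-1] * C(w_{i-k+1} ... w_i) / C(w_{i-k+1} ... w_{i-1})
        //
        // where n is the cardinality and C() is the database n-gram
        // count; for k == 1 the denominator is the cached sum of all
        // unigram counts.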

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}


void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn()" << endl;

    if (wanna_learn) {
        // learning is turned on

        // n-gram cardinality (i.e. what is the n in n-gram?)
        size_t cardinality = deltas.size();

        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++) {

            // idx walks the change vector back to front
            for (std::vector<std::string>::const_reverse_iterator idx = change.rbegin();
                 idx != change.rend();
                 idx++)
            {
                Ngram ngram;

                // try to fill in the ngram to be learnt with change
                // tokens first
                for (std::vector<std::string>::const_reverse_iterator inner_idx = idx;
                     inner_idx != change.rend() && ngram.size() < curr_cardinality;
                     inner_idx++)
                {
                    ngram.insert(ngram.begin(), *inner_idx);
                }

                // then use the past stream if the ngram is not filled in yet
                for (int tk_idx = 1;
                     ngram.size() < curr_cardinality;
                     tk_idx++)
                {
                    // The ContextTracker has already consumed the
                    // latest tokens that we need to learn, so we look
                    // at the sliding window maintained by the context
                    // change detector and obtain past tokens from
                    // there.
                    //
                    // getSlidingWindowToken returns tokens from the
                    // stream tied to the context change detector's
                    // sliding window.

                    ngram.insert(ngram.begin(), contextTracker->getSlidingWindowToken(tk_idx));
                }

                // now we have built the ngram to be learnt
                logger << INFO << "Considering to learn ngram: |";
                for (size_t j = 0; j < ngram.size(); j++) {
                    logger << INFO << ngram[j] << '|';
                }
                logger << INFO << endl;

                if (ngram.end() == find(ngram.begin(), ngram.end(), "")) {
                    // only learn the ngram if it does not contain empty strings
                    try
                    {
                        db->beginTransaction();

                        db->incrementNgramCount(ngram);
                        check_learn_consistency(ngram);

                        db->endTransaction();
                        logger << INFO << "Committed ngram update to database" << endl;
                    }
                    catch (SqliteDatabaseConnector::SqliteDatabaseConnectorException& ex)
                    {
                        db->rollbackTransaction();
                        logger << ERROR << ex.what() << endl;
                        throw;
                    }
                } else {
                    logger << INFO << "Discarded ngram" << endl;
                }
            }
        }
    }

    logger << DEBUG << "end learn()" << endl;
}
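
// A worked example of learn() (the context is hypothetical): with
// cardinality == 3, change == {"quick", "brown"} and a sliding window
// whose most recent past token is "the", the n-grams whose counts get
// incremented are
//
//     {"brown"}, {"quick"},                             (cardinality 1)
//     {"quick", "brown"}, {"the", "quick"},             (cardinality 2)
//     {"the", "quick", "brown"}, {"?", "the", "quick"}  (cardinality 3)
//
// where "?" stands for the next token back in the sliding window; any
// n-gram containing an empty string is discarded rather than learnt.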

void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we are called from
    // within an existing transaction started in learn()

    // BEWARE: if the previous sentence is not true, then performance
    // will suffer!

    int size = ngram.size();
    for (int i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // must be initialized to the right size
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t j = 0; j < sub_ngram.size(); j++) {
                    logger << sub_ngram[j] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
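
// The invariant enforced above: for every n-gram in the database, each
// of its prefix (n-1)-grams must have a count at least as large as the
// n-gram itself.  For example (illustration only), after learning
// "the quick brown", C("the quick") must be >= C("the quick brown")
// and C("the") >= C("the quick"); any prefix whose count falls short
// is incremented to restore the invariant.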

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}