presage
0.8.7
smoothedNgramPredictor.cpp

/******************************************************
 *  Presage, an extensible predictive text entry system
 *  ---------------------------------------------------
 *
 *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 **********(*)*/


#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>

const char* SmoothedNgramPredictor::LOGGER          = "Presage.Predictors.SmoothedNgramPredictor.LOGGER";
const char* SmoothedNgramPredictor::DBFILENAME      = "Presage.Predictors.SmoothedNgramPredictor.DBFILENAME";
const char* SmoothedNgramPredictor::DELTAS          = "Presage.Predictors.SmoothedNgramPredictor.DELTAS";
const char* SmoothedNgramPredictor::LEARN           = "Presage.Predictors.SmoothedNgramPredictor.LEARN";
const char* SmoothedNgramPredictor::DATABASE_LOGGER = "Presage.Predictors.SmoothedNgramPredictor.DatabaseConnector.LOGGER";

SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct)
    : Predictor(config,
                ct,
                "SmoothedNgramPredictor",
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description."),
      db (0),
      dispatcher (this)
{
    // build notification dispatch map
    dispatcher.map (config->find (LOGGER),          & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (DBFILENAME),      & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DELTAS),          & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN),           & SmoothedNgramPredictor::set_learn);
}


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << filename << endl;

    delete db;

    if (dbloglevel.empty ()) {
        // open database connector
        db = new SqliteDatabaseConnector(dbfilename);
    } else {
        // open database connector with logger level
        db = new SqliteDatabaseConnector(dbfilename, dbloglevel);
    }
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
    }
    logger << INFO << "DELTAS: " << value << endl;
}


void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    wanna_learn = Utility::isTrue (value);
    logger << INFO << "LEARN: " << value << endl;
}

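
/** Builds the n-gram to be counted and returns its count from the
 *  database.
 *
 *  (The doxygen block documenting count() is elided in this listing;
 *  this summary is reconstructed from how the function is used in
 *  predict() and check_learn_consistency() below.)
 *
 *  @param tokens      context tokens, oldest first
 *  @param offset      non-positive shift from the end of tokens: the
 *                     n-gram ends at tokens.end() + offset
 *  @param ngram_size  number of tokens in the n-gram; 0 requests the
 *                     sum of all unigram counts instead
 *
 *  For example, with tokens = {"the", "quick", "fox"}:
 *    count(tokens,  0, 2)  ->  count of the bigram "quick fox"
 *    count(tokens, -1, 2)  ->  count of the bigram "the quick"
 *    count(tokens,  0, 0)  ->  sum of all unigram counts
 */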
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    assert(offset <= 0);      // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());

        logger << DEBUG << "count ngram: ";
        for (size_t j = 0; j < ngram.size(); j++) {
            logger << DEBUG << ngram[j] << ' ';
        }
        logger << DEBUG << endl;

        return db->getNgramCount(ngram);
    } else {
        return db->getUnigramCountsSum();
    }
}

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // n-gram cardinality (i.e. what is the n in n-gram?)
    int cardinality = deltas.size();

    // Cache all the needed tokens.
    // tokens[cardinality - 1 - k] corresponds to w_{i-k} in the
    // generalized smoothed n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = Utility::strtolower(contextTracker->getToken(i));
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }
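
    // Spelled out (notation illustrative; C() denotes a database
    // n-gram count, n the cardinality, w_i the token being
    // predicted), the smoothed probability computed further below is:
    //
    //   P(w_i | w_{i-n+1} ... w_{i-1}) =
    //       sum over k in [0, n-1] of
    //           deltas[k] * C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // with the k = 0 denominator taken to be the sum of all unigram
    // counts.  deltas[k] is thus the weight of the (k+1)-gram
    // relative frequency, so later entries in the DELTAS
    // configuration value weight the higher-order, more
    // context-specific estimates.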
    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database
    // the _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts would take
    // precedence over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // order n-gram table, falling back on lower order n-gram tables
    // if the initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t t = 0; t < partial[j].size(); t++) {
                    logger << DEBUG << partial[j][t] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; the last column of each row
            // in the NgramTable holds the n-gram count, so
            // it->end() - 2 points to the completion token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }
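
    // Worked example (illustrative numbers only): with
    // deltas = {0.1, 0.9}, previous token "quick" and candidate
    // "fox", if the database held
    //
    //   C("fox") = 10,  C("quick") = 20,  C("quick fox") = 4,
    //   sum of unigram counts = 1000,
    //
    // the loop below would compute
    //
    //   k = 0:  0.1 * (10 / 1000) = 0.001
    //   k = 1:  0.9 * ( 4 /   20) = 0.180
    //   probability               = 0.181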
    // compute smoothed probabilities for all candidates
    //
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // for some sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

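
// learn() updates the n-gram counts for every order from 1 up to the
// configured cardinality.  For instance (illustrative), with
// cardinality 3 and change tokens {"the", "quick", "fox"}, the loops
// below increment the counts of the unigrams "the", "quick" and
// "fox", of the bigrams "the quick" and "quick fox", and of the
// trigram "the quick fox", as well as of the n-grams straddling the
// start of the change, whose leading tokens are fetched from the
// sliding window of the context change detector.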
void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn()" << endl;

    if (wanna_learn) {
        // learning is turned on

        // n-gram cardinality (i.e. what is the n in n-gram?)
        size_t cardinality = deltas.size();

        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++) {

            // idx walks the change vector back to front
            for (std::vector<std::string>::const_reverse_iterator idx = change.rbegin();
                 idx != change.rend();
                 idx++)
            {
                Ngram ngram;

                // try to fill in the ngram to be learnt with change
                // tokens first
                for (std::vector<std::string>::const_reverse_iterator inner_idx = idx;
                     inner_idx != change.rend() && ngram.size() < curr_cardinality;
                     inner_idx++)
                {
                    ngram.insert(ngram.begin(), *inner_idx);
                }

                // then use the past stream if the ngram is not
                // filled in yet
                for (int tk_idx = 1;
                     ngram.size() < curr_cardinality;
                     tk_idx++)
                {
                    // ContextTracker already sees the latest tokens
                    // that we need to learn, hence we look at the
                    // sliding window and obtain the tokens from
                    // there.
                    //
                    // getSlidingWindowToken returns tokens from the
                    // stream tied to the sliding window of the
                    // context change detector.

                    ngram.insert(ngram.begin(), contextTracker->getSlidingWindowToken(tk_idx));
                }

                // now we have built the ngram to be learnt
                logger << INFO << "Considering whether to learn ngram: |";
                for (size_t j = 0; j < ngram.size(); j++) {
                    logger << INFO << ngram[j] << '|';
                }
                logger << INFO << endl;

                if (ngram.end() == find(ngram.begin(), ngram.end(), "")) {
                    // only learn the ngram if it doesn't contain empty strings
                    try
                    {
                        db->beginTransaction();

                        db->incrementNgramCount(ngram);
                        check_learn_consistency(ngram);

                        db->endTransaction();
                        logger << INFO << "Committed ngram update to database" << endl;
                    }
                    catch (SqliteDatabaseConnector::SqliteDatabaseConnectorException& ex)
                    {
                        db->rollbackTransaction();
                        logger << ERROR << ex.what() << endl;
                        throw;
                    }
                } else {
                    logger << INFO << "Discarded ngram" << endl;
                }
            }
        }
    }

    logger << DEBUG << "end learn()" << endl;
}

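
// check_learn_consistency() enforces the invariant that the count of
// an n-gram never exceeds the count of the (n-1)-gram obtained by
// dropping its last token, e.g. C("the quick fox") <= C("the quick").
// If incrementing an n-gram count in learn() breaks this, the
// offending lower-order count is incremented as well; without the
// adjustment, predict() could compute a relative frequency greater
// than 1 and trip its assert(numerator <= denominator).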
void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // will suffer!

    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        if (count(ngram, -static_cast<int>(i), size - i) > count(ngram, -static_cast<int>(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -static_cast<int>(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram whose count is to be adjusted: ";
                for (size_t j = 0; j < sub_ngram.size(); j++) {
                    logger << sub_ngram[j] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}


void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}
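
// Configuration flow, for orientation: when a profile variable such
// as Presage.Predictors.SmoothedNgramPredictor.DBFILENAME changes
// value, the configuration layer notifies this predictor through
// update(), and the dispatcher routes the change to the setter
// registered for that variable in the constructor (set_dbfilename in
// this case, which reopens the database connection).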