// Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_H__ #define DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_H__ #include <vector> #include "../matrix.h" #include "../statistics.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename reg_funct_type, typename sample_type, typename label_type > label_type test_regression_function ( const reg_funct_type& reg_funct, const std::vector<sample_type>& x_test, const std::vector<label_type>& y_test ) { typedef typename reg_funct_type::scalar_type scalar_type; typedef typename reg_funct_type::mem_manager_type mem_manager_type; // make sure requires clause is not broken DLIB_ASSERT( is_learning_problem(x_test,y_test) == true, "\tmatrix test_regression_function()" << "\n\t invalid inputs were given to this function" << "\n\t is_learning_problem(x_test,y_test): " << is_learning_problem(x_test,y_test)); running_stats<label_type> rs; for (unsigned long i = 0; i < x_test.size(); ++i) { // compute error label_type temp = reg_funct(x_test[i]) - y_test[i]; rs.add(temp*temp); } return rs.mean(); } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename sample_type, typename label_type > label_type cross_validate_regression_trainer ( const trainer_type& trainer, const std::vector<sample_type>& x, const std::vector<label_type>& y, const long folds ) { typedef typename trainer_type::scalar_type scalar_type; typedef typename trainer_type::mem_manager_type mem_manager_type; // make sure requires clause is not broken DLIB_ASSERT(is_learning_problem(x,y) == true && 1 < folds && folds <= static_cast<long>(x.size()), "\tmatrix cross_validate_regression_trainer()" << "\n\t invalid inputs were given to this function" << "\n\t x.size(): " << x.size() << "\n\t folds: " << folds << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) ); const long num_in_test = x.size()/folds; const long num_in_train = x.size() - num_in_test; std::vector<sample_type> x_test, x_train; std::vector<label_type> y_test, y_train; running_stats<label_type> rs; long next_test_idx = 0; for (long i = 0; i < folds; ++i) { x_test.clear(); y_test.clear(); x_train.clear(); y_train.clear(); // load up the test samples for (long cnt = 0; cnt < num_in_test; ++cnt) { x_test.push_back(x[next_test_idx]); y_test.push_back(y[next_test_idx]); next_test_idx = (next_test_idx + 1)%x.size(); } // load up the training samples long next = next_test_idx; for (long cnt = 0; cnt < num_in_train; ++cnt) { x_train.push_back(x[next]); y_train.push_back(y[next]); next = (next + 1)%x.size(); } try { // do the training and testing rs.add(test_regression_function(trainer.train(x_train,y_train),x_test,y_test)); } catch (invalid_nu_error&) { // just ignore cases which result in an invalid nu } } // for (long i = 0; i < folds; ++i) return rs.mean(); } } // ---------------------------------------------------------------------------------------- #endif // DLIB_CROSS_VALIDATE_REGRESSION_TRaINER_H__