SHOGUN
v1.1.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/PRCEvaluation.h> 00012 #include <shogun/mathematics/Math.h> 00013 00014 using namespace shogun; 00015 00016 CPRCEvaluation::~CPRCEvaluation() 00017 { 00018 SG_FREE(m_PRC_graph); 00019 } 00020 00021 float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00022 { 00023 ASSERT(predicted && ground_truth); 00024 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels()); 00025 ASSERT(ground_truth->is_two_class_labeling()); 00026 00027 // number of true positive examples 00028 float64_t tp = 0.0; 00029 int32_t i; 00030 00031 // total number of positive labels in predicted 00032 int32_t pos_count=0; 00033 00034 // initialize number of labels and labels 00035 SGVector<float64_t> orig_labels = predicted->get_labels(); 00036 int32_t length = orig_labels.vlen; 00037 float64_t* labels = CMath::clone_vector(orig_labels.vector, length); 00038 orig_labels.free_vector(); 00039 00040 // get indexes for sort 00041 int32_t* idxs = SG_MALLOC(int32_t, length); 00042 for(i=0; i<length; i++) 00043 idxs[i] = i; 00044 00045 // sort indexes by labels ascending 00046 CMath::qsort_backward_index(labels,idxs,length); 00047 00048 // clean and initialize graph and auPRC 00049 SG_FREE(labels); 00050 SG_FREE(m_PRC_graph); 00051 m_PRC_graph = SG_MALLOC(float64_t, length*2); 00052 m_thresholds = SG_MALLOC(float64_t, length); 00053 m_auPRC = 0.0; 00054 00055 // get total numbers of positive and negative labels 00056 for (i=0; i<length; i++) 00057 { 00058 if (ground_truth->get_label(i) > 0) 00059 pos_count++; 00060 } 00061 00062 // assure number of positive examples is >0 00063 ASSERT(pos_count>0); 00064 00065 // create PRC curve 00066 for (i=0; i<length; i++) 00067 { 00068 // update number of true positive examples 00069 if (ground_truth->get_label(idxs[i]) > 0) 00070 tp += 1.0; 00071 00072 // precision (x) 00073 m_PRC_graph[2*i] = tp/float64_t(i+1); 00074 // recall (y) 00075 m_PRC_graph[2*i+1] = tp/float64_t(pos_count); 00076 00077 m_thresholds[i]= predicted->get_label(idxs[i]); 00078 } 00079 00080 // calc auRPC using area under curve 00081 m_auPRC = CMath::area_under_curve(m_PRC_graph,length,true); 00082 00083 // set PRC length and computed indicator 00084 m_PRC_length = length; 00085 m_computed = true; 00086 00087 return m_auPRC; 00088 } 00089 00090 SGMatrix<float64_t> CPRCEvaluation::get_PRC() 00091 { 00092 if (!m_computed) 00093 SG_ERROR("Uninitialized, please call evaluate first"); 00094 00095 ASSERT(m_PRC_graph); 00096 00097 return SGMatrix<float64_t>(m_PRC_graph,2,m_PRC_length); 00098 } 00099 00100 SGVector<float64_t> CPRCEvaluation::get_thresholds() 00101 { 00102 if (!m_computed) 00103 SG_ERROR("Uninitialized, please call evaluate first"); 00104 00105 ASSERT(m_thresholds); 00106 00107 return SGVector<float64_t>(m_thresholds,m_PRC_length); 00108 } 00109 00110 float64_t CPRCEvaluation::get_auPRC() 00111 { 00112 if (!m_computed) 00113 SG_ERROR("Uninitialized, please call evaluate first"); 00114 00115 return m_auPRC; 00116 } 00117 00118