// SHOGUN v1.1.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/mathematics/Math.h> 00013 00014 using namespace shogun; 00015 00016 CROCEvaluation::~CROCEvaluation() 00017 { 00018 SG_FREE(m_ROC_graph); 00019 } 00020 00021 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00022 { 00023 ASSERT(predicted && ground_truth); 00024 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels()); 00025 ASSERT(ground_truth->is_two_class_labeling()); 00026 00027 // assume threshold as negative infinity 00028 float64_t threshold = CMath::ALMOST_NEG_INFTY; 00029 // false positive rate 00030 float64_t fp = 0.0; 00031 // true positive rate 00032 float64_t tp=0.0; 00033 00034 int32_t i; 00035 // total number of positive labels in predicted 00036 int32_t pos_count=0; 00037 int32_t neg_count=0; 00038 00039 // initialize number of labels and labels 00040 SGVector<float64_t> orig_labels = predicted->get_labels(); 00041 int32_t length = orig_labels.vlen; 00042 float64_t* labels = CMath::clone_vector(orig_labels.vector, length); 00043 orig_labels.free_vector(); 00044 00045 // get sorted indexes 00046 int32_t* idxs = SG_MALLOC(int32_t, length); 00047 for(i=0; i<length; i++) 00048 idxs[i] = i; 00049 00050 CMath::qsort_backward_index(labels,idxs,length); 00051 00052 // number of different predicted labels 00053 int32_t diff_count=1; 00054 00055 // get number of different labels 00056 for (i=0; i<length-1; i++) 00057 { 00058 if (labels[i] != labels[i+1]) 00059 diff_count++; 00060 } 00061 00062 delete [] labels; 00063 00064 // 
initialize graph and auROC 00065 SG_FREE(m_ROC_graph); 00066 m_ROC_graph = SG_MALLOC(float64_t, diff_count*2+2); 00067 m_thresholds = SG_MALLOC(float64_t, length); 00068 m_auROC = 0.0; 00069 00070 // get total numbers of positive and negative labels 00071 for(i=0; i<length; i++) 00072 { 00073 if (ground_truth->get_label(i) > 0) 00074 pos_count++; 00075 else 00076 neg_count++; 00077 } 00078 00079 // assure both number of positive and negative examples is >0 00080 ASSERT(pos_count>0 && neg_count>0); 00081 00082 int32_t j = 0; 00083 float64_t label; 00084 00085 // create ROC curve and calculate auROC 00086 for(i=0; i<length; i++) 00087 { 00088 label = predicted->get_label(idxs[i]); 00089 00090 if (label != threshold) 00091 { 00092 threshold = label; 00093 m_ROC_graph[2*j] = fp/neg_count; 00094 m_ROC_graph[2*j+1] = tp/pos_count; 00095 j++; 00096 } 00097 00098 m_thresholds[i]=threshold; 00099 00100 if (ground_truth->get_label(idxs[i]) > 0) 00101 tp+=1.0; 00102 else 00103 fp+=1.0; 00104 } 00105 00106 // add (1,1) to ROC curve 00107 m_ROC_graph[2*diff_count] = 1.0; 00108 m_ROC_graph[2*diff_count+1] = 1.0; 00109 00110 // set ROC length 00111 m_ROC_length = diff_count+1; 00112 00113 // calc auROC using area under curve 00114 m_auROC = CMath::area_under_curve(m_ROC_graph,m_ROC_length,false); 00115 00116 m_computed = true; 00117 00118 return m_auROC; 00119 } 00120 00121 SGMatrix<float64_t> CROCEvaluation::get_ROC() 00122 { 00123 if (!m_computed) 00124 SG_ERROR("Uninitialized, please call evaluate first"); 00125 00126 ASSERT(m_ROC_graph); 00127 00128 return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length); 00129 } 00130 00131 SGVector<float64_t> CROCEvaluation::get_thresholds() 00132 { 00133 if (!m_computed) 00134 SG_ERROR("Uninitialized, please call evaluate first"); 00135 00136 ASSERT(m_thresholds); 00137 00138 return SGVector<float64_t>(m_thresholds,m_ROC_length); 00139 } 00140 00141 float64_t CROCEvaluation::get_auROC() 00142 { 00143 if (!m_computed) 00144 
SG_ERROR("Uninitialized, please call evaluate first"); 00145 00146 return m_auROC; 00147 }