contrib/mul/clsfy/clsfy_classifier_base.cxx

Go to the documentation of this file.
00001 // Copyright: (C) 2000 British Telecommunications plc
00002 
00003 //:
00004 // \file
00005 // \brief Implement bits of an abstract classifier
00006 // \author Ian Scott
00007 // \date 2000/05/10
00008 // \verbatim
00009 //  Modifications
00010 //  2 May 2001 IMS Converted to VXL
00011 // \endverbatim
00012 
00013 #include "clsfy_classifier_base.h"
00014 #include <vcl_iostream.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_vector.h>
00017 #include <vsl/vsl_indent.h>
00018 #include <vsl/vsl_binary_loader.h>
00019 
00020 //=======================================================================
00021 
00022 unsigned clsfy_classifier_base::classify(const vnl_vector<double> &input) const
00023 {
00024   unsigned N = n_classes();
00025 
00026   vcl_vector<double> probs;
00027   class_probabilities(probs, input);
00028 
00029   if (N == 1) // This is a binary classifier
00030   {
00031     if (probs[0] > 0.5)
00032       return 1u;
00033     else return 0u;
00034   }
00035   else
00036   {
00037     unsigned bestIndex = 0;
00038     unsigned i = 1;
00039     double bestProb = probs[bestIndex];
00040 
00041     while (i < N)
00042     {
00043       if (probs[i] > bestProb)
00044       {
00045         bestIndex = i;
00046         bestProb = probs[i];
00047       }
00048       i++;
00049     }
00050     return bestIndex;
00051   }
00052 }
00053 
00054 //=======================================================================
00055 
00056 void clsfy_classifier_base::classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<vnl_vector<double> > &inputs) const
00057 {
00058   outputs.resize(inputs.size());
00059 
00060   inputs.reset();
00061   unsigned i=0;
00062 
00063   do
00064   {
00065     outputs[i++] = classify(inputs.current());
00066   } while (inputs.next());
00067 }
00068 
00069 //=======================================================================
00070 
00071 vcl_string clsfy_classifier_base::is_a() const
00072 {
00073   return vcl_string("clsfy_classifier_base");
00074 }
00075 
00076 //=======================================================================
00077 
00078 bool clsfy_classifier_base::is_class(vcl_string const& s) const
00079 {
00080   return s == clsfy_classifier_base::is_a();
00081 }
00082 
00083 //=======================================================================
00084 
00085 vcl_ostream& operator<<(vcl_ostream& os, clsfy_classifier_base const& b)
00086 {
00087   os << b.is_a() << ": ";
00088   vsl_indent_inc(os);
00089   b.print_summary(os);
00090   vsl_indent_dec(os);
00091   return os;
00092 }
00093 
00094 //=======================================================================
00095 
00096 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_classifier_base* b)
00097 {
00098   if (b)
00099     return os << *b;
00100   else
00101     return os << vsl_indent() << "No clsfy_classifier_base defined.";
00102 }
00103 
00104 //=======================================================================
00105 
00106 void vsl_add_to_binary_loader(const clsfy_classifier_base& b)
00107 {
00108   vsl_binary_loader<clsfy_classifier_base>::instance().add(b);
00109 }
00110 
00111 //=======================================================================
00112 
00113 void vsl_b_write(vsl_b_ostream& os, const clsfy_classifier_base& b)
00114 {
00115   b.b_write(os);
00116 }
00117 
00118 //=======================================================================
00119 
00120 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_base& b)
00121 {
00122   b.b_read(bfs);
00123 }
00124 
00125 //=======================================================================
00126 //: Calculate the fraction of test samples which are classified incorrectly
00127 double clsfy_test_error(const clsfy_classifier_base &classifier,
00128                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00129                         const vcl_vector<unsigned> & test_outputs)
00130 {
00131   assert(test_inputs.size() == test_outputs.size());
00132   if (test_inputs.size()==0) return -1;
00133 
00134   vcl_vector<unsigned> results;
00135   classifier.classify_many(results, test_inputs);
00136   unsigned sum_diff = 0;
00137   const unsigned n = results.size();
00138   for (unsigned i=0; i < n; ++i)
00139     if (results[i] != test_outputs[i]) sum_diff++;
00140   return ((double) sum_diff) / ((double) n);
00141 }
00142 
00143 //=======================================================================
00144 //: Calculate the fraction of test samples of a particular class which are classified incorrectly
00145 // \return -1 if there are no samples of test_class. 
00146 double clsfy_test_error(const clsfy_classifier_base &classifier,
00147                         mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00148                         const vcl_vector<unsigned> & test_outputs,
00149                         unsigned test_class)
00150 {
00151   assert(test_inputs.size() == test_outputs.size());
00152   if (test_inputs.size()==0) return -1;
00153   test_inputs.reset();
00154   unsigned n_class=0, n_bad=0, i=0;
00155   do
00156   {
00157     if (test_outputs[i] == test_class)
00158     {
00159       if (test_outputs[i] != classifier.classify(test_inputs.current()))
00160         n_bad ++;
00161       n_class ++;
00162     }
00163     i++;
00164   } while (test_inputs.next());
00165 
00166   if (n_class==0) return -1.0;
00167   return ((double) n_bad) / ((double) n_class);
00168 }
00169 

Generated on Thu Jan 8 05:11:33 2009 for contrib/mul/clsfy by  doxygen 1.5.1