00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include "clsfy_classifier_base.h"
00014 #include <vcl_iostream.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_vector.h>
00017 #include <vsl/vsl_indent.h>
00018 #include <vsl/vsl_binary_loader.h>
00019
00020
00021
00022 unsigned clsfy_classifier_base::classify(const vnl_vector<double> &input) const
00023 {
00024 unsigned N = n_classes();
00025
00026 vcl_vector<double> probs;
00027 class_probabilities(probs, input);
00028
00029 if (N == 1)
00030 {
00031 if (probs[0] > 0.5)
00032 return 1u;
00033 else return 0u;
00034 }
00035 else
00036 {
00037 unsigned bestIndex = 0;
00038 unsigned i = 1;
00039 double bestProb = probs[bestIndex];
00040
00041 while (i < N)
00042 {
00043 if (probs[i] > bestProb)
00044 {
00045 bestIndex = i;
00046 bestProb = probs[i];
00047 }
00048 i++;
00049 }
00050 return bestIndex;
00051 }
00052 }
00053
00054
00055
00056 void clsfy_classifier_base::classify_many(vcl_vector<unsigned> &outputs, mbl_data_wrapper<vnl_vector<double> > &inputs) const
00057 {
00058 outputs.resize(inputs.size());
00059
00060 inputs.reset();
00061 unsigned i=0;
00062
00063 do
00064 {
00065 outputs[i++] = classify(inputs.current());
00066 } while (inputs.next());
00067 }
00068
00069
00070
00071 vcl_string clsfy_classifier_base::is_a() const
00072 {
00073 return vcl_string("clsfy_classifier_base");
00074 }
00075
00076
00077
00078 bool clsfy_classifier_base::is_class(vcl_string const& s) const
00079 {
00080 return s == clsfy_classifier_base::is_a();
00081 }
00082
00083
00084
00085 vcl_ostream& operator<<(vcl_ostream& os, clsfy_classifier_base const& b)
00086 {
00087 os << b.is_a() << ": ";
00088 vsl_indent_inc(os);
00089 b.print_summary(os);
00090 vsl_indent_dec(os);
00091 return os;
00092 }
00093
00094
00095
00096 vcl_ostream& operator<<(vcl_ostream& os,const clsfy_classifier_base* b)
00097 {
00098 if (b)
00099 return os << *b;
00100 else
00101 return os << vsl_indent() << "No clsfy_classifier_base defined.";
00102 }
00103
00104
00105
00106 void vsl_add_to_binary_loader(const clsfy_classifier_base& b)
00107 {
00108 vsl_binary_loader<clsfy_classifier_base>::instance().add(b);
00109 }
00110
00111
00112
00113 void vsl_b_write(vsl_b_ostream& os, const clsfy_classifier_base& b)
00114 {
00115 b.b_write(os);
00116 }
00117
00118
00119
00120 void vsl_b_read(vsl_b_istream& bfs, clsfy_classifier_base& b)
00121 {
00122 b.b_read(bfs);
00123 }
00124
00125
00126
00127 double clsfy_test_error(const clsfy_classifier_base &classifier,
00128 mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00129 const vcl_vector<unsigned> & test_outputs)
00130 {
00131 assert(test_inputs.size() == test_outputs.size());
00132 if (test_inputs.size()==0) return -1;
00133
00134 vcl_vector<unsigned> results;
00135 classifier.classify_many(results, test_inputs);
00136 unsigned sum_diff = 0;
00137 const unsigned n = results.size();
00138 for (unsigned i=0; i < n; ++i)
00139 if (results[i] != test_outputs[i]) sum_diff++;
00140 return ((double) sum_diff) / ((double) n);
00141 }
00142
00143
00144
00145
00146 double clsfy_test_error(const clsfy_classifier_base &classifier,
00147 mbl_data_wrapper<vnl_vector<double> > & test_inputs,
00148 const vcl_vector<unsigned> & test_outputs,
00149 unsigned test_class)
00150 {
00151 assert(test_inputs.size() == test_outputs.size());
00152 if (test_inputs.size()==0) return -1;
00153 test_inputs.reset();
00154 unsigned n_class=0, n_bad=0, i=0;
00155 do
00156 {
00157 if (test_outputs[i] == test_class)
00158 {
00159 if (test_outputs[i] != classifier.classify(test_inputs.current()))
00160 n_bad ++;
00161 n_class ++;
00162 }
00163 i++;
00164 } while (test_inputs.next());
00165
00166 if (n_class==0) return -1.0;
00167 return ((double) n_bad) / ((double) n_class);
00168 }
00169