00001
00002
00003 #include "clsfy_binary_hyperplane_ls_builder.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vcl_string.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_algorithm.h>
00017 #include <vnl/algo/vnl_svd.h>
00018 #include <vnl/vnl_math.h>
00019
00020
00021
00022 vcl_string clsfy_binary_hyperplane_ls_builder::is_a() const
00023 {
00024 return vcl_string("clsfy_binary_hyperplane_ls_builder");
00025 }
00026
00027
00028
00029 bool clsfy_binary_hyperplane_ls_builder::is_class(vcl_string const& s) const
00030 {
00031 return s == clsfy_binary_hyperplane_ls_builder::is_a() || clsfy_builder_base::is_class(s);
00032 }
00033
00034
00035
00036 short clsfy_binary_hyperplane_ls_builder::version_no() const
00037 {
00038 return 1;
00039 }
00040
00041
00042
00043 void clsfy_binary_hyperplane_ls_builder::print_summary(vcl_ostream& os) const
00044 {
00045 os << is_a();
00046 }
00047
00048
00049
00050
00051 double clsfy_binary_hyperplane_ls_builder::build(
00052 clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00053 const vcl_vector<unsigned> &outputs) const
00054 {
00055 assert(outputs.size() == inputs.size());
00056 assert(* vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00057 assert(classifier.is_class("clsfy_binary_hyperplane"));
00058
00059 clsfy_binary_hyperplane &hyperplane = (clsfy_binary_hyperplane &) classifier;
00060
00061 inputs.reset();
00062 const unsigned k = inputs.current().size();
00063 vnl_matrix<double> XtX(k+1, k+1, 0.0);
00064 vnl_vector<double> XtY(k+1, 0.0);
00065
00066 #if 0 // The calculation is as follows
00067 do
00068 {
00069
00070 const vnl_vector<double> &x=inputs.current();
00071 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00072 vnl_vector<double> xp(k+1);
00073 xp.update(x, 0);
00074 xp(k) = -1.0;
00075 XtX += outer_product(xp, xp);
00076 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00077 XtY += y * xp;
00078 } while (inputs.next());
00079 #else// However the following version is faster
00080 do
00081 {
00082
00083 const vnl_vector<double> &x=inputs.current();
00084 double y = outputs[inputs.index()] ? 1.0 : -1.0;
00085 for (unsigned i=0; i<k; ++i)
00086 {
00087 for (unsigned j=0; j<i; ++j)
00088 XtX(i,j) += x(i) * x(j);
00089 XtX(i,i) += vnl_math_sqr(x(i));
00090 XtX(i,k) -= x(i);
00091 XtY(i) += y * x(i);
00092 }
00093 XtY(k) += y * -1.0;
00094
00095 } while (inputs.next());
00096 for (unsigned i=0; i<k; ++i)
00097 {
00098 for (unsigned j=0; j<i; ++j)
00099 XtX(j,i) += XtX(i,j);
00100 XtX(k,i) = XtX(i,k);
00101 }
00102 XtX(k, k) = (double) inputs.size();
00103 #endif
00104
00105
00106
00107
00108
00109
00110 vnl_svd<double> svd(XtX, 1.0e-12);
00111 vnl_vector<double> w = svd.solve(XtY);
00112 #if 0
00113 vcl_cerr << "XtX: " << XtX << vcl_endl
00114 << "XtY: " << XtY << vcl_endl
00115 << "w: " << w << vcl_endl;
00116 #endif
00117 vnl_vector<double> weights(&w(0), k);
00118 hyperplane.set(weights, w(k));
00119
00120 return clsfy_test_error(classifier, inputs, outputs);
00121 }
00122
00123
00124
00125
00126
00127
00128
00129
00130 double clsfy_binary_hyperplane_ls_builder::build(
00131 clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00132 unsigned n_classes, const vcl_vector<unsigned> &outputs) const
00133 {
00134 assert (n_classes == 1);
00135 return build(classifier, inputs, outputs);
00136 }
00137
00138
00139
00140 void clsfy_binary_hyperplane_ls_builder::b_write(vsl_b_ostream &bfs) const
00141 {
00142 const int version_no=1;
00143 vsl_b_write(bfs, version_no);
00144 }
00145
00146
00147
00148 void clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream &bfs)
00149 {
00150 if (!bfs) return;
00151
00152 short version;
00153 vsl_b_read(bfs,version);
00154 switch (version)
00155 {
00156 case (1):
00157 break;
00158 default:
00159 vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream&)\n"
00160 << " Unknown version number "<< version << '\n';
00161 bfs.is().clear(vcl_ios::badbit);
00162 }
00163 }