contrib/mul/clsfy/clsfy_binary_hyperplane_ls_builder.cxx

Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_hyperplane_ls_builder.cxx
00002 // Copyright: (C) 2001 British Telecommunications PLC
00003 #include "clsfy_binary_hyperplane_ls_builder.h"
00004 //:
00005 // \file
00006 // \brief Implement a two-class output linear classifier builder
00007 // \author Ian Scott
00008 // \date 4 June 2001
00009 
00010 //=======================================================================
00011 
00012 #include <vcl_string.h>
00013 #include <vcl_iostream.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_cassert.h>
00016 #include <vcl_algorithm.h>
00017 #include <vnl/algo/vnl_svd.h>
00018 #include <vnl/vnl_math.h>
00019 
00020 //=======================================================================
00021 
00022 vcl_string clsfy_binary_hyperplane_ls_builder::is_a() const
00023 {
00024   return vcl_string("clsfy_binary_hyperplane_ls_builder");
00025 }
00026 
00027 //=======================================================================
00028 
00029 bool clsfy_binary_hyperplane_ls_builder::is_class(vcl_string const& s) const
00030 {
00031   return s == clsfy_binary_hyperplane_ls_builder::is_a() || clsfy_builder_base::is_class(s);
00032 }
00033 
00034 //=======================================================================
00035 
00036 short clsfy_binary_hyperplane_ls_builder::version_no() const
00037 {
00038   return 1;
00039 }
00040 
00041 //=======================================================================
00042 
00043 void clsfy_binary_hyperplane_ls_builder::print_summary(vcl_ostream& os) const
00044 {
00045   os << is_a();
00046 }
00047 
00048 //=======================================================================
00049 
00050 //: Build a multi layer perceptron classifier, with the given data.
00051 double clsfy_binary_hyperplane_ls_builder::build(
00052   clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00053   const vcl_vector<unsigned> &outputs) const
00054 {
00055   assert(outputs.size() == inputs.size());
00056   assert(* vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00057   assert(classifier.is_class("clsfy_binary_hyperplane"));
00058 
00059   clsfy_binary_hyperplane &hyperplane = (clsfy_binary_hyperplane &) classifier;
00060 
00061   inputs.reset();
00062   const unsigned k = inputs.current().size();
00063   vnl_matrix<double> XtX(k+1, k+1, 0.0);
00064   vnl_vector<double> XtY(k+1, 0.0);
00065 
00066 #if 0 // The calculation is as follows
00067   do
00068   {
00069     // XtX += [x, -1]' * [x, -1]
00070     const vnl_vector<double> &x=inputs.current();
00071     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00072     vnl_vector<double> xp(k+1);
00073     xp.update(x, 0);
00074     xp(k) = -1.0;
00075     XtX += outer_product(xp, xp);
00076     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00077     XtY += y * xp;
00078   } while (inputs.next());
00079 #else// However the following version is faster
00080   do
00081   {
00082     // XtX += [x, -1]' * [x, -1]
00083     const vnl_vector<double> &x=inputs.current();
00084     double y = outputs[inputs.index()] ? 1.0 : -1.0;
00085     for (unsigned i=0; i<k; ++i)
00086     {
00087       for (unsigned j=0; j<i; ++j)
00088         XtX(i,j) += x(i) * x(j);
00089       XtX(i,i) += vnl_math_sqr(x(i));
00090       XtX(i,k) -= x(i);
00091       XtY(i) += y * x(i);
00092     }
00093     XtY(k) += y * -1.0;
00094 
00095   } while (inputs.next());
00096   for (unsigned i=0; i<k; ++i)
00097   {
00098     for (unsigned j=0; j<i; ++j)
00099       XtX(j,i) += XtX(i,j);
00100     XtX(k,i) = XtX(i,k);
00101   }
00102   XtX(k, k) = (double) inputs.size();
00103 #endif
00104 
00105 
00106   // Find the solution to X w = Y;
00107   // However it is easier to find X' X w = X' Y;
00108   // because X is n_train x n_elems whereas X'X is n_elems x n_elems
00109 
00110   vnl_svd<double> svd(XtX, 1.0e-12); // 1e-12 = zero-tolerance for singular values
00111   vnl_vector<double> w = svd.solve(XtY);
00112 #if 0
00113   vcl_cerr << "XtX: " << XtX << vcl_endl
00114            << "XtY: " << XtY << vcl_endl
00115            << "w: "   << w   << vcl_endl;
00116 #endif
00117   vnl_vector<double> weights(&w(0), k);
00118   hyperplane.set(weights, w(k));
00119 
00120   return clsfy_test_error(classifier, inputs, outputs);
00121 }
00122 
00123 
00124 //=======================================================================
00125 
00126 
00127 //: Build a linear classifier, with the given data.
00128 // Return the mean error over the training set.
00129 // n_classes must be 1.
00130 double clsfy_binary_hyperplane_ls_builder::build(
00131   clsfy_classifier_base &classifier, mbl_data_wrapper<vnl_vector<double> > &inputs,
00132   unsigned n_classes, const vcl_vector<unsigned> &outputs) const
00133 {
00134   assert (n_classes == 1);
00135   return build(classifier, inputs, outputs);
00136 }
00137 
00138 //=======================================================================
00139 
00140 void clsfy_binary_hyperplane_ls_builder::b_write(vsl_b_ostream &bfs) const
00141 {
00142   const int version_no=1;
00143   vsl_b_write(bfs, version_no);
00144 }
00145 
00146 //=======================================================================
00147 
00148 void clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream &bfs)
00149 {
00150   if (!bfs) return;
00151 
00152   short version;
00153   vsl_b_read(bfs,version);
00154   switch (version)
00155   {
00156     case (1):
00157       break;
00158     default:
00159       vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_ls_builder::b_read(vsl_b_istream&)\n"
00160                << "           Unknown version number "<< version << '\n';
00161       bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00162   }
00163 }

Generated on Thu Jan 8 05:11:33 2009 for contrib/mul/clsfy by  doxygen 1.5.1