contrib/mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx

Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_rbf_svm_smo_1_builder.cxx
00002 // Copyright: (C) 2001 British Telecommunications plc.
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004 //:
00005 // \file
00006 // \brief Implement an interface to SMO algorithm SVM builder and additional logic
00007 // \author Ian Scott
00008 // \date Dec 2001
00009 
00010 //=======================================================================
00011 
00012 #include <clsfy/clsfy_smo_1.h>
00013 #include <vcl_string.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_algorithm.h>
00016 #include <vcl_cassert.h>
00017 
00018 #include <mbl/mbl_data_wrapper.h>
00019 
00020 //=======================================================================
00021 
00022 clsfy_rbf_svm_smo_1_builder::clsfy_rbf_svm_smo_1_builder()
00023 {
00024   boundC_ = 0;
00025   rbf_width_ = 1.0;
00026 }
00027 
00028 //=======================================================================
00029 
//: Map a class label onto the target value handed to the SVM solver.
// Label 1 becomes +1; every other label becomes -1.
inline int class_to_svm_target(unsigned v)
{
  if (v == 1)
    return 1;
  return -1;
}
00031 
00032 //=======================================================================
00033 //: Build classifier from data
00034 // returns the training error, or +INF if there is an error.
00035 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00036                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00037                                           const vcl_vector<unsigned> &outputs) const
00038 {
00039   inputs.reset();
00040 //const unsigned int nDims = inputs.current().size(); // unused variable
00041   const unsigned int nSamples = inputs.size();
00042   assert(outputs.size() == nSamples);
00043   assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00044 
00045   assert(classifier.is_class("clsfy_rbf_svm"));
00046   clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00047 
00048   clsfy_smo_1_rbf svAPI;
00049   vcl_vector<int> targets(nSamples);
00050   vcl_transform(outputs.begin(), outputs.end(),
00051     targets.begin(), class_to_svm_target);
00052 
00053   svAPI.set_data(inputs, targets);
00054 
00055 
00056   // Set the SVM solver parameters
00057   svAPI.set_C(boundC_);
00058   svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00059   // Solve the SVM
00060   svAPI.calc();
00061 
00062 
00063   // Get the SVM description, and build an SVM machine
00064   {
00065     vcl_vector<vnl_vector<double> > supportVectors;
00066     const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00067     vcl_vector<double> alphas;
00068     vcl_vector<unsigned> labels;
00069     for (unsigned i=0; i<nSamples; ++i)
00070       if (allAlphas[i]!=0.0)
00071       {
00072         alphas.push_back(allAlphas[i]);
00073         labels.push_back(outputs[i]);
00074         inputs.set_index(i);
00075         supportVectors.push_back(inputs.current());
00076       }
00077     svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00078   }
00079 
00080   return svAPI.error_rate();
00081 }
00082 
00083 //=======================================================================
00084 //: Build classifier from data.
00085 // returns the training error, or +INF if there is an error.
00086 // nClasses must be 1.
00087 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00088                                           mbl_data_wrapper<vnl_vector<double> >& inputs,
00089                                           unsigned nClasses,
00090                                           const vcl_vector<unsigned> &outputs) const
00091 {
00092   assert(nClasses == 1);
00093   return build(classifier, inputs, outputs);
00094 }
00095 
00096 //=======================================================================
00097 
00098 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00099 {
00100   return rbf_width_;
00101 }
00102 
00103 //=======================================================================
00104 
00105 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00106 {
00107   rbf_width_ = rbf_width;
00108 }
00109 //=======================================================================
00110 
00111 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00112 {
00113   return vcl_string("clsfy_rbf_svm_smo_1_builder");
00114 }
00115 
00116 //=======================================================================
00117 
00118 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00119 {
00120   return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00121 }
00122 
00123 //=======================================================================
00124 
00125 short clsfy_rbf_svm_smo_1_builder::version_no() const
00126 {
00127   return 1;
00128 }
00129 
00130 //=======================================================================
00131 
00132 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00133 {
00134   return new clsfy_rbf_svm_smo_1_builder(*this);
00135 }
00136 
00137 //=======================================================================
00138 
00139 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00140 {
00141   // os << data_; // example of data output
00142   os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00143 }
00144 
00145 //=======================================================================
00146 
00147 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00148 {
00149   vsl_b_write(bfs,version_no());
00150   vsl_b_write(bfs,boundC_);
00151   vsl_b_write(bfs,rbf_width_);
00152 }
00153 
00154 //=======================================================================
00155 
00156 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00157 {
00158   if (!bfs) return;
00159 
00160   short version;
00161   vsl_b_read(bfs,version);
00162   switch (version)
00163   {
00164   case (1):
00165     vsl_b_read(bfs,boundC_);
00166     vsl_b_read(bfs,rbf_width_);
00167     break;
00168   default:
00169     vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00170              << "           Unknown version number "<< version << "\n";
00171     bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00172     return;
00173   }
00174 }

Generated on Thu Jan 8 05:11:34 2009 for contrib/mul/clsfy by  doxygen 1.5.1