00001
00002
00003 #include "clsfy_rbf_svm_smo_1_builder.h"
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <clsfy/clsfy_smo_1.h>
00013 #include <vcl_string.h>
00014 #include <vcl_vector.h>
00015 #include <vcl_algorithm.h>
00016 #include <vcl_cassert.h>
00017
00018 #include <mbl/mbl_data_wrapper.h>
00019
00020
00021
00022 clsfy_rbf_svm_smo_1_builder::clsfy_rbf_svm_smo_1_builder()
00023 {
00024 boundC_ = 0;
00025 rbf_width_ = 1.0;
00026 }
00027
00028
00029
// Convert a class label into an SVM target value.
// Label 1 maps to +1; every other label maps to -1.
inline int class_to_svm_target(unsigned v)
{
  return (v == 1) ? 1 : -1;
}
00031
00032
00033
00034
00035 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00036 mbl_data_wrapper<vnl_vector<double> >& inputs,
00037 const vcl_vector<unsigned> &outputs) const
00038 {
00039 inputs.reset();
00040
00041 const unsigned int nSamples = inputs.size();
00042 assert(outputs.size() == nSamples);
00043 assert(*vcl_max_element(outputs.begin(), outputs.end()) <= 1);
00044
00045 assert(classifier.is_class("clsfy_rbf_svm"));
00046 clsfy_rbf_svm &svm = static_cast<clsfy_rbf_svm &>(classifier);
00047
00048 clsfy_smo_1_rbf svAPI;
00049 vcl_vector<int> targets(nSamples);
00050 vcl_transform(outputs.begin(), outputs.end(),
00051 targets.begin(), class_to_svm_target);
00052
00053 svAPI.set_data(inputs, targets);
00054
00055
00056
00057 svAPI.set_C(boundC_);
00058 svAPI.set_gamma(1.0/(2.0*rbf_width_*rbf_width_));
00059
00060 svAPI.calc();
00061
00062
00063
00064 {
00065 vcl_vector<vnl_vector<double> > supportVectors;
00066 const vnl_vector<double> &allAlphas = svAPI.lagrange_mults();
00067 vcl_vector<double> alphas;
00068 vcl_vector<unsigned> labels;
00069 for (unsigned i=0; i<nSamples; ++i)
00070 if (allAlphas[i]!=0.0)
00071 {
00072 alphas.push_back(allAlphas[i]);
00073 labels.push_back(outputs[i]);
00074 inputs.set_index(i);
00075 supportVectors.push_back(inputs.current());
00076 }
00077 svm.set(supportVectors, alphas, labels, rbf_width_, svAPI.bias());
00078 }
00079
00080 return svAPI.error_rate();
00081 }
00082
00083
00084
00085
00086
00087 double clsfy_rbf_svm_smo_1_builder::build(clsfy_classifier_base& classifier,
00088 mbl_data_wrapper<vnl_vector<double> >& inputs,
00089 unsigned nClasses,
00090 const vcl_vector<unsigned> &outputs) const
00091 {
00092 assert(nClasses == 1);
00093 return build(classifier, inputs, outputs);
00094 }
00095
00096
00097
00098 double clsfy_rbf_svm_smo_1_builder::rbf_width() const
00099 {
00100 return rbf_width_;
00101 }
00102
00103
00104
00105 void clsfy_rbf_svm_smo_1_builder::set_rbf_width(double rbf_width)
00106 {
00107 rbf_width_ = rbf_width;
00108 }
00109
00110
00111 vcl_string clsfy_rbf_svm_smo_1_builder::is_a() const
00112 {
00113 return vcl_string("clsfy_rbf_svm_smo_1_builder");
00114 }
00115
00116
00117
00118 bool clsfy_rbf_svm_smo_1_builder::is_class(vcl_string const& s) const
00119 {
00120 return s == clsfy_rbf_svm_smo_1_builder::is_a() || clsfy_builder_base::is_class(s);
00121 }
00122
00123
00124
00125 short clsfy_rbf_svm_smo_1_builder::version_no() const
00126 {
00127 return 1;
00128 }
00129
00130
00131
00132 clsfy_builder_base* clsfy_rbf_svm_smo_1_builder::clone() const
00133 {
00134 return new clsfy_rbf_svm_smo_1_builder(*this);
00135 }
00136
00137
00138
00139 void clsfy_rbf_svm_smo_1_builder::print_summary(vcl_ostream& os) const
00140 {
00141
00142 os << "RBF width = " << rbf_width_ << ", bounds = " << boundC_;
00143 }
00144
00145
00146
00147 void clsfy_rbf_svm_smo_1_builder::b_write(vsl_b_ostream& bfs) const
00148 {
00149 vsl_b_write(bfs,version_no());
00150 vsl_b_write(bfs,boundC_);
00151 vsl_b_write(bfs,rbf_width_);
00152 }
00153
00154
00155
00156 void clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream& bfs)
00157 {
00158 if (!bfs) return;
00159
00160 short version;
00161 vsl_b_read(bfs,version);
00162 switch (version)
00163 {
00164 case (1):
00165 vsl_b_read(bfs,boundC_);
00166 vsl_b_read(bfs,rbf_width_);
00167 break;
00168 default:
00169 vcl_cerr << "I/O ERROR: clsfy_rbf_svm_smo_1_builder::b_read(vsl_b_istream&)\n"
00170 << " Unknown version number "<< version << "\n";
00171 bfs.is().clear(vcl_ios::badbit);
00172 return;
00173 }
00174 }