00001
00002 #include "clsfy_binary_threshold_1d_builder.h"
00003
00004
00005
00006
00007
00008 #include <vcl_iostream.h>
00009 #include <vcl_string.h>
00010 #include <vcl_cassert.h>
00011 #include <vsl/vsl_binary_loader.h>
00012 #include <vnl/vnl_double_2.h>
00013 #include <clsfy/clsfy_builder_1d.h>
00014 #include <clsfy/clsfy_binary_threshold_1d.h>
00015 #include <vcl_algorithm.h>
00016
00017
00018
00019 clsfy_binary_threshold_1d_builder::clsfy_binary_threshold_1d_builder()
00020 {
00021 }
00022
00023
00024
00025 clsfy_binary_threshold_1d_builder::~clsfy_binary_threshold_1d_builder()
00026 {
00027 }
00028
00029
00030
00031 short clsfy_binary_threshold_1d_builder::version_no() const
00032 {
00033 return 1;
00034 }
00035
00036
00037
00038
00039 clsfy_classifier_1d* clsfy_binary_threshold_1d_builder::new_classifier() const
00040 {
00041 return new clsfy_binary_threshold_1d();
00042 }
00043
00044
00045
00046
00047
00048
00049
00050
00051 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00052 const vnl_vector<double>& egs,
00053 const vnl_vector<double>& wts,
00054 const vcl_vector<unsigned> &outputs) const
00055 {
00056
00057 assert(classifier.is_class("clsfy_mean_square_1d"));
00058
00059 unsigned int n= egs.size();
00060 assert ( wts.size() == n );
00061 assert ( outputs.size() == n );
00062
00063
00064 vcl_vector<vbl_triple<double,int,int> > data;
00065
00066 vbl_triple<double,int,int> t;
00067
00068 for (unsigned int i=0;i<n;++i)
00069 {
00070 t.first=egs(i);
00071 t.second= outputs[i];
00072 t.third = i;
00073 data.push_back(t);
00074 }
00075
00076 vbl_triple<double,int,int> *data_ptr=&data[0];
00077 vcl_sort(data_ptr,data_ptr+n);
00078 return build_from_sorted_data(classifier, &data[0], wts);
00079 }
00080
00081
00082
00083
00084
00085 double clsfy_binary_threshold_1d_builder::build(clsfy_classifier_1d& classifier,
00086 vnl_vector<double>& egs0,
00087 vnl_vector<double>& wts0,
00088 vnl_vector<double>& egs1,
00089 vnl_vector<double>& wts1) const
00090 {
00091
00092 assert(classifier.is_class("clsfy_binary_threshold_1d"));
00093
00094 vcl_vector<vbl_triple<double,int,int> > data;
00095 unsigned int n0 = egs0.size();
00096 unsigned int n1 = egs1.size();
00097 vnl_vector<double> wts(n0+n1);
00098 vbl_triple<double,int,int> t;
00099
00100 for (unsigned int i=0;i<n0;++i)
00101 {
00102 t.first=egs0[i];
00103 t.second=0;
00104 t.third = i;
00105 wts(i)= wts0[i];
00106 data.push_back(t);
00107 }
00108
00109
00110 for (unsigned int i=0;i<n1;++i)
00111 {
00112 t.first=egs1[i];
00113 t.second=1;
00114 t.third = i+n0;
00115 wts(i+n0)= wts1[i];
00116 data.push_back(t);
00117 }
00118
00119 unsigned int n=n0+n1;
00120
00121 vbl_triple<double,int,int> *data_ptr=&data[0];
00122 vcl_sort(data_ptr,data_ptr+n);
00123
00124 return build_from_sorted_data(classifier,&data[0], wts);
00125 }
00126
00127
00128
00129
00130 double clsfy_binary_threshold_1d_builder::build_from_sorted_data(
00131 clsfy_classifier_1d& classifier,
00132 const vbl_triple<double,int,int> *data,
00133 const vnl_vector<double>& wts
00134 ) const
00135 {
00136
00137
00138
00139
00140
00141
00142
00143 unsigned int n=wts.size();
00144 double tot_wts0=0.0, tot_wts1=0.0;
00145 for (unsigned int i=0;i<n;++i)
00146 if (data[i].second==0)
00147 tot_wts0+=wts(data[i].third);
00148 else
00149 tot_wts1+=wts(data[i].third);
00150
00151 double e0=0.0, e1=0.0, min_err=2.0;
00152 double etot0,etot1;
00153 unsigned int index=n; int polarity=0;
00154 for (unsigned int i=0;i<n;++i)
00155 {
00156 if (data[i].second==0)
00157 e0+=wts(data[i].third);
00158 else
00159 e1+=wts(data[i].third);
00160
00161 etot0=(tot_wts0-e0) +e1;
00162 etot1=(tot_wts1-e1) +e0;
00163
00164 if ( etot0< min_err)
00165 {
00166
00167
00168 polarity=+1;
00169 index=i;
00170
00171 min_err= etot0;
00172 }
00173
00174 if ( etot1< min_err)
00175 {
00176
00177
00178 polarity=-1;
00179 index=i;
00180
00181 min_err= etot1;
00182 }
00183 }
00184
00185 assert ( index!=n );
00186
00187
00188 double threshold;
00189 if ( index+1==n )
00190 threshold=data[index].first+0.01;
00191 else
00192 threshold=(data[index].first+data[index+1].first)/2;
00193
00194
00195 vnl_double_2 params(polarity, threshold*polarity);
00196 classifier.set_params(params.as_vector());
00197 return min_err;
00198 }
00199
00200
00201
00202 vcl_string clsfy_binary_threshold_1d_builder::is_a() const
00203 {
00204 return vcl_string("clsfy_binary_threshold_1d_builder");
00205 }
00206
00207 bool clsfy_binary_threshold_1d_builder::is_class(vcl_string const& s) const
00208 {
00209 return s == clsfy_binary_threshold_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
00210 }
00211
00212
00213
#if 0
// NOTE(review): Disabled copy constructor and assignment operator. They
// refer to members data_ptr_ and data_, which none of the active code in
// this file uses. While this stays disabled, clone() relies on whatever
// copy constructor the header provides (possibly the implicit one).

//: Copy constructor: initialises data_ptr_ to null, then delegates to operator=.
clsfy_binary_threshold_1d_builder::clsfy_binary_threshold_1d_builder(
const clsfy_binary_threshold_1d_builder& new_b) :
data_ptr_(0)
{
*this = new_b;
}


//: Assignment: deep-copies data_ptr_ via clone() and copies data_.
clsfy_binary_threshold_1d_builder&
clsfy_binary_threshold_1d_builder::operator=(const clsfy_binary_threshold_1d_builder& new_b)
{
if (&new_b==this) return *this; // self-assignment guard

// Release the currently held object before copying new_b's.
delete data_ptr_; data_ptr_=0;

if (new_b.data_ptr_)
data_ptr_ = new_b.data_ptr_->clone();


data_ = new_b.data_;

return *this;
}

#endif // 0
00245
00246
00247
00248 clsfy_builder_1d* clsfy_binary_threshold_1d_builder::clone() const
00249 {
00250 return new clsfy_binary_threshold_1d_builder(*this);
00251 }
00252
00253
00254
00255
00256 void clsfy_binary_threshold_1d_builder::print_summary(vcl_ostream& ) const
00257 {
00258
00259
00260
00261 vcl_cerr << "clsfy_binary_threshold_1d_builder::print_summary() NYI\n";
00262 }
00263
00264
00265
00266
00267 void clsfy_binary_threshold_1d_builder::b_write(vsl_b_ostream& ) const
00268 {
00269
00270
00271
00272 vcl_cerr << "clsfy_binary_threshold_1d_builder::b_write() NYI\n";
00273 }
00274
00275
00276
00277
00278 void clsfy_binary_threshold_1d_builder::b_read(vsl_b_istream& )
00279 {
00280 vcl_cerr << "clsfy_binary_threshold_1d_builder::b_read() NYI\n";
00281 #if 0
00282 if (!bfs) return;
00283
00284 short version;
00285 vsl_b_read(bfs,version);
00286 switch (version)
00287 {
00288 case (1):
00289
00290 vsl_b_read(bfs,data_);
00291 break;
00292 default:
00293 vcl_cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_binary_threshold_1d_builder&)\n"
00294 << " Unknown version number "<< version << '\n';
00295 bfs.is().clear(vcl_ios::badbit);
00296 return;
00297 }
00298 #endif // 0
00299 }