00001
00002 #include "clsfy_binary_hyperplane_gmrho_builder.h"
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vnl/vnl_vector_ref.h>
00019 #include <vnl/algo/vnl_lbfgs.h>
00020
00021
00022
00023 namespace clsfy_binary_hyperplane_gmrho_builder_helpers
00024 {
00025
00026 class gmrho_sum : public vnl_cost_function
00027 {
00028
00029 const vnl_matrix<double>& x_;
00030
00031 const vnl_vector<double>& y_;
00032
00033 double sigma_;
00034
00035 double var_;
00036
00037 unsigned num_examples_;
00038
00039 unsigned num_vars_;
00040
00041 double alpha_;
00042
00043 double beta_;
00044 public:
00045
00046 gmrho_sum(const vnl_matrix<double>& x,
00047 const vnl_vector<double>& y,double sigma=1);
00048
00049
00050 void set_sigma(double sigma);
00051
00052
00053 virtual double f(vnl_vector<double> const& w);
00054
00055
00056 virtual void gradf(vnl_vector<double> const& x, vnl_vector<double>& gradient);
00057 };
00058
00059
00060 class gm_grad_accum
00061 {
00062 const double* px_;
00063 const double wt_;
00064 public:
00065 gm_grad_accum(const double* px,double wt) : px_(px),wt_(wt) {}
00066 void operator()(double& grad)
00067 {
00068 grad += (*px_++) * wt_;
00069 }
00070 };
00071
00072
00073 class category_value
00074 {
00075 const double y0;
00076 const double y1;
00077 public:
00078 category_value(unsigned num_category1,unsigned num_total):
00079 y0(-1.0*double(num_total-num_category1)/double(num_total)),
00080 y1(double(num_category1)/double(num_total)) {}
00081
00082 double operator()(const unsigned& classNum)
00083 {
00084
00085 return classNum ? 1.0 : -1.0;
00086 }
00087 };
00088 };
00089
00090
00091
00092
00093
00094
00095
00096 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00097 mbl_data_wrapper<vnl_vector<double> >& inputs,
00098 unsigned n_classes,
00099 const vcl_vector<unsigned>& outputs) const
00100 {
00101 assert (n_classes == 1);
00102 return clsfy_binary_hyperplane_gmrho_builder::build(classifier, inputs, outputs);
00103 }
00104
00105
00106
00107
00108 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00109 mbl_data_wrapper<vnl_vector<double> >& inputs,
00110 const vcl_vector<unsigned>& outputs) const
00111 {
00112 using clsfy_binary_hyperplane_gmrho_builder_helpers::category_value;
00113
00114
00115 clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00116
00117 num_examples_ = inputs.size();
00118 if (num_examples_ == 0)
00119 {
00120 vcl_cerr<<"WARNING - clsfy_binary_hyperplane_gmrho_builder::build called with no data\n";
00121 return 0.0;
00122 }
00123
00124
00125 inputs.reset();
00126 num_vars_ = inputs.current().size();
00127 vnl_matrix<double> data(num_examples_,num_vars_,0.0);
00128 unsigned i=0;
00129 do
00130 {
00131 double* row=data[i++];
00132 vcl_copy(inputs.current().begin(),inputs.current().end(),row);
00133 } while (inputs.next());
00134
00135
00136 vnl_vector<double> y(num_examples_,0.0);
00137 vcl_transform(outputs.begin(),outputs.end(),
00138 y.begin(),
00139 category_value(vcl_count(outputs.begin(),outputs.end(),1u),outputs.size()));
00140 weights_.set_size(num_vars_+1);
00141
00142
00143 clsfy_binary_hyperplane& hyperplane = dynamic_cast<clsfy_binary_hyperplane &>(classifier);
00144
00145 weights_.update(hyperplane.weights(),0);
00146 weights_[num_vars_] = hyperplane.bias();
00147
00148
00149 double sigma_scale_target = sigma_preset_;
00150 if (auto_estimate_sigma_)
00151 sigma_scale_target=estimate_sigma(data,y);
00152
00153
00154
00155 double kappa = 5.0;
00156 const double alpha_anneal=0.75;
00157
00158 int N = 1+int(vcl_log(1.1/kappa)/vcl_log(alpha_anneal));
00159 if (N<1) N=1;
00160 double sigma_scale = kappa * sigma_scale_target;
00161
00162 epsilon_ = 1.0E-4;
00163 for (int ianneal=0;ianneal<N;++ianneal)
00164 {
00165
00166 determine_weights(data,y,sigma_scale);
00167
00168 sigma_scale *= alpha_anneal;
00169 }
00170
00171 epsilon_ = 1.0E-8;
00172
00173
00174
00175
00176 for (unsigned iter=0; iter<(auto_estimate_sigma_ ? 2u : 1u); ++iter)
00177 {
00178 if (auto_estimate_sigma_)
00179 sigma_scale_target=estimate_sigma(data,y);
00180 else
00181 sigma_scale_target = sigma_preset_;
00182
00183 determine_weights(data,y,sigma_scale_target);
00184 }
00185
00186 vnl_vector_ref<double > weights(num_vars_,weights_.data_block());
00187 hyperplane.set(weights, weights_[num_vars_]);
00188
00189 return clsfy_test_error(classifier, inputs, outputs);
00190 }
00191
00192 void clsfy_binary_hyperplane_gmrho_builder::determine_weights(const vnl_matrix<double>& data,
00193 const vnl_vector<double >& y,
00194 double sigma) const
00195 {
00196
00197
00198 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum costFn(data,y,sigma);
00199
00200
00201 vnl_lbfgs cgMinimiser(costFn);
00202
00203 cgMinimiser.set_f_tolerance(epsilon_);
00204 cgMinimiser.set_x_tolerance(epsilon_);
00205
00206 cgMinimiser.minimize(weights_);
00207 }
00208
00209 double clsfy_binary_hyperplane_gmrho_builder::estimate_sigma(const vnl_matrix<double>& data,
00210 const vnl_vector<double >& y) const
00211 {
00212
00213
00214
00215
00216 vcl_vector<double > falsePosScores;
00217 vcl_vector<double > falseNegScores;
00218
00219 double b=weights_[num_vars_];
00220 for (unsigned i=0; i<num_examples_;++i)
00221 {
00222 const double* px=data[i];
00223 double yval = y[i];
00224 double ypred = vcl_inner_product(px,px+num_vars_,weights_.begin(),0.0) - b ;
00225 if (yval>0.0)
00226 {
00227 if (ypred<0.0)
00228 {
00229 falseNegScores.push_back(vcl_fabs(ypred));
00230 }
00231 }
00232 else
00233 {
00234 if (ypred>0.0)
00235 {
00236 falsePosScores.push_back(vcl_fabs(ypred));
00237 }
00238 }
00239 }
00240 double sigma=1.0;
00241 double delta0=0.0;
00242 if (!falsePosScores.empty())
00243 {
00244 vcl_vector<double >::iterator medianIter=falsePosScores.begin() + falsePosScores.size()/2;
00245 vcl_nth_element(falsePosScores.begin(),medianIter,falsePosScores.end());
00246 delta0 = (*medianIter);
00247 }
00248 double delta1=0.0;
00249 if (!falseNegScores.empty())
00250 {
00251 vcl_vector<double >::iterator medianIter=falseNegScores.begin() + falseNegScores.size()/2;
00252 vcl_nth_element(falseNegScores.begin(),medianIter,falseNegScores.end());
00253 delta1 = (*medianIter);
00254 }
00255 sigma += vcl_max(delta0,delta1);
00256
00257 sigma *= vcl_sqrt(3.0);
00258 return sigma;
00259 }
00260
00261
00262
00263 void clsfy_binary_hyperplane_gmrho_builder::b_write(vsl_b_ostream &bfs) const
00264 {
00265 const int version_no=1;
00266 vsl_b_write(bfs, version_no);
00267 clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00268 }
00269
00270
00271
00272 void clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream &bfs)
00273 {
00274 if (!bfs) return;
00275
00276 short version;
00277 vsl_b_read(bfs,version);
00278 switch (version)
00279 {
00280 case (1):
00281 clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00282 break;
00283 default:
00284 vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream&)\n"
00285 << " Unknown version number "<< version << '\n';
00286 bfs.is().clear(vcl_ios::badbit);
00287 }
00288 }
00289
00290
00291
00292 vcl_string clsfy_binary_hyperplane_gmrho_builder::is_a() const
00293 {
00294 return vcl_string("clsfy_binary_hyperplane_gmrho_builder");
00295 }
00296
00297
00298
00299 bool clsfy_binary_hyperplane_gmrho_builder::is_class(vcl_string const& s) const
00300 {
00301 return s == clsfy_binary_hyperplane_gmrho_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00302 }
00303
00304
00305
00306 short clsfy_binary_hyperplane_gmrho_builder::version_no() const
00307 {
00308 return 1;
00309 }
00310
00311
00312
00313 void clsfy_binary_hyperplane_gmrho_builder::print_summary(vcl_ostream& os) const
00314 {
00315 os << is_a();
00316 }
00317
00318
00319 clsfy_builder_base* clsfy_binary_hyperplane_gmrho_builder::clone() const
00320 {
00321 return new clsfy_binary_hyperplane_gmrho_builder(*this);
00322 }
00323
00324
00325
00326
00327
00328
00329
00330 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gmrho_sum(const vnl_matrix<double>& x,
00331 const vnl_vector<double>& y,
00332 double sigma):
00333 vnl_cost_function(x.cols()+1),
00334 x_(x),y_(y),sigma_(1.0),var_(1.0),num_examples_(x.rows()),num_vars_(x.cols())
00335 {
00336 set_sigma(sigma);
00337 }
00338
00339 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::set_sigma(double sigma)
00340 {
00341 sigma_ = sigma;
00342 var_ = sigma*sigma;
00343 double s=1.0+var_;
00344 s = s*s;
00345 alpha_ = var_/s;
00346 beta_ = 1.0/s;
00347 }
00348
00349
00350
00351 double clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::f(vnl_vector<double> const& w)
00352 {
00353
00354 double sum=0.0;
00355 double b=w[num_vars_];
00356 for (unsigned i=0; i<num_examples_;++i)
00357 {
00358 const double* px=x_[i];
00359 double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00360 double e = y_[i] - pred;
00361 double e2 = e*e;
00362 if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00363 ((y_[i] < 0.0) && (e >= -1.0)) )
00364 {
00365
00366
00367 sum += e2/(e2+var_);
00368 }
00369 else
00370 {
00371
00372
00373 sum += alpha_*e2 + beta_;
00374 }
00375 }
00376 return sum;
00377 }
00378
00379
00380 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gradf(vnl_vector<double> const& w,
00381 vnl_vector<double>& gradient)
00382 {
00383 using clsfy_binary_hyperplane_gmrho_builder_helpers::gm_grad_accum;
00384 double b=w[num_vars_];
00385 gradient.fill(0.0);
00386
00387 for (unsigned i=0; i<num_examples_;++i)
00388 {
00389 const double* px=x_[i];
00390 double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00391
00392 double e = y_[i] - pred;
00393 double e2 = e*e;
00394 double wt=1.0;
00395 if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00396 ((y_[i] < 0.0) && (e >= -1.0)) )
00397 {
00398 wt = e2 + var_;
00399 }
00400 else
00401 {
00402
00403 wt = 1.0 + var_;
00404 }
00405
00406 double wtInv = -e/(wt*wt);
00407 vcl_for_each(gradient.begin(),gradient.begin()+num_vars_,
00408 gm_grad_accum(px,wtInv));
00409
00410 gradient[num_vars_] += (-wtInv);
00411 }
00412
00413 vcl_transform(gradient.begin(),gradient.end(),gradient.begin(),
00414 vcl_bind2nd(vcl_multiplies<double>(),2.0*var_));
00415 }