contrib/mul/clsfy/clsfy_binary_hyperplane_gmrho_builder.cxx

Go to the documentation of this file.
00001 // This is mul/clsfy/clsfy_binary_hyperplane_gmrho_builder.cxx
00002 #include "clsfy_binary_hyperplane_gmrho_builder.h"
00003 //:
00004 // \file
00005 // \brief Implement a two-class output linear classifier builder using a Geman-McClure robust error function
00006 // \author Martin Roberts
00007 // \date 4 Nov 2006
00008 
00009 //=======================================================================
00010 
00011 #include <vcl_string.h>
00012 #include <vcl_iostream.h>
00013 #include <vcl_vector.h>
00014 #include <vcl_cassert.h>
00015 #include <vcl_cmath.h>
00016 #include <vcl_algorithm.h>
00017 #include <vcl_numeric.h>
00018 #include <vnl/vnl_vector_ref.h>
00019 #include <vnl/algo/vnl_lbfgs.h>
00020 
00021 
00022 //: Some helper stuff, like the error function to be minimised
00023 namespace clsfy_binary_hyperplane_gmrho_builder_helpers
00024 {
00025     //: The cost function, sum Geman-McClure error functions over all training examples
00026     class gmrho_sum : public vnl_cost_function
00027     {
00028         //: Reference to data matrix, one row per training example
00029         const vnl_matrix<double>& x_;
00030         //: Reference to required outputs
00031         const vnl_vector<double>& y_;
00032         //: Scale factor used in Geman-McClure error function
00033         double sigma_;
00034         //: sigma squared
00035         double var_;
00036         //: Number of training examples (x_.rows())
00037         unsigned num_examples_;
00038         //: Number of dimensions (x_.cols())
00039         unsigned num_vars_;
00040         //: var_/(1+var_)^2 - ensures continuity of derivative at hyperplane boundary
00041         double alpha_;
00042         //: 1/(1+var_)^2 - with alpha, ensures continuity of function at hyperplane boundary
00043         double beta_;
00044       public:
00045         //: construct passing in reference to data matrix
00046         gmrho_sum(const vnl_matrix<double>& x,
00047                   const vnl_vector<double>& y,double sigma=1);
00048 
00049         //: reset the scaling factor
00050         void set_sigma(double sigma);
00051 
00052         //:  The main function.  Given the vector of weights parameters vector , compute the value of f(x).
00053         virtual double f(vnl_vector<double> const& w);
00054 
00055         //:  Calculate the gradient of f at parameter vector x.
00056         virtual void gradf(vnl_vector<double> const& x, vnl_vector<double>& gradient);
00057     };
00058 
00059     //: functor to accumulate gradient contributions for given training example
00060     class gm_grad_accum
00061     {
00062         const double* px_;
00063         const double wt_;
00064       public:
00065         gm_grad_accum(const double* px,double wt) : px_(px),wt_(wt) {}
00066         void operator()(double& grad)
00067         {
00068             grad += (*px_++) * wt_;
00069         }
00070     };
00071 
00072     //: Given the class category variable, return the associated regression value (e.g. 1 for class 1, -1 for class 0)
00073     class category_value
00074     {
00075         const double y0;
00076         const double y1;
00077       public:
00078         category_value(unsigned num_category1,unsigned num_total):
00079             y0(-1.0*double(num_total-num_category1)/double(num_total)),
00080             y1(double(num_category1)/double(num_total)) {}
00081 
00082         double operator()(const unsigned& classNum)
00083         {
00084             //return classNum ? y1 : y0;
00085             return classNum ? 1.0 : -1.0;
00086         }
00087     };
00088 };
00089 
00090 //-----------------------------------------------------------------------------------------------
00091 //------------------------ The builder member functions ------------------------------------------
00092 //------------------------------------------------------------------------------------------------
00093 //: Build a linear classifier, with the given data.
00094 // Return the mean error over the training set.
00095 // n_classes must be 1.
00096 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00097                                                     mbl_data_wrapper<vnl_vector<double> >& inputs,
00098                                                     unsigned n_classes,
00099                                                     const vcl_vector<unsigned>& outputs) const
00100 {
00101     assert (n_classes == 1);
00102     return clsfy_binary_hyperplane_gmrho_builder::build(classifier, inputs, outputs);
00103 }
00104 
00105 //: Build a linear hyperplane classifier with the given data.
00106 // Reduce the influence of well classified points far into their correct region by
00107 // applying a Geman-McClure robust error function, rather than a least squares fit
00108 double clsfy_binary_hyperplane_gmrho_builder::build(clsfy_classifier_base& classifier,
00109                                                     mbl_data_wrapper<vnl_vector<double> >& inputs,
00110                                                     const vcl_vector<unsigned>& outputs) const
00111 {
00112     using clsfy_binary_hyperplane_gmrho_builder_helpers::category_value;
00113 
00114     //First let the base class get us a starting solution
00115     clsfy_binary_hyperplane_ls_builder::build( classifier,inputs,outputs);
00116     //Extract the data into a matrix
00117     num_examples_ = inputs.size();
00118     if (num_examples_ == 0)
00119     {
00120         vcl_cerr<<"WARNING - clsfy_binary_hyperplane_gmrho_builder::build called with no data\n";
00121         return 0.0;
00122     }
00123 
00124     //Now copy from the urggghh data wrapper into a sensible data structure (matrix!)
00125     inputs.reset();
00126     num_vars_ = inputs.current().size();
00127     vnl_matrix<double> data(num_examples_,num_vars_,0.0);
00128     unsigned i=0;
00129     do
00130     {
00131         double* row=data[i++];
00132         vcl_copy(inputs.current().begin(),inputs.current().end(),row);
00133     } while (inputs.next());
00134 
00135     //Set up category regression values determined by output class
00136     vnl_vector<double> y(num_examples_,0.0);
00137     vcl_transform(outputs.begin(),outputs.end(),
00138                   y.begin(),
00139                   category_value(vcl_count(outputs.begin(),outputs.end(),1u),outputs.size()));
00140     weights_.set_size(num_vars_+1);
00141 
00142     //Initialise the weights using the standard least squares fit of my base class
00143     clsfy_binary_hyperplane& hyperplane = dynamic_cast<clsfy_binary_hyperplane &>(classifier);
00144 
00145     weights_.update(hyperplane.weights(),0);
00146     weights_[num_vars_] = hyperplane.bias();
00147 
00148     //Estimate the scaling factor used in the Geman-McClure function
00149     double sigma_scale_target = sigma_preset_;
00150     if (auto_estimate_sigma_)
00151         sigma_scale_target=estimate_sigma(data,y);
00152 
00153     //To avoid local minima perform deterministic annealing starting from a large initial sigma
00154     //Set initial kappa so that everything is an inlier
00155     double kappa = 5.0;
00156     const double alpha_anneal=0.75;
00157     //Num of iterations to reduce back to 10% on top of required sigma
00158     int N = 1+int(vcl_log(1.1/kappa)/vcl_log(alpha_anneal));
00159     if (N<1) N=1;
00160     double sigma_scale = kappa * sigma_scale_target;
00161 
00162     epsilon_ = 1.0E-4; //slacken off convergence tolerance during annealing
00163     for (int ianneal=0;ianneal<N;++ianneal)
00164     {
00165         //Then do it at this sigma
00166         determine_weights(data,y,sigma_scale);
00167         //and then reduce sigma
00168         sigma_scale *= alpha_anneal;
00169     }
00170 
00171     epsilon_ = 1.0E-8; //re-impose a more precise convergence criterion
00172     //Then re-estimate sigma scale and do a final pair of iterations
00173     //as sigma depends on the mis-classification overlap depth
00174 
00175 
00176     for (unsigned iter=0; iter<(auto_estimate_sigma_ ? 2u : 1u); ++iter)
00177     {
00178         if (auto_estimate_sigma_)
00179             sigma_scale_target=estimate_sigma(data,y);
00180         else
00181             sigma_scale_target = sigma_preset_;
00182         //Finally do it at exactly the target sigma
00183         determine_weights(data,y,sigma_scale_target);
00184     }
00185     //And finally copy the parameters into the hyperplane
00186     vnl_vector_ref<double > weights(num_vars_,weights_.data_block());
00187     hyperplane.set(weights, weights_[num_vars_]);
00188 
00189     return clsfy_test_error(classifier, inputs, outputs);
00190 }
00191 
00192 void clsfy_binary_hyperplane_gmrho_builder::determine_weights(const vnl_matrix<double>& data,
00193                                                               const vnl_vector<double >& y,
00194                                                               double sigma) const
00195 {
00196     //Optimise the weights to fit the data to y
00197 
00198     clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum costFn(data,y,sigma);
00199 
00200     //minimise using the quasi-Newton lbfgs method
00201     vnl_lbfgs cgMinimiser(costFn);
00202 
00203     cgMinimiser.set_f_tolerance(epsilon_);
00204     cgMinimiser.set_x_tolerance(epsilon_);
00205 
00206     cgMinimiser.minimize(weights_);
00207 }
00208 
00209 double clsfy_binary_hyperplane_gmrho_builder::estimate_sigma(const vnl_matrix<double>& data,
00210                                                              const vnl_vector<double >& y) const
00211 {
00212     //Sigma is set to root(3) * (1+d), where d is the median distance past zero
00213     //of the misclassified values
00214     //The root(3) is because GM function reduces influence after sigma/sqrt(3)
00215 
00216     vcl_vector<double > falsePosScores;
00217     vcl_vector<double > falseNegScores;
00218 
00219     double b=weights_[num_vars_]; //constant stored as final variable
00220     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00221     {
00222         const double* px=data[i];
00223         double yval = y[i];
00224         double ypred = vcl_inner_product(px,px+num_vars_,weights_.begin(),0.0) - b ;
00225         if (yval>0.0)
00226         {
00227             if (ypred<0.0) // mis-classified false negative
00228             {
00229                 falseNegScores.push_back(vcl_fabs(ypred));
00230             }
00231         }
00232         else
00233         {
00234             if (ypred>0.0)//mis-classified false negative
00235             {
00236                 falsePosScores.push_back(vcl_fabs(ypred));
00237             }
00238         }
00239     }
00240     double sigma=1.0;
00241     double delta0=0.0;
00242     if (!falsePosScores.empty())
00243     {
00244         vcl_vector<double >::iterator medianIter=falsePosScores.begin() + falsePosScores.size()/2;
00245         vcl_nth_element(falsePosScores.begin(),medianIter,falsePosScores.end());
00246         delta0 = (*medianIter);
00247     }
00248     double delta1=0.0;
00249     if (!falseNegScores.empty())
00250     {
00251         vcl_vector<double >::iterator medianIter=falseNegScores.begin() + falseNegScores.size()/2;
00252         vcl_nth_element(falseNegScores.begin(),medianIter,falseNegScores.end());
00253         delta1 = (*medianIter);
00254     }
00255     sigma += vcl_max(delta0,delta1);
00256 
00257     sigma *= vcl_sqrt(3.0);
00258     return sigma;
00259 }
00260 
00261 //=======================================================================
00262 
00263 void clsfy_binary_hyperplane_gmrho_builder::b_write(vsl_b_ostream &bfs) const
00264 {
00265     const int version_no=1;
00266     vsl_b_write(bfs, version_no);
00267     clsfy_binary_hyperplane_ls_builder::b_write(bfs);
00268 }
00269 
00270 //=======================================================================
00271 
00272 void clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream &bfs)
00273 {
00274     if (!bfs) return;
00275 
00276     short version;
00277     vsl_b_read(bfs,version);
00278     switch (version)
00279     {
00280         case (1):
00281             clsfy_binary_hyperplane_ls_builder::b_read(bfs);
00282             break;
00283         default:
00284             vcl_cerr << "I/O ERROR: clsfy_binary_hyperplane_gmrho_builder::b_read(vsl_b_istream&)\n"
00285                      << "           Unknown version number "<< version << '\n';
00286             bfs.is().clear(vcl_ios::badbit); // Set an unrecoverable IO error on stream
00287     }
00288 }
00289 
00290 //=======================================================================
00291 
00292 vcl_string clsfy_binary_hyperplane_gmrho_builder::is_a() const
00293 {
00294     return vcl_string("clsfy_binary_hyperplane_gmrho_builder");
00295 }
00296 
00297 //=======================================================================
00298 
00299 bool clsfy_binary_hyperplane_gmrho_builder::is_class(vcl_string const& s) const
00300 {
00301     return s == clsfy_binary_hyperplane_gmrho_builder::is_a() || clsfy_binary_hyperplane_ls_builder::is_class(s);
00302 }
00303 
00304 //=======================================================================
00305 
00306 short clsfy_binary_hyperplane_gmrho_builder::version_no() const
00307 {
00308     return 1;
00309 }
00310 
00311 //=======================================================================
00312 
00313 void clsfy_binary_hyperplane_gmrho_builder::print_summary(vcl_ostream& os) const
00314 {
00315     os << is_a();
00316 }
00317 
00318 //=======================================================================
00319 clsfy_builder_base* clsfy_binary_hyperplane_gmrho_builder::clone() const
00320 {
00321     return new clsfy_binary_hyperplane_gmrho_builder(*this);
00322 }
00323 
00324 //---------------------------------------------------------------------------------------------
00325 //: The error function class
00326 //  This returns a geman-mcclure robust function if the point is correctly classified
00327 // Otherwise the squared error is returned, with coefficient and offset to ensure continuity
00328 // and smoothness at the join
00329 //---------------------------------------------------------------------------------------------
00330 clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gmrho_sum(const vnl_matrix<double>& x,
00331                                                                     const vnl_vector<double>& y,
00332                                                                     double sigma):
00333         vnl_cost_function(x.cols()+1),
00334         x_(x),y_(y),sigma_(1.0),var_(1.0),num_examples_(x.rows()),num_vars_(x.cols())
00335 {
00336     set_sigma(sigma);
00337 }
00338 
00339 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::set_sigma(double sigma)
00340 {
00341     sigma_ = sigma;
00342     var_ = sigma*sigma;
00343     double s=1.0+var_;
00344     s = s*s;
00345     alpha_ = var_/s;
00346     beta_ = 1.0/s;
00347 }
00348 
00349 
00350 //: Return the error sum function
00351 double clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::f(vnl_vector<double> const& w)
00352 {
00353     //Sum the error contributions from each example
00354     double sum=0.0;
00355     double b=w[num_vars_]; //constant stored as final variable
00356     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00357     {
00358         const double* px=x_[i];
00359         double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00360         double e =  y_[i] - pred;
00361         double e2 = e*e;
00362         if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00363              ((y_[i] < 0.0) && (e >= -1.0)) )
00364         {
00365             //In the correctly classified region
00366             //So use Geman-McClure function
00367             sum += e2/(e2+var_);
00368         }
00369         else
00370         {
00371             //Misclassified, so keep as quadratic (influence increases with error)
00372             //NB alpha and beta are chosen for continuity of function and gradient at boundary
00373             sum += alpha_*e2 + beta_;
00374         }
00375     }
00376     return sum;
00377 }
00378 
00379 //: Calculate gradient of the error sum function
00380 void clsfy_binary_hyperplane_gmrho_builder_helpers::gmrho_sum::gradf(vnl_vector<double> const& w,
00381                                                                      vnl_vector<double>& gradient)
00382 {
00383     using clsfy_binary_hyperplane_gmrho_builder_helpers::gm_grad_accum;
00384     double b=w[num_vars_]; //constant stored as final variable
00385     gradient.fill(0.0);
00386 
00387     for (unsigned i=0; i<num_examples_;++i) //Loop over examples (matrix rows)
00388     {
00389         const double* px=x_[i];
00390         double pred = vcl_inner_product(px,px+num_vars_,w.begin(),0.0) - b;
00391 
00392         double e =  y_[i] - pred;
00393         double e2 = e*e;
00394         double wt=1.0;
00395         if ( ((y_[i] > 0.0) && (e <= 1.0)) ||
00396              ((y_[i] < 0.0) && (e >= -1.0)) )
00397         {
00398             wt = e2 + var_;
00399         }
00400         else
00401         {
00402             //Freeze weight decay once in misclassification region
00403             wt = 1.0 + var_;
00404         }
00405 
00406         double wtInv = -e/(wt*wt);
00407         vcl_for_each(gradient.begin(),gradient.begin()+num_vars_,
00408                      gm_grad_accum(px,wtInv));
00409 
00410         gradient[num_vars_] += (-wtInv); //dg/db, last term is for constant
00411     }
00412     //And multiply everything by 2sigma^2
00413     vcl_transform(gradient.begin(),gradient.end(),gradient.begin(),
00414                   vcl_bind2nd(vcl_multiplies<double>(),2.0*var_));
00415 }

Generated on Thu Nov 20 05:11:41 2008 for contrib/mul/clsfy by  doxygen 1.5.1