00001
00002 #ifndef bsta_adaptive_updater_h_
00003 #define bsta_adaptive_updater_h_
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include <bsta/bsta_distribution.h>
00020 #include <bsta/bsta_mixture.h>
00021 #include <bsta/bsta_mixture_fixed.h>
00022 #include <bsta/bsta_attributes.h>
00023 #include <bsta/bsta_gauss_ff3.h>
00024 #include "bsta_gaussian_updater.h"
00025 #include "bsta_gaussian_stats.h"
00026
00027
00028
00029
//: A mixture of Gaussians adaptive updater.
//  Base class for the functionality shared by the adaptive updating schemes:
//  a model Gaussian used to seed newly inserted components, and a cap on the
//  number of components kept in the mixture.
template <class mix_dist_>
class bsta_mg_adaptive_updater
{
  private:
    typedef typename mix_dist_::dist_type obs_gaussian_;
    typedef typename obs_gaussian_::contained_type gaussian_;
    typedef typename gaussian_::math_type T;
    typedef typename gaussian_::vector_type vector_;

  public:
    typedef typename gaussian_::field_type field_type;
    typedef mix_dist_ distribution_type;

  protected:
    //: Constructor.
    //  \param model   Gaussian copied into every newly inserted component
    //                 (its mean is replaced by the triggering sample).
    //  \param max_cmp maximum number of components kept in the mixture.
    bsta_mg_adaptive_updater(const gaussian_& model,
                             unsigned int max_cmp = 5)
     : init_gaussian_(model,T(1)),
       max_components_(max_cmp) {}

    //: Insert a new component with mean \p sample and weight \p init_weight.
    //  If the mixture already holds max_components_ or more components, the
    //  last ones are removed until there is room.  The survivors are then
    //  rescaled so that their weights sum to 1 - init_weight.
    void insert(mix_dist_& mixture, const vector_& sample, T init_weight) const
    {
      bool removed = false;
      // Stop once the mixture is empty: with max_components_ == 0 the
      // capacity test can never fail, and remove_last() must not be called
      // on an empty mixture.
      while (mixture.num_components() > 0 &&
             mixture.num_components() >= max_components_) {
        mixture.remove_last();
        removed = true;
      }

      // If a mode was removed, renormalize the remaining ones.
      if (removed) {
        T total = T(0);
        for (unsigned int i=0; i<mixture.num_components(); ++i)
          total += mixture.weight(i);
        // Nothing to rescale when no components survive or they all carry
        // zero weight; dividing by the zero total would produce inf/NaN.
        if (total > T(0)) {
          const T adjust = (T(1)-init_weight) / total;
          for (unsigned int i=0; i<mixture.num_components(); ++i)
            mixture.set_weight(i, mixture.weight(i)*adjust);
        }
      }
      init_gaussian_.set_mean(sample);
      mixture.insert(init_gaussian_,init_weight);
    }

    //: Model for newly inserted components (mutable so insert() can stay const).
    mutable obs_gaussian_ init_gaussian_;

    //: Maximum number of components in the mixture.
    unsigned int max_components_;
};
00080
00081
00082
00083
00084 template <class mix_dist_>
00085 class bsta_mg_statistical_updater : public bsta_mg_adaptive_updater<mix_dist_>
00086 {
00087 public:
00088 typedef typename mix_dist_::dist_type obs_gaussian_;
00089 typedef typename obs_gaussian_::contained_type gaussian_;
00090 typedef typename gaussian_::math_type T;
00091 typedef typename gaussian_::vector_type vector_;
00092 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00093
00094
00095 typedef obs_mix_dist_ distribution_type;
00096
00097 enum { data_dimension = gaussian_::dimension };
00098
00099
00100 bsta_mg_statistical_updater(const gaussian_& model,
00101 unsigned int max_cmp = 5,
00102 T g_thresh = T(3),
00103 T min_stdev = T(0))
00104 : bsta_mg_adaptive_updater<mix_dist_>(model, max_cmp),
00105 gt2_(g_thresh*g_thresh), min_var_(min_stdev*min_stdev) {}
00106
00107
00108 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00109 {
00110 mix.num_observations += T(1);
00111 this->update(mix, sample, T(1)/mix.num_observations);
00112 }
00113
00114 void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00115 #if 0
00116 void update( mix_dist_& mix, const T & sample, T alpha ) const;
00117 #endif
00118
00119 T gt2_;
00120
00121 T min_var_;
00122 };
00123
00124
00125
00126
00127 template <class mix_dist_>
00128 class bsta_mg_window_updater : public bsta_mg_statistical_updater<mix_dist_>
00129 {
00130 public:
00131 typedef typename mix_dist_::dist_type obs_gaussian_;
00132 typedef typename obs_gaussian_::contained_type gaussian_;
00133 typedef typename gaussian_::math_type T;
00134 typedef typename gaussian_::vector_type vector_;
00135 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00136
00137
00138 typedef obs_mix_dist_ distribution_type;
00139
00140 enum { data_dimension = gaussian_::dimension };
00141
00142
00143 bsta_mg_window_updater(const gaussian_& model,
00144 unsigned int max_cmp = 5,
00145 T g_thresh = T(3),
00146 T min_stdev = T(0),
00147 unsigned int window_size = 40)
00148 : bsta_mg_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev),
00149 window_size_(window_size) {}
00150
00151
00152 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00153 {
00154 if (mix.num_observations < window_size_)
00155 mix.num_observations += T(1);
00156 this->update(mix, sample, T(1)/mix.num_observations);
00157 }
00158
00159 protected:
00160 unsigned int window_size_;
00161 };
00162
00163
00164
00165 template <class mix_dist_>
00166 class bsta_mg_weighted_updater : bsta_mg_statistical_updater<mix_dist_>
00167 {
00168 public:
00169 typedef typename mix_dist_::dist_type obs_gaussian_;
00170 typedef typename obs_gaussian_::contained_type gaussian_;
00171 typedef typename gaussian_::math_type T;
00172 typedef typename gaussian_::vector_type vector_;
00173 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00174
00175
00176 typedef obs_mix_dist_ distribution_type;
00177
00178 enum { data_dimension = gaussian_::dimension };
00179
00180
00181 bsta_mg_weighted_updater(const gaussian_& model,
00182 unsigned int max_cmp = 5,
00183 T g_thresh = T(3),
00184 T min_stdev = T(0))
00185 : bsta_mg_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev){}
00186
00187
00188 void operator() ( obs_mix_dist_& mix, const vector_& sample, const T weight ) const
00189 {
00190 mix.num_observations += weight;
00191 this->update(mix, sample, weight/mix.num_observations);
00192 }
00193 };
00194
00195
00196
00197
00198 template <class mix_dist_>
00199 class bsta_mg_grimson_statistical_updater : public bsta_mg_adaptive_updater<mix_dist_>
00200 {
00201 public:
00202 typedef typename mix_dist_::dist_type obs_gaussian_;
00203 typedef typename obs_gaussian_::contained_type gaussian_;
00204 typedef typename gaussian_::math_type T;
00205 typedef typename gaussian_::vector_type vector_;
00206 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00207
00208
00209 typedef obs_mix_dist_ distribution_type;
00210
00211 enum { data_dimension = gaussian_::dimension };
00212
00213
00214 bsta_mg_grimson_statistical_updater(const gaussian_& model,
00215 unsigned int max_cmp = 5,
00216 T g_thresh = T(3),
00217 T min_stdev = T(0) )
00218 : bsta_mg_adaptive_updater<mix_dist_>(model, max_cmp),
00219 gt2_(g_thresh*g_thresh), min_var_(min_stdev*min_stdev) {}
00220
00221
00222 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00223 {
00224 mix.num_observations += T(1);
00225 this->update(mix, sample, T(1)/mix.num_observations);
00226 }
00227
00228 void update( mix_dist_& mix, const vector_& sample, T alpha ) const;
00229
00230
00231 T gt2_;
00232
00233 T min_var_;
00234 };
00235
00236
00237
00238 template <class mix_dist_>
00239 class bsta_mg_grimson_window_updater : public bsta_mg_grimson_statistical_updater<mix_dist_>
00240 {
00241 public:
00242 typedef typename mix_dist_::dist_type obs_gaussian_;
00243 typedef typename obs_gaussian_::contained_type gaussian_;
00244 typedef typename gaussian_::math_type T;
00245 typedef typename gaussian_::vector_type vector_;
00246 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00247
00248
00249 typedef obs_mix_dist_ distribution_type;
00250
00251 enum { data_dimension = gaussian_::dimension };
00252
00253
00254 bsta_mg_grimson_window_updater(const gaussian_& model,
00255 unsigned int max_cmp = 5,
00256 T g_thresh = T(3),
00257 T min_stdev = T(0),
00258 unsigned int window_size = 40)
00259 : bsta_mg_grimson_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev),
00260 window_size_(window_size) {}
00261
00262
00263 void operator() ( obs_mix_dist_& mix, const vector_& sample ) const
00264 {
00265 if (mix.num_observations < window_size_)
00266 mix.num_observations += T(1);
00267 this->update(mix, sample, T(1)/mix.num_observations);
00268 }
00269
00270 protected:
00271 unsigned int window_size_;
00272 };
00273
00274
00275
00276 template <class mix_dist_>
00277 class bsta_mg_grimson_weighted_updater : bsta_mg_grimson_statistical_updater<mix_dist_>
00278 {
00279 public:
00280 typedef typename mix_dist_::dist_type obs_gaussian_;
00281 typedef typename obs_gaussian_::contained_type gaussian_;
00282 typedef typename gaussian_::math_type T;
00283 typedef typename gaussian_::vector_type vector_;
00284 typedef bsta_num_obs<mix_dist_> obs_mix_dist_;
00285
00286
00287 typedef obs_mix_dist_ distribution_type;
00288
00289 enum { data_dimension = gaussian_::dimension };
00290
00291
00292 bsta_mg_grimson_weighted_updater(const gaussian_& model,
00293 unsigned int max_cmp = 5,
00294 T g_thresh = T(3),
00295 T min_stdev = T(0) )
00296 : bsta_mg_grimson_statistical_updater<mix_dist_>(model, max_cmp, g_thresh, min_stdev){}
00297
00298
00299 void operator() ( obs_mix_dist_& mix, const vector_& sample, const T weight ) const
00300 {
00301 mix.num_observations += weight;
00302 this->update(mix, sample, weight/mix.num_observations);
00303 }
00304 };
00305
00306
00307 #endif // bsta_adaptive_updater_h_