SHOGUN  v3.0.1
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
ScatterSVM.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2009 Soeren Sonnenburg
8  * Written (W) 2009 Marius Kloft
9  * Copyright (C) 2009 TU Berlin and Max-Planck-Society
10  */
11 
12 
13 #include <shogun/kernel/Kernel.h>
17 #include <shogun/io/SGIO.h>
18 
19 using namespace shogun;
20 
// Default constructor (its signature line is absent from this listing).
// The visible initialisers clear model, norm_wc and norm_wcw and zero rho and
// m_num_classes; the body only emits the SG_UNSTABLE notice for this constructor.
23  model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
24 {
25  SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n")
26 }
27 
// Constructor that receives the scatter type as `type` (signature line absent from this listing).
// Builds the CMulticlassSVM base with a newly allocated CMulticlassOneVsRestStrategy,
// stores `type` in scatter_type and resets model, norm_wc, norm_wcw, rho and m_num_classes.
29 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL),
30  norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
31 {
32 }
33 
// Constructor that forwards C, k and lab to the CMulticlassSVM base together with a
// newly allocated CMulticlassOneVsRestStrategy (signature line absent from this listing).
// scatter_type is fixed to NO_BIAS_LIBSVM; all trained state starts out empty.
35 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
36  norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
37 {
38 }
39 
// Destructor body (signature line absent from this listing): releases the two
// per-class norm arrays. model is not touched here; both training routines
// destroy it and reset it to NULL before they return successfully.
41 {
42  SG_FREE(norm_wc);
43  SG_FREE(norm_wcw);
44 }
45 
// Body of train_machine(CFeatures* data): validates the inputs, counts the classes that
// actually occur in the labels and dispatches to one of the two training routines.
// Returns the result of the selected routine.
// NOTE(review): the signature and several statements are absent from this listing
// (its line numbering jumps), so the notes below cover only what is visible.
47 {
50 
52  int32_t num_vectors = m_labels->get_num_labels();
53 
// When features are supplied they must match the labels one-to-one; the kernel is
// then initialised with `data` on both sides.
54  if (data)
55  {
56  if (m_labels->get_num_labels() != data->get_num_vectors())
57  SG_ERROR("Number of training vectors does not match number of labels\n")
58  m_kernel->init(data, data);
59  }
60 
// numc[k] counts the training examples labelled k.
// NOTE(review): no zero-initialisation of numc is visible here (listing line 62 is missing) - confirm upstream.
61  int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
63 
64  for (int32_t i=0; i<num_vectors; i++)
65  numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++;
66 
// Nc   = number of classes with at least one example
// Nmin = size of the smallest non-empty class
67  int32_t Nc=0;
68  int32_t Nmin=num_vectors;
69  for (int32_t i=0; i<m_num_classes; i++)
70  {
71  if (numc[i]>0)
72  {
73  Nc++;
74  Nmin=CMath::min(Nmin, numc[i]);
75  }
76 
77  }
78  SG_FREE(numc);
// From here on m_num_classes holds the number of non-empty classes only.
79  m_num_classes=Nc;
80 
81  bool result=false;
82 
// NOTE(review): the condition guarding this block is absent from the listing (line 83).
84  {
85  result=train_no_bias_libsvm();
86  }
87 
// NOTE(review): the condition guarding this block is absent from the listing (line 88).
89  {
// nu must lie in [Nc/num_vectors, Nc*Nmin/num_vectors]; anything else is reported via SG_ERROR.
90  float64_t nu_min=((float64_t) Nc)/num_vectors;
91  float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;
92 
93  SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max)
94 
95  if (get_nu()<nu_min || get_nu()>nu_max)
96  SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max)
97 
98  result=train_testrule12();
99  }
100  else
101  SG_ERROR("Unknown Scatter type\n")
102 
103  return result;
104 }
105 
// Trains through svm_train() with param.svm_type set to C_SVC and every target set to +1,
// then copies the resulting model into one CSVM per class.
// Returns true when svm_train() produced a model, false otherwise.
// NOTE(review): several statements are absent from this listing (its line numbering jumps),
// among them the assignment of problem.l and the start of the statement that ends on
// listing line 139; the notes below cover only what is visible.
106 bool CScatterSVM::train_no_bias_libsvm()
107 {
108  struct svm_node* x_space;
109 
111  SG_INFO("%d trainlabels\n", problem.l)
112 
// Two svm_node entries per example: one holding the example index, one with index -1.
113  problem.y=SG_MALLOC(float64_t, problem.l);
114  problem.x=SG_MALLOC(struct svm_node*, problem.l);
115  x_space=SG_MALLOC(struct svm_node, 2*problem.l);
116 
117  for (int32_t i=0; i<problem.l; i++)
118  {
119  problem.y[i]=+1;
120  problem.x[i]=&x_space[2*i];
121  x_space[2*i].index=i;
122  x_space[2*i+1].index=-1;
123  }
124 
125  int32_t weights_label[2]={-1,+1};
// NOTE(review): get_C()/get_C() is 1.0 for any finite non-zero C and NaN for C==0 - confirm intent.
126  float64_t weights[2]={1.0,get_C()/get_C()};
127 
130 
131  param.svm_type=C_SVC; // Nu MC SVM
132  param.kernel_type = LINEAR;
133  param.degree = 3;
134  param.gamma = 0; // 1/k
135  param.coef0 = 0;
136  param.nu = get_nu(); // Nu
// The kernel's current normalizer is remembered here and put back after training.
137  CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
139  m_num_classes-1, -1, m_labels, prev_normalizer));
140  param.kernel=m_kernel;
141  param.cache_size = m_kernel->get_cache_size();
142  param.C = 0;
143  param.eps = get_epsilon();
144  param.p = 0.1;
145  param.shrinking = 0;
146  param.nr_weight = 2;
147  param.weight_label = weights_label;
148  param.weight = weights;
149  param.nr_class=m_num_classes;
150  param.use_bias = svm_proto()->get_bias_enabled();
151 
152  const char* error_msg = svm_check_parameter(&problem,&param);
153 
154  if(error_msg)
155  SG_ERROR("Error: %s\n",error_msg)
156 
157  model = svm_train(&problem, &param);
158  m_kernel->set_normalizer(prev_normalizer);
159  SG_UNREF(prev_normalizer);
160 
161  if (model)
162  {
// NOTE(review): model->sv_coef is tested twice in this assertion; one operand is redundant.
163  ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef))
164 
165  ASSERT(model->nr_class==m_num_classes)
167 
// rho[0] is kept as the shared offset; rho[i+1] becomes the bias of class i below.
168  rho=model->rho[0];
169 
170  SG_FREE(norm_wcw);
171  norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
172 
// Copy bias, norm, alphas and support-vector indices of each class into its own CSVM.
173  for (int32_t i=0; i<m_num_classes; i++)
174  {
175  int32_t num_sv=model->nSV[i];
176 
177  CSVM* svm=new CSVM(num_sv);
178  svm->set_bias(model->rho[i+1]);
179  norm_wcw[i]=model->normwcw[i];
180 
181 
182  for (int32_t j=0; j<num_sv; j++)
183  {
184  svm->set_alpha(j, model->sv_coef[i][j]);
185  svm->set_support_vector(j, model->SV[i][j].index);
186  }
187 
188  set_svm(i, svm);
189  }
190 
// NOTE(review): problem.x, problem.y and x_space are released only on this success path;
// the else branch below returns without freeing them.
191  SG_FREE(problem.x);
192  SG_FREE(problem.y);
193  SG_FREE(x_space);
// The per-class SV arrays are freed and nulled here, before svm_destroy_model() is called.
194  for (int32_t i=0; i<m_num_classes; i++)
195  {
196  SG_FREE(model->SV[i]);
197  model->SV[i]=NULL;
198  }
199  svm_destroy_model(model);
200 
// NOTE(review): listing line 201 is missing; it may hold a condition guarding the next call - check upstream.
202  compute_norm_wc();
203 
204  model=NULL;
205  return true;
206  }
207  else
208  return false;
209 }
210 
211 
212 
// Trains through svm_train() with param.svm_type set to NU_MULTICLASS_SVC, using the
// multiclass labels as targets, then copies the resulting model into one CSVM per class.
// Returns true when svm_train() produced a model, false otherwise.
// NOTE(review): listing lines 234, 244 and 302 are absent, so some statements of the
// original function are not visible here.
213 bool CScatterSVM::train_testrule12()
214 {
215  struct svm_node* x_space;
216  problem.l=m_labels->get_num_labels();
217  SG_INFO("%d trainlabels\n", problem.l)
218 
// Two svm_node entries per example: one holding the example index, one with index -1.
219  problem.y=SG_MALLOC(float64_t, problem.l);
220  problem.x=SG_MALLOC(struct svm_node*, problem.l);
221  x_space=SG_MALLOC(struct svm_node, 2*problem.l);
222 
223  for (int32_t i=0; i<problem.l; i++)
224  {
225  problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
226  problem.x[i]=&x_space[2*i];
227  x_space[2*i].index=i;
228  x_space[2*i+1].index=-1;
229  }
230 
231  int32_t weights_label[2]={-1,+1};
// NOTE(review): get_C()/get_C() is 1.0 for any finite non-zero C and NaN for C==0 - confirm intent.
232  float64_t weights[2]={1.0,get_C()/get_C()};
233 
// The kernel's left-hand side must hold exactly one vector per label.
235  ASSERT(m_kernel->get_num_vec_lhs()==problem.l)
236 
237  param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM
238  param.kernel_type = LINEAR;
239  param.degree = 3;
240  param.gamma = 0; // 1/k
241  param.coef0 = 0;
242  param.nu = get_nu(); // Nu
243  param.kernel=m_kernel;
245  param.C = 0;
246  param.eps = get_epsilon();
247  param.p = 0.1;
248  param.shrinking = 0;
249  param.nr_weight = 2;
250  param.weight_label = weights_label;
251  param.weight = weights;
252  param.nr_class=m_num_classes;
253  param.use_bias = svm_proto()->get_bias_enabled();
254 
255  const char* error_msg = svm_check_parameter(&problem,&param);
256 
257  if(error_msg)
258  SG_ERROR("Error: %s\n",error_msg)
259 
260  model = svm_train(&problem, &param);
261 
262  if (model)
263  {
// NOTE(review): model->sv_coef is tested twice in this assertion; one operand is redundant.
264  ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef))
265 
266  ASSERT(model->nr_class==m_num_classes)
267  create_multiclass_svm(m_num_classes);
268 
// rho[0] is kept as the shared offset; rho[i+1] becomes the bias of class i below.
269  rho=model->rho[0];
270 
271  SG_FREE(norm_wcw);
272  norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
273 
// Copy bias, norm, alphas and support-vector indices of each class into its own CSVM.
274  for (int32_t i=0; i<m_num_classes; i++)
275  {
276  int32_t num_sv=model->nSV[i];
277 
278  CSVM* svm=new CSVM(num_sv);
279  svm->set_bias(model->rho[i+1]);
280  norm_wcw[i]=model->normwcw[i];
281 
282 
283  for (int32_t j=0; j<num_sv; j++)
284  {
285  svm->set_alpha(j, model->sv_coef[i][j]);
286  svm->set_support_vector(j, model->SV[i][j].index);
287  }
288 
289  set_svm(i, svm);
290  }
291 
// NOTE(review): problem.x, problem.y and x_space are released only on this success path;
// the else branch below returns without freeing them.
292  SG_FREE(problem.x);
293  SG_FREE(problem.y);
294  SG_FREE(x_space);
// The per-class SV arrays are freed and nulled here, before svm_destroy_model() is called.
295  for (int32_t i=0; i<m_num_classes; i++)
296  {
297  SG_FREE(model->SV[i]);
298  model->SV[i]=NULL;
299  }
300  svm_destroy_model(model);
301 
// NOTE(review): listing line 302 is missing; it may hold a condition guarding the next call - check upstream.
303  compute_norm_wc();
304 
305  model=NULL;
306  return true;
307  }
308  else
309  return false;
310 }
311 
312 void CScatterSVM::compute_norm_wc()
313 {
314  SG_FREE(norm_wc);
315  norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements());
316  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
317  norm_wc[i]=0;
318 
319 
320  for (int c=0; c<m_machines->get_num_elements(); c++)
321  {
322  CSVM* svm=get_svm(c);
323  int32_t num_sv = svm->get_num_support_vectors();
324 
325  for (int32_t i=0; i<num_sv; i++)
326  {
327  int32_t ii=svm->get_support_vector(i);
328  for (int32_t j=0; j<num_sv; j++)
329  {
330  int32_t jj=svm->get_support_vector(j);
331  norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j);
332  }
333  }
334  }
335 
336  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
337  norm_wc[i]=CMath::sqrt(norm_wc[i]);
338 
340 }
341 
// Body of classify_one_vs_rest(): labels every right-hand-side vector of m_kernel with the
// index of the winning class and returns the result as a referenced CMulticlassLabels.
// Returns NULL when no kernel is set.
// NOTE(review): the signature and listing lines 351, 361 and 368 are absent here.
343 {
344  CMulticlassLabels* output=NULL;
345  if (!m_kernel)
346  {
347  SG_ERROR("SVM can not proceed without kernel!\n")
348  return NULL;
349  }
350 
// NOTE(review): the condition guarding this return sits on the missing listing line 351.
352  return NULL;
353 
354  int32_t num_vectors=m_kernel->get_num_vec_rhs();
355 
356  output=new CMulticlassLabels(num_vectors);
357  SG_REF(output);
358 
// TEST_RULE1: defer to apply_one() for each example.
359  if (scatter_type == TEST_RULE1)
360  {
362  for (int32_t i=0; i<num_vectors; i++)
363  output->set_label(i, apply_one(i));
364  }
365 
// Otherwise: run every sub-SVM over all examples, then pick per example the class
// whose output divided by norm_wc is largest.
366  else
367  {
369  ASSERT(num_vectors==output->get_num_labels())
370  CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements());
371 
372  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
373  {
374  //SG_PRINT("svm %d\n", i)
375  CSVM *svm = get_svm(i);
376  ASSERT(svm)
377  svm->set_kernel(m_kernel);
378  svm->set_labels(m_labels);
379  outputs[i]=svm->apply();
// Drop the reference obtained from get_svm().
380  SG_UNREF(svm);
381  }
382 
383  for (int32_t i=0; i<num_vectors; i++)
384  {
385  int32_t winner=0;
// NOTE(review): outputs[] elements are cast to CRegressionLabels without a type check - confirm apply() returns that type.
386  float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0];
387 
388  for (int32_t j=1; j<m_machines->get_num_elements(); j++)
389  {
390  float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j];
391 
392  if (out>max_out)
393  {
394  winner=j;
395  max_out=out;
396  }
397  }
398 
399  output->set_label(i, winner);
400  }
401 
// Release the per-machine label objects and the pointer array itself.
402  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
403  SG_UNREF(outputs[i]);
404 
405  SG_FREE(outputs);
406  }
407 
408  return output;
409 }
410 
411 float64_t CScatterSVM::apply_one(int32_t num)
412 {
414  float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements());
415  int32_t winner=0;
416 
417  if (scatter_type == TEST_RULE1)
418  {
419  for (int32_t c=0; c<m_machines->get_num_elements(); c++)
420  outputs[c]=get_svm(c)->get_bias()-rho;
421 
422  for (int32_t c=0; c<m_machines->get_num_elements(); c++)
423  {
424  float64_t v=0;
425 
426  for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++)
427  {
428  float64_t alpha=get_svm(c)->get_alpha(i);
429  int32_t svidx=get_svm(c)->get_support_vector(i);
430  v += alpha*m_kernel->kernel(svidx, num);
431  }
432 
433  outputs[c] += v;
434  for (int32_t j=0; j<m_machines->get_num_elements(); j++)
435  outputs[j] -= v/m_machines->get_num_elements();
436  }
437 
438  for (int32_t j=0; j<m_machines->get_num_elements(); j++)
439  outputs[j]/=norm_wcw[j];
440 
441  float64_t max_out=outputs[0];
442  for (int32_t j=0; j<m_machines->get_num_elements(); j++)
443  {
444  if (outputs[j]>max_out)
445  {
446  max_out=outputs[j];
447  winner=j;
448  }
449  }
450  }
451 
452  else
453  {
454  float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0];
455 
456  for (int32_t i=1; i<m_machines->get_num_elements(); i++)
457  {
458  outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i];
459  if (outputs[i]>max_out)
460  {
461  winner=i;
462  max_out=outputs[i];
463  }
464  }
465  }
466 
467  SG_FREE(outputs);
468  return winner;
469 }
virtual float64_t apply_one(int32_t num)
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Kernel.cpp:83
virtual bool train_machine(CFeatures *data=NULL)
Definition: ScatterSVM.cpp:46
#define SG_INFO(...)
Definition: SGIO.h:120
virtual ELabelType get_label_type() const =0
Real Labels are real-valued labels.
float64_t * norm_wcw
Definition: ScatterSVM.h:125
virtual float64_t apply_one(int32_t num)
Definition: ScatterSVM.cpp:411
virtual int32_t get_num_labels() const
no bias w/ libsvm
Definition: ScatterSVM.h:28
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
virtual int32_t get_num_labels() const =0
multi-class labels 0,1,...
Definition: LabelTypes.h:16
virtual bool set_normalizer(CKernelNormalizer *normalizer)
Definition: Kernel.cpp:124
virtual int32_t get_num_vectors() const =0
CLabels * m_labels
Definition: Machine.h:356
#define SG_ERROR(...)
Definition: SGIO.h:131
float64_t kernel(int32_t idx_a, int32_t idx_b)
Definition: Kernel.h:198
float64_t * norm_wc
Definition: ScatterSVM.h:122
virtual int32_t get_num_vec_lhs()
Definition: Kernel.h:355
#define SG_REF(x)
Definition: SGObject.h:53
int32_t cache_size
cache_size in MB
Definition: Kernel.h:695
bool set_label(int32_t idx, float64_t label)
Multiclass Labels for multi-class classification.
virtual CKernelNormalizer * get_normalizer()
Definition: Kernel.cpp:136
#define ASSERT(x)
Definition: SGIO.h:203
class MultiClassSVM
Definition: MulticlassSVM.h:26
void set_bias(float64_t bias)
CMulticlassStrategy * m_multiclass_strategy
virtual ~CScatterSVM()
Definition: ScatterSVM.cpp:40
double float64_t
Definition: common.h:48
bool set_alpha(int32_t idx, float64_t val)
SCATTER_TYPE scatter_type
Definition: ScatterSVM.h:111
float64_t get_alpha(int32_t idx)
the scatter kernel normalizer
bool set_support_vector(int32_t idx, int32_t val)
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:271
The class Kernel Normalizer defines a function to post-process kernel values.
int32_t get_support_vector(int32_t idx)
virtual int32_t get_num_vec_rhs()
Definition: Kernel.h:364
#define SG_UNREF(x)
Definition: SGObject.h:54
SCATTER_TYPE
Definition: ScatterSVM.h:25
training with bias using test rule 2
Definition: ScatterSVM.h:33
The class Features is the base class of all feature objects.
Definition: Features.h:62
training with bias using test rule 1
Definition: ScatterSVM.h:31
bool create_multiclass_svm(int32_t num_classes)
static T min(T a, T b)
return the minimum of two integers
Definition: Math.h:160
A generic Support Vector Machine Interface.
Definition: SVM.h:47
The Kernel base class.
Definition: Kernel.h:150
int32_t get_cache_size()
Definition: Kernel.h:435
void set_kernel(CKernel *k)
svm_parameter param
Definition: ScatterSVM.h:116
multiclass one vs rest strategy used to train generic multiclass machines for K-class problems with b...
bool set_svm(int32_t num, CSVM *svm)
void display_vector(const char *name="vector", const char *prefix="") const
Definition: SGVector.cpp:405
static float32_t sqrt(float32_t x)
x^0.5
Definition: Math.h:252
virtual CLabels * classify_one_vs_rest()
Definition: ScatterSVM.cpp:342
virtual bool has_features()
Definition: Kernel.h:373
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:75
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:134
CSVM * get_svm(int32_t num)
Definition: MulticlassSVM.h:74
svm_problem problem
Definition: ScatterSVM.h:114
struct svm_model * model
Definition: ScatterSVM.h:119
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:162

SHOGUN Machine Learning Toolbox - Documentation