SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MultiClassSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #include <shogun/lib/common.h>
12 #include <shogun/io/SGIO.h>
14 
15 using namespace shogun;
16 
// Initialiser list and body of a constructor whose signature line is not
// visible in this listing. The CSVM base is constructed with argument 0, the
// scheme is set to ONE_VS_REST, and the sub-SVM array starts out empty
// (m_num_svms == 0, m_svms == NULL) before init() is called.
18 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL)
19 {
20  init();
21 }
22 
23 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
24 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
25 {
26  init();
27 }
28 
// Tail of a constructor whose opening signature line is not visible in this
// listing. C, k and lab are forwarded to the CSVM base constructor; type is
// stored in multiclass_type; the sub-SVM array starts out empty before init()
// is called.
30  EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
31 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
32 {
33  init();
34 }
35 
// Body of a member whose signature line is not visible in this listing
// (presumably the destructor -- confirm against the full file). It only
// delegates to cleanup().
37 {
38  cleanup();
39 }
40 
// Registers members with m_parameters. Called from every constructor shown in
// this file.
// NOTE(review): the opening lines of the first and third registration calls
// are not visible in this listing; only their trailing argument lines
// ("multiclass_type" / "m_svms") survive, so the exact calls cannot be
// confirmed from here.
41 void CMultiClassSVM::init()
42 {
44  "multiclass_type", "Type of MultiClassSVM.");
45  m_parameters->add(&m_num_classes, "m_num_classes",
46  "Number of classes.");
48  &m_num_svms, "m_svms");
49 }
50 
// Body of a member whose signature line is not visible in this listing
// (presumably cleanup() -- confirm). Drops one reference on every slot of
// m_svms, frees the array, and resets the object to the "no sub-SVMs" state.
52 {
// SG_UNREF is applied to every slot, including ones that were never filled
// by set_svm().
53  for (int32_t i=0; i<m_num_svms; i++)
54  SG_UNREF(m_svms[i]);
55 
56  SG_FREE(m_svms);
57  m_num_svms=0;
58  m_svms=NULL;
59 }
60 
// Discards any existing sub-SVMs and sizes the sub-SVM array for num_classes
// classes under the current multiclass_type.
//
// num_classes: number of classes; values <= 0 make the function return false
//              without touching the object.
// Returns true when m_svms is non-NULL after set-up (all slots zeroed),
// false otherwise.
61 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
62 {
63  if (num_classes>0)
64  {
65  cleanup();
66 
67  m_num_classes=num_classes;
68 
// ONE_VS_REST uses one machine per class, ONE_VS_ONE one machine per
// unordered pair of classes.
// NOTE(review): num_classes*(num_classes-1) is evaluated in int32_t and can
// overflow for large num_classes (roughly above 46k) -- confirm whether such
// inputs are possible.
69  if (multiclass_type==ONE_VS_REST)
70  m_num_svms=num_classes;
71  else if (multiclass_type==ONE_VS_ONE)
72  m_num_svms=num_classes*(num_classes-1)/2;
73  else
74  SG_ERROR("unknown multiclass type\n");
75 
// NOTE(review): the statement that assigns m_svms is not visible in this
// listing (one original line is missing here); presumably it allocates
// m_num_svms pointers -- confirm against the full file.
77  if (m_svms)
78  {
// Zero every slot so that unset entries read as NULL.
79  memset(m_svms,0, m_num_svms*sizeof(CSVM*));
80  return true;
81  }
82  }
83  return false;
84 }
85 
86 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
87 {
88  if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
89  {
90  SG_REF(svm);
91  m_svms[num]=svm;
92  return true;
93  }
94  return false;
95 }
96 
// Body of a member whose signature line is not visible in this listing
// (presumably the batch apply() -- confirm). Dispatches to the batch
// classifier that matches multiclass_type.
98 {
99  if (multiclass_type==ONE_VS_REST)
100  return classify_one_vs_rest();
101  else if (multiclass_type==ONE_VS_ONE)
102  return classify_one_vs_one();
103  else
104  SG_ERROR("unknown multiclass type\n");
105 
// Reached only for an unknown scheme, and only if SG_ERROR returns.
106  return NULL;
107 }
108 
// Body of a member whose signature line is not visible in this listing
// (presumably classify_one_vs_one() -- confirm). Runs every pairwise sub-SVM
// on the kernel's right-hand-side vectors and labels each vector with the
// class that collects the most pairwise votes. Returns the new CLabels object
// (with one reference taken), or NULL if the guarded section is skipped.
110 {
111  ASSERT(m_num_svms>0);
113  CLabels* result=NULL;
114 
115  if (!kernel)
116  {
117  SG_ERROR( "SVM can not proceed without kernel!\n");
// NOTE(review): every other return in this body yields a CLabels*; `false`
// here relies on conversion to a null pointer. NULL would state the intent.
118  return false ;
119  }
120 
// NOTE(review): the condition guarding the following block is on a line that
// is not visible in this listing.
122  {
123  int32_t num_vectors=kernel->get_num_vec_rhs();
124 
125  result=new CLabels(num_vectors);
126  SG_REF(result);
127 
128  ASSERT(num_vectors==result->get_num_labels());
129  CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
130 
// Evaluate each sub-SVM once on the shared kernel; outputs[i] holds the
// labels produced by machine i.
131  for (int32_t i=0; i<m_num_svms; i++)
132  {
// NOTE(review): m_svms[i] is a pointer but is printed with %X, which expects
// an unsigned int; %p would match the argument type.
133  SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
134  ASSERT(m_svms[i]);
135  m_svms[i]->set_kernel(kernel);
136  outputs[i]=m_svms[i]->apply();
137  }
138 
139  int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
140  for (int32_t v=0; v<num_vectors; v++)
141  {
// s walks the machines in the order of the class pairs (i,j), i<j.
// NOTE(review): nothing visible here checks that m_num_svms equals the number
// of pairs, so outputs[s++] relies on that holding -- confirm.
142  int32_t s=0;
143  memset(votes, 0, sizeof(int32_t)*m_num_classes);
144 
145  for (int32_t i=0; i<m_num_classes; i++)
146  {
147  for (int32_t j=i+1; j<m_num_classes; j++)
148  {
// A positive output is a vote for class i, anything else a vote for class j.
149  if (outputs[s++]->get_label(v)>0)
150  votes[i]++;
151  else
152  votes[j]++;
153  }
154  }
155 
// Pick the class with the most votes; the strict '>' means ties go to the
// lowest class index.
156  int32_t winner=0;
157  int32_t max_votes=votes[0];
158 
159  for (int32_t i=1; i<m_num_classes; i++)
160  {
161  if (votes[i]>max_votes)
162  {
163  max_votes=votes[i];
164  winner=i;
165  }
166  }
167 
168  result->set_label(v, winner);
169  }
170 
171  SG_FREE(votes);
172 
// Release the per-machine label objects obtained from apply().
173  for (int32_t i=0; i<m_num_svms; i++)
174  SG_UNREF(outputs[i]);
175  SG_FREE(outputs);
176  }
177 
178  return result;
179 }
180 
// Body of a member whose signature line is not visible in this listing
// (presumably classify_one_vs_rest() -- confirm). Runs every sub-SVM on the
// kernel's right-hand-side vectors and labels each vector with the index of
// the machine that produced the largest output. Returns the new CLabels
// object (with one reference taken), or NULL if the guarded section is
// skipped.
182 {
183  ASSERT(m_num_svms>0);
184  CLabels* result=NULL;
185 
186  if (!kernel)
187  {
188  SG_ERROR( "SVM can not proceed without kernel!\n");
// NOTE(review): every other return in this body yields a CLabels*; `false`
// here relies on conversion to a null pointer. NULL would state the intent.
189  return false ;
190  }
191 
// NOTE(review): the condition guarding the following block is on a line that
// is not visible in this listing.
193  {
194  int32_t num_vectors=kernel->get_num_vec_rhs();
195 
196  result=new CLabels(num_vectors);
197  SG_REF(result);
198 
199  ASSERT(num_vectors==result->get_num_labels());
200  CLabels** outputs=SG_MALLOC(CLabels*, m_num_svms);
201 
// Evaluate each sub-SVM once on the shared kernel; outputs[i] holds the
// labels produced by machine i.
202  for (int32_t i=0; i<m_num_svms; i++)
203  {
204  ASSERT(m_svms[i]);
205  m_svms[i]->set_kernel(kernel);
206  outputs[i]=m_svms[i]->apply();
207  }
208 
209  for (int32_t i=0; i<num_vectors; i++)
210  {
// Arg-max over machines; the strict '>' means ties go to the lowest index.
211  int32_t winner=0;
212  float64_t max_out=outputs[0]->get_label(i);
213 
214  for (int32_t j=1; j<m_num_svms; j++)
215  {
216  float64_t out=outputs[j]->get_label(i);
217 
218  if (out>max_out)
219  {
220  winner=j;
221  max_out=out;
222  }
223  }
224 
225  result->set_label(i, winner);
226  }
227 
// Release the per-machine label objects obtained from apply().
228  for (int32_t i=0; i<m_num_svms; i++)
229  SG_UNREF(outputs[i]);
230 
231  SG_FREE(outputs);
232  }
233 
234  return result;
235 }
236 
// Body of a member whose signature line is not visible in this listing
// (presumably the single-example apply(num) -- confirm). Dispatches to the
// per-example classifier that matches multiclass_type.
238 {
239  if (multiclass_type==ONE_VS_REST)
240  return classify_example_one_vs_rest(num);
241  else if (multiclass_type==ONE_VS_ONE)
242  return classify_example_one_vs_one(num);
243  else
244  SG_ERROR("unknown multiclass type\n");
245 
// Reached only for an unknown scheme, and only if SG_ERROR returns.
246  return 0;
247 }
248 
// Body of a member whose signature line is not visible in this listing
// (presumably classify_example_one_vs_rest(num) -- confirm). Evaluates every
// sub-SVM on example num and returns the index of the machine with the
// largest output; ties go to the lowest index because of the strict '>'.
250 {
251  ASSERT(m_num_svms>0);
// NOTE(review): the declaration/allocation of `outputs` is on a line that is
// not visible in this listing. In the visible code outputs[0] is never
// written and each outputs[i] is only read immediately after being assigned,
// so the buffer looks like it could be replaced by a local -- confirm.
253  int32_t winner=0;
254  float64_t max_out=m_svms[0]->apply(num);
255 
256  for (int32_t i=1; i<m_num_svms; i++)
257  {
258  outputs[i]=m_svms[i]->apply(num);
259  if (outputs[i]>max_out)
260  {
261  winner=i;
262  max_out=outputs[i];
263  }
264  }
265  SG_FREE(outputs);
266 
267  return winner;
268 }
269 
// Body of a member whose signature line is not visible in this listing
// (presumably classify_example_one_vs_one(num) -- confirm). Evaluates each
// pairwise sub-SVM on example num, tallies one vote per pair, and returns the
// class index with the most votes.
271 {
272  ASSERT(m_num_svms>0);
274 
// NOTE(review): no statement visible in this listing zeroes `votes` before
// the ++ below (one original line is missing just above, so confirm against
// the full file). The batch one-vs-one classifier in this file memsets its
// tally before counting; if SG_MALLOC does not zero memory, the counts here
// start from indeterminate values and a memset is needed.
275  int32_t* votes=SG_MALLOC(int32_t, m_num_classes);
// s walks the machines in the order of the class pairs (i,j), i<j.
276  int32_t s=0;
277 
278  for (int32_t i=0; i<m_num_classes; i++)
279  {
280  for (int32_t j=i+1; j<m_num_classes; j++)
281  {
// A positive output is a vote for class i, anything else a vote for class j.
282  if (m_svms[s++]->apply(num)>0)
283  votes[i]++;
284  else
285  votes[j]++;
286  }
287  }
288 
// Pick the class with the most votes; the strict '>' means ties go to the
// lowest class index.
289  int32_t winner=0;
290  int32_t max_votes=votes[0];
291 
292  for (int32_t i=1; i<m_num_classes; i++)
293  {
294  if (votes[i]>max_votes)
295  {
296  max_votes=votes[i];
297  winner=i;
298  }
299  }
300 
301  SG_FREE(votes);
302 
303  return winner;
304 }
305 
// Reads a multiclass model from modelfl in the text format written by save()
// and rebuilds the sub-SVM array from it.
//
// modelfl: open stream positioned at the "%MultiClassSVM" header line.
// Returns the value also stored in svm_loaded: true unless a section header
// or terminator failed to parse (and SG_ERROR returned).
//
// The fields are parsed strictly in file order; line_number only feeds the
// error messages.
// NOTE(review): two original lines (just after the local declarations and
// just before the final return) are not visible in this listing.
306 bool CMultiClassSVM::load(FILE* modelfl)
307 {
308  bool result=true;
309  char char_buffer[1024];
310  int32_t int_buffer;
311  float64_t double_buffer;
312  int32_t line_number=1;
313  int32_t svm_idx=-1;
314 
316 
// Header line: must read "%MultiClassSVM".
317  if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
318  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
319  else
320  {
321  char_buffer[15]='\0';
322  if (strcmp("%MultiClassSVM", char_buffer)!=0)
323  SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
324 
325  line_number++;
326  }
327 
// The stored scheme must match the scheme this object was constructed with.
328  int_buffer=0;
329  if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
330  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
331 
332  if (!feof(modelfl))
333  line_number++;
334 
// NOTE(review): %ld is used with an int32_t and an enum argument; %d would
// match the argument types.
335  if (int_buffer != multiclass_type)
336  SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
337 
338  int_buffer=0;
339  if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
340  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
341 
342  if (!feof(modelfl))
343  line_number++;
344 
345  if (int_buffer < 2)
346  SG_ERROR("less than 2 classes - how is this multiclass?\n");
347 
// Re-create the (empty) sub-SVM array for the class count from the file.
// NOTE(review): the bool result of create_multiclass_svm() is ignored.
348  create_multiclass_svm(int_buffer);
349 
// The machine count in the file must agree with the count derived above.
350  int_buffer=0;
351  if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
352  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
353 
354  if (!feof(modelfl))
355  line_number++;
356 
357  if (m_num_svms != int_buffer)
358  SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
359 
// The kernel name is read but not used afterwards in this function.
// NOTE(review): %s has no field width, so a long token can overrun the
// 1024-byte char_buffer; a bounded conversion such as %1023s would prevent
// that. Also, %s stops only at whitespace, so the trailing "';" is consumed
// into the buffer rather than matched literally.
360  if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
361  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
362 
363  if (!feof(modelfl))
364  line_number++;
365 
// One section per sub-SVM, each expected in index order.
366  for (int32_t n=0; n<m_num_svms; n++)
367  {
// Section header: "%SVM <idx> of <last>".
368  svm_idx=-1;
369  if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
370  {
371  result=false;
372  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
373  }
374  else
375  {
376  char_buffer[4]='\0';
377  if (strncmp("%SVM", char_buffer, 4)!=0)
378  {
379  result=false;
380  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
381  }
382 
383  if (svm_idx != n)
384  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
385 
386  line_number++;
387  }
388 
// Number of support vectors for this machine.
389  int_buffer=0;
390  if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
391  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
392 
393  if (svm_idx != n)
394  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
395 
396  if (!feof(modelfl))
397  line_number++;
398 
// NOTE(review): %ld is used with an int32_t argument; %d would match.
399  SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
// NOTE(review): svm is only handed to set_svm() at the end of the iteration;
// if SG_ERROR aborts the function in between (its behaviour is not visible
// here) the object is not released -- confirm.
400  CSVM* svm=new CSVM(int_buffer);
401 
402  double_buffer=0;
403 
// Bias term.
404  if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
405  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
406 
407  if (svm_idx != n)
408  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
409 
410  if (!feof(modelfl))
411  line_number++;
412 
413  svm->set_bias(double_buffer);
414 
// Start of the "[alpha,support-vector-index]" list.
415  if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
416  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
417 
418  if (svm_idx != n)
419  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
420 
421  if (!feof(modelfl))
422  line_number++;
423 
424  for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
425  {
426  double_buffer=0;
427  int_buffer=0;
428 
429  if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
430  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
431 
432  if (!feof(modelfl))
433  line_number++;
434 
435  svm->set_support_vector(i, int_buffer);
436  svm->set_alpha(i, double_buffer);
437  }
438 
// List terminator: must read "];".
439  if (fscanf(modelfl,"%2s", char_buffer) == EOF)
440  {
441  result=false;
442  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
443  }
444  else
445  {
// NOTE(review): %2s stores at most two characters plus the terminator at
// index 2, so writing '\0' at index 3 has no effect on the comparison;
// index 2 was probably intended.
446  char_buffer[3]='\0';
447  if (strcmp("];", char_buffer)!=0)
448  {
449  result=false;
450  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
451  }
452  line_number++;
453  }
454 
// set_svm() takes its own reference on svm.
455  set_svm(n, svm);
456  }
457 
458  svm_loaded=result;
459 
461  return result;
462 }
463 
// Writes the multiclass model to modelfl as text: a "%MultiClassSVM" header,
// the scheme, class count, machine count and kernel name, then one "%SVM"
// section per sub-SVM with its support-vector count, bias and
// [alpha,support-vector-index] pairs.
//
// modelfl: open, writable stream.
// Returns true; failures are reported through SG_ERROR / ASSERT. The return
// values of fprintf are not checked.
// NOTE(review): two original lines (first statement of the body and the one
// before SG_DONE) are not visible in this listing.
464 bool CMultiClassSVM::save(FILE* modelfl)
465 {
467 
468  if (!kernel)
469  SG_ERROR("Kernel not defined!\n");
470 
// NOTE(review): `m_num_classes <=2` also rejects a two-class model, whereas
// load() in this file accepts any class count >= 2. `< 2` would make the two
// consistent -- confirm the intended minimum.
471  if (!m_svms || m_num_svms<1 || m_num_classes <=2)
472  SG_ERROR("Multiclass SVM not trained!\n");
473 
474  SG_INFO( "Writing model file...");
475  fprintf(modelfl,"%%MultiClassSVM\n");
476  fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
477  fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
478  fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
479  fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
480 
481  for (int32_t i=0; i<m_num_svms; i++)
482  {
483  CSVM* svm=m_svms[i];
484  ASSERT(svm);
// The section header prints the index of the last machine (m_num_svms-1),
// not the machine count.
485  fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
486  fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
487  fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
488 
489  fprintf(modelfl, "alphas%d=[\n", i);
490 
491  for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
492  {
493  fprintf(modelfl,"\t[%+10.16e,%d];\n",
494  svm->get_alpha(j), svm->get_support_vector(j));
495  }
496 
497  fprintf(modelfl, "];\n");
498  }
499 
501  SG_DONE();
502  return true ;
503 }

SHOGUN Machine Learning Toolbox - Documentation