SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MPDSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
12 #include <shogun/io/SGIO.h>
13 #include <shogun/lib/common.h>
15 
16 using namespace shogun;
17 
// NOTE(review): the signature line for this definition (source line 18) is
// missing from this listing. The initializer list delegates to the CSVM
// base-class default constructor; the body sets no additional state.
19 : CSVM()
20 {
21 }
22 
// NOTE(review): the signature line for this definition (source line 23) is
// missing from this listing. C, k and lab are forwarded unchanged to the
// matching CSVM base-class constructor (presumably regularization constant,
// kernel and labels -- confirm against the header); the body adds nothing.
24 : CSVM(C, k, lab)
25 {
26 }
27 
// NOTE(review): the signature line for this definition (source line 28) is
// missing from this listing; presumably the class destructor -- confirm.
// The body is empty: nothing is released here, so any cleanup that is
// needed has to happen elsewhere (e.g. in base classes).
29 {
30 }
31 
// Training routine of this SVM. The signature (source line 32) is missing
// from this listing; the body reads a parameter named `data`.
//
// Behaviour visible below:
//  - if `data` is non-NULL, checks that its vector count matches the label
//    count (SG_ERROR otherwise) and initializes the kernel on (data, data);
//  - runs an iterative solver (C-SVC variant; the nu-SVC variant survives
//    only as commented-out code). Each iteration picks the single index with
//    the largest KKT violation |dalphas[i]|, moves alphas[i] by
//    -dalphas[i]/H(i,i) clipped to [0, d], and -- once the primal side is
//    below tolerance -- shifts the bias term `etas` by detas/hessest;
//  - stops when the largest violation is below primaleps and either |detas|
//    is below dualeps or no alpha lies strictly inside (0, d), or after
//    maxiter iterations (with a warning);
//  - builds the model from every alpha > 0 (coefficient alpha*label) with
//    bias `etas`, frees its buffers and deletes kernel_cache, returns true.
//
// NOTE(review): this listing omits source lines 43, 59, 65 and 254. `F` and
// `kernel_cache` are used (and freed) below but their allocations are not
// visible here; presumably they sit on the missing lines -- confirm.
33 {
34  ASSERT(labels);
35  ASSERT(kernel);
36 
37  if (data)
38  {
39  if (labels->get_num_labels() != data->get_num_vectors())
40  SG_ERROR("Number of training vectors does not match number of labels\n");
41  kernel->init(data, data);
42  }
44 
// Solver constants: d is the box bound (C for C-SVC); the dual tolerance is
// scaled by n as a heuristic (per the original comment).
45  //const float64_t nu=0.32;
46  const float64_t alpha_eps=1e-12;
47  const float64_t eps=get_epsilon();
48  const int64_t maxiter = 1L<<30;
49  //const bool nustop=false;
50  //const int32_t k=2;
51  const int32_t n=labels->get_num_labels();
52  ASSERT(n>0);
53  //const float64_t d = 1.0/n/nu; //NUSVC
54  const float64_t d = get_C1(); //CSVC
55  const float64_t primaleps=eps;
56  const float64_t dualeps=eps*n; //heuristic
57  int64_t niter=0;
58 
// Per-example work buffers (length n); all are SG_FREE'd at the end.
60  float64_t* alphas=SG_MALLOC(float64_t, n);
61  float64_t* dalphas=SG_MALLOC(float64_t, n);
62  //float64_t* hessres=SG_MALLOC(float64_t, 2*n);
63  float64_t* hessres=SG_MALLOC(float64_t, n);
64  //float64_t* F=SG_MALLOC(float64_t, 2*n);
66 
67  //float64_t hessest[2]={0,0};
68  //float64_t hstep[2];
69  //float64_t etas[2]={0,0};
70  //float64_t detas[2]={0,1}; //NUSVC
71  float64_t etas=0;
72  float64_t detas=0; //CSVC
73  float64_t hessest=0;
74  float64_t hstep;
75 
76  const float64_t stopfac = 1;
77 
78  bool primalcool;
79  bool dualcool;
80 
81  //if (nustop)
82  //etas[1] = 1;
83 
// Start from alphas = 0 with dalphas = -1; F and hessres start as the labels.
84  for (int32_t i=0; i<n; i++)
85  {
86  alphas[i]=0;
87  F[i]=labels->get_label(i);
88  //F[i+n]=-1;
89  hessres[i]=labels->get_label(i);
90  //hessres[i+n]=-1;
91  //dalphas[i]=F[i+n]*etas[1]; //NUSVC
92  dalphas[i]=-1; //CSVC
93  }
94 
95  // go ...
96  while (niter++ < maxiter)
97  {
98  int32_t maxpidx=-1;
99  float64_t maxpviol = -1;
100  //float64_t maxdviol = CMath::abs(detas[0]);
101  float64_t maxdviol = CMath::abs(detas);
102  bool free_alpha=false;
103 
104  //if (CMath::abs(detas[1])> maxdviol)
105  //maxdviol=CMath::abs(detas[1]);
106 
107  // compute kkt violations with correct sign ...
108  for (int32_t i=0; i<n; i++)
109  {
110  float64_t v=CMath::abs(dalphas[i]);
111 
112  if (alphas[i] > 0 && alphas[i] < d)
113  free_alpha=true;
114 
// Not a violation if dalphas[i] is zero, or if alphas[i] sits on a bound
// (0 or d) and the step -dalphas[i]/H(i,i) would push it outside [0, d]
// (assuming H(i,i) > 0).
115  if ( (dalphas[i]==0) ||
116  (alphas[i]==0 && dalphas[i] >0) ||
117  (alphas[i]==d && dalphas[i] <0)
118  )
119  v=0;
120 
121  if (v > maxpviol)
122  {
123  maxpviol=v;
124  maxpidx=i;
125  } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
126  else if (v == maxpviol)
127  {
128  if (kernel_cache->is_cached(i))
129  maxpidx=i;
130  }
131  }
132 
// NOTE(review): maxdviol is an absolute value, so `maxdviol<0` presumably can
// never hold. If SG_ERROR does not return (throws/aborts), the buffers above
// and kernel_cache are leaked on this path -- confirm SG_ERROR semantics.
133  if (maxpidx<0 || maxdviol<0)
134  SG_ERROR( "no violation no convergence, should not happen!\n");
135 
136  // ... and evaluate stopping conditions
137  //if (nustop)
138  //stopfac = CMath::max(etas[1], 1e-10);
139  //else
140  //stopfac = 1;
141 
// Every 10000 iterations, recompute -sum(alpha) + 0.5*sum y_i y_j a_i a_j K(i,j)
// from scratch just for the debug log.
// NOTE(review): this costs n*n kernel evaluations and runs even when debug
// output is not shown (assuming SG_DEBUG is filtered by log level); consider
// guarding it by the current log level.
142  if (niter%10000 == 0)
143  {
144  float64_t obj=0;
145 
146  for (int32_t i=0; i<n; i++)
147  {
148  obj-=alphas[i];
149  for (int32_t j=0; j<n; j++)
150  obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
151  }
152 
// NOTE(review): niter is int64_t but is formatted with %d -- a mismatch for
// printf-style formatting (assuming SG_DEBUG is printf-style). Use %lld with
// a cast to long long, or PRId64.
153  SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
154  }
155 
156  //for (int32_t i=0; i<n; i++)
157  // SG_DEBUG( "alphas:%f dalphas:%f\n", alphas[i], dalphas[i]);
158 
// primalcool: largest coordinate violation is below tolerance.
// dualcool: |detas| is below tolerance, or no alpha lies strictly in (0, d).
159  primalcool = (maxpviol < primaleps*stopfac);
160  dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
161 
162  // done?
163  if (primalcool && dualcool)
164  {
// NOTE(review): same %d vs int64_t niter mismatch in both messages below.
165  if (!free_alpha)
166  SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
167  else
168  SG_INFO( " done! #iter=%d\n", niter);
169  break;
170  }
171 
172 
173  ASSERT(maxpidx>=0 && maxpidx<n);
// hstep feeds both hessest (the divisor of the eta update below) and the
// running hessres vector.
// NOTE(review): compute_H(maxpidx,maxpidx) is evaluated here and again for
// the primal step; it could be computed once per iteration.
174  // hessian updates
175  hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
176  //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
177  //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
178 
179  hessest-=F[maxpidx]*hstep;
180  //hessest[0]-=F[maxpidx]*hstep[0];
181  //hessest[1]-=F[maxpidx+n]*hstep[1];
182 
// Step the chosen coordinate by -dalphas/H(i,i), then snap into [0, d]:
// values within alpha_eps of a bound are set exactly to that bound.
183  // do primal updates ..
184  float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
185 
186  if (tmpalpha > d-alpha_eps)
187  tmpalpha = d;
188 
189  if (tmpalpha < 0+alpha_eps)
190  tmpalpha = 0;
191 
192  // update alphas & dalphas & detas ...
193  float64_t alphachange = tmpalpha - alphas[maxpidx];
194  alphas[maxpidx] = tmpalpha;
195 
// Propagate the change through row maxpidx as returned by lock_kernel_row;
// the lock/unlock pair presumably pins that row in kernel_cache while it is
// read -- confirm.
196  KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
197  for (int32_t i=0; i<n; i++)
198  {
199  hessres[i]+=h[i]*hstep;
200  //hessres[i]+=h[i]*hstep[0];
201  //hessres[i+n]+=h[i]*hstep[1];
202  dalphas[i] +=h[i]*alphachange;
203  }
204  unlock_kernel_row(maxpidx);
205 
206  detas+=F[maxpidx]*alphachange;
207  //detas[0]+=F[maxpidx]*alphachange;
208  //detas[1]+=F[maxpidx+n]*alphachange;
209 
// The bias term etas only moves once the primal side is below tolerance.
// NOTE(review): nothing guards against hessest == 0 in the division below.
210  // if at primal minimum, do eta update ...
211  if (primalcool)
212  {
213  //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
214  float64_t etachange = detas/hessest;
215 
216  etas+=etachange;
217  //etas[0]+=etachange[0];
218  //etas[1]+=etachange[1];
219 
220  // update dalphas
221  for (int32_t i=0; i<n; i++)
222  dalphas[i]+= F[i] * etachange;
223  //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
224  }
225  }
226 
227  if (niter >= maxiter)
228  SG_WARNING( "increase maxiter ... \n");
229 
230 
// Every example with alpha > 0 becomes a support vector.
231  int32_t nsv=0;
232  for (int32_t i=0; i<n; i++)
233  {
234  if (alphas[i]>0)
235  nsv++;
236  }
237 
238 
239  create_new_model(nsv);
240  //set_bias(etas[0]/etas[1]);
241  set_bias(etas);
242 
// Store coefficient alpha_i * y_i and the training index for each SV.
243  int32_t j=0;
244  for (int32_t i=0; i<n; i++)
245  {
246  if (alphas[i]>0)
247  {
248  //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
249  set_alpha(j, alphas[i]*labels->get_label(i));
250  set_support_vector(j, i);
251  j++;
252  }
253  }
255  SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
// NOTE(review): %ld is used for get_num_support_vectors(), whose return type
// is not visible here; if it is int32_t this is a format mismatch.
256  SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
257 
258  SG_FREE(alphas);
259  SG_FREE(dalphas);
260  SG_FREE(hessres);
261  SG_FREE(F);
// NOTE(review): kernel_cache is deleted but not reset to NULL; if any other
// code path (e.g. a destructor) also deletes it, that would be a double
// delete -- confirm ownership.
262  delete kernel_cache;
263 
264  return true;
265 }

SHOGUN Machine Learning Toolbox - Documentation