SHOGUN  v3.0.1
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
MulticlassModel.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Fernando José Iglesias García
8  * Copyright (C) 2012 Fernando José Iglesias García
9  */
10 
15 
16 using namespace shogun;
17 
20 {
21  init();
22 }
23 
25 : CStructuredModel(features, labels)
26 {
    // The features/labels pair is forwarded to the CStructuredModel base
    // above; init() then registers m_num_classes and resets it to 0.
    // NOTE(review): the constructor's signature line is not present in
    // this listing -- confirm parameter types against the header.
27  init();
28 }
29 
31 {
32 }
33 
35 {
    // Returns feats_dim * num_classes: the feature-space dimension reported
    // by m_features multiplied by the number of classes reported by m_labels.
    // NOTE(review): the signature line is missing from this listing; this
    // is presumably CMulticlassModel::get_dim() -- confirm against the header.
36  // TODO make the casts safe!
    // Both casts below are unchecked C-style downcasts: nothing here verifies
    // that m_labels is a CMulticlassSOLabels or m_features a CDotFeatures.
37  int32_t num_classes = ((CMulticlassSOLabels*) m_labels)->get_num_classes();
38  int32_t feats_dim = ((CDotFeatures*) m_features)->get_dim_feature_space();
39 
40  return feats_dim*num_classes;
41 }
42 
44 {
    // Builds the joint feature vector psi: all zeros except for one slice of
    // length x.vlen, starting at label_value * x.vlen, which receives a copy
    // of x.
    // NOTE(review): the signature and the declarations of psi, x and r are
    // missing from this listing (listing lines 43, 45, 48, 50). Presumably
    // this is get_joint_feature_vector(feat_idx, y), x is the feature vector
    // for feat_idx and r is the CRealNumber behind the label -- confirm.
46  psi.zero();
47 
49  get_computed_dot_feature_vector(feat_idx);
51  ASSERT(r != NULL)
52  float64_t label_value = r->value;
53 
    // j is an index_t initialised from a float64_t product, so label_value is
    // truncated to an integer offset; i walks x while j walks the psi slice.
54  for ( index_t i = 0, j = label_value*x.vlen ; i < x.vlen ; ++i, ++j )
55  psi[j] = x[i];
56 
57  return psi;
58 }
59 
62  int32_t feat_idx,
63  bool const training)
64 {
    // Scans every class c in [0, m_num_classes), scores it as the dot product
    // of feature vector feat_idx with the c-th block of w (plus
    // delta_loss(feat_idx, c) when training), and returns a CResultSet
    // describing the highest-scoring class.
    // NOTE(review): several lines are missing from this listing (the first
    // part of the signature, the declarations of df and ml, and parts of the
    // training-only block at the end). Comments below describe only what is
    // visible.
66  int32_t feats_dim = df->get_dim_feature_space();
67 
68  if ( training )
69  {
    // During training the class count is refreshed from ml on every call.
71  m_num_classes = ml->get_num_classes();
72  }
73  else
74  {
    // For prediction the count cached by an earlier training call is reused;
    // init() sets it to 0, so 0 means no training call has happened yet.
75  REQUIRE(m_num_classes > 0, "The model needs to be trained before "
76  "using it for prediction\n");
77  }
78 
    // w must hold exactly get_dim() entries.
79  int32_t dim = get_dim();
80  ASSERT(dim == w.vlen)
81 
82  // Find the class that gives the maximum score
83 
84  float64_t score = 0, ypred = 0;
85  float64_t max_score = -CMath::INFTY;
86 
87  for ( int32_t c = 0 ; c < m_num_classes ; ++c )
88  {
    // w is read as consecutive blocks of feats_dim weights, one per class;
    // block c starts at w.vector + c*feats_dim.
89  score = df->dense_dot(feat_idx, w.vector+c*feats_dim, feats_dim);
90  if ( training )
91  score += delta_loss(feat_idx, c);
92 
    // Strict '>' means the lowest class index wins when scores tie.
93  if ( score > max_score )
94  {
95  max_score = score;
96  ypred = c;
97  }
98  }
99 
100  // Build the CResultSet object to return
    // A reference is taken on both the result set and the predicted label
    // before they are handed out.
101  CResultSet* ret = new CResultSet();
102  SG_REF(ret);
103  CRealNumber* y = new CRealNumber(ypred);
104  SG_REF(y);
105 
106  ret->psi_pred = get_joint_feature_vector(feat_idx, y);
107  ret->score = max_score;
108  ret->argmax = y;
109  if ( training )
110  {
111  ret->delta = CStructuredModel::delta_loss(feat_idx, y);
    // NOTE(review): listing lines 112 and 114 are missing; the two fragments
    // below are the tails of statements that presumably fill ret->psi_truth
    // and then use it together with dim -- confirm against the real source.
113  feat_idx, feat_idx);
115  ret->psi_truth.vector, dim);
116  }
117 
118  return ret;
119 }
120 
122 {
    // Delegates to the value-based delta_loss overload, using the value
    // fields of rn1 and rn2.
    // NOTE(review): the signature and the declarations of rn1/rn2 (listing
    // lines 121, 123, 124) are missing here; presumably they are obtained
    // from the two CStructuredData arguments via
    // CRealNumber::obtain_from_generic -- confirm.
125  ASSERT(rn1 != NULL)
126  ASSERT(rn2 != NULL)
127 
128  return delta_loss(rn1->value, rn2->value);
129 }
130 
132 {
    // Loss between the label stored at index y1_idx and the class value y2,
    // computed by the value-based delta_loss overload.
    //
    // The index must lie in [0, num_labels-1]. The two bounds have to be
    // joined with '&&': joined with '||' the condition holds for every
    // integer (negative indices satisfy the upper bound, non-negative ones
    // the lower bound), so an out-of-range index would never be rejected.
133  REQUIRE(y1_idx >= 0 && y1_idx < m_labels->get_num_labels(),
134  "The label index must be inside [0, num_labels-1]\n");
135 
    // NOTE(review): the declaration of rn1 (listing line 136) is missing from
    // this listing; presumably it is the CRealNumber for label y1_idx -- confirm.
    // The reference held on rn1 is released once its value has been read.
137  float64_t ret = delta_loss(rn1->value, y2);
138  SG_UNREF(rn1);
139 
140  return ret;
141 }
142 
144 {
    // 0-1 loss: 0 when the two values compare exactly equal, 1 otherwise.
    // NOTE(review): the signature line is missing from this listing; callers
    // in this file pass float64_t values, so this is presumably
    // delta_loss(float64_t y1, float64_t y2) -- confirm.
145  return (y1 == y2) ? 0 : 1;
146 }
147 
149  float64_t regularization,
157 {
159 }
160 
// Shared constructor helper: registers m_num_classes through SG_ADD and
// resets it to 0 (argmax requires a value > 0 when called for prediction).
161 void CMulticlassModel::init()
162 {
163  SG_ADD(&m_num_classes, "m_num_classes", "The number of classes",
    // NOTE(review): the final argument and closing parenthesis of this SG_ADD
    // invocation (listing line 164) are missing from this listing.
165 
166  m_num_classes = 0;
167 }
168 
SGVector< float64_t > psi_truth
static float64_t dot(const bool *v1, const bool *v2, int32_t n)
compute dot product between v1 and v2 (blas optimized)
Definition: SGVector.h:344
Base class of the labels used in Structured Output (SO) problems.
virtual void init_primal_opt(float64_t regularization, SGMatrix< float64_t > &A, SGVector< float64_t > a, SGMatrix< float64_t > B, SGVector< float64_t > &b, SGVector< float64_t > lb, SGVector< float64_t > ub, SGMatrix< float64_t > &C)
int32_t index_t
Definition: common.h:60
static const float64_t INFTY
infinity
Definition: Math.h:1355
virtual float64_t dense_dot(int32_t vec_idx1, const float64_t *vec2, int32_t vec2_len)=0
#define REQUIRE(x,...)
Definition: SGIO.h:208
virtual float64_t delta_loss(CStructuredData *y1, CStructuredData *y2)
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)
SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, int32_t lab_idx)
Features that support dot products among other operations.
Definition: DotFeatures.h:41
#define SG_REF(x)
Definition: SGObject.h:53
static CRealNumber * obtain_from_generic(CStructuredData *base_data)
virtual SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, CStructuredData *y)
virtual int32_t get_dim_feature_space() const =0
#define ASSERT(x)
Definition: SGIO.h:203
double float64_t
Definition: common.h:48
static SGMatrix< T > create_identity_matrix(index_t size, T scale)
float64_t delta_loss(int32_t ytrue_idx, CStructuredData *ypred)
virtual int32_t get_dim() const
Class CStructuredModel that represents the application specific model and contains most of the applic...
#define SG_UNREF(x)
Definition: SGObject.h:54
CStructuredLabels * m_labels
The class Features is the base class of all feature objects.
Definition: Features.h:62
CStructuredData * argmax
CStructuredData * get_label(int32_t idx)
SGVector< float64_t > psi_pred
Class CRealNumber to be used in the application of Structured Output (SO) learning to multiclass clas...
#define SG_ADD(...)
Definition: SGObject.h:83
Class CMulticlassSOLabels to be used in the application of Structured Output (SO) learning to multicl...
Base class of the components of StructuredLabels.
index_t vlen
Definition: SGVector.h:706

SHOGUN Machine Learning Toolbox - Documentation