SHOGUN
v1.1.0
|
Class CVowpalWabbit is the implementation of the online learning algorithm used in Vowpal Wabbit.
VW is a fast online learning algorithm which operates on sparse features. It uses an online gradient descent technique.
For more details, refer to the tutorial at https://github.com/JohnLangford/vowpal_wabbit/wiki/v5.1_tutorial.pdf
Definition at line 38 of file VowpalWabbit.h.
Public Member Functions | |
CVowpalWabbit () | |
CVowpalWabbit (CStreamingVwFeatures *feat) | |
~CVowpalWabbit () | |
void | reinitialize_weights () |
void | set_no_training (bool dont_train) |
void | set_adaptive (bool adaptive_learning) |
void | set_exact_adaptive_norm (bool exact_adaptive) |
void | set_num_passes (int32_t passes) |
void | load_regressor (char *file_name) |
void | set_regressor_out (char *file_name, bool is_text=true) |
void | set_prediction_out (char *file_name) |
void | add_quadratic_pair (char *pair) |
virtual bool | train_machine (CFeatures *feat=NULL) |
virtual float32_t | predict_and_finalize (VwExample *ex) |
float32_t | compute_exact_norm (VwExample *&ex, float32_t &sum_abs_x) |
float32_t | compute_exact_norm_quad (float32_t *weights, VwFeature &page_feature, v_array< VwFeature > &offer_features, vw_size_t mask, float32_t g, float32_t &sum_abs_x) |
virtual CVwEnvironment * | get_env () |
virtual const char * | get_name () const |
Public Member Functions inherited from COnlineLinearMachine
COnlineLinearMachine () | |
virtual | ~COnlineLinearMachine () |
virtual void | get_w (float32_t *&dst_w, int32_t &dst_dims) |
virtual void | get_w (float64_t *&dst_w, int32_t &dst_dims) |
virtual SGVector< float32_t > | get_w () |
virtual void | set_w (float32_t *src_w, int32_t src_w_dim) |
virtual void | set_w (float64_t *src_w, int32_t src_w_dim) |
virtual void | set_bias (float32_t b) |
virtual float32_t | get_bias () |
virtual bool | load (FILE *srcfile) |
virtual bool | save (FILE *dstfile) |
virtual void | set_features (CStreamingDotFeatures *feat) |
virtual CLabels * | apply () |
virtual CLabels * | apply (CFeatures *data) |
virtual float64_t | apply (int32_t vec_idx) |
get output for example "vec_idx" | |
virtual float32_t | apply (float32_t *vec, int32_t len) |
virtual float32_t | apply_to_current_example () |
virtual CStreamingDotFeatures * | get_features () |
Public Member Functions inherited from CMachine
CMachine () | |
virtual | ~CMachine () |
virtual bool | train (CFeatures *data=NULL) |
virtual void | set_labels (CLabels *lab) |
virtual CLabels * | get_labels () |
virtual float64_t | get_label (int32_t i) |
void | set_max_train_time (float64_t t) |
float64_t | get_max_train_time () |
virtual EClassifierType | get_classifier_type () |
void | set_solver_type (ESolverType st) |
ESolverType | get_solver_type () |
virtual void | set_store_model_features (bool store_model) |
Public Member Functions inherited from CSGObject
CSGObject () | |
CSGObject (const CSGObject &orig) | |
virtual | ~CSGObject () |
virtual bool | is_generic (EPrimitiveType *generic) const |
template<class T > | |
void | set_generic () |
void | unset_generic () |
virtual void | print_serializable (const char *prefix="") |
virtual bool | save_serializable (CSerializableFile *file, const char *prefix="") |
virtual bool | load_serializable (CSerializableFile *file, const char *prefix="") |
void | set_global_io (SGIO *io) |
SGIO * | get_global_io () |
void | set_global_parallel (Parallel *parallel) |
Parallel * | get_global_parallel () |
void | set_global_version (Version *version) |
Version * | get_global_version () |
SGVector< char * > | get_modelsel_names () |
char * | get_modsel_param_descr (const char *param_name) |
index_t | get_modsel_param_index (const char *param_name) |
Protected Attributes | |
CStreamingVwFeatures * | features |
Features. | |
CVwEnvironment * | env |
Environment for VW, i.e., globals. | |
CVwLearner * | learner |
Learner to use. | |
CVwRegressor * | reg |
Regressor. | |
Protected Attributes inherited from COnlineLinearMachine
int32_t | w_dim |
float32_t * | w |
float32_t | bias |
CStreamingDotFeatures * | features |
Protected Attributes inherited from CMachine
float64_t | max_train_time |
CLabels * | labels |
ESolverType | solver_type |
bool | m_store_model_features |
Additional Inherited Members | |
Public Attributes inherited from CSGObject
SGIO * | io |
Parallel * | parallel |
Version * | version |
Parameter * | m_parameters |
Parameter * | m_model_selection_parameters |
Protected Member Functions inherited from CMachine
virtual void | store_model_features () |
CVowpalWabbit()
Default constructor
Definition at line 20 of file VowpalWabbit.cpp.
CVowpalWabbit(CStreamingVwFeatures *feat)
Constructor, taking a features object as argument
feat | StreamingVwFeatures object |
Definition at line 28 of file VowpalWabbit.cpp.
~CVowpalWabbit()
Destructor
Definition at line 36 of file VowpalWabbit.cpp.
void add_quadratic_pair(char *pair)
Add a pair of namespaces whose features should be crossed for quadratic updates
pair | a string with the two namespace names concatenated |
Definition at line 101 of file VowpalWabbit.cpp.
float32_t compute_exact_norm(VwExample *&ex, float32_t &sum_abs_x)
Computes the exact norm during adaptive learning
ex | example |
sum_abs_x | set by reference; sum of the absolute values of the features |
Definition at line 382 of file VowpalWabbit.cpp.
float32_t compute_exact_norm_quad(float32_t *weights, VwFeature &page_feature, v_array<VwFeature> &offer_features, vw_size_t mask, float32_t g, float32_t &sum_abs_x)
Computes the exact norm for quadratic features during adaptive learning
weights | weights |
page_feature | current feature |
offer_features | paired features |
mask | mask |
g | square of gradient |
sum_abs_x | sum of absolute value of features |
Definition at line 419 of file VowpalWabbit.cpp.
virtual CVwEnvironment* get_env()
Get the environment
Definition at line 176 of file VowpalWabbit.h.
virtual const char* get_name() const
Return the name of the object
Reimplemented from COnlineLinearMachine.
Definition at line 187 of file VowpalWabbit.h.
void load_regressor(char *file_name)
Load regressor from a dump file
file_name | name of regressor file |
Definition at line 80 of file VowpalWabbit.cpp.
virtual float32_t predict_and_finalize(VwExample *ex)
Predict for an example
ex | VwExample to predict for |
Definition at line 178 of file VowpalWabbit.cpp.
void reinitialize_weights()
Reinitialize the weight vectors. Call this after updating environment variables, e.g. stride.
Definition at line 43 of file VowpalWabbit.cpp.
void set_adaptive(bool adaptive_learning)
Set whether learning is adaptive or not
adaptive_learning | true if adaptive |
Definition at line 56 of file VowpalWabbit.cpp.
void set_exact_adaptive_norm(bool exact_adaptive)
Set whether to use the more expensive exact norm for adaptive learning
exact_adaptive | true if exact norm is required |
Definition at line 69 of file VowpalWabbit.cpp.
void set_no_training(bool dont_train)
Set whether to skip training and only make passes over all examples.
This is useful if one wants to create a cache file from data.
dont_train | true if one doesn't want to train |
Definition at line 73 of file VowpalWabbit.h.
void set_num_passes(int32_t passes)
Set the number of passes (only works for cached input).
passes | number of passes |
Definition at line 95 of file VowpalWabbit.h.
void set_prediction_out(char *file_name)
Set file name of prediction output
file_name | name of file to save predictions to |
Definition at line 93 of file VowpalWabbit.cpp.
void set_regressor_out(char *file_name, bool is_text = true)
Set regressor output parameters
file_name | name of file to save regressor to |
is_text | whether the regressor is written in human-readable (text) form |
Definition at line 87 of file VowpalWabbit.cpp.
virtual bool train_machine(CFeatures *feat = NULL)
Train on a StreamingVwFeatures object
feat | StreamingVwFeatures object to train on |
Reimplemented from CMachine.
Definition at line 106 of file VowpalWabbit.cpp.
CVwEnvironment* env [protected]
Environment for VW, i.e., globals.
Definition at line 266 of file VowpalWabbit.h.
CStreamingVwFeatures* features [protected]
Features.
Definition at line 263 of file VowpalWabbit.h.
CVwLearner* learner [protected]
Learner to use.
Definition at line 269 of file VowpalWabbit.h.
CVwRegressor* reg [protected]
Regressor.
Definition at line 272 of file VowpalWabbit.h.