AROWのコードを移植してみた
id:tsubosaka さんの日記でAROW (Adaptive Regularization of Weight Vectors) がJavaで実装されていたのでC++の勉強がてらに,C++で実装してみました. 動作確認は g++ 4.2.1 (Mac OS X 10.6.2), g++ 4.3.3 (Ubuntu 9.04) で行いました.蛇足ですが,最初AROWをARROWと空目してたので, ファイル名やクラス名,名前空間はARROW or arrowになっています.
変更点:
- 入力データの1行目に素性 (特徴) の数が書かれていなくてはならないようにしました.
- コマンドライン引数 (第2引数) で, 訓練データの数を指定できるようにしました.
インストール:
警告オプション, 最適化オプションは省略しています. -O2 や -Wall -Wextra あたり付けておくとよいかもしれません.
g++ -c arrow.cc
g++ -c main.cc
g++ -o arrow arrow.o main.o
実行結果
- 実験データ news20.binary
- クラス数: 2
- データ数: 19,996
- 素性数: 1,355,191
- 繰り返し回数: 10
をシャッフルし,15000例の訓練データと4996例のテストデータに分けました.
1 th iteration: # of mistake: 136 error rate: 0.0272218 2 th iteration: # of mistake: 134 error rate: 0.0268215 3 th iteration: # of mistake: 135 error rate: 0.0270216 4 th iteration: # of mistake: 134 error rate: 0.0268215 5 th iteration: # of mistake: 134 error rate: 0.0268215 6 th iteration: # of mistake: 134 error rate: 0.0268215 7 th iteration: # of mistake: 134 error rate: 0.0268215 8 th iteration: # of mistake: 134 error rate: 0.0268215 9 th iteration: # of mistake: 134 error rate: 0.0268215 10 th iteration: # of mistake: 134 error rate: 0.0268215
このように,収束が非常に速いことが見て取れます.
コード
arrow.h
// -*- mode: c++ -*-
#ifndef ARROW_H
#define ARROW_H

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

namespace arrow {

// One entry of a sparse feature vector.
struct f_node {
  int index;      // feature id; used directly as an index into the model
  double weight;  // value of that feature
};

typedef std::vector<f_node> feature_vec;

// A labelled training / test example.
struct example {
  int label;       // class label (ParseLine only accepts +1 or -1)
  feature_vec fv;  // sparse features of the example
};

// Linear classifier holding a per-feature mean (weight) and a per-feature
// (diagonal) covariance, trained online through Update().
class Arrow {
 public:
  // Creates a model for feature ids in [0, feature_size): all means start
  // at 0.0, all covariances at 1.0, and the regularization constant r at 0.1.
  Arrow(std::size_t feature_size)
      : feature_size_(feature_size),
        mean_(feature_size, 0.0),
        cov_(feature_size, 1.0),
        r_(0.1) {}

  virtual ~Arrow() {}

  // Returns the dot product of the mean vector and `fv`.
  // Entries whose index is negative or >= feature_size are ignored instead
  // of being read out of bounds.
  double GetMargin(const feature_vec& fv) const {
    double res = 0.0;
    for (feature_vec::const_iterator it = fv.begin(); it != fv.end(); ++it) {
      if (!InRange(it->index)) continue;
      res += mean_[static_cast<std::size_t>(it->index)] * it->weight;
    }
    return res;
  }

  // Returns sum(cov[i] * x_i^2) over the entries of `fv`.
  // Out-of-range entries are ignored, as in GetMargin().
  double GetConfidence(const feature_vec& fv) const {
    double res = 0.0;
    for (feature_vec::const_iterator it = fv.begin(); it != fv.end(); ++it) {
      if (!InRange(it->index)) continue;
      res += cov_[static_cast<std::size_t>(it->index)] * it->weight * it->weight;
    }
    return res;
  }

  // Performs one online update with the example (fv, label); defined in
  // arrow.cc. Returns 1 when the margin had the wrong sign before the
  // update and 0 otherwise.
  int Update(feature_vec& fv, int& label);

  // Predicts +1 when the margin is strictly positive, -1 otherwise.
  int Predict(const feature_vec& fv) const {
    return GetMargin(fv) > 0 ? 1 : -1;
  }

 private:
  // True when `index` addresses an existing slot of mean_ / cov_.
  bool InRange(int index) const {
    return index >= 0 && static_cast<std::size_t>(index) < mean_.size();
  }

  std::size_t feature_size_;
  std::vector<double> mean_;
  std::vector<double> cov_;
  double r_;
};

// Random number functor usable as the generator of a shuffle.
// Construction reseeds rand() from the current time.
class Random {
 public:
  Random() { srand(static_cast<unsigned int>(time(NULL))); }

  // Returns a value in [0, max); returns 0 when max == 0.
  // Dividing by RAND_MAX + 1.0 (not RAND_MAX) keeps the scaled value
  // strictly below 1, so the result can never be `max` itself.
  unsigned int operator() (unsigned int max) {
    const double unit =
        static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0);
    return static_cast<unsigned int>(unit * max);
  }
};

// Reads the data set in `file` ("-" means standard input) into `data` and
// stores the feature count from its first line in `num_feature`.
// Returns 0 on success, -1 on failure.
int ReadData(const std::string& file,
             std::vector<example>& data,
             std::size_t& num_feature);

// Parses one "label id:value id:value ..." line into `label` and `fv`.
// Returns 0 on success, -1 on failure.
int ParseLine(const std::string& line, feature_vec& fv, int& label);

// Randomly permutes `data` in place.
void ShuffleData(std::vector<example>& data);

} // namespace arrow

#endif // ARROW_H
arrow.cc
#include "arrow.h"

#include <cstdlib>

namespace arrow {

// Performs one online update with the example (fv, label).
//
// Returns 1 when the margin had the wrong sign before the update
// (m * label < 0) and 0 otherwise. When the margin already satisfies
// m * label >= 1 the model is left untouched and 0 is returned.
// Entries of `fv` whose index does not address a model slot are skipped
// rather than written out of bounds.
int Arrow::Update(feature_vec& fv, int& label) {
  const double m = GetMargin(fv);
  const int loss = m * label < 0 ? 1 : 0;
  if (m * label >= 1) return 0;

  const double conf = GetConfidence(fv);
  const double beta = 1.0 / (conf + r_);
  const double alpha = (1.0 - label * m) * beta;
  const std::size_t dim = mean_.size();

  // Update mean. This loop must finish before any covariance changes,
  // because every mean update uses the covariance from before this example.
  for (feature_vec::const_iterator it = fv.begin(); it != fv.end(); ++it) {
    if (it->index < 0 || static_cast<std::size_t>(it->index) >= dim) continue;
    const std::size_t k = static_cast<std::size_t>(it->index);
    mean_[k] += alpha * label * cov_[k] * it->weight;
  }

  // Update covariance (diagonal only).
  for (feature_vec::const_iterator it = fv.begin(); it != fv.end(); ++it) {
    if (it->index < 0 || static_cast<std::size_t>(it->index) >= dim) continue;
    const std::size_t k = static_cast<std::size_t>(it->index);
    cov_[k] = 1.0 / ((1.0 / cov_[k]) + it->weight * it->weight / r_);
  }

  return loss;
}

// Reads a data set from `file`, or from standard input when `file` is "-".
//
// The first line must hold the number of features (a positive integer),
// which is stored in `num_feature`. Every following line is handed to
// ParseLine() and appended to `data`; empty lines and lines starting with
// '#' are skipped. Returns 0 on success and -1 on any failure, after
// writing a diagnostic to std::cerr.
int ReadData(const std::string& file,
             std::vector<example>& data,
             std::size_t& num_feature) {
  // The file stream is a local object, so it is closed on every return
  // path; nothing is heap-allocated.
  std::ifstream fin;
  std::istream* in = &std::cin;
  if (file != "-") {
    fin.open(file.c_str());
    in = &fin;
  }

  if (!*in) {
    std::cerr << "Cannot open: " << file << std::endl;
    return -1;
  }

  std::string line;

  // Get feature size. It is parsed as a signed value so that a negative
  // header is rejected instead of wrapping around to a huge unsigned count.
  std::getline(*in, line);
  char* end = 0;
  const long declared = std::strtol(line.c_str(), &end, 10);
  if (end == line.c_str() || declared <= 0) {
    std::cerr << "Invalid # of feature" << std::endl;
    return -1;
  }
  num_feature = static_cast<std::size_t>(declared);

  // Line number within the file; the header read above was line 1.
  std::size_t line_num = 1;
  while (std::getline(*in, line)) {
    ++line_num;
    if (line.empty() || line[0] == '#') continue;  // blank line or comment

    example ex;
    ex.label = 0;
    if (ParseLine(line, ex.fv, ex.label) == -1) {
      // ParseLine already printed the reason without a newline.
      std::cerr << " line: " << line_num << std::endl;
      return -1;
    }
    data.push_back(ex);
  }

  return 0;
}

// Parses one line of the form "label id<sep>value id<sep>value ...".
//
// `label` must be +1 or -1. Parsed entries are appended to `fv`; reading
// stops silently at the first token that is not an id/separator/value
// triple. Returns 0 on success, or -1 after writing a message (without a
// trailing newline) to std::cerr.
int ParseLine(const std::string& line, feature_vec& fv, int& label) {
  std::istringstream is(line);
  if (!(is >> label)) {
    std::cerr << "parse error: no label";
    return -1;
  }

  if (label != 1 && label != -1) {
    std::cerr << "parse error: label is not +1 nor -1 ";
    return -1;
  }

  int id = 0;
  char sep = 0;
  double val = 0.0;
  while (is >> id >> sep >> val) {
    // Ids are used as vector indices, so a negative id can never be valid.
    if (id < 0) {
      std::cerr << "parse error: negative feature index";
      return -1;
    }
    f_node feature;
    feature.index = id;
    feature.weight = val;
    fv.push_back(feature);
  }

  return 0;
}

// Randomly permutes `data` in place with a Fisher-Yates shuffle driven by
// Random. std::random_shuffle is not used because it is deprecated in
// C++14 and removed in C++17.
void ShuffleData(std::vector<example>& data) {
  Random r;
  for (std::size_t remaining = data.size(); remaining > 1; --remaining) {
    std::size_t pick = r(static_cast<unsigned int>(remaining));
    // Defensive clamp: the swap must stay inside [0, remaining).
    if (pick >= remaining) pick = remaining - 1;
    std::swap(data[remaining - 1], data[pick]);
  }
}

} // namespace arrow
main.cc
#include "arrow.h"

#include <cstdlib>

// Trains the classifier on a shuffled split of the data set and prints the
// test error after each of 10 passes over the training part.
//
// Usage: <program> file [train_size]
//   file        data file, or "-" for standard input
//   train_size  number of training examples; when it is missing, not
//               positive, or not smaller than the data set, 75% of the
//               examples are used for training
// Returns 0 on success and -1 on a usage or input error.
int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "Usage: " << argv[0] << " file [train_size]" << std::endl;
    return -1;
  }

  typedef std::vector<arrow::example> examples;

  const std::string file = argv[1];
  examples data;
  std::size_t num_feature = 0; // # of features

  if (arrow::ReadData(file, data, num_feature) != 0) {
    std::cerr << "Cannot read" << std::endl;
    return -1;
  }

  // Read the optional train size from argv[2]. It is parsed as a signed
  // value so that a negative argument falls back to the default instead of
  // wrapping around.
  std::size_t train_size = 0;
  if (argc >= 3) {
    const long requested = std::strtol(argv[2], 0, 10);
    if (requested > 0) train_size = static_cast<std::size_t>(requested);
  }

  if (train_size == 0 || train_size >= data.size()) {
    train_size = static_cast<std::size_t>(data.size() * 0.75);
  }

  const std::size_t test_size = data.size() - train_size;
  if (test_size == 0) {
    std::cerr << "Size of train data is large" << std::endl;
    return -1;
  }

  arrow::ShuffleData(data);

  // The first train_size shuffled examples are for training, the rest for
  // testing.
  const examples train(data.begin(), data.begin() + train_size);
  const examples test(data.begin() + train_size, data.end());

  arrow::Arrow arow(num_feature);

  // # of iteration
  const std::size_t iter_num = 10;

  for (std::size_t i = 0; i < iter_num; ++i) {
    // Update. Update() takes non-const references, so each example is
    // passed through local copies.
    for (examples::const_iterator it = train.begin(); it != train.end(); ++it) {
      arrow::feature_vec fv = it->fv;
      int label = it->label;
      arow.Update(fv, label);
    }

    // Predict
    std::size_t mistake = 0;
    for (examples::const_iterator it = test.begin(); it != test.end(); ++it) {
      arrow::feature_vec fv = it->fv;
      if (arow.Predict(fv) != it->label) ++mistake;
    }

    // Print result
    std::cout << i+1 << " th iteration: " << std::endl;
    std::cout << "# of mistake: " << mistake << std::endl;
    std::cout << "error rate: " << mistake * 1.0 / test_size << std::endl;
    std::cout << std::endl;
  }

  return 0;
}