LDAのコードをC++で書いてみた
前回の記事で思ったより勉強になったので、調子に乗って再び id:tsubosaka さんのJavaで書かれたLDAの実装をC++で書いてみました。ベースとなる手法は同じく collapsed gibbs sampling(Griffiths and Steyvers, PNAS, 2004) です。動作確認は g++ 4.3.3 (Ubuntu 9.04) で行っています。ソースコードは前回のARROWよりかなり長くなってしまいました。今度から長い場合は github あたりにアップするかもしれません。
2010.01.08追記: pcomp.h のソースが抜けていたので追加しました。
2010.01.09追記: メモリのバグを修正しました。
インストール
g++ -Wall -O2 -c lda.cc
g++ -Wall -O2 -c main.cc
g++ -Wall -O2 -o lda lda.o main.o
使用法
./lda Bag-Of-Wordsファイル Vocabファイル
結果は標準出力に出力されるようにしました。
実験結果
- 実験データ: UCI Machine Learning Repository: Bag of Words Data Set のNIPSのデータセット
- 文書数: 1500
- vocabulary の単語数: 12,419
- collection の単語数: 746,316
- 繰り返しの回数: 200
トピック数 K = 50 の場合について、いくつかのトピック（0, 5, 49）の単語生成確率の上位10件を以下に示します。
繰り返し数 = 1 の時
topic: 0 network 0.014843 model 0.00880151 input 0.00787398 function 0.00699658 algorithm 0.00692138 set 0.00679604 learning 0.00614426 system 0.00566796 data 0.00541728 neural 0.005317 topic: 5 network 0.014523 neural 0.00990054 learning 0.00972469 model 0.00952371 function 0.00688591 set 0.00628299 data 0.0061825 input 0.00598153 system 0.00573031 training 0.00565494 topic: 49 network 0.0139108 model 0.00867621 input 0.00844862 function 0.00839805 learning 0.00791758 set 0.00695665 data 0.00607158 neural 0.00594514 unit 0.00551525 system 0.00513593
繰り返し数 = 200 の時
topic: 0 algorithm 0.0815442 problem 0.0440062 number 0.0218871 result 0.014514 step 0.014514 search 0.0144454 solution 0.0141296 method 0.0130587 run 0.00742933 topic: 5 robot 0.0332836 environment 0.0234933 goal 0.0178843 control 0.0126152 system 0.0125812 path 0.0109495 position 0.0106775 place 0.00921579 world 0.00897783 behavior 0.00884186 topic: 49 neuron 0.0350895 synaptic 0.024088 model 0.0231375 input 0.0228341 cell 0.0185872 synapses 0.0131876 membrane 0.0115091 potential 0.0111046 current 0.010417 response 0.00819246 values 0.00731949
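このように、繰り返し数 1 の時点ではどのトピックでも network, model, input などの同じ頻出語が上位を占めていますが、繰り返しの回数が増加するにつれて、トピックごとに単語の出現確率が異なっていく（例: topic 0 は algorithm/problem、topic 5 は robot/environment、topic 49 は neuron/synaptic）ことが見て取れます。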
ソース
token.h
#ifndef TOKEN_H
#define TOKEN_H

namespace lda {

// One occurrence of a word in a document (a single token of the corpus).
// Both ids are 0-based: the LDA model uses them directly as row indices of
// its count tables.
//
// Token is a plain value type that is stored once per word occurrence and
// copied in the sampler's inner loops, so it intentionally has no virtual
// destructor (which would add a vtable pointer to every token and make the
// type non-trivially destructible).
struct Token {
  int doc_id_;   // index of the document containing this token
  int word_id_;  // index of the word in the vocabulary

  Token(int d, int w) : doc_id_(d), word_id_(w) {}
};

}  // namespace lda

#endif  // TOKEN_H
random.h
#ifndef RANDOM_H
#define RANDOM_H

#include <cstdlib>
#include <ctime>

namespace lda {

// Thin wrapper around the C library generator std::rand().
//
// NOTE: std::rand() keeps a single process-wide state, so every Random
// instance draws from (and every constructor re-seeds) the same stream.
class Random {
 public:
  // Seeds the shared generator from the current time (runs differ).
  Random() { std::srand(static_cast<unsigned int>(std::time(NULL))); }

  // Seeds the shared generator with an explicit value so that a run can be
  // reproduced exactly.
  explicit Random(unsigned int seed) { std::srand(seed); }

  virtual ~Random() {}

  // For max > 0, returns an integer uniformly drawn from [0, max).
  int gen(int max) { return static_cast<int>(Uniform01() * max); }

  // For max > 0, returns a double uniformly drawn from [0, max).
  double gen(double max) { return Uniform01() * max; }

 private:
  // std::rand() lies in [0, RAND_MAX]; dividing by RAND_MAX + 1 maps it into
  // [0, 1) so that scaling by max never reaches max itself.
  static double Uniform01() {
    return static_cast<double>(std::rand()) /
           (static_cast<double>(RAND_MAX) + 1.0);
  }
};

}  // namespace lda

#endif  // RANDOM_H
lda.h
#ifndef LDA_H #define LDA_H #include <iostream> #include <vector> #include <string> #include <sstream> #include <fstream> #include <algorithm> #include <ctime> #include "token.h" #include "random.h" namespace lda { class LDA { public: LDA(int doc_num, int topic_num, int word_num, std::vector<Token>& tok_lis) { D_ = doc_num; K_ = topic_num; W_ = word_num; tokens_ = tok_lis; alpha_ = 50.0 / topic_num; beta_ = 0.1; ptr_random_ = new Random(); // 2010.01.09 update Init(); } ~LDA() { try { // Delete word_count_ for (int i = 0; i < W_; ++i) delete [] word_count_[i]; delete [] word_count_; // Delete doc_count_ for (int i = 0; i < D_; ++i) delete [] doc_count_[i]; delete [] doc_count_; delete [] topic_count_; delete [] z_; delete [] p_; delete ptr_random_; // 2010.01.09 update } catch (...) { std::cerr << "~LDA(): Out of memory" << std::endl; exit(EXIT_FAILURE); } } void Update(); double** GetTheta(); double** GetPhi(); private: int D_; // # of document int K_; // # of topic int W_; // # of unique word int** word_count_; int** doc_count_; int* topic_count_; // hyper parameter double alpha_; double beta_; std::vector<Token> tokens_; double* p_; int* z_; // topic assignment Random* ptr_random_; void Init(); void _Init(); int SelectNextTopic(const Token t); void Resample(std::size_t token_id); }; int DeleteTheta(double** theta, int doc_num); int DeletePhi(double** phi, int topic_num); } // namespace lda #endif // LDA_H
lda.cc
#include "lda.h"

#include <cstdlib>
#include <new>

namespace lda {

namespace {

// Allocates a rows x cols matrix of doubles as an array of row arrays (free
// it with DeleteTheta / DeletePhi). On allocation failure prints
// "<who>: Out of memory" to stderr and terminates the process.
double** NewMatrix(int rows, int cols, const char* who) {
  try {
    double** m = new double*[rows];
    for (int i = 0; i < rows; ++i) m[i] = new double[cols];
    return m;
  } catch (const std::bad_alloc&) {
    std::cerr << who << ": Out of memory" << std::endl;
    std::exit(EXIT_FAILURE);
  }
}

}  // namespace

// Allocates the zero-filled count tables, the assignment vector and the
// sampling buffer, then draws a random initial topic for every token.
// Terminates the process if any allocation fails.
void LDA::Init() {
  try {
    // new T[n]() value-initialises every element, i.e. fills it with zero.
    word_count_ = new int*[W_];
    for (int w = 0; w < W_; ++w) word_count_[w] = new int[K_]();
    topic_count_ = new int[K_]();
    doc_count_ = new int*[D_];
    for (int d = 0; d < D_; ++d) doc_count_[d] = new int[K_]();
    z_ = new int[tokens_.size()]();
    p_ = new double[K_]();
  } catch (const std::bad_alloc&) {
    std::cerr << "Init(): Out of memory" << std::endl;
    std::exit(EXIT_FAILURE);
  }
  _Init();
}

// Assigns every token a topic drawn uniformly from [0, K) and records the
// assignment in the count tables.
void LDA::_Init() {
  for (std::size_t i = 0; i < tokens_.size(); ++i) {
    const Token& t = tokens_[i];  // no per-token copy
    const int topic = ptr_random_->gen(K_);
    ++word_count_[t.word_id_][topic];
    ++doc_count_[t.doc_id_][topic];
    ++topic_count_[topic];
    z_[i] = topic;
  }
}

// Draws a new topic for token t from the collapsed Gibbs full conditional
//   p(z = k | rest) ∝ (n_wk + beta) * (n_dk + alpha) / (n_k + W * beta).
// The caller must already have removed t's current assignment from the counts.
int LDA::SelectNextTopic(Token t) {
  const int* word_row = word_count_[t.word_id_];
  const int* doc_row = doc_count_[t.doc_id_];
  const double w_beta = W_ * beta_;

  // p_[k] holds the cumulative (unnormalised) weight of topics 0..k.
  double cumulative = 0.0;
  for (int k = 0; k < K_; ++k) {
    cumulative += (word_row[k] + beta_) * (doc_row[k] + alpha_) /
                  (topic_count_[k] + w_beta);
    p_[k] = cumulative;
  }

  const double u = ptr_random_->gen(1.0) * p_[K_ - 1];
  for (int k = 0; k < K_; ++k) {
    if (u < p_[k]) return k;
  }
  return K_ - 1;  // only reached through floating-point rounding
}

// Removes token `token_id` from its current topic, samples a new one and adds
// it back to the counts.
inline void LDA::Resample(std::size_t token_id) {
  const Token& t = tokens_[token_id];  // no per-token copy in the hot loop
  int topic = z_[token_id];

  // remove from current topic
  --word_count_[t.word_id_][topic];
  --doc_count_[t.doc_id_][topic];
  --topic_count_[topic];

  topic = SelectNextTopic(t);

  ++word_count_[t.word_id_][topic];
  ++doc_count_[t.doc_id_][topic];
  ++topic_count_[topic];
  z_[token_id] = topic;
}

// One Gibbs sweep over all tokens, in corpus order.
void LDA::Update() {
  for (std::size_t i = 0; i < tokens_.size(); ++i) Resample(i);
}

// theta[d][k] = (n_dk + alpha) / sum_k' (n_dk' + alpha). Caller frees the
// result with DeleteTheta(theta, D).
double** LDA::GetTheta() {
  double** theta = NewMatrix(D_, K_, "GetTheta()");
  for (int d = 0; d < D_; ++d) {
    double sum = 0.0;
    for (int k = 0; k < K_; ++k) {
      theta[d][k] = alpha_ + doc_count_[d][k];
      sum += theta[d][k];
    }
    // normalize
    const double inv = 1.0 / sum;
    for (int k = 0; k < K_; ++k) theta[d][k] *= inv;
  }
  return theta;
}

// phi[k][w] = (n_wk + beta) / sum_w' (n_w'k + beta). Caller frees the result
// with DeletePhi(phi, K).
double** LDA::GetPhi() {
  double** phi = NewMatrix(K_, W_, "GetPhi()");
  for (int k = 0; k < K_; ++k) {
    double sum = 0.0;
    for (int w = 0; w < W_; ++w) {
      phi[k][w] = beta_ + word_count_[w][k];
      sum += phi[k][w];
    }
    // normalize
    const double inv = 1.0 / sum;
    for (int w = 0; w < W_; ++w) phi[k][w] *= inv;
  }
  return phi;
}

// Frees a matrix returned by GetTheta(); a NULL matrix is ignored.
// delete[] cannot throw, so no exception handling is needed. Returns 0.
int DeleteTheta(double** theta, int doc_num) {
  if (theta == NULL) return 0;
  for (int i = 0; i < doc_num; ++i) delete[] theta[i];
  delete[] theta;
  return 0;
}

// Frees a matrix returned by GetPhi(); a NULL matrix is ignored. Returns 0.
int DeletePhi(double** phi, int topic_num) {
  if (phi == NULL) return 0;
  for (int i = 0; i < topic_num; ++i) delete[] phi[i];
  delete[] phi;
  return 0;
}

}  // namespace lda
pcomp.h
#ifndef PCOMP_H
#define PCOMP_H

namespace lda {

// A (word id, probability) pair used when ranking the words of a topic.
struct Pcomp {
  int id;       // word index into the vocabulary
  double prob;  // probability of that word under the topic
};

// Comparator that orders HIGHER probabilities first. Despite its name,
// std::sort(..., LessProb()) arranges Pcomp entries in descending prob order.
class LessProb {
 public:
  bool operator()(const Pcomp& lhs, const Pcomp& rhs) const {
    return lhs.prob > rhs.prob;
  }
};

}  // namespace lda

#endif  // PCOMP_H
実験に使ったコード: main.cc
#include <cstring> #include "lda.h" #include "pcomp.h" namespace lda { inline int ParseLine(const std::string& line, std::vector<Token>& tokens) { std::istringstream is(line); int doc_id = 0; int word_id = 0; int count = 0; is >> doc_id >> word_id >> count; if (!doc_id || !word_id || !count) { std::cerr << "parse error"; return -1; } for (int i = 0; i < count; ++i) { tokens.push_back(Token(doc_id - 1, word_id - 1)); } return 0; } int ReadBOWData(const std::string& file, std::vector<Token>& tokens, int& D, int& W, int& N) { std::istream *ifs; if (file == "-") { ifs = &std::cin; } else { ifs = new std::ifstream(file.c_str()); } if (!*ifs) { std::cerr << "Cannot open: " << file << std::endl; return -1; } std::size_t line_num = 0; std::string line; // Get feature size std::getline(*ifs, line); D = atoi(line.c_str()); std::getline(*ifs, line); W = atoi(line.c_str()); std::getline(*ifs, line); N = atoi(line.c_str()); line_num += 3; if (N <= 0) { std::cerr << "Invalid # of N" << std::endl; return -1; } for (int i = 0; i < N; ++i) { std::getline(*ifs, line); if (line[0] == '#') continue; // comment if (ParseLine(line, tokens) == - 1) { std::cerr << " line: " << line_num; return -1; } line_num++; } if (file != "-") delete ifs; return 0; } int ReadVocabData(const std::string& file, std::vector<const char*>& words, int& W) { std::istream *ifs; if (file == "-") { ifs = &std::cin; } else { ifs = new std::ifstream(file.c_str()); } if (!*ifs) { std::cerr << "Cannot open: " << file << std::endl; return -1; } std::string line; for (int i = 0; i < W; ++i) { std::getline(*ifs, line); if (line[0] == '#') continue; char *tmp = new char[line.size()+1]; std::strcpy(tmp, line.c_str()); words[i] = tmp; } if (file != "-") delete ifs; return 0; } void PrintWordTopic(double** phi, int K, int W, std::vector<const char *>& words) { for (int k = 0; k < K; ++k) { std::cout << "topic: " << k << std::endl; std::vector<Pcomp> ps(W); for (int w = 0; w < W; ++w) { Pcomp pc; pc.id = w; pc.prob = 
phi[k][w]; ps[w] = pc; } std::sort(ps.begin(), ps.end(), LessProb()); // print top 10 words for (int i = 0; i < 10; ++i) { Pcomp p = ps[i]; std::cout << words[p.id] << ' ' << p.prob << std::endl; } std::cout << std::endl; } } } // namespace lda int main(int argc, char** argv) { if (argc < 3) { std::cerr << "Usage: " << argv[0] << " bow_file vocab_file" << std::endl; return -1; } std::string bow_file = argv[1]; std::string vocab_file = argv[2]; const int K = 50; const int num_iter = 200; // # of iteration std::vector<lda::Token> tokens; int D = 0; int W = 0; int N = 0; // Read bow file if (lda::ReadBOWData(bow_file, tokens, D, W, N) != 0) { std::cerr << "Cannot read" << std::endl; return -1; } // Read vocabrary file std::vector<const char*> words(W); if (lda::ReadVocabData(vocab_file, words, W) != 0) { std::cerr << "Cannot read" << std::endl; return -1; } lda::LDA lda(D, K, W, tokens); for (int i = 0; i <= num_iter; ++i) { std::cout << "# of iteration: " << i << std::endl; lda.Update(); if (i % 10 == 0) { double** phi = lda.GetPhi(); lda::PrintWordTopic(phi, K, W, words); lda::DeletePhi(phi, K); } std::cout << std::endl; } for (std::size_t i = 0; i < words.size(); ++i) delete [] words[i]; return 0; }