LDAのコードをC++で書いてみた

前回の記事を書いてみて思ったより勉強になったので、調子に乗って再び id:tsubosaka さんが Java で書いた LDA の実装を C++ で書き直してみました。ベースとなる手法は前回と同じく collapsed Gibbs sampling (Griffiths and Steyvers, PNAS, 2004) です。動作確認は g++ 4.3.3 (Ubuntu 9.04) で行っています。ソースコードは前回の AROW よりかなり長くなってしまいました。今後、長くなる場合は GitHub あたりにアップするかもしれません。

2010.01.08追記: pcomp.h のソースが抜けていたので追加しました。
2010.01.09追記: メモリのバグを修正しました。

インストール

g++ -Wall -O2 -c lda.cc
g++ -Wall -O2 -c main.cc
g++ -Wall -O2 -o lda lda.o main.o

使用法

./lda Bag-Of-Wordsファイル Vocabファイル

結果は標準出力に出力されるようにしました。

実験結果

トピック数 K = 50 の場合について、各トピックにおける単語生成確率の上位10件を以下に示します（紙面の都合上、トピック 0, 5, 49 のみ抜粋しています）。

繰り返し数 = 1 の時

topic: 0
network 0.014843
model 0.00880151
input 0.00787398
function 0.00699658
algorithm 0.00692138
set 0.00679604
learning 0.00614426
system 0.00566796
data 0.00541728
neural 0.005317

topic: 5
network 0.014523
neural 0.00990054
learning 0.00972469
model 0.00952371
function 0.00688591
set 0.00628299
data 0.0061825
input 0.00598153
system 0.00573031
training 0.00565494

topic: 49
network 0.0139108
model 0.00867621
input 0.00844862
function 0.00839805
learning 0.00791758
set 0.00695665
data 0.00607158
neural 0.00594514
unit 0.00551525
system 0.00513593

繰り返し数 = 200 の時

topic: 0
algorithm 0.0815442
problem 0.0440062
number 0.0218871
result 0.014514
step 0.014514
search 0.0144454
solution 0.0141296
method 0.0130587
run 0.00742933

topic: 5
robot 0.0332836
environment 0.0234933
goal 0.0178843
control 0.0126152
system 0.0125812
path 0.0109495
position 0.0106775
place 0.00921579
world 0.00897783
behavior 0.00884186

topic: 49
neuron 0.0350895
synaptic 0.024088
model 0.0231375
input 0.0228341
cell 0.0185872
synapses 0.0131876
membrane 0.0115091
potential 0.0111046
current 0.010417
response 0.00819246
values 0.00731949

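このように、繰り返し1回目の時点ではどのトピックでも network, model, input といった全体的な頻出語が上位を占めていますが、繰り返し回数が増えるにつれて、トピックごとに特徴的な単語（探索アルゴリズム、ロボット、神経細胞など）の生成確率が高くなり、トピック間で単語の分布がはっきりと異なってくることが見て取れます。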

ソース

token.h

#ifndef TOKEN_H
#define TOKEN_H

namespace lda {

// One occurrence of a word in the corpus.
// Both fields are used by the sampler as 0-based row indices into its
// per-document and per-word count tables.
struct Token {
  int doc_id_;   // index of the document containing the occurrence
  int word_id_;  // index of the word in the vocabulary

  Token(int doc, int word) : doc_id_(doc), word_id_(word) {}
  virtual ~Token() {}
};

} // namespace lda

#endif  // TOKEN_H

random.h

#ifndef RANDOM_H
#define RANDOM_H

namespace lda {

// Small wrapper around the C library generator rand().
// Constructing an instance reseeds the global generator from time(NULL);
// time() typically has one-second resolution, so instances created within
// the same second reseed to the same state.
class Random {
 public:
  Random() { srand(static_cast<std::size_t>(time(NULL))); }

  virtual ~Random() {}

  // Returns an integer in [0, max) for max > 0.
  int gen(int max) { return static_cast<int>(Unit() * max); }

  // Returns a real number in [0, max).
  double gen(double max) { return Unit() * max; }

 private:
  // rand() scaled into [0, 1); dividing by RAND_MAX + 1.0 keeps 1.0 excluded.
  static double Unit() {
    return static_cast<double>(rand()) / (RAND_MAX + 1.0);
  }
};

} // namespace lda

#endif  // RANDOM_H

lda.h

#ifndef LDA_H
#define LDA_H

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "token.h"
#include "random.h"

namespace lda {

class LDA {
 public:
  LDA(int doc_num, int topic_num, int word_num,
      std::vector<Token>& tok_lis) {
    D_ = doc_num;
    K_ = topic_num;
    W_ = word_num;
    tokens_ = tok_lis;
    alpha_ = 50.0 / topic_num;
    beta_ = 0.1;
    ptr_random_ = new Random();   // 2010.01.09 update

    Init();
  }

  ~LDA() {
    try {
      // Delete word_count_
      for (int i = 0; i < W_; ++i) delete [] word_count_[i];
      delete [] word_count_;

      // Delete doc_count_
      for (int i = 0; i < D_; ++i) delete [] doc_count_[i];
      delete [] doc_count_;

      delete [] topic_count_;
      delete [] z_;
      delete [] p_;

      delete ptr_random_;   // 2010.01.09 update
    }
    catch (...) {
      std::cerr << "~LDA(): Out of memory" << std::endl;
      exit(EXIT_FAILURE);
    }
  }

  void Update();
  double** GetTheta();
  double** GetPhi();

 private:
  int D_;                         // # of document
  int K_;                       // # of topic
  int W_;                        // # of unique word

  int** word_count_;
  int** doc_count_;
  int* topic_count_;

  // hyper parameter
  double alpha_;
  double beta_;
  
  std::vector<Token> tokens_;
  
  double* p_;
  int* z_;                              // topic assignment

  Random* ptr_random_;

  void Init();
  void _Init();
  
  int SelectNextTopic(const Token t);
  
  void Resample(std::size_t token_id);
};


// Frees a matrix returned by LDA::GetTheta(). doc_num must be the number of
// rows (documents) the matrix was allocated with. Returns 0.
int DeleteTheta(double** theta, int doc_num);

// Frees a matrix returned by LDA::GetPhi(). topic_num must be the number of
// rows (topics) the matrix was allocated with. Returns 0.
int DeletePhi(double** phi, int topic_num);

} // namespace lda

#endif  // LDA_H

lda.cc

#include "lda.h"

namespace lda {

// Allocates and zero-fills the word/topic, document/topic and topic count
// tables, the per-token assignment array z_ and the sampling buffer p_, then
// draws the initial random assignment via _Init(). Any allocation failure
// prints a message and terminates the process.
void LDA::Init() {
  try {
    // `new T[n]()` value-initializes every element, i.e. fills with zero.
    word_count_ = new int*[W_];
    for (int w = 0; w < W_; ++w)
      word_count_[w] = new int[K_]();

    topic_count_ = new int[K_]();

    doc_count_ = new int*[D_];
    for (int d = 0; d < D_; ++d)
      doc_count_[d] = new int[K_]();

    z_ = new int[tokens_.size()]();
    p_ = new double[K_]();
  }
  catch (...) {
    std::cerr << "Init(): Out of memory" << std::endl;
    exit(EXIT_FAILURE);
  }

  _Init();
}

void LDA::_Init() {
  for (std::size_t i = 0; i < tokens_.size(); ++i) {
    Token t = tokens_[i];
    std::size_t assign = static_cast<std::size_t>((*ptr_random_).gen(K_));
    
    word_count_[t.word_id_][assign]++;
    doc_count_[t.doc_id_][assign]++;
    topic_count_[assign]++;
    z_[i] = assign;
  }
}

// Samples a topic for token t proportionally to the collapsed Gibbs weight
//   (n_wk + beta) * (n_dk + alpha) / (n_k + W * beta)
// using the current counts (Resample removes t from them beforehand).
// p_ is overwritten with the running sums of these weights; the returned
// topic is the first k whose running sum exceeds a uniform draw scaled to the
// total, or K_ - 1 if none does.
int LDA::SelectNextTopic(Token t) {
  const int* const word_row = word_count_[t.word_id_];
  const int* const doc_row = doc_count_[t.doc_id_];

  for (int k = 0; k < K_; ++k) {
    const double weight = (word_row[k] + beta_) * (doc_row[k] + alpha_)
        / (topic_count_[k] + W_ * beta_);
    p_[k] = (k == 0) ? weight : weight + p_[k - 1];
  }

  const double u = ptr_random_->gen(1.0) * p_[K_ - 1];

  int chosen = 0;
  while (chosen < K_ && !(u < p_[chosen]))
    ++chosen;
  return (chosen < K_) ? chosen : K_ - 1;
}

inline void LDA::Resample(std::size_t token_id) {
  Token t = tokens_[token_id];
  int assign = z_[token_id];

  // remove from current topic
  word_count_[t.word_id_][assign]--;
  doc_count_[t.doc_id_][assign]--;
  topic_count_[assign]--;
    
  assign = SelectNextTopic(t);
  word_count_[t.word_id_][assign]++;
  doc_count_[t.doc_id_][assign]++;
  topic_count_[assign]++;
  
  z_[token_id] = assign;
}

void LDA::Update() {
  for (std::size_t i = 0; i < tokens_.size(); ++i)
    LDA::Resample(i);
}

// Builds the smoothed document-topic matrix from the current counts:
//   theta[d][k] = (n_dk + alpha) / sum_k' (n_dk' + alpha)
// Returns a newly allocated D x K array of rows; the caller owns it and must
// release it with DeleteTheta(theta, D). Allocation failure prints a message
// and terminates the process.
double** LDA::GetTheta() {
  double** theta = NULL;
  try {
    theta = new double*[D_];
    for (int d = 0; d < D_; ++d)
      theta[d] = new double[K_];
  }
  catch (...) {
    std::cerr << "GetTheta(): Out of memory" << std::endl;
    exit(EXIT_FAILURE);
  }

  for (int d = 0; d < D_; ++d) {
    double* const row = theta[d];
    const int* const counts = doc_count_[d];

    double total = 0.0;
    for (int k = 0; k < K_; ++k) {
      row[k] = alpha_ + counts[k];
      total += row[k];
    }

    // normalize each row to sum to 1
    const double inv_total = 1.0 / total;
    for (int k = 0; k < K_; ++k)
      row[k] *= inv_total;
  }
  return theta;
}

// Builds the smoothed topic-word matrix from the current counts:
//   phi[k][w] = (n_wk + beta) / sum_w' (n_w'k + beta)
// Returns a newly allocated K x W array of rows; the caller owns it and must
// release it with DeletePhi(phi, K). Allocation failure prints a message and
// terminates the process.
double** LDA::GetPhi() {
  double** phi = NULL;
  try {
    phi = new double*[K_];
    for (int k = 0; k < K_; ++k)
      phi[k] = new double[W_];
  }
  catch (...) {
    std::cerr << "GetPhi(): Out of memory" << std::endl;
    exit(EXIT_FAILURE);
  }

  for (int k = 0; k < K_; ++k) {
    double* const row = phi[k];

    double total = 0.0;
    for (int w = 0; w < W_; ++w) {
      row[w] = beta_ + word_count_[w][k];
      total += row[w];
    }

    // normalize each row to sum to 1
    const double inv_total = 1.0 / total;
    for (int w = 0; w < W_; ++w)
      row[w] *= inv_total;
  }
  return phi;
}

// Releases a matrix returned by LDA::GetTheta().
// doc_num must equal the number of rows (documents) it was allocated with.
// A NULL matrix is accepted and ignored. Always returns 0; the return value
// is kept for interface compatibility.
int DeleteTheta(double** theta, int doc_num) {
  // delete[] never throws, so no try/catch is needed; the former handler,
  // which reported "Out of memory" on a deallocation, was unreachable.
  if (theta == NULL) return 0;

  for (int i = 0; i < doc_num; ++i)
    delete [] theta[i];
  delete [] theta;
  return 0;
}

// Releases a matrix returned by LDA::GetPhi().
// topic_num must equal the number of rows (topics) it was allocated with.
// A NULL matrix is accepted and ignored. Always returns 0; the return value
// is kept for interface compatibility.
int DeletePhi(double** phi, int topic_num) {
  // delete[] never throws, so no try/catch is needed; the former handler,
  // which reported "Out of memory" on a deallocation, was unreachable.
  if (phi == NULL) return 0;

  for (int i = 0; i < topic_num; ++i)
    delete [] phi[i];
  delete [] phi;
  return 0;
}

} // namespace lda

pcomp.h

#ifndef PCOMP_H
#define PCOMP_H

namespace lda {

// A (word id, probability) pair used when ranking the words of a topic.
struct Pcomp {
  int id;        // word index
  double prob;   // probability of that word
};

// Comparator for std::sort and friends. Despite its name it returns true when
// lhs has the HIGHER probability, so sorting with it yields descending
// order of prob.
struct LessProb {
  bool operator()(const Pcomp& lhs, const Pcomp& rhs) const {
    return lhs.prob > rhs.prob;
  }
};

} // namespace lda

#endif  // PCOMP_H

実験に使ったコード: main.cc

#include <cstring>
#include "lda.h"
#include "pcomp.h"

namespace lda {

inline int ParseLine(const std::string& line, std::vector<Token>& tokens) {
  std::istringstream is(line);
  int doc_id = 0;
  int word_id = 0;
  int count = 0;
  is >> doc_id >> word_id >> count;
  if (!doc_id || !word_id || !count) {
    std::cerr << "parse error";
    return -1;
  }
  
  for (int i = 0; i < count; ++i) {
    tokens.push_back(Token(doc_id - 1, word_id - 1));
  }
  return 0;
}

int ReadBOWData(const std::string& file, std::vector<Token>& tokens,
             int& D, int& W, int& N) {
  std::istream *ifs;

  if (file == "-") {
    ifs = &std::cin;
  } else {
    ifs = new std::ifstream(file.c_str());
  }

  if (!*ifs) {
    std::cerr << "Cannot open: " << file << std::endl;
    return -1;
  }
  
  std::size_t line_num = 0;
  std::string line;

  // Get feature size
  std::getline(*ifs, line);
  D = atoi(line.c_str());
  std::getline(*ifs, line);
  W = atoi(line.c_str());
  std::getline(*ifs, line);
  N = atoi(line.c_str());

  line_num += 3;

  if (N <= 0) {
    std::cerr << "Invalid # of N" << std::endl;
    return -1;
  }

  for (int i = 0; i < N; ++i) {
    std::getline(*ifs, line);
    if (line[0] == '#') continue;       // comment
    if (ParseLine(line, tokens) == - 1) {
      std::cerr << " line: " << line_num;
      return -1;
    }
    line_num++;
  }  
  if (file != "-") delete ifs;

  return 0;
}

// Reads W vocabulary entries, one word per line, from `file` ("-" means
// stdin) into words[0..W-1].
// Each entry is a newly allocated NUL-terminated copy that the caller must
// release with delete[]. Lines starting with '#' are comments; they are
// skipped without consuming a slot. If words holds fewer than W entries it
// is grown to W, and new slots start out NULL.
// Returns 0 on success. Returns -1, after writing a diagnostic to stderr,
// if the file cannot be opened or ends before W words were read; words
// already stored remain in `words` for the caller to free.
int ReadVocabData(const std::string& file,
                  std::vector<const char*>& words, int& W) {
  // The ifstream lives on the stack, so it is closed on every return path.
  std::ifstream file_stream;
  std::istream* ifs = &std::cin;
  if (file != "-") {
    file_stream.open(file.c_str());
    ifs = &file_stream;
  }

  if (!*ifs) {
    std::cerr << "Cannot open: " << file << std::endl;
    return -1;
  }

  if (W > 0 && words.size() < static_cast<std::size_t>(W))
    words.resize(W);   // new slots are value-initialized to null pointers

  std::string line;
  int filled = 0;
  while (filled < W) {
    if (!std::getline(*ifs, line)) {
      std::cerr << "Vocab file ended after " << filled << " of " << W
                << " words: " << file << std::endl;
      return -1;
    }
    // Comments must not take a slot; otherwise every later word would be
    // shifted and the skipped slot left NULL.
    if (!line.empty() && line[0] == '#') continue;

    char* copy = new char[line.size() + 1];
    std::strcpy(copy, line.c_str());
    words[filled++] = copy;
  }

  return 0;
}

void PrintWordTopic(double** phi, int K, int W,
                    std::vector<const char *>& words) {
  for (int k = 0; k < K; ++k) {
    std::cout << "topic: " << k << std::endl;
    std::vector<Pcomp> ps(W);

    for (int w = 0; w < W; ++w) {
      Pcomp pc;
      pc.id = w;
      pc.prob = phi[k][w];
      ps[w] = pc;
    }
    
    std::sort(ps.begin(), ps.end(), LessProb());

    // print top 10 words
    for (int i = 0; i < 10; ++i) {
      Pcomp p = ps[i];
      std::cout << words[p.id] << ' ' << p.prob << std::endl;
    }
    std::cout << std::endl;
  }
}

} // namespace lda

// Command-line driver:
//   ./lda bow_file vocab_file [num_topics] [num_iterations]
// Reads the corpus and the vocabulary, then performs num_iterations + 1 Gibbs
// sweeps, labelled 0..num_iterations. After every sweep whose label is a
// multiple of 10, it prints the top words of each topic to stdout.
// num_topics defaults to 50 and num_iterations to 200, the values that were
// previously hard-coded. Returns -1 on any error.
int main(int argc, char** argv) {
  if (argc < 3) {
    std::cerr << "Usage: " << argv[0]
              << " bow_file vocab_file [num_topics] [num_iterations]"
              << std::endl;
    return -1;
  }
  std::string bow_file = argv[1];
  std::string vocab_file = argv[2];

  const int K = (argc > 3) ? atoi(argv[3]) : 50;           // # of topics
  const int num_iter = (argc > 4) ? atoi(argv[4]) : 200;   // # of iteration
  if (K <= 0 || num_iter < 0) {
    std::cerr << "num_topics must be positive and num_iterations non-negative"
              << std::endl;
    return -1;
  }

  std::vector<lda::Token> tokens;
  int D = 0;
  int W = 0;
  int N = 0;

  // Read bow file
  if (lda::ReadBOWData(bow_file, tokens, D, W, N) != 0) {
    std::cerr << "Cannot read" << std::endl;
    return -1;
  }

  // Read vocabulary file
  std::vector<const char*> words(W);
  if (lda::ReadVocabData(vocab_file, words, W) != 0) {
    std::cerr << "Cannot read" << std::endl;
    // Unfilled slots are NULL, for which delete[] is a no-op.
    for (std::size_t i = 0; i < words.size(); ++i)
      delete [] words[i];
    return -1;
  }

  lda::LDA lda(D, K, W, tokens);

  for (int i = 0; i <= num_iter; ++i) {
    std::cout << "# of iteration: " << i << std::endl;
    lda.Update();
    if (i % 10 == 0) {
      double** phi = lda.GetPhi();
      lda::PrintWordTopic(phi, K, W, words);
      lda::DeletePhi(phi, K);
    }
    std::cout << std::endl;
  }

  for (std::size_t i = 0; i < words.size(); ++i)
    delete [] words[i];

  return 0;
}