// This program converts a text file of feature vectors into two lmdb/leveldb
// databases (train and test) by storing each line as a Datum proto buffer.
// Usage:
//   <program> [FLAGS] TRAIN_DB_NAME TEST_DB_NAME
//
// The input file is given by --data. Each line must hold --dim
// whitespace-separated float feature values followed by one label value, e.g.
// with --dim=3:
//   0.5 1.25 -3.0 7
//   ....
// The first --train_cnt lines are written to TRAIN_DB_NAME and the remaining
// lines are written to TEST_DB_NAME.

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>
#include <stdio.h>
#include <iterator>
#include <sstream>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

using namespace caffe;  // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;

// Command-line flags. Names, types and defaults are part of the tool's
// interface; only the help texts shown by --help are filled in here.

// Input file path; it is handed verbatim to std::ifstream (no glob expansion).
DEFINE_string(data, "*.tab",
        "Path of the input text file; each line holds `dim` "
        "whitespace-separated feature values followed by a label");
// Shuffling is currently disabled.
//DEFINE_bool(shuffle, false,
//    "Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb",
        "The backend {lmdb, leveldb} for storing the result");
DEFINE_int32(train_cnt, 1,
        "Number of leading input lines written to the train DB; "
        "the remaining lines are written to the test DB");
DEFINE_int32(dim, 0,
        "Number of feature values per input line (the label follows them)");

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("...");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  std::ifstream infile(FLAGS_data.c_str());
  std::vector<std::pair<std::string, int> > lines;
  string line;
  int line_num = 1; // start from 1 to keep compatible with matlab
  while (std::getline(infile, line)) {
    lines.push_back(make_pair(line, line_num++));
  }
  //if (FLAGS_shuffle) {
  //  // randomly shuffle data
  //  LOG(INFO) << "Shuffling data";
  //  shuffle(lines.begin(), lines.end());
  //}
  LOG(INFO) << "A total of " << lines.size() << " datum. ";
  CHECK_EQ(lines.size(), line_num - 1);

  int num_train = FLAGS_train_cnt; 
  int num_test = lines.size() - num_train;
  LOG(INFO) << "#train = " << num_train; 
  LOG(INFO) << "#test = " << num_test;

  // output index mapping
  //std::ofstream outfile(FLAGS_data + "_mapping");
  //for (int i = 0; i < lines.size(); ++i) {
  //  outfile << lines[i].second << std::endl;
  //} 
  //outfile.flush();
  //outfile.close();

  // Create new DB
  scoped_ptr<db::DB> train_db(db::GetDB(FLAGS_backend));
  string train_db_name(argv[1]);
  train_db->Open(train_db_name.c_str(), db::NEW);
  scoped_ptr<db::Transaction> train_txn(train_db->NewTransaction());

  // Storing to db
  Datum datum;
  datum.set_channels(1);
  datum.set_height(1);
  datum.set_width(FLAGS_dim);
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size = 0;
  bool data_size_initialized = false;
  //float fea[FLAGS_dim];
  float label;
  for (int line_id = 0; line_id < num_train; ++line_id) {
    datum.clear_data();
    datum.clear_float_data();
    std::istringstream iss(lines[line_id].first);
    vector<float> fea((std::istream_iterator<float>(iss)), std::istream_iterator<float>());
    CHECK_EQ(fea.size(), FLAGS_dim + 1);
    for (int i = 0; i < FLAGS_dim; ++i) {
      datum.add_float_data(fea[i]);
    }
    datum.set_label((int)fea[FLAGS_dim]);
    //char* str = lines[line_id].c_str();
    //for (int i = 0; i < FLAGS_dim; ++i) {
    //  sscanf(lines[line_id].c_str(), "%f", &(fea[i]));
    //  datum.add_float_data(fea[i]);
    //}   
    //sscanf(lines[line_id].c_str(), "%f", &label);
    //datum.set_float_label(label);

    // sequential
    lines[line_id].first.resize(200);
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    train_txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {
      // Commit db
      train_txn->Commit();
      train_txn.reset(train_db->NewTransaction());
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    train_txn->Commit();
    LOG(ERROR) << "Processed " << count << " files.";
  }

  // ----------------------------
  LOG(ERROR) << "train db done.";
  
  // Create new DB
  scoped_ptr<db::DB> test_db(db::GetDB(FLAGS_backend));
  string test_db_name(argv[2]);
  test_db->Open(test_db_name.c_str(), db::NEW);
  scoped_ptr<db::Transaction> test_txn(test_db->NewTransaction());

  // Storing to db
  count = 0;
  data_size = 0;
  data_size_initialized = false;
  for (int line_id = num_train; line_id < num_test + num_train; ++line_id) {
    datum.clear_data();
    datum.clear_float_data();
    datum.clear_data();
    datum.clear_float_data();
    std::istringstream iss(lines[line_id].first);
    vector<float> fea((std::istream_iterator<float>(iss)), std::istream_iterator<float>());
    CHECK_EQ(fea.size(), FLAGS_dim + 1);
    for (int i = 0; i < FLAGS_dim; ++i) {
      datum.add_float_data(fea[i]);
    }
    datum.set_label((int)fea[FLAGS_dim]);
    //for (int i = 0; i < FLAGS_dim; ++i) {
    //  sscanf(lines[line_id].c_str(), "%f", &(fea[i]));
    //  datum.add_float_data(fea[i]);
    //}   
    //sscanf(lines[line_id].c_str(), "%f", &label);
    //datum.set_float_label(label);

    // sequential
    lines[line_id].first.resize(200);
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    test_txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {
      // Commit db
      test_txn->Commit();
      test_txn.reset(test_db->NewTransaction());
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    test_txn->Commit();
    LOG(ERROR) << "Processed " << count << " files.";
  }

  return 0;
}
