// This tool converts the MNIST dataset (an image file and a label file) to a
// plain-text file, <path>/mnist.dat, containing one line per example: the
// rows*cols pixel intensities (0-255) followed by the label, all separated
// by tabs.
// Usage:
//    convert_mnist_data [--path=OUTPUT_DIR] input_image_file input_label_file
// The MNIST dataset could be downloaded at
//    http://yann.lecun.com/exdb/mnist/

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/imgproc/imgproc.hpp>
#include "caffe/util/io.hpp"

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
#include <lmdb.h>
#include <stdint.h>
#include <sys/stat.h>

#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>
#include <cstdio>
#include "caffe/proto/caffe.pb.h"

using namespace caffe;  // NOLINT(build/namespaces)
using std::string;

// Directory that receives the generated text file (mnist.dat).
DEFINE_string(path, "./", "Directory in which the output file mnist.dat is written");

// Reverses the byte order of a 32-bit unsigned value,
// e.g. 0x01020304 becomes 0x04030201.
uint32_t swap_endian(uint32_t val) {
  const uint32_t lowest = val & 0xFFu;
  const uint32_t second = (val >> 8) & 0xFFu;
  const uint32_t third = (val >> 16) & 0xFFu;
  const uint32_t highest = val >> 24;
  return (lowest << 24) | (second << 16) | (third << 8) | highest;
}

void convert_dataset(const char* image_filename, const char* label_filename,
        const char* path) {
  // Open files
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
  CHECK(image_file) << "Unable to open file " << image_filename;
  CHECK(label_file) << "Unable to open file " << label_filename;

  std::string out_filename(path);
  out_filename = out_filename + "/mnist.dat";
  LOG(ERROR) << "Output to " << out_filename;
  std::ofstream out_file(out_filename.c_str());

  // Read the magic and the meta data
  uint32_t magic;
  uint32_t num_items;
  uint32_t num_labels;
  uint32_t rows;
  uint32_t cols;

  image_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
  label_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
  image_file.read(reinterpret_cast<char*>(&num_items), 4);
  num_items = swap_endian(num_items);
  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
  num_labels = swap_endian(num_labels);
  CHECK_EQ(num_items, num_labels);
  image_file.read(reinterpret_cast<char*>(&rows), 4);
  rows = swap_endian(rows);
  image_file.read(reinterpret_cast<char*>(&cols), 4);
  cols = swap_endian(cols);

  char label;
  char* pixels = new char[rows * cols];
  int count = 0;

  Datum datum;
  datum.set_channels(1);
  datum.set_height(rows);
  datum.set_width(cols);
  LOG(ERROR) << "A total of " << num_items << " items.";
  LOG(ERROR) << "Rows: " << rows << " Cols: " << cols;
  for (int item_id = 0; item_id < num_items; ++item_id) {
    image_file.read(pixels, rows * cols);
    label_file.read(&label, 1);
    float float_label = (float)label;
    datum.set_data(pixels, rows*cols);
    datum.set_float_label(float_label);
    datum.set_encoded(true);

    for(int i = 0; i < rows*cols; ++i) {
      out_file << static_cast<float>(static_cast<uint8_t>(pixels[i])) << "\t";
      //std::cout << pixels[i] << "\t";
      //printf("%d\t", pixels[i]);
      //if ((i + 1) % cols == 0) {
      //  out_file << "\n";
      //  printf("\n");
      //}
    }
    out_file << (float)label;
    out_file << std::endl;

    //// save to jpg files
    //char filename[500] = {};
    //int fn_length = sprintf(filename, "%s/%d.jpg", path, item_id);
    //
    ////std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
    //std::vector<uchar> vec_data(pixels, pixels + rows * cols);
    //cv::Mat mat = cv::imdecode(vec_data, CV_LOAD_IMAGE_GRAYSCALE);
    //if (!mat.data) {
    //  LOG(ERROR) << "Could not decode datum " << rows << "\t" << cols;
    //  for (int r = 0; r < rows; ++r) {
    //    for (int c = 0; c < cols; ++c) {
    //      printf("%d\t", pixels[r*cols+c]);
    //    }
    //    printf("\n");
    //  }
    //}

    ////cv::Mat mat = caffe::DecodeDatumToCVMat(datum, false);
    //imwrite(filename, mat);
    //LOG(ERROR) << filename << "\t" << datum.float_label();
  }
  out_file.flush();
  out_file.close();

  delete[] pixels;
}

// Entry point: parses flags, validates the two positional arguments
// (image file, label file) and writes <--path>/mnist.dat.
// Returns 1 and prints the usage text when the argument count is wrong.
int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("This script converts the MNIST dataset to\n"
        "a tab-separated text file, <path>/mnist.dat, with one line per\n"
        "image: its pixel values followed by its label.\n"
        "Usage:\n"
        "    convert_mnist_data [FLAGS] input_image_file input_label_file\n"
        "The MNIST dataset could be downloaded at\n"
        "    http://yann.lecun.com/exdb/mnist/\n"
        "You should gunzip them after downloading, "
        "or directly use data/mnist/get_mnist.sh\n");
  // remove_flags=true leaves only argv[0] and the positional arguments.
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  google::InitGoogleLogging(argv[0]);

  // Exactly two positional inputs are required; without this check
  // argv[1]/argv[2] could be null or lie past the end of argv.
  if (argc != 3) {
    std::fprintf(stderr, "%s\n", gflags::ProgramUsage());
    return 1;
  }
  convert_dataset(argv[1], argv[2], FLAGS_path.c_str());

  return 0;
}
