#include "Gauge.h"

#ifdef _WIN32
#include <windows.h>
#include <psapi.h>
#endif
#include <sstream>
#include <iomanip>
#include <cmath>
#include <limits>

// Returns the current process's private memory usage in bytes
// (PROCESS_MEMORY_COUNTERS_EX::PrivateUsage) on Windows. Returns 0 if the
// query fails, and always returns 0 on other platforms, where no probe is
// implemented.
static size_t
usedMemory()
{
#ifdef _WIN32
  // Zero-initialise so no field is ever read uninitialised.
  PROCESS_MEMORY_COUNTERS_EX pmc = {};
  if (!GetProcessMemoryInfo(GetCurrentProcess(),
                            reinterpret_cast<PROCESS_MEMORY_COUNTERS *>(&pmc),
                            sizeof(pmc)))
    return 0;  // query failed: report "unknown" as 0 rather than garbage
  return pmc.PrivateUsage;
#else
  return 0;
#endif
}


// Formats a byte count (may be negative) with 3 significant digits and a
// decimal unit suffix: " GB", " MB", " kB" or " B".
std::string
Gauge::memToString(double mem)
{
  struct Unit { double scale; const char* suffix; };
  static const Unit kUnits[] = {
    { 1e9, " GB" },
    { 1e6, " MB" },
    { 1e3, " kB" },
  };

  std::stringstream ss;
  ss << std::setprecision(3);

  // With 3 significant digits, anything >= 999.5 of a unit would round to
  // 1000 and be printed in scientific form ("1e+03 kB"). Promote such values
  // to the next larger unit instead. Only GB has no larger unit to move to.
  const double magnitude = std::abs(mem);
  const size_t numUnits = sizeof(kUnits) / sizeof(kUnits[0]);
  for (size_t u = 0; u < numUnits; ++u) {
    if (magnitude >= 0.9995 * kUnits[u].scale) {
      ss << mem / kUnits[u].scale << kUnits[u].suffix;
      return ss.str();
    }
  }

  ss << mem << " B";
  return ss.str();
}

// Formats a duration with 3 significant digits followed by " sec".
std::string
Gauge::timeToString(double eltime)
{
  std::ostringstream out;
  out << std::setprecision(3) << eltime << " sec";
  return out.str();
}

// Starts the stage history with an untagged baseline entry, so sample()
// always has a previous entry to diff against.
Gauge::Gauge()
{
  reset(std::string());
}


// No explicit cleanup is performed here; members are torn down by their own
// destructors.
Gauge::~Gauge() {}

// Records a new stage named `tag`, together with the current memory reading
// and elapsed time. Returns the memory change (bytes) since the previous stage,
// saturated to the int range. When `display` is true, also prints the stage
// duration and memory change to std::cout.
int
Gauge::sample(const std::string& tag, bool display)
{
  stage_names_.push_back(tag);
  stage_mems_.push_back(usedMemory());
  stage_times_.push_back(m_elapsedTime.getElapsed());

  // reset() (run by the constructor) always records a baseline entry, so the
  // sample just pushed has a predecessor at index last - 1.
  const size_t last = stage_mems_.size() - 1;

  // Signed delta computed in double. Subtracting the counters directly could
  // wrap if they are unsigned (usedMemory() returns size_t), and converting a
  // difference larger than INT_MAX bytes straight to int would overflow.
  const double memDelta = static_cast<double>(stage_mems_[last])
                        - static_cast<double>(stage_mems_[last - 1]);

  // The return type is int, so saturate rather than overflow on huge deltas.
  // max/min are parenthesised so <windows.h>'s min/max macros cannot expand.
  const int intMax = (std::numeric_limits<int>::max)();
  const int intMin = (std::numeric_limits<int>::min)();
  int memUsed;
  if (memDelta >= static_cast<double>(intMax))
    memUsed = intMax;
  else if (memDelta <= static_cast<double>(intMin))
    memUsed = intMin;
  else
    memUsed = static_cast<int>(memDelta);

  if (display)
    std::cout << stage_names_.back() << " in "
              << timeToString(stage_times_.back() - stage_times_[stage_times_.size() - 2])
              << ", mem " << memToString(memDelta) << std::endl;
  return memUsed;
}


// Discards all recorded stages, resets the elapsed-time counter, and records
// a single baseline entry tagged `tag`.
void
Gauge::reset(const std::string& tag)
{
  stage_times_.clear();
  stage_mems_.clear();
  stage_names_.clear();
  m_elapsedTime.reset();

  // Baseline entry: sample() and report() compute every delta relative to
  // the entry that precedes it.
  stage_names_.push_back(tag);
  stage_mems_.push_back(usedMemory());
  stage_times_.push_back(m_elapsedTime.getElapsed());
}


// Builds a multi-line report. The first line is the baseline tag. Each
// following line gives a stage name, its duration and its signed memory change
// relative to the previous stage. Returns an empty string if nothing has been
// recorded.
std::string
Gauge::report() const
{
  std::stringstream ss;

  if (!stage_mems_.empty())
  {
    const size_t num_sample = stage_mems_.size();

    // Entry 0 is the baseline written by reset(); only its tag is printed.
    ss << stage_names_[0] << '\n';
    for (size_t i = 1; i < num_sample; ++i) {
      // Signed difference, so a decrease is shown as negative (consistent
      // with sample()). Computing in double avoids unsigned wrap-around.
      const double memDelta = static_cast<double>(stage_mems_[i])
                            - static_cast<double>(stage_mems_[i - 1]);
      ss << stage_names_[i]
         << ", time " << timeToString(stage_times_[i] - stage_times_[i - 1])
         << ", mem " << memToString(memDelta)
         << '\n';
    }
  }

  return ss.str();
}

// Summarises the stages tagged "extraction". The first extraction is reported
// on its own. Later ones ("re-extractions") contribute to the average and
// minimum time and to the average memory change. Returns the human-readable
// text and appends five columns, each preceded by `sep`, to `result`:
// first time, average time, minimum time, first memory (MB) and average
// memory (MB).
// If no "extraction" stage follows the baseline, returns "" and leaves
// `result` untouched. If there are no re-extractions, the averages and the
// minimum are reported as NaN.
std::string
Gauge::reportStat(std::ostringstream& result, const char* sep) const
{
  std::stringstream ss;

  const size_t num_sample = stage_mems_.size();

  // Find the first extraction. Entry 0 is the reset() baseline and has no
  // predecessor to diff against, so the bounded search starts at index 1.
  size_t first = 1;
  while (first < num_sample && stage_names_[first] != "extraction")
    ++first;
  if (first >= num_sample)
    return ss.str();

  // The first extract is not benched, to ignore the initial memory allocation.
  double mint = std::numeric_limits<double>::infinity();
  double sumt = 0;
  double summ = 0;  // double: an int accumulator overflows past ~2 GB
  int numExtract = 0;
  for (size_t i = first + 1; i < num_sample; ++i) {
    if (stage_names_[i] != "extraction")
      continue;
    const double dt = stage_times_[i] - stage_times_[i - 1];
    if (dt < mint)
      mint = dt;
    sumt += dt;
    summ += static_cast<double>(stage_mems_[i]) - static_cast<double>(stage_mems_[i - 1]);
    ++numExtract;
  }

  // With no re-extraction the statistics are undefined. Report NaN explicitly
  // instead of dividing by zero.
  double avgt = std::numeric_limits<double>::quiet_NaN();
  double avgm = avgt;
  if (numExtract > 0) {
    avgt = sumt / numExtract;
    avgm = summ / numExtract;
  } else {
    mint = avgt;
  }

  const double firstTime = stage_times_[first] - stage_times_[first - 1];
  // Signed delta in double, so a decrease does not wrap through unsigned arithmetic.
  const double firstMem = static_cast<double>(stage_mems_[first])
                        - static_cast<double>(stage_mems_[first - 1]);

  ss << "Time to first extract: " << timeToString(firstTime) << std::endl;
  ss << "Average Time to re-extract: " << timeToString(avgt) << std::endl;
  ss << "Min Time to re-extract:" << timeToString(mint) << std::endl;

  ss << "First extract mem allocation: " << memToString(firstMem) << std::endl;
  ss << "Average mem to re-extract: " << memToString(avgm) << std::endl;

  result << sep << std::setw(6) << firstTime << sep << std::setw(6) << avgt << sep << std::setw(6) << mint;
  result << sep << std::setw(6) << firstMem / 1e6;
  result << sep << std::setw(6) << avgm / 1e6;

  return ss.str();
}
