/*=======================================================================
** VSG_COPYRIGHT_TAG
**=======================================================================*/
/*=======================================================================
** Author      : VSG (MMM YYYY)
**=======================================================================*/

#pragma once

#include <cstring>
#include <limits>
#include <vector>

#include <Inventor/SbTime.h>
#include <Inventor/SbVec.h>
#include <LDM/nodes/SoDataSet.h>

#include <Inventor/SbDataType.h>
#include <Inventor/SbElapsedTime.h>
#include <Inventor/devices/SoCpuBufferObject.h>

#include <LDM/compressors/SoDataCompressor.h>
#include <LDM/tiles/SoCpuBufferCompressed.h>

#include "metrics.h"
#include "VolumeInfo.h"

/**
 * A structure that holds tile analysis data.
 * Filled by SoTileAnalyzer::analyze.
 *
 * All statistics compare the original voxels of one tile with the voxels
 * obtained after a compress/uncompress round trip through the codec.
 */
struct SbTileAnalyzeData
{
  /** Size in bytes of the raw tile: number of voxels times the size of one voxel. */
  size_t uncompressedSize;
  /** Value returned by the codec's compress() call (presumably the compressed size in bytes). */
  size_t compressedSize;
  /** Elapsed time measured around the codec's compress() call. */
  SbTime encodingTime;
  /** Elapsed time measured around the codec's uncompress() call. */
  SbTime decodingTime;
  /** Mean of the original voxel values. */
  double originalMean;
  /** Variance of the original voxel values (mean of squares minus squared mean). */
  double originalVariance;
  /** Mean of the decoded voxel values. */
  double decodedMean;
  /** Variance of the decoded voxel values (mean of squares minus squared mean). */
  double decodedVariance;
  /** Covariance between original and decoded values (mean of products minus product of means). */
  double covariance;
  /** Value returned by the ssim() helper for the means, variances, covariance and data range. */
  double ssim;
  /** Largest absolute per-voxel error found by the analysis. */
  double maxError;
  /** Smallest absolute per-voxel error found by the analysis. */
  double minError;
  /** Mean of the absolute per-voxel errors. */
  double mae;
  /** Mean of the squared per-voxel errors. */
  double mse;
  /** Value returned by the psnr() helper for mse and the data range; +infinity when mse is not greater than zero. */
  double psnr;
};

/**
 * This tool will analyze the data in one tile.
 *
 * A tile is compressed and uncompressed again with the configured codec, and
 * the size, timing and error statistics of that round trip are reported in a
 * SbTileAnalyzeData structure.
 */
class SoTileAnalyzer
{
public:
  /**
   * Constructor.
   * @param info       Description of the volume (data type, data range, RGBA flag).
   *                   NOTE(review): stored as a reference member, so it presumably
   *                   has to outlive this analyzer - confirm against the constructor.
   * @param codec      Name of the compressor to analyze.
   * @param level      Compression level; a negative value leaves the level of the
   *                   codec unchanged.
   * @param dimensions The tile dimensions, in voxels.
   */
  SoTileAnalyzer( const VolumeInfo& info, const SbString& codec, int level, SbVec3i32 dimensions );

  /**
   * Destructor.
   */
  virtual ~SoTileAnalyzer();

  /** Returns the tile dimensions used by this analyzer. */
  const SbVec3i32& getDimensions() const { return m_dimensions; }

  /**
   * Analyze tile's data and fill the analysis structure.
   * @param tileBuffer The tile data, as read from SoLDMReader.
   * @param data       A structure that will be filled with analysis data.
   */
  void analyze( SoBufferObject *tileBuffer, SbTileAnalyzeData &data ) const;

private:
  /** Signature shared by every instantiation of analyze_template. */
  typedef void (SoTileAnalyzer::*AnalyzeFunction)( SoBufferObject* tileBuffer, SbTileAnalyzeData& data ) const;

  /**
   * Typed implementation of the analysis.
   * @tparam DataType  Voxel type used to read the tile memory.
   * @param tileBuffer The tile data to analyze.
   * @param data       A structure that will be filled with analysis data.
   */
  template <typename DataType>
  void analyze_template( SoBufferObject* tileBuffer, SbTileAnalyzeData& data ) const;

  /** The volume description: data type, data range and RGBA flag are read from it. */
  const VolumeInfo& m_info;

  /** Name used to look up the compressor through SoDataCompressor::getAppropriateCompressor. */
  SbString m_codec;

  /** Compression level given to the codec when it is not negative. */
  int m_level;

  /** The tile dimensions */
  SbVec3i32 m_dimensions;

  /** The analyzing function chosen according to the datatype */
  AnalyzeFunction m_analyzeFunction;
};

//*****************************************************************************
/**
 * Runs one compress/uncompress round trip on a tile and fills @p data with
 * the resulting size, timing and distortion statistics.
 *
 * The tile is read as m_dimensions[0] * m_dimensions[1] * m_dimensions[2]
 * voxels of type DataType. On any failure (NULL buffer, invalid dimensions,
 * unknown codec, buffer that cannot be mapped) an error is posted through
 * SoDebugError and @p data is left untouched.
 *
 * NOTE: seven arrays of doubles, each with one entry per voxel, are kept
 * alive at the same time to feed the kahanSum() reductions.
 *
 * @tparam DataType  Voxel type used to read the mapped tile memory.
 * @param tileBuffer The tile to analyze. It is downcast to
 *                   SoCpuBufferCompressed without any runtime check.
 * @param data       Receives the analysis results.
 */
template <typename DataType>
void SoTileAnalyzer::analyze_template( SoBufferObject* tileBuffer, SbTileAnalyzeData& data ) const
{
  if ( tileBuffer == NULL )
  {
    SoDebugError::post("analyze", "NULL tile buffer specified");
    return;
  }

  // A zero or negative extent would give an empty or wrapped-around voxel
  // count; taking the address of the first decoded voxel and dividing by the
  // voxel count would then be invalid.
  if ( m_dimensions[0] <= 0 || m_dimensions[1] <= 0 || m_dimensions[2] <= 0 )
  {
    SoDebugError::post("SoTileAnalyzer::analyze_template", "invalid tile dimensions (%d, %d, %d).",
                       static_cast<int>( m_dimensions[0] ),
                       static_cast<int>( m_dimensions[1] ),
                       static_cast<int>( m_dimensions[2] ) );
    return;
  }

  // NOTE(review): unchecked downcast, the caller has to pass a SoCpuBufferCompressed.
  SoCpuBufferCompressed* buffer = ( SoCpuBufferCompressed* ) tileBuffer;

  // Initialize compressor. It is deleted on every exit path below.
  SoDataCompressor* codec = SoDataCompressor::getAppropriateCompressor( m_codec );

  if ( codec == NULL )
  {
    SoDebugError::post("SoTileAnalyzer::analyze_template", "no codec \"%s\" can be found.", m_codec.toLatin1() );
    return;
  }

  // A negative level means "keep the level the codec already has".
  if ( m_level >= 0 )
  {
    codec->setCompressionLevel( static_cast<size_t>( m_level ) );
  }

  // The product is accumulated in size_t so that it is not computed in 32 bits.
  size_t nbVoxels = m_dimensions[0];
  nbVoxels *= m_dimensions[1];
  nbVoxels *= m_dimensions[2];
  const size_t byteCount = nbVoxels * sizeof( DataType );

  std::vector<DataType> decodedData( nbVoxels );

  SbDataType datatype = m_info.datatype;
  SbVec2d range = m_info.range;
  SoDataCompressor::TileInfo tileInfo( m_dimensions, datatype, range );
  tileInfo.isRGBA = m_info.isRGBA;

  void* bufferMapping = buffer->map( SoBufferObject::READ_ONLY );
  if ( bufferMapping == NULL )
  {
    SoDebugError::post("SoTileAnalyzer::analyze_template", "unable to map the tile buffer.");
    delete codec;
    return;
  }
  const DataType* voxelBuffer = static_cast<const DataType*>( bufferMapping );

  data.uncompressedSize = byteCount;

  SbElapsedTime timer;

  // Compress
  timer.reset();
  data.compressedSize = codec->compress( bufferMapping, byteCount, tileInfo );
  data.encodingTime = timer.getElapsedTime();

  // Uncompress
  timer.reset();
  codec->uncompress( &decodedData[0], byteCount, tileInfo );
  data.decodingTime = timer.getElapsedTime();

  // Neutral seeds: every measured absolute error is >= 0 and <= the largest
  // double, so both bounds always end up holding a measured value.
  data.maxError = 0.0;
  data.minError = std::numeric_limits<double>::max();

  std::vector<double> originalSamples( nbVoxels );
  std::vector<double> originalSamplesSquare( nbVoxels );

  std::vector<double> decodedSamples( nbVoxels );
  std::vector<double> decodedSamplesSquare( nbVoxels );

  std::vector<double> covarianceSamples( nbVoxels );

  std::vector<double> absoluteError( nbVoxels );
  std::vector<double> squareError( nbVoxels );

  for ( size_t i = 0; i < nbVoxels; ++i )
  {
    const double original = static_cast<double>( voxelBuffer[i] );
    const double decoded = static_cast<double>( decodedData[i] );

    originalSamples[i] = original;
    originalSamplesSquare[i] = original*original;

    decodedSamples[i] = decoded;
    decodedSamplesSquare[i] = decoded*decoded;

    covarianceSamples[i] = original*decoded;

    // The difference is taken on the double values: done in DataType
    // arithmetic it wraps around for unsigned voxel types as soon as the
    // decoded value exceeds the original one.
    const double error = original - decoded;
    squareError[i] = error*error;

    const double magnitude = SbMathHelper::abs( error );
    absoluteError[i] = magnitude;
    data.maxError = SbMathHelper::Max( data.maxError, magnitude );
    data.minError = SbMathHelper::Min( data.minError, magnitude );
  }

  const double sumNormalize = static_cast<double>( nbVoxels );

  // Variances and covariance use the E[xy] - E[x]E[y] form.
  data.originalMean = kahanSum( originalSamples.begin(), originalSamples.end() ) / sumNormalize;
  data.originalVariance = kahanSum( originalSamplesSquare.begin(), originalSamplesSquare.end() ) / sumNormalize;
  data.originalVariance -= data.originalMean*data.originalMean;

  data.decodedMean = kahanSum( decodedSamples.begin(), decodedSamples.end() ) / sumNormalize;
  data.decodedVariance = kahanSum( decodedSamplesSquare.begin(), decodedSamplesSquare.end() ) / sumNormalize;
  data.decodedVariance -= data.decodedMean*data.decodedMean;

  data.covariance = kahanSum( covarianceSamples.begin(), covarianceSamples.end() ) / sumNormalize;
  data.covariance -= data.originalMean*data.decodedMean;

  data.ssim = ssim(
    data.originalMean, data.decodedMean,
    data.originalVariance, data.decodedVariance,
    data.covariance,
    range
  );

  data.mae = kahanSum( absoluteError.begin(), absoluteError.end() ) / sumNormalize;
  data.mse = kahanSum( squareError.begin(), squareError.end() ) / sumNormalize;

  // psnr() is only called for a positive mse; otherwise the ratio is
  // reported as infinite.
  if ( data.mse > 0.0 )
  {
    data.psnr = psnr( data.mse, range );
  }
  else
  {
    data.psnr = std::numeric_limits<double>::infinity();
  }

  buffer->unmap();
  delete codec;
}
