//#include <Inventor/nodes/SoAlgebraicArrow.h>
#include "SoAlgebraicArrow.h"
#include <Inventor/nodes/SoFragmentShader.h>
#include <Inventor/nodes/SoAlgebraicCone.h>
#include <Inventor/actions/SoGLRenderAction.h>
#include <Inventor/actions/SoCallbackAction.h>
#include <Inventor/actions/SoWriteAction.h>
#include <Inventor/actions/SoGetPrimitiveCountAction.h>
#include <Inventor/actions/SoRayPickAction.h>
#include <Inventor/STL/algorithm>


// Open Inventor boilerplate defining the static class data (type id, field
// data) of the SoAlgebraicArrow node class.
SO_NODE_SOURCE(SoAlgebraicArrow);


SoAlgebraicArrow::SoAlgebraicArrow ()
  : SoAlgebraicShape()
{
  SO_NODE_CONSTRUCTOR(SoAlgebraicArrow);

  // Geometric fields with their default values (the order of these calls is
  // the order in which the fields are registered on the node).
  SO_NODE_ADD_FIELD(length, (2.f));
  SO_NODE_ADD_FIELD(coneRadius, (0.4f));
  SO_NODE_ADD_FIELD(cylinderRadius, (0.1f));
  SO_NODE_ADD_FIELD(arrowhead, (0.2f));

  // Fragment shader that performs the ray/arrow intersection for rendering.
  SoFragmentShader* rayShader = SoAlgebraicShape::fragmentShaderFromFile("$OIVHOME/examples/source/Inventor/Features/AlgebraicShape/CustomAlgebraicShape/shaders/SoAlgebraicArrow_RayIntersection_frag.glsl");

  // Each shader uniform is initialized from, then connected to, the node
  // field of the same name so later field edits reach the shader.
  SoShaderParameter1f* coneRadiusParam = rayShader->addShaderParameter1f("coneRadius", coneRadius.getValue());
  coneRadiusParam->value.connectFrom(&coneRadius);

  SoShaderParameter1f* cylinderRadiusParam = rayShader->addShaderParameter1f("cylinderRadius", cylinderRadius.getValue());
  cylinderRadiusParam->value.connectFrom(&cylinderRadius);

  SoShaderParameter1f* arrowheadParam = rayShader->addShaderParameter1f("arrowhead", arrowhead.getValue());
  arrowheadParam->value.connectFrom(&arrowhead);

  // Install the shader in the ray-intersection slot of the algebraic shape.
  rayIntersection.setValue(rayShader);

  // The shape is evaluated in a box workspace.
  workspace.setValue(SoAlgebraicShape::BOX);

  isBuiltIn = TRUE;
}

// Registers the SoAlgebraicArrow class with the Open Inventor type system under
// the file-format name "AlgebraicArrow", as a subclass of SoAlgebraicShape.
// Presumably called once at application startup, before any instance is
// created — confirm against the example's main().
void
SoAlgebraicArrow::initClass ()
{
  SO__NODE_INIT_CLASS(SoAlgebraicArrow, "AlgebraicArrow", SoAlgebraicShape);
}

// Counterpart of initClass(): releases the class registration of
// SoAlgebraicArrow from the Open Inventor type system.
void
SoAlgebraicArrow::exitClass ()
{
  SO__NODE_EXIT_CLASS(SoAlgebraicArrow);
}

// Bounding box of the arrow: it lies along the y axis, centered on the origin,
// spanning [-length/2, length/2] in y. The radial half-extent is the radius of
// the parts that exist: the cone alone (arrowhead >= 1), the cylinder alone
// (arrowhead == 0), or the wider of the two otherwise.
void
SoAlgebraicArrow::computeBBox ( SbBox3f &box, SbVec3f &center )
{
  const float head = arrowhead.getValue();
  const float radius = (head >= 1.0) ? coneRadius.getValue()
                     : (head == 0.0) ? cylinderRadius.getValue()
                     : std::max(coneRadius.getValue(), cylinderRadius.getValue());

  const float halfLength = 0.5f * length.getValue();
  box.setBounds(-radius, -halfLength, -radius,
                 radius,  halfLength,  radius);
  center.setValue(0.f, 0.f, 0.f);
}

// Ray-pick handler for the arrow.
//
// The object-space pick ray is intersected analytically with the shaft
// (rayPickCylinder) and the head (rayPickCone), each in its own local frame.
// The hit closest to the ray origin is registered with the action.
void 
SoAlgebraicArrow::rayPick ( SoRayPickAction *action )
{
  // First see if the object is pickable
  if (! shouldRayPick(action))
    return;

  // Local frames of the two parts, as translations from object space:
  //  - shaft: centered on its own middle, at y = -arrowhead*length/2,
  //  - head:  apex at the arrow tip, y = +length/2.
  const float len = length.getValue();
  const SbVec3f cylinderFrameTranslation(0.f, -arrowhead.getValue() * 0.5f * len, 0.f);
  const SbVec3f coneFrameTranslation(0.f, 0.5f * len, 0.f);

  // Compute the picking ray in our current object space
  computeObjectSpaceRay(action);
  const SbLine ray = action->getLine();

  SbVec3f coneSolution, cylinderSolution;
  const bool hitCone = rayPickCone(changeFrame(ray, coneFrameTranslation), coneSolution);
  const bool hitCylinder = rayPickCylinder(changeFrame(ray, cylinderFrameTranslation), cylinderSolution);

  if ( !hitCone && !hitCylinder )
    return;

  // Only a part that was actually hit has a written solution: bring it back
  // to object space and measure its distance from the ray origin. A missed
  // part keeps a huge distance so it is never selected.
  float distCone = 1e10f;
  if (hitCone)
  {
    backToOriginalFrame(coneSolution, coneFrameTranslation);
    distCone = (coneSolution - ray.getPosition()).length();
  }

  float distCylinder = 1e10f;
  if (hitCylinder)
  {
    backToOriginalFrame(cylinderSolution, cylinderFrameTranslation);
    distCylinder = (cylinderSolution - ray.getPosition()).length();
  }

  // On a tie the cylinder hit is kept.
  const SbVec3f p = ( distCone < distCylinder ) ? coneSolution : cylinderSolution;

  std::cout << "Picked point: " << p << std::endl;

  action->addIntersection(p);
}

// Intersects a ray with the arrow shaft, a closed cylinder.
//
// The ray must be expressed in the cylinder frame: axis along y, centered on
// the origin, radius cylinderRadius, height h = (1 - arrowhead) * length, so
// the shaft spans y in [-h/2, h/2] with a disc cap at each end.
//
// ray: pick ray in the cylinder frame.
// p:   receives the nearest intersection point (cylinder frame) on success;
//      left untouched otherwise.
// Returns true if the lateral surface or one of the end caps is hit in front
// of the ray origin (t > 0).
bool 
SoAlgebraicArrow::rayPickCylinder( SbLine ray, SbVec3f& p)
{
  // No shaft when the head takes the whole length or the shaft has no thickness.
  if (arrowhead.getValue() >= 1.0 || cylinderRadius.getValue() == 0.0)
    return false; 

  const SbVec3f rs = ray.getPosition();
  const SbVec3f rd = ray.getDirection();

  const float r = cylinderRadius.getValue();
  const float h = (1.f-arrowhead.getValue())*length.getValue();
  if (h == 0.f)
    return false; 
  const float halfH = 0.5f * h;

  float t = 1e10f;   // nearest accepted ray parameter so far
  bool  hit = false;

  // Lateral surface x^2 + z^2 = r^2: quadratic in t using the xz components.
  // No real root only means the lateral surface is missed; the caps are still
  // tested below (e.g. for rays parallel to the axis).
  const SbVec2f rsxz (rs[0], rs[2]);
  const SbVec2f rdxz (rd[0], rd[2]);
  const SbVec3f abc (rdxz.dot(rdxz), 2.f*rsxz.dot(rdxz), rsxz.dot(rsxz) - r*r );
  SbVec2f roots;
  if (solveQuadric(abc, roots))
  {
    // keep roots whose point lies within the shaft height, y in [-h/2; h/2]
    for ( int i = 0; i < 2; ++i )
    {
      const float y = rs[1] + roots[i]*rd[1];
      if ( roots[i] > 0.f && roots[i] < t && fabs(y) <= halfH )
      {
        t = roots[i];
        hit = true;
      }
    }
  }

  // End caps: discs of radius r in the planes y = -h/2 and y = +h/2.
  // The radius test is done on the intersection point, not the ray origin.
  if ( fabs(rd[1]) >= 1e-5f )
  {
    const float capY[2] = { -halfH, halfH };
    for ( int i = 0; i < 2; ++i )
    {
      const float tt = (capY[i] - rs[1]) / rd[1];
      if ( tt > 0.f && tt < t )
      {
        const float qx = rs[0] + tt*rd[0];
        const float qz = rs[2] + tt*rd[2];
        if ( qx*qx + qz*qz <= r*r )
        {
          t = tt;
          hit = true;
        }
      }
    }
  }

  if (!hit)
    return false;
  p = rs + t*rd;
  return true;
}

// Intersects a ray with the arrow head, a closed cone.
//
// The ray must be expressed in the cone frame: apex at the origin, axis
// pointing towards -y, height h = arrowhead * length, base disc of radius
// coneRadius in the plane y = -h.
//
// ray: pick ray in the cone frame.
// p:   receives the nearest intersection point (cone frame) on success;
//      left untouched otherwise.
// Returns true if the lateral surface or the base disc is hit in front of the
// ray origin (t > 0).
bool 
SoAlgebraicArrow::rayPickCone( SbLine ray , SbVec3f& p)
{
  // No head when it has zero relative length or zero radius.
  if ( arrowhead.getValue() == 0.0 || coneRadius.getValue() == 0.0)
    return false; 

  const SbVec3f rs = ray.getPosition();
  const SbVec3f rd = ray.getDirection(); 

  const float r = coneRadius.getValue();
  const float h = arrowhead.getValue()*length.getValue();
  if (h == 0.f)
    return false; // degenerate head (zero length): also avoids 0/0 below when r == 0

  float t = 1e10f;   // nearest accepted ray parameter so far
  bool  hit = false;

  // Lateral surface.
  // See http://lousodrome.net/blog/light/2017/01/03/intersection-of-a-ray-and-a-cone/
  //
  // With V the unit axis (apex towards base) and cos^2(theta) = h^2 / (h^2 + r^2),
  // a point P = rs + t*rd lies on the (double) cone when a*t^2 + b*t + c = 0 with:
  const SbVec3f V(0.f, -1.f, 0.f);
  const float cos2theta = h*h / (h*h + r*r);
  const float dv = rd.dot(V);
  const float ov = rs.dot(V);
  const SbVec3f abc(dv*dv - cos2theta,
                    2.f*(dv*ov - rd.dot(rs)*cos2theta),
                    ov*ov - rs.dot(rs)*cos2theta );

  // No real root only means the lateral surface is missed; the base is still tested.
  SbVec2f roots;
  if (solveQuadric(abc, roots))
  {
    // keep roots on the real nappe, between apex (y = 0) and base (y = -h)
    for ( int i = 0; i < 2; ++i )
    {
      const float y = rs[1] + roots[i]*rd[1];
      if ( roots[i] > 0.f && roots[i] < t && y <= 0.f && y >= -h )
      {
        t = roots[i];
        hit = true;
      }
    }
  }

  // Base disc: plane y = -h, radius r. The radius test is done on the
  // intersection point, not the ray origin.
  if ( fabs(rd[1]) > 1e-5f )
  {
    const float tt = (-h - rs[1]) / rd[1];
    if ( tt > 0.f && tt < t )
    {
      const float qx = rs[0] + tt*rd[0];
      const float qz = rs[2] + tt*rd[2];
      if ( qx*qx + qz*qz <= r*r )
      {
        t = tt;
        hit = true;
      }
    }
  }

  if (!hit)
    return false;
  p = rs + t * rd;
  return true; 
}

// Re-expresses a ray in a local frame whose origin sits at `translation`
// (object space): the ray origin is shifted by -translation and the direction
// is kept as is.
SbLine 
SoAlgebraicArrow::changeFrame( SbLine ray,  SbVec3f translation) const
{
  const SbVec3f localOrigin = ray.getPosition() - translation;

  SbLine localRay;
  localRay.setPosDir(localOrigin, ray.getDirection());
  return localRay;
}

// Inverse of changeFrame for points: moves a point expressed in a local frame
// (origin at `translation`) back into object space, in place.
void 
SoAlgebraicArrow::backToOriginalFrame (SbVec3f& point, const SbVec3f translation) const
{
  point += translation;
}
