#version 330 core

//!oiv_include <Inventor/oivShaderState.h>


//!oiv_include <Inventor/oivAlgebraicShape.h>

// Arrow shape parameters, read by OivASRayIntersection below.
// Only the ratio coneRadius / cylinderRadius matters: the larger of the two is
// mapped to radius 1 in the shape's local space.
uniform float coneRadius;
uniform float cylinderRadius;
// Fraction of the arrow's local y span [-1, 1] taken by the cone (head): the cone
// covers y in [1 - 2*arrowhead, 1], the cylinder (shaft) y in [-1, 1 - 2*arrowhead].
uniform float arrowhead;

// Intersects a ray with the closed unit cylinder: lateral surface x^2 + z^2 = 1
// for y in [-1, 1], closed by two cap disks of radius 1.
// Only hits with 0 < t < |re - rs| are accepted; the nearest one wins.
// On success writes the hit position and outward normal into p and returns true;
// on failure p is left untouched and false is returned.
bool
OivASRayIntersectionCylinder ( in OivASRay ray, inout OivASPoint p )
{
  // Lateral surface as a*t^2 + b*t + c = 0 along the ray.
  vec3 coeffs = vec3( dot(ray.rd.xz, ray.rd.xz),
                      2.0*dot(ray.rs.xz, ray.rd.xz),
                      dot(ray.rs.xz, ray.rs.xz) - 1.0 );
  vec2 sideHits = vec2(0.0);

  // No real root: the ray never enters the infinite cylinder, so it cannot reach the caps either.
  if ( !OivASSolveQuadric(coeffs, sideHits) )
    return false;

  float tBest = length(ray.re - ray.rs);
  int   hitKind = 0; // 0: nothing, 1: lateral surface, 2: cap

  // Lateral hits are only valid inside the finite height y in [-1, 1].
  for ( int k = 0; k < 2; ++k ) {
    float tk = sideHits[k];
    bool withinSegment = tk > 0.0 && tk < tBest;
    if ( withinSegment && abs(ray.rs.y + tk*ray.rd.y) <= 1.0 ) {
      tBest = tk;
      hitKind = 1;
    }
  }

  // Cap planes; a plane hit only counts when it lies inside the unit disk.
  vec2 capSigns = vec2(1.0, -1.0);
  for ( int k = 0; k < 2; ++k ) {
    float tc = 0.0;
    if ( OivASRayPlaneIntersection(ray, vec4(0.0, capSigns[k], 0.0, 1.0), tc) ) {
      vec2 q = ray.rs.xz + tc*ray.rd.xz;
      if ( tc > 0.0 && tc < tBest && dot(q, q) <= 1.0 ) {
        tBest = tc;
        hitKind = 2;
      }
    }
  }

  if ( hitKind == 0 )
    return false;

  p.position = ray.rs + tBest*ray.rd;
  if ( hitKind == 1 ) {
    // Radial direction in the xz plane.
    p.normal = normalize(vec3(p.position.x,
                              0.0,
                              p.position.z));
  } else {
    // Caps face +y or -y depending on which one was hit.
    p.normal = normalize(vec3(0.0, p.position.y, 0.0));
  }
  return true;
}

// Intersects a ray with the closed unit cone: lateral surface
// x^2 + z^2 = 0.25*(y - 1)^2 for y in [-1, 1] (apex at y = 1, radius 1 at y = -1),
// closed by a base disk of radius 1. The base plane is passed as (0, -1, 0, 1);
// presumably this is y = -1 — confirm against OivASRayPlaneIntersection's plane convention.
// Only hits with 0 < t < |re - rs| are accepted; the nearest one wins.
// On success writes the hit position and outward normal into p and returns true.
bool
OivASRayIntersectionCone ( in OivASRay ray, inout OivASPoint p )
{
  // Lateral surface as a*t^2 + b*t + c = 0 along the ray.
  vec3 coeffs = vec3( ray.rd.x*ray.rd.x - 0.25*ray.rd.y*ray.rd.y + ray.rd.z*ray.rd.z,
                      2.0*(ray.rs.x*ray.rd.x - 0.25*ray.rs.y*ray.rd.y + 0.25*ray.rd.y + ray.rs.z*ray.rd.z),
                      ray.rs.x*ray.rs.x - 0.25*(ray.rs.y - 1.0)*(ray.rs.y - 1.0) + ray.rs.z*ray.rs.z );
  vec2 sideHits = vec2(-1.0);

  if ( !OivASSolveQuadric(coeffs, sideHits) )
    return false;

  float tBest = length(ray.re - ray.rs);
  int   hitKind = 0; // 0: nothing, 1: lateral surface, 2: base disk

  // The quadric is a double cone; restricting y to [-1, 1] keeps only the lower nappe.
  for ( int k = 0; k < 2; ++k ) {
    float tk = sideHits[k];
    bool withinSegment = tk > 0.0 && tk < tBest;
    if ( withinSegment && abs(ray.rs.y + tk*ray.rd.y) <= 1.0 ) {
      tBest = tk;
      hitKind = 1;
    }
  }

  // Base plane; the hit only counts inside the unit disk.
  float tBase = 0.0;
  if ( OivASRayPlaneIntersection(ray, vec4(0.0, -1.0, 0.0, 1.0), tBase) ) {
    vec2 q = ray.rs.xz + tBase*ray.rd.xz;
    if ( tBase > 0.0 && tBase < tBest && dot(q, q) <= 1.0 ) {
      tBest = tBase;
      hitKind = 2;
    }
  }

  if ( hitKind == 0 )
    return false;

  p.position = ray.rs + tBest*ray.rd;
  if ( hitKind == 1 ) {
    // Normalized gradient of x^2 + z^2 - 0.25*(y - 1)^2 (scaled by 1).
    p.normal = normalize(vec3(p.position.x*2.0,
                              0.5 - 0.5*p.position.y,
                              p.position.z*2.0));
  } else {
    p.normal = normalize(vec3(0.0, -1.0, 0.0));
  }
  return true;
}

// Maps a point into a local frame defined by a per-axis scale and a translation
// (inverse of unchangeFramePos): local = (world - translation) / scale.
vec3
changeFramePos (in vec3 pos, vec3 scale, vec3 translation)
{
  vec3 offset = pos - translation;
  return offset / scale;
}

// Maps a point from a local frame back to the original one (inverse of changeFramePos):
// world = local * scale + translation.
vec3
unchangeFramePos (in vec3 pos, vec3 scale, vec3 translation)
{
  vec3 scaled = pos * scale;
  return scaled + translation;
}

// Divides a direction component-wise by the frame scale and renormalizes it.
// For a diagonal scale this is the inverse-transpose, i.e. the correct transform
// for taking a local-frame normal back to the original frame.
vec3
changeFrameDir (in vec3 dir, vec3 scale)
{
  vec3 unscaled = dir / scale;
  return normalize(unscaled);
}

// Expresses a ray in a local frame (see changeFramePos). Both endpoints are mapped,
// and the direction is recomputed from them and normalized, so local ray parameters
// are local-space distances.
OivASRay
changeFrame(in OivASRay ray, vec3 scale, vec3 translation)
{
  vec3 localStart = changeFramePos(ray.rs, scale, translation);
  vec3 localEnd   = changeFramePos(ray.re, scale, translation);

  OivASRay localRay;
  localRay.rs = localStart;
  localRay.re = localEnd;
  localRay.rd = normalize(localEnd - localStart);
  return localRay;
}

// Transforms a hit point found in a local frame back to the original frame:
// the position by p * scale + translation, the normal by division by scale
// (inverse-transpose of the diagonal scale) followed by renormalization.
void
backToOriginalFrame (inout OivASPoint point, vec3 scale, vec3 translation)
{
  point.position = point.position * scale + translation;
  point.normal = normalize(point.normal / scale);
}

// Ray / arrow intersection entry point.
// The arrow spans y in [-1, 1] in shape space and is the union of:
//   - a cylinder (shaft) covering y in [-1, 1 - 2*arrowhead], and
//   - a cone (head) covering y in [1 - 2*arrowhead, 1], apex at y = 1.
// Each part is intersected in its own unit-shape frame, and the hit is mapped back.
// The hit closest to ray.rs is written into p (ties go to the cylinder).
// Returns false when neither part is hit; p is then left untouched.
// Degenerate parts (zero radius, or zero height when arrowhead is 0 or 1) are skipped,
// because changeFrame would otherwise divide by zero.
bool
OivASRayIntersection ( in OivASRay ray, inout OivASPoint p )
{
  // Only the ratio between the radii matters: the larger one maps to unit radius.
  float maxRatio = max(coneRadius, cylinderRadius);
  if ( maxRatio == 0.0 )
    return false; // both ratios would be 0/0
  float cylinderRatio = cylinderRadius/maxRatio;
  float coneRatio = coneRadius/maxRatio;

  // Local unit shape (y in [-1, 1]) -> arrow space: world = local * scale + translation.
  vec3 cylinderFrameScale = vec3(cylinderRatio, 1.0-arrowhead, cylinderRatio);
  vec3 cylinderFrameTranslation = vec3(0.0, -arrowhead, 0.0);
  vec3 coneFrameScale = vec3(coneRatio, arrowhead, coneRatio);
  vec3 coneFrameTranslation = vec3(0.0, 1.0-arrowhead, 0.0);

  // Start from a copy of p so that any fields not written by the part intersectors survive.
  OivASPoint coneSolution = p;
  OivASPoint cylinderSolution = p;
  bool hitCone = false;
  bool hitCylinder = false;

  // A zero scale component would make changeFrame divide by zero (undefined in GLSL).
  if ( all(notEqual(coneFrameScale, vec3(0.0))) )
    hitCone = OivASRayIntersectionCone(changeFrame(ray, coneFrameScale, coneFrameTranslation), coneSolution);
  if ( all(notEqual(cylinderFrameScale, vec3(0.0))) )
    hitCylinder = OivASRayIntersectionCylinder(changeFrame(ray, cylinderFrameScale, cylinderFrameTranslation), cylinderSolution);

  if ( !hitCone && !hitCylinder )
    return false;

  // Only map back the parts that were actually hit.
  if ( hitCone )
    backToOriginalFrame(coneSolution, coneFrameScale, coneFrameTranslation);
  if ( hitCylinder )
    backToOriginalFrame(cylinderSolution, cylinderFrameScale, cylinderFrameTranslation);

  if ( !hitCylinder ) {
    p = coneSolution;
  } else if ( !hitCone ) {
    p = cylinderSolution;
  } else {
    // Both parts hit: both points lie on the same world ray, so compare distances to its start.
    float distCone = length(ray.rs - coneSolution.position);
    float distCylinder = length(ray.rs - cylinderSolution.position);
    p = (distCone < distCylinder) ? coneSolution : cylinderSolution;
  }
  return true;
}







