// Ray-marching limits.
//   MAX_DISTANCE    - castRay reports a miss once the accumulated ray distance reaches this value.
//   MIN_DELTA       - castRay's hit threshold; also the finite-difference step used by getNormal
//                     and the basis of the surface offset applied in drawScene.
//   MAX_REFLECTIONS - number of bounces traced per path in drawScene (and the size of its
//                     per-bounce colour lists).
#define MAX_DISTANCE        (16.0f)
#define MIN_DELTA           (0.0001f)
#define MAX_REFLECTIONS     (8)

// Background
//   BG_DIRECTION is normalized by getBackground before use. Rays pointing away from it get
//   BG_COLOR1; rays aligned with it are blended toward BG_COLOR2.
#define BG_DIRECTION        ((float3)(0.1, -0.1, 1.0))
#define BG_COLOR1           ((float3)(1.0, 1.0, 1.0))
#define BG_COLOR2           ((float3)(0.6, 0.8, 1.0) * 0.5f)

// Floored modulo: x - y * floor(x / y). The result takes the sign of y (for y > 0 it lies in
// [0, y)), unlike a truncating remainder. Arguments are evaluated more than once, so do not
// pass expressions with side effects.
#define MOD(x, y)           ((x) - (y) * floor((x) / (y)))

// Fallback definition of pi (as a float literal) for compilers that do not predefine M_PI.
// NOTE(review): where the OpenCL implementation already provides M_PI it may be a
// double-precision constant; confirm whether M_PI_F would be the better choice for this file.
#ifndef M_PI
#define M_PI                (3.1415926535897932384626433832795f)
#endif

// ------------------------------------------------------------------------------------------------
// Pseudo-random pair generator.
// Advances *seed by (-1, +1), then returns the fractional parts of a sine hash (x component)
// and a cosine hash (y component) of the updated seed.
float2 rand2n(float2 *seed)
{
    // implementation based on: lumina.sourceforge.net/Tutorials/Noise.html
    const float2 state = *seed + (float2)(-1.0f, 1.0f);
    *seed = state;

    // fract() also writes the integral part through its second argument; it is not needed here.
    float discarded = 0.0f;

    const float sineHash   = sin(dot(state, (float2)(12.9898f, 78.233f))) * 43758.5453f;
    const float cosineHash = cos(dot(state, (float2)(4.898f, 7.23f))) * 23421.631f;

    const float first  = fract(sineHash, &discarded);
    const float second = fract(cosineHash, &discarded);

    return (float2)(first, second);
}

// ------------------------------------------------------------------------------------------------
// Returns a vector perpendicular to v (not normalized).
// Two candidates are available, (-v.y, v.x, 0) and (0, -v.z, v.y); the one chosen depends on
// whether |v.x| exceeds |v.z|. A zero input yields a zero result.
float3 ortho(float3 v)
{
    //  See : http://lolengine.net/blog/2013/09/21/picking-orthogonal-vector-combing-coconuts
    if (fabs(v.x) > fabs(v.z))
    {
        return (float3)(-v.y, v.x, 0.0f);
    }

    // Single-precision literal: keeps the kernel free of double constants.
    return (float3)(0.0f, -v.z, v.y);
}

// ------------------------------------------------------------------------------------------------
// Draws a random direction in the hemisphere around `dir`.
//   dir   - hemisphere axis (normalized internally).
//   power - bias exponent: the cosine of the polar angle is u^(1 / (power + 1)) for a random u,
//           so larger values concentrate samples around `dir`.
//   seed  - random state, advanced by rand2n.
float3 getSampleBiased(float3 dir, float power, float2 *seed)
{
    // Orthonormal frame with `axis` as the pole.
    const float3 axis      = normalize(dir);
    const float3 tangent   = normalize(ortho(axis));
    const float3 bitangent = normalize(cross(axis, tangent));

    const float2 xi = rand2n(seed);

    const float phi      = xi.x * 2.0f * M_PI;                    // azimuth
    const float cosTheta = pow(xi.y, 1.0f / (power + 1.0f));      // biased polar cosine
    const float sinTheta = sqrt(1.0f - cosTheta*cosTheta);

    return cos(phi)*sinTheta*tangent + sin(phi)*sinTheta*bitangent + cosTheta*axis;
}

// ------------------------------------------------------------------------------------------------
// Hemisphere sample around `dir` with no cosine bias: a power of 0 makes the exponent in
// getSampleBiased equal to 1, so the random polar cosine is used as-is.
float3 getSample(float3 dir, float2 *seed)
{
    const float unbiasedPower = 0.0f; // <- unbiased!
    return getSampleBiased(dir, unbiasedPower, seed);
}

// ------------------------------------------------------------------------------------------------
// Hemisphere sample around `dir` using a bias power of 1, i.e. the polar cosine is the square
// root of the random value drawn in getSampleBiased.
float3 getCosineWeightedSample(float3 dir, float2 *seed)
{
    const float cosinePower = 1.0f;
    return getSampleBiased(dir, cosinePower, seed);
}

// ------------------------------------------------------------------------------------------------
// Draws a random direction inside a cone around `dir`.
//   dir    - cone axis (normalized internally).
//   extent - cone size expressed on the cosine of the polar angle: the cosine is 1 - u * extent
//            for a random u, so 0 always returns the axis itself. Meaningful range is [0, 2].
//   seed   - random state, advanced by rand2n.
float3 getConeSample(float3 dir, float extent, float2 *seed)
{
    // Formula 34 in GI Compendium
    dir = normalize(dir);
    float3 o1 = normalize(ortho(dir));
    float3 o2 = normalize(cross(dir, o1));
    float2 r = rand2n(seed);
    r.x = r.x * 2.0f * M_PI;
    // Float literal: avoids promoting the whole expression to double precision.
    r.y = 1.0f - r.y*extent;
    // Clamp the radicand so an out-of-range extent (|r.y| > 1) cannot produce NaN.
    float oneminus = sqrt(fmax(0.0f, 1.0f - r.y*r.y));
    return cos(r.x)*oneminus*o1 + sin(r.x)*oneminus*o2 + r.y*dir;
}

// ------------------------------------------------------------------------------------------------
// Rotates p about the X axis by `angle` radians; x is left unchanged.
// Convention: y' = cos*y + sin*z, z' = -sin*y + cos*z.
float3 vRotateX(float3 p, float angle)
{
    const float cosA = cos(angle);
    const float sinA = sin(angle);

    float3 rotated;
    rotated.x = p.x;
    rotated.y = cosA*p.y + sinA*p.z;
    rotated.z = -sinA*p.y + cosA*p.z;
    return rotated;
}

// ------------------------------------------------------------------------------------------------
// Rotates p about the Y axis by `angle` radians; y is left unchanged.
// Convention: x' = cos*x - sin*z, z' = sin*x + cos*z.
float3 vRotateY(float3 p, float angle)
{
    const float cosA = cos(angle);
    const float sinA = sin(angle);

    float3 rotated;
    rotated.x = cosA*p.x - sinA*p.z;
    rotated.y = p.y;
    rotated.z = sinA*p.x + cosA*p.z;
    return rotated;
}

// ------------------------------------------------------------------------------------------------
// Rotates p about the Z axis by `angle` radians; z is left unchanged.
// Convention: x' = cos*x + sin*y, y' = -sin*x + cos*y.
float3 vRotateZ(float3 p, float angle)
{
    const float cosA = cos(angle);
    const float sinA = sin(angle);

    float3 rotated;
    rotated.x = cosA*p.x + sinA*p.y;
    rotated.y = -sinA*p.x + cosA*p.y;
    rotated.z = p.z;
    return rotated;
}

// ------------------------------------------------------------------------------------------------
// Signed distance from p to a sphere of radius r centred at the origin
// (negative inside, positive outside).
float sphere(float3 p, float r)
{
    const float distanceFromCentre = length(p);
    return distanceFromCentre - r;
}

// ------------------------------------------------------------------------------------------------
// Signed distance from p to the horizontal plane at height z, measured along the Z axis
// (positive above the plane).
float plane(float3 p, float z)
{
    const float heightAbovePlane = p.z - z;
    return heightAbovePlane;
}

// ------------------------------------------------------------------------------------------------
// Signed distance from p to an origin-centred, axis-aligned box with half-extents b.
float sdBox(float3 p, float3 b)
{
    // Per-axis distance to the box faces; negative components are inside along that axis.
    const float3 q = fabs(p) - b;

    // Non-zero only when p is inside the box: the largest (least negative) component.
    const float interior = fmin(fmax(q.x, fmax(q.y, q.z)), 0.0f);

    // Non-zero only when p is outside the box: Euclidean distance to the surface.
    const float exterior = length(fmax(q, 0.0f));

    return interior + exterior;
}

// ------------------------------------------------------------------------------------------------
// Parameters for the DE3 distance estimator (boxFold + sphereFold iteration).
//   DE3_ITER          - number of fold/scale iterations in DE3.
//   DE3_SCALE         - scale applied to the point on every iteration.
//   DE3_MIN_RADIUS    - inner threshold in sphereFold; compared against the SQUARED length dot(z, z).
//   DE3_FIXED_RADIUS  - outer threshold in sphereFold; also compared against the squared length.
//   DE3_FOLDING_LIMIT - per-component clamp limit used by boxFold.
#define DE3_ITER            (32)
#define DE3_SCALE           (2.0f)
#define DE3_MIN_RADIUS      (0.5f)
#define DE3_FIXED_RADIUS    (1.0f)
#define DE3_FOLDING_LIMIT   (1.0f)
// Sphere fold: scales *z (and the running derivative *dz by the same factor) depending on the
// squared length of *z.
//   squared length <  DE3_MIN_RADIUS   -> constant factor DE3_FIXED_RADIUS / DE3_MIN_RADIUS
//   squared length <  DE3_FIXED_RADIUS -> inversion factor DE3_FIXED_RADIUS / squared length
//   otherwise                          -> left untouched
// Note that both thresholds are compared against dot(z, z), not length(z).
void sphereFold(float3 *z, float *dz)
{
    const float lengthSquared = dot(*z, *z);
    float factor;

    if (lengthSquared < DE3_MIN_RADIUS)
    {
        factor = DE3_FIXED_RADIUS / DE3_MIN_RADIUS;
    }
    else if (lengthSquared < DE3_FIXED_RADIUS)
    {
        factor = DE3_FIXED_RADIUS / lengthSquared;
    }
    else
    {
        return;
    }

    *z *= factor;
    *dz *= factor;
}

// Box fold: reflects each component of *z that lies outside [-DE3_FOLDING_LIMIT, DE3_FOLDING_LIMIT]
// back across the nearest limit (clamp(z) * 2 - z); components inside the range are unchanged.
void boxFold(float3 *z)
{
    const float3 point = *z;
    const float3 limit = (float3)(DE3_FOLDING_LIMIT);

    const float3 clamped = clamp(point, -limit, limit);
    *z = clamped * 2.0f - point;
}

// Distance estimate built from DE3_ITER rounds of boxFold + sphereFold followed by
// z = DE3_SCALE * z + offset, where offset is the original sample point.
//   z     - sample point.
//   orbit - out: the LARGEST length(z) seen across all iterations.
// Returns length(z) / |dr|, where dr is the running derivative accumulated alongside z.
//
// NOTE(review): the original local was called min_dist but was initialised to -1e9 and updated
// with max(), so it has always tracked a maximum. DE2 tracks a minimum instead; confirm which
// is intended here. Behaviour is left unchanged.
float DE3(float3 z, float *orbit)
{
    float3 offset = z;
    float dr = 1.0f;

    // Single-precision literal: keeps the kernel free of double constants.
    float max_dist = -1e9f;

    for (int n = 0; n < DE3_ITER; n++)
    {
        boxFold(&z);
        sphereFold(&z, &dr);

        z  = DE3_SCALE * z + offset;
        dr = dr * fabs(DE3_SCALE) + 1.0f;

        max_dist = max(max_dist, length(z));
    }

    *orbit = max_dist;

    float r = length(z);
    return r / fabs(dr);
}

// ------------------------------------------------------------------------------------------------
// Parameters for the DE2 distance estimator.
//   DE2_SCALE  - scale applied to the point on every iteration.
//   DE2_OFFSET - scaling centre offset (applied as DE2_OFFSET * (DE2_SCALE - 1) per iteration).
//   DE2_ITER   - number of rotate/fold/scale iterations.
#define DE2_SCALE   (1.5f)
#define DE2_OFFSET  (2.0f)
#define DE2_ITER    (48)

// Distance estimate for a rotated, folded and scaled iterated system.
//   z     - sample point.
//   orbit - out: the smallest length(z) seen across all iterations.
// Returns the final length(z) divided by DE2_SCALE^DE2_ITER, i.e. the distance rescaled back
// to the original space.
float DE2(float3 z, float *orbit)
{
    float min_dist = 1e9f;

    z = vRotateZ(z, M_PI / 2.0f);

    // Initialised so the return value is well defined even if DE2_ITER is set to 0.
    float d = length(z);

    for (int n = 0; n < DE2_ITER; n++)
    {
        z = vRotateX(z, 0.31f);
        z = vRotateY(z, 0.31f);

        // Each fold mirrors the point across a plane through the origin when it lies on the
        // negative side of that plane.
        if(z.x + z.y < 0.0f) { z.xy = -z.yx; } // fold 1: plane x + y = 0
        if(z.x + z.z < 0.0f) { z.xz = -z.zx; } // fold 2: plane x + z = 0
        if(z.y + z.z < 0.0f) { z.zy = -z.yz; } // fold 3: plane y + z = 0
        //z.xy = fabs(z.xy);
        z = z * DE2_SCALE - DE2_OFFSET * (DE2_SCALE - 1.0f);

        d = length(z);
        min_dist = min(min_dist, d);
    }

    *orbit = min_dist;
    return d * pow(DE2_SCALE, -(float)DE2_ITER);
}

// ------------------------------------------------------------------------------------------------
// Signed distance from p to the horizontal slab lower <= z <= upper
// (negative inside the slab, positive above or below it).
float sdLayer(float3 p, float lower, float upper)
{
    const float belowLower = -p.z + lower;
    const float aboveUpper = p.z - upper;
    return max(belowLower, aboveUpper);
}

// ------------------------------------------------------------------------------------------------
// Scene distance function.
//   p      - sample point.
//   object - out: id of the nearest object; the scene currently holds only the DE2 shape,
//            so this is always set to 1.
//   orbit  - out: orbit value reported by DE2 for p.
// Returns the estimated distance from p to the scene.
float getMap(float3 p, int *object, float *orbit)
{
    // Optional scene tilt (disabled):
    //p = vRotateZ(p, -M_PI / 32.0f);
    //p = vRotateX(p, -M_PI / 16.0f);

    *orbit = 0.0f;
    const float distance = DE2(p, orbit);
    *object = 1;

    return distance;
}

// ------------------------------------------------------------------------------------------------
// Estimates the surface normal at p as the normalized central-difference gradient of getMap,
// using a step of MIN_DELTA along each axis (six getMap evaluations).
float3 getNormal(float3 p)
{
    const float eps = MIN_DELTA;
    const float3 stepX = (float3)(eps, 0.0f, 0.0f);
    const float3 stepY = (float3)(0.0f, eps, 0.0f);
    const float3 stepZ = (float3)(0.0f, 0.0f, eps);

    // getMap requires these outputs; their values are not used here.
    int ignoredObject = 0;
    float ignoredOrbit = 0.0f;

    float3 gradient;
    gradient.x = getMap(p + stepX, &ignoredObject, &ignoredOrbit) - getMap(p - stepX, &ignoredObject, &ignoredOrbit);
    gradient.y = getMap(p + stepY, &ignoredObject, &ignoredOrbit) - getMap(p - stepY, &ignoredObject, &ignoredOrbit);
    gradient.z = getMap(p + stepZ, &ignoredObject, &ignoredOrbit) - getMap(p - stepZ, &ignoredObject, &ignoredOrbit);

    return normalize(gradient);
}

// ------------------------------------------------------------------------------------------------
// Marches a ray through the scene, advancing by the distance reported by getMap at each step.
//   origin, direction - ray start and direction.
//   object            - out: id of the object hit, or 0 on a miss.
//   orbit             - out: orbit value from the most recent getMap evaluation.
// Returns the distance travelled to the hit, or MAX_DISTANCE on a miss (distance limit reached
// or 4096 steps exhausted).
float castRay(float3 origin, float3 direction, int *object, float *orbit)
{
    float travelled = 0.0f;
    float3 probe = origin;

    *object = 0;

    for (int iteration = 0; iteration < 4096; ++iteration)
    {
        const float stepLength = getMap(probe, object, orbit);

        travelled += stepLength;
        probe = origin + direction * travelled;

        // Hit: the last step was small enough to count as touching the surface.
        if (fabs(stepLength) <= MIN_DELTA)
        {
            return travelled;
        }

        // Left the scene: fall through to the shared miss handling below.
        if (travelled >= MAX_DISTANCE)
        {
            break;
        }
    }

    *object = 0;
    return MAX_DISTANCE;
}

// ------------------------------------------------------------------------------------------------
// Background colour for a ray that left the scene.
//   direction - ray direction (expected to be unit length; it is not normalized here).
// Rays pointing away from BG_DIRECTION get BG_COLOR1. Rays facing it are blended from
// BG_COLOR1 toward BG_COLOR2 as the alignment approaches 1.
float3 getBackground(float3 direction)
{
    float bgVal = dot(direction, normalize(BG_DIRECTION));

    if (bgVal >= 0.0f)
    {
        // Rounding can push the dot product of two unit vectors slightly above 1; clamp the
        // base at 0 so pow() cannot return NaN and the blend factor stays within [0, 1].
        bgVal = pow(fmax(1.0f - bgVal, 0.0f), 2.5f);
        return mix(BG_COLOR2, BG_COLOR1, bgVal);
    }

    return BG_COLOR1;
}

// ------------------------------------------------------------------------------------------------
// Cosine colour palette: a + b * cos(2*pi * (c * t + d)), evaluated per component.
//   t          - palette coordinate.
//   a, b, c, d - per-channel offset, amplitude, frequency and phase.
float3 pal(float t, float3 a, float3 b, float3 c, float3 d)
{
    const float3 phase = 6.28318f * (c * t + d);
    return a + b*cos(phase);
}

// ------------------------------------------------------------------------------------------------
// Colour of whatever a ray hit.
//   position  - hit position (currently unused).
//   direction - ray direction; used for the background lookup on a miss.
//   object    - object id from castRay: 0 = miss, 1..4 = fixed colours, anything else = black.
//   orbit     - orbit value at the hit (only used by the disabled palette colouring below).
// Returns the background colour unscaled for a miss, otherwise the object colour times 0.5.
float3 getSurfaceColor(float3 position, float3 direction, int object, float orbit)
{
    float3 color;

    switch (object)
    {
    case 0:
        return getBackground(direction);
    case 1:
        color = (float3)(1.0f, 1.0f, 1.0f);
        break;
    case 2:
        color = (float3)(1.0f, 0.2f, 0.2f);
        break;
    case 3:
        color = (float3)(0.2f, 1.0f, 0.2f);
        break;
    case 4:
        color = (float3)(0.2f, 0.2f, 1.0f);
        break;
    default:
        color = (float3)(0.0f, 0.0f, 0.0f);
        break;
    }

    // Disabled: orbit-based palette colouring, with two alternative parameter sets.
    /*
    color = pal(
        2.0f * orbit,
        (float3)(0.8,0.5,0.4),
        (float3)(0.2,0.4,0.2),
        (float3)(2.0,1.0,1.0),
        (float3)(0.0,0.25,0.25));
    */

        /*
        (float3)(0.5,0.5,0.5),
        (float3)(0.5,0.5,0.5),
        (float3)(1.0,1.0,1.0),
        (float3)(0.3,0.20,0.20));
        */

        /*
        (float3)(0.5,0.5,0.5),
        (float3)(0.5,0.5,0.5),
        (float3)(1.0,0.7,0.4),
        (float3)(0.0,0.15,0.20));
        */

    return color * 0.5f;
}

// ------------------------------------------------------------------------------------------------
// Folds the per-bounce colours of one path into a single colour.
//   surfaceList - surface colour recorded at each bounce.
//   directList  - direct lighting recorded at each bounce.
//   samples     - index of the LAST valid entry (not a count); entries 0..samples are used.
// Starting from white, walks from the last bounce back to the first, applying
// colour = (colour + direct[i]) * surface[i] at each step.
float3 compositeSurfaceSamples(float3 *surfaceList, float3 *directList, int samples)
{
    float3 accumulated = (float3)(1.0f, 1.0f, 1.0f);

    for (int bounce = samples; bounce >= 0; --bounce)
    {
        accumulated = (accumulated + directList[bounce]) * surfaceList[bounce];
    }

    return accumulated;
}

// ------------------------------------------------------------------------------------------------
// Traces one path of up to MAX_REFLECTIONS bounces and returns its colour.
//   origin, direction - primary ray.
//   uv                - per-pixel value used to derive the random seed for each bounce.
//   current_sample    - index of the current sample, mixed into the seed.
// At every bounce the surface colour and the direct light contribution (one shadow ray toward a
// fixed light direction) are recorded; the path ends early when a ray misses the scene. The
// recorded lists are combined by compositeSurfaceSamples.
float3 drawScene(float3 origin, float3 direction, float2 uv, int current_sample)
{
    float3 surfaceList[MAX_REFLECTIONS];
    float3 directList[MAX_REFLECTIONS];

    const float3 LightVec = normalize((float3)(1.0f, 2.0f, 3.0f));
    const float3 LightColor = (float3)(1.0f, 0.8f, 0.6f) * 2.0f;

    float3 next_origin = origin;
    float3 next_direction = direction;
    float orbit = 0.0f;

    for (int i = 0; i < MAX_REFLECTIONS; i++)
    {
        float2 seed = uv * (float)(current_sample + i + 1) * 1.12256f;

        // castRay resets the object id itself, so a fresh local per bounce is sufficient.
        int hitObject = 0;
        float distance = castRay(next_origin, next_direction, &hitObject, &orbit);
        float3 position = next_origin + next_direction * distance;

        surfaceList[i] = getSurfaceColor(position, next_direction, hitObject, orbit);

        if (hitObject == 0)
        {
            // Miss: the background colour terminates the path.
            directList[i] = (float3)(0.0f, 0.0f, 0.0f);
            return compositeSurfaceSamples(surfaceList, directList, i);
        }

        // The normal costs six distance evaluations, so compute it only for real hits.
        float3 normal = getNormal(position);

        // Add in direct lighting.
        int shadowObject = 0;
        float3 shadowDirection = getConeSample(LightVec, 0.0002f, &seed);
        float3 shadowOrigin = position + normal * MIN_DELTA * 2.0f;
        float directNDotL = max(0.0f, dot(normal, LightVec));
        if (directNDotL > 0.0f)
        {
            // Surfaces facing away from the light contribute nothing regardless of occlusion,
            // so the shadow ray is only traced when it can change the result.
            castRay(shadowOrigin, shadowDirection, &shadowObject, &orbit);
        }
        // step() yields 1 when the shadow ray hit something (object id >= 1), 0 otherwise.
        directList[i] = (1.0f - step(0.5f, (float)shadowObject)) * directNDotL * LightColor;
        //directList[i] = (float3)(0.0f, 0.0f, 0.0f);

        next_origin = shadowOrigin;
        next_direction = getCosineWeightedSample(normal, &seed);
    }

    // Bounce budget exhausted without reaching the background: the last bounce contributes black.
    surfaceList[MAX_REFLECTIONS - 1] = (float3)(0.0f, 0.0f, 0.0f);
    return compositeSurfaceSamples(surfaceList, directList, MAX_REFLECTIONS - 1);
}

// ------------------------------------------------------------------------------------------------
// Kernel entry point: renders one jittered sample for the work-item's pixel and blends it into
// the tile buffer.
//   tile               - accumulation buffer, indexed as gid_x + gid_y * tile_size.
//   width, height      - full image size; used to map pixel coordinates to the view plane.
//   tile_size          - row stride of `tile`.
//   offset_x, offset_y - position of this tile within the full image.
//   sample             - zero-based sample index; the new colour is blended with weight
//                        1 / (sample + 1), so sample 0 overwrites the stored value.
__kernel void mainimage(
    __global float4 *tile,
    int width,
    int height,
    int tile_size,
    int offset_x,
    int offset_y,
    int sample)
{
    const int gid_x = get_global_id(0);
    const int gid_y = get_global_id(1);
    const int pixelIndex = gid_x + gid_y * tile_size;

    const float sampleCount = (float)(sample + 1);

    // Pixel position in full-image coordinates, plus a random sub-pixel jitter.
    const float2 screen = (float2)(gid_x + offset_x, gid_y + offset_y);
    float2 seed = screen * sampleCount * 1.2256f;
    const float2 jitter = rand2n(&seed);

    // Map to view-plane coordinates: vertical range [-1, 1], horizontal scaled by aspect ratio.
    float2 p = 2.0f * (screen + jitter) / (float)height;
    p -= (float2)((float)width / (float)height, 1.0f);

    // Camera sits at (0, 2, 0) looking along -Y.
    const float3 origin = (float3)(0.0f, 2.0f, 0.0f);
    const float3 direction = normalize((float3)(p.x, -1.0f, p.y));

    float4 shaded = (float4)(0.0f, 0.0f, 0.0f, 1.0f);
    shaded.xyz = drawScene(origin, direction, screen, sample);
    shaded.xyz = pow(shaded.xyz, (float3)(1.0f / 2.2f)); // gamma encode

    // Running average over samples.
    tile[pixelIndex] = mix(tile[pixelIndex], shaded, 1.0f / sampleCount);
}

