// File : raytrace.cpp
// Info : This file is the main source file for the raytracer.
// Author : Joel Duggan (joelgduggan@gmail.com)

#include <iostream>
#include <iomanip>
#include <math.h>
#include <time.h>
#include <stdlib.h>
#include <float.h>

#include <SDL.h>

#include "FreeImage.h"
#include "tinyxml2.h"

#include "math.h"
#include "color.h"
#include "object.h"
#include "light.h"

using namespace std;

// V A R I A B L E S //////////////////////////////////////////////////////////////////////////////

// SDL window surface the image is shown on (created in main via SDL_SetVideoMode)
SDL_Surface *Display = NULL;

// Render settings, filled in by Load_Config_File() from "raytracer config.xml".
int screenWidth;
int screenHeight;
double viewDistance;        // screen_width / focal_length; used as the z component of camera rays

Vector cameraLocation;      // origin of all primary rays
Vector cameraRotation;      // angles passed to Rotate_X / Rotate_Y / Rotate_Z in getCameraRay

int maxDepth;               // Trace_Ray returns Black once its depth exceeds this
int useExposure;            // 1 when an <exposure> element is present, else 0
double exposure;            // value passed to TColor::doExposure when useExposure is set
int useSRGB;                // nonzero: apply TColor::doSRGB to shaded colors
int antialias;              // supersampling mode: 0 = off, 1 = 2x2, 2 = 4x4
int numThreads;             // number of render threads; thread i renders rows i, i + numThreads, ...

bool saveToFile = false;    // save a PNG (saveImage) after rendering

Object**    objectList = NULL;                   // array of objects
int         numObjects = 0;

Light**     lightList = NULL;                    // array of lights
int         numLights = 0;

// Rendered pixels, screenWidth * screenHeight entries in row-major order (row 0 = top),
// written by the render threads and copied to Display by renderImage().
Uint32* imageData = NULL;

// P R O T O T Y P E S ////////////////////////////////////////////////////////////////////////////

int    processLines(void* offset);                  // render-thread entry point; *offset is the first row (Uint16)
Uint32 renderImage(Uint32 interval, void *param);   // copy imageData to the screen; also used as an SDL timer callback
TColor Trace_Ray(const Vector& start, const Vector& dir, Uint8 depth, int sender = -1);   // sender = object to ignore, -1 for none
TColor Shade_Ray(Uint16 object_index, const Vector& start, const Vector& intersection, Uint8 depth);
TColor Trace_Shadow_Ray(const Vector& start, const Vector& dir, Uint16 light_index, Uint16 sender);
void   Load_Config_File();                          // reads "raytracer config.xml" into the globals above
void   saveImage(int num);                          // writes imageData to a PNG file

// F U N C T I O N S //////////////////////////////////////////////////////////////////////////////

int main(int argc, char *argv[])
{

    // read in render setting and the scene (objects + lights)
    cout << "Raytracer (version 2.0), programmed by Joel Duggan\n\n";
    cout.flush();
    Load_Config_File();

    // setup sdl
    SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER);
    SDL_WM_SetCaption("Raytracer", 0);
    Display = SDL_SetVideoMode(screenWidth, screenHeight, 32, SDL_HWSURFACE | SDL_DOUBLEBUF);
    SDL_FillRect(Display, NULL, 0x0);

    imageData = new Uint32[screenWidth * screenHeight];

    // get the starting render time
    Uint32 start_time = SDL_GetTicks();

    cout << "\nRendering...";

    // start threads
    SDL_Thread** thread = new SDL_Thread*[numThreads];
    Uint16* offsets = new Uint16[numThreads];

    for (int i = 0; i < numThreads; i++) {
        offsets[i] = i;
        thread[i] = SDL_CreateThread(processLines, &offsets[i]);
    }

    // render progress every 500 milliseconds
    SDL_TimerID timer = SDL_AddTimer(500, renderImage, NULL);

    // wait for all threads to finish rendering
    for (int i = 0; i < numThreads; i++)
        SDL_WaitThread(thread[i], NULL);
    SDL_RemoveTimer(timer);

    delete[] thread;
    delete[] offsets;

    // render image
    renderImage(0, NULL);

    // print the render time
    cout << "\n\nRender time : " << (SDL_GetTicks() - start_time) / 1000.0 << " seconds\n\n";

    // save image
    if (saveToFile)
        saveImage(0);

    // wait for quit
    SDL_Event event;
    while (true) {

        if (SDL_PollEvent(&event) != 0)
            if (event.type == SDL_QUIT)
                break;
    }

    SDL_Quit();

    // delete all allocations

    delete[] imageData;

    for (int i = 0; i < numObjects; i++)
        delete objectList[i];
    delete[] objectList;

    for (int i = 0; i < numLights; i++)
        delete lightList[i];
    delete[] lightList;

    return 0;
}

//-------------------------------------------------------------------------------------------------
// Copies the current contents of imageData to the display surface, flips it, and prints a
// progress dot. main also registers it as an SDL timer callback, hence the signature; it returns
// `interval` so the timer keeps firing with the same period. `param` is unused.
//
// NOTE(review): as a timer callback this runs on SDL's timer thread rather than the main thread;
// SDL does not guarantee video calls are safe from other threads — consider signalling the main
// thread instead.
//-------------------------------------------------------------------------------------------------
Uint32 renderImage(Uint32 interval, void *param) {

    // if the surface cannot be locked right now, skip this update and try on the next tick
    if (SDL_MUSTLOCK(Display) && SDL_LockSurface(Display) < 0)
        return interval;

    // The surface pitch (bytes per row) may be larger than screenWidth * 4 because of row
    // padding, so copy one row at a time rather than the whole buffer as one block.
    Uint8* destRow = static_cast<Uint8*>(Display->pixels);
    const Uint32* srcRow = imageData;
    for (int y = 0; y < screenHeight; y++) {
        memcpy(destRow, srcRow, screenWidth * sizeof(Uint32));
        destRow += Display->pitch;
        srcRow += screenWidth;
    }

    if (SDL_MUSTLOCK(Display))
        SDL_UnlockSurface(Display);
    SDL_Flip(Display);

    cout << ".";

    return interval;
}

// Returns the normalized direction of the camera ray through continuous screen position (x, y),
// where x grows to the right and y grows downward from the top-left corner. The image plane sits
// viewDistance in front of the eye; the camera rotation is applied about X, then Y, then Z.
Vector getCameraRay(double x, double y) {

    const double halfWidth  = screenWidth / 2.0;
    const double halfHeight = screenHeight / 2.0;

    // screen space -> camera space (flip y so that up is positive)
    Vector ray(x - halfWidth, halfHeight - y, viewDistance);

    // orient the ray by the camera rotation
    ray.Rotate_X(cameraRotation.x);
    ray.Rotate_Y(cameraRotation.y);
    ray.Rotate_Z(cameraRotation.z);

    ray.Normalize();
    return ray;
}

int processLines(void* offset) {

    Uint16 start = *(Uint16*)offset;
    TColor color;                            // color of a pixel

    for (Uint16 y = start; y < screenHeight; y += numThreads)
    {
        for (Uint16 x = 0; x < screenWidth; x++)
        {

            // check if we should antialias

            if (antialias == 0)
            {
                color = Trace_Ray(cameraLocation, getCameraRay(x + 0.5, y - 0.5), 1);        // trace the ray
            }
            else if (antialias == 1) // 2x2
            {
                color =  Trace_Ray(cameraLocation, getCameraRay(x + 0.25, y - 0.25), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.75, y - 0.25), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.25, y - 0.75), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.75, y - 0.75), 1);
                color /= 4.0;
            }
            else if (antialias == 2)  // 4x4
            {
                color =  Trace_Ray(cameraLocation, getCameraRay(x + 0.125, y - 0.125), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.375, y - 0.125), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.625, y - 0.125), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.875, y - 0.125), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.125, y - 0.375), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.375, y - 0.375), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.625, y - 0.375), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.875, y - 0.375), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.125, y - 0.625), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.375, y - 0.625), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.625, y - 0.625), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.875, y - 0.625), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.125, y - 0.875), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.375, y - 0.875), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.625, y - 0.875), 1);
                color += Trace_Ray(cameraLocation, getCameraRay(x + 0.875, y - 0.875), 1);
                color /= 16.0;
            }

            imageData[y * screenWidth + x] = SDL_MapRGB(Display->format, color.getR8(), color.getG8(), color.getB8());

        }    // end for each pixel
    }    // end for each row

    return 0;
}


//-------------------------------------------------------------------------------------------------
// Casts a ray from `start` along `dir` and returns the shaded color of the nearest object it hits.
// Returns Black when the ray hits nothing or when `depth` is already past maxDepth. The object at
// index `sender` (-1 for none) is skipped so a ray leaving a surface does not hit that surface.
//-------------------------------------------------------------------------------------------------
TColor Trace_Ray(const Vector& start, const Vector& dir, Uint8 depth, int sender)
{
    // recursion limit reached
    if (depth > maxDepth)
        return Black;

    int nearest = -1;
    double nearestDistance = DBL_MAX;
    double hitDistance;

    for (int candidate = 0; candidate < numObjects; ++candidate)
    {
        if (candidate == sender)
            continue;

        if (!objectList[candidate]->Ray_Intersect(start, dir, hitDistance))
            continue;

        if (hitDistance < nearestDistance)
        {
            nearestDistance = hitDistance;
            nearest = candidate;
        }
    }

    // nothing in the way of the ray
    if (nearest == -1)
        return Black;

    return Shade_Ray(nearest, start, start + dir * nearestDistance, depth + 1);
}


//-------------------------------------------------------------------------------------------------
// Computes the color at `intersection` on object `objectIndex`, seen from `start`. The result is
// the object's ambient color, plus per-light diffuse and specular terms attenuated by shadow
// rays, plus recursively traced reflection and transmission.
//
// Exposure and sRGB encoding are display transforms, so they are applied only to the primary
// hit. processLines traces primary rays with depth 1 and Trace_Ray shades its hit with
// depth + 1, so the primary hit has depth == 2. Reflected and transmitted colors are returned
// untransformed, which keeps them from being tone-mapped and gamma-encoded a second time when
// the parent adds them into its own color.
//-------------------------------------------------------------------------------------------------
TColor Shade_Ray(Uint16 objectIndex, const Vector& start, const Vector& intersection, Uint8 depth)
{
    const Uint8 primaryHitDepth = 2;    // depth at which a camera ray's hit is shaded

    Object* obj = objectList[objectIndex];

    Vector normal = obj->Get_Normal(intersection);
    Vector v = Normalize(start - intersection);    // direction to viewer

    TColor color = obj->ambient;    // used to accumulate final color

    // loop through each light and do diffuse and specular light
    for (int i = 0; i < numLights; i++)
    {
        // get a ray to the light from the intersection point
        Vector dirToLight = lightList[i]->Get_Light_Vector(intersection);
        double angle = dirToLight * normal;     // N . L

        // make sure the light hits the front of the surface
        if (angle > 0.0)
        {
            // how much (and what color) of the light gets past other objects
            TColor lightAmount = Trace_Shadow_Ray(intersection, dirToLight, i, objectIndex);

            if (lightAmount != Black)
            {
                // diffuse term
                TColor lit = obj->getDiffuseColorAt(intersection) * angle;

                // specular term: light direction reflected about the normal, compared with the view direction
                double specular = (normal * angle * 2.0 - dirToLight) * v;

                if (specular >= 0.0)
                    lit += obj->coefficient * pow(specular, obj->exponent);

                // attenuate by the shadow result and scale by the light's color
                lit *= lightAmount;
                color += (lit * lightList[i]->Get_Diffuse(intersection));
            }
        }
    }

    // do reflections
    if (obj->reflective > 0.0)
    {
        // mirror the view direction about the normal
        Vector r = Normalize(normal * 2.0 * (normal * v) - v);

        // start slightly off the surface along r, and skip this object, to avoid self-hits
        color += Trace_Ray(intersection + r * MYEPSILON, r, depth + 1, objectIndex) * obj->reflective;
    }

    // do transparency; the transmitted color is weighted by (1 - transparency), so a value of
    // 1.0 adds nothing
    if (obj->transparency < 1.0)
    {
        double t = obj->transparency;

        // continue the incoming ray (-v) from just behind the surface
        TColor transmitted = Trace_Ray(intersection + (-normal) * MYEPSILON, -v, depth, objectIndex) * obj->diffuse;
        color += transmitted * (1.0 - t);
    }

    // display transforms only once, on the final primary-hit color
    if (depth == primaryHitDepth)
    {
        if (useExposure)
            color.doExposure(exposure);
        if (useSRGB)
            color.doSRGB();
    }
    color.makeValid();
    return color;

}    // end Shade_Ray


//-------------------------------------------------------------------------------------------------
// Follows a shadow ray from `start` along `dirToLight` and returns how much of light
// `light_index` reaches `start`. The result starts as White. Every object except `sender` that
// the ray hits, and that the light's Shadow() test reports as lying in the way, multiplies the
// result by its diffuse color times (1 - transparency). White therefore means unobstructed;
// any other color means the light passed through (or was blocked by) those objects.
//-------------------------------------------------------------------------------------------------
TColor Trace_Shadow_Ray(const Vector& start, const Vector& dirToLight, Uint16 light_index, Uint16 sender)
{
    TColor transmitted(White);
    Light* light = lightList[light_index];
    double hitDistance;

    for (int obj = 0; obj < numObjects; ++obj)
    {
        if (obj == sender)
            continue;

        if (!objectList[obj]->Ray_Intersect(start, dirToLight, hitDistance))
            continue;

        // only objects between the point and the light cast a shadow
        if (!light->Shadow(start, start + dirToLight * hitDistance))
            continue;

        transmitted *= objectList[obj]->diffuse * (1.0 - objectList[obj]->transparency);
    }

    return transmitted;
}


// functions for loading config xml file

// Counts the child elements of `e`. When `ofType` is given, only children with that tag name are
// counted; with the default NULL, every child element is counted. Returns 0 when `e` is NULL.
Uint16 numChildren(tinyxml2::XMLElement* e, const char* ofType = NULL) {

    if (e == NULL)
        return 0;

    Uint16 num = 0;

    // pass ofType to both lookups so siblings with other tag names are not counted
    for (tinyxml2::XMLElement* child = e->FirstChildElement(ofType); child != NULL;
         child = child->NextSiblingElement(ofType))
        num++;

    return num;
}

double getFloat(tinyxml2::XMLElement* element, char* name) {

    return atof(element->FirstChildElement(name)->GetText());
}

// Reads a 3-component vector from the attributes of child element `name` of `element`:
// x/y/z when an "x" attribute is present, otherwise a/b/c. Prints a warning and returns
// (0, 0, 0) when the element is missing, instead of dereferencing a null pointer.
Vector getVector(tinyxml2::XMLElement* element, const char* name) {

    tinyxml2::XMLElement* e = (element != NULL) ? element->FirstChildElement(name) : NULL;
    if (e == NULL) {
        std::cerr << "Warning: missing element <" << name << ">, using (0, 0, 0)" << std::endl;
        return Vector(0.0f, 0.0f, 0.0f);
    }

    if (e->Attribute("x") != NULL)
        return Vector(e->FloatAttribute("x"), e->FloatAttribute("y"), e->FloatAttribute("z"));
    else
        return Vector(e->FloatAttribute("a"), e->FloatAttribute("b"), e->FloatAttribute("c"));
}

// Reads an RGB color from the r/g/b attributes of child element `name` of `element`.
// Prints a warning and returns (0, 0, 0) when the element is missing, instead of
// dereferencing a null pointer.
TColor getColor(tinyxml2::XMLElement* element, const char* name) {

    tinyxml2::XMLElement* e = (element != NULL) ? element->FirstChildElement(name) : NULL;
    if (e == NULL) {
        std::cerr << "Warning: missing element <" << name << ">, using color (0, 0, 0)" << std::endl;
        return TColor(0.0f, 0.0f, 0.0f);
    }

    return TColor(e->FloatAttribute("r"), e->FloatAttribute("g"), e->FloatAttribute("b"));
}

void Load_Config_File()
{
    cout << "Loading configuration file : raytracer config.xml...\n" << endl;
    tinyxml2::XMLDocument doc;
    doc.LoadFile("raytracer config.xml");

    // render settings
    tinyxml2::XMLElement* element = doc.FirstChildElement("render_settings");

    screenWidth        = atoi(element->FirstChildElement("screen_width")->GetText());
    screenHeight    = atoi(element->FirstChildElement("screen_height")->GetText());
    cout << "screen size       = " << screenWidth << " x " << screenHeight << endl;

    viewDistance    = (double)screenWidth / atof(element->FirstChildElement("focal_length")->GetText());
    cout << "focal length      = " << atof(element->FirstChildElement("focal_length")->GetText()) << endl;

    cameraLocation    = getVector(element, "camera_location");
    cout << "camera location   = (" << cameraLocation.x << ", " << cameraLocation.y << ", " << cameraLocation.z << ")" << endl;

    cameraRotation    = getVector(element, "camera_rotation");
    cout << "camera rotation   = (" << cameraRotation.x << ", " << cameraRotation.y << ", " << cameraRotation.z << ")" << endl;

    maxDepth        = atoi(element->FirstChildElement("max_depth")->GetText());
    cout << "maximum depth     = " << maxDepth << endl;

    if (element->FirstChildElement("exposure") != 0) {
        useExposure = 1;
        exposure = atof(element->FirstChildElement("exposure")->GetText());
        cout << "exposure amount   = " << exposure << endl;
    }
    else
        useExposure = 0;

    useSRGB            = atoi(element->FirstChildElement("srgb")->GetText());
    cout << "srgb              = " << useSRGB << endl;

    antialias        = atoi(element->FirstChildElement("antialias")->GetText());
    cout << "antialias         = " << (antialias == 0 ? "off" : antialias == 1 ? "2x2" : "4x4") << endl;

    numThreads        = atoi(element->FirstChildElement("numThreads")->GetText());
    cout << "number of threads = " << numThreads << endl;

    saveToFile        = atoi(element->FirstChildElement("save_to_file")->GetText()) == 0 ? false : true;
    cout << "save to file      = " << (saveToFile ? "yes" : "no") << endl;

    // objects
    element = doc.FirstChildElement("objects");
    numObjects = numChildren(element);
    objectList = new Object*[numObjects];

    cout << "found " << numObjects << " objects..." << endl;

    element = element->FirstChildElement();
    for (int i = 0; i < numObjects; i++ ) {

        const char* type = element->FirstChildElement("type")->GetText();

        if (strcmp(type, "sphere") == 0)
            objectList[i] = new Sphere(getVector(element, "position"), getFloat(element, "radius"));
        else if (strcmp(type, "plane") == 0) {
            objectList[i] = new Infinite_Plane(Plane(getVector(element, "normal"), getFloat(element, "distance")));

            // see if it should have a checkerboard pattern
            tinyxml2::XMLElement* checker = element->FirstChildElement("checkerboard");
            if (checker != 0 && atoi(checker->GetText()) != 0)
                ((Infinite_Plane*)objectList[i])->createCheckerboard(getColor(element, "diffuse_color2"), getFloat(element, "checkerboard_size"));
        }
        else if (strcmp(type, "triangle") == 0)
            objectList[i] = new Triangle(getVector(element, "v1"), getVector(element, "v2"), getVector(element, "v3"));
        else if (strcmp(type, "quad") == 0)
            objectList[i] = new Quad(getVector(element, "v1"), getVector(element, "v2"), getVector(element, "v3"), getVector(element, "v4"));

        objectList[i]->Set_Colors(getColor(element, "ambient_color"), getColor(element, "diffuse_color"), getFloat(element, "specular_exponent"),
                                  getFloat(element, "specular_coefficient"), getFloat(element, "reflectivity"), getFloat(element, "transparency") );

        element = element->NextSiblingElement();
    }

    // lights
    element = doc.FirstChildElement("lights");
    numLights = numChildren(element);
    lightList = new Light*[numLights];

    cout << "found " << numLights << " lights...\n" << endl;

    element = element->FirstChildElement();
    for (int i = 0; i < numLights; i++ ) {

        const char* type = element->FirstChildElement("type")->GetText();

        if (strcmp(type, "point") == 0)
            lightList[i] = new Point_Light(getVector(element, "position"), getFloat(element, "radius"), getVector(element, "attenuation"));
        else if (strcmp(type, "directional") == 0)
            lightList[i] = new Directional_Light(getVector(element, "direction"));

        lightList[i]->diffuse = getColor(element, "color");

        element = element->NextSiblingElement();
    }
}

void saveImage(int num) {

    FreeImage_Initialise();
    FIBITMAP* bitmap = FreeImage_Allocate(screenWidth, screenHeight, 24);
    RGBQUAD color;

    for (Uint16 y = 0; y < screenHeight; y++)
    {
        for (Uint16 x = 0; x < screenWidth; x++)
        {
            SDL_GetRGB(imageData[y * screenWidth + x], Display->format, &color.rgbRed, &color.rgbGreen, &color.rgbBlue);
            FreeImage_SetPixelColor(bitmap, x, (screenHeight - 1) - y, &color);
        }
    }

    char filename[100];

    sprintf(filename, "raytrace_%d.png", time(NULL));
    //sprintf(filename, "raytrace_%d.png", num);

    if (FreeImage_Save(FIF_PNG, bitmap, filename, 0))
        cout << "Image saved as " << filename << endl;
    else
        cout << "Error while trying to save .png image" << endl;

    FreeImage_DeInitialise();
}