// ====================================================
//
// Copyright (c) 2001 Sean Wilson. All Rights Reserved.
//
// ====================================================

import java.awt.*;
import java.awt.image.BufferedImage;

/** Represents canvas that contains a raytraced image. */
/** Canvas that ray traces a {@link Scene} and displays the rendered image. */
public class Raytracer extends Canvas
{
	/** Ray t parameter values below this value will not be considered in intersection tests. */
	public static final float THRESHOLD = 0.001f;
	/** Refractive index of air. */
	public static final float REFRACTIVE_INDEX_AIR = 1.0f;

	/** Scene being rendered. */
	private Scene scene;
	/** Origin all eye rays start at. */
	private Vector3d viewPoint;
	/** Rendered image, or <code>null</code> if nothing has been rendered yet. */
	private Image image;

	/** Paints the image directly, skipping the default background clear performed by {@link Component#update(Graphics)}. */
	public void update(Graphics g)
	{
		paint(g);
	}

	/** Draws the most recently rendered image, if any, at the canvas origin. */
	public void paint(Graphics g)
	{
		if (image != null)
		{
			g.drawImage(image, 0, 0, this);
		}
	}

	/** Renders a scene.
	 *  @param g                 graphics context the partially rendered image is drawn to after each row. If <code>null</code>, this will not happen.
	 *  @param initScene         scene to render.
	 *  @param width             width of rendered image.
	 *  @param height            height of rendered image.
	 *  @param antialiasing      whether to use antialiasing (four samples per pixel).
	 *  @param maxRecursiveDepth maximum depth reflected/refracted rays will be traced to.
	 *  @return                  the rendered image, which is also kept for {@link #paint(Graphics)}.
	 */
	public Image render(Graphics g, Scene initScene, int width, int height, boolean antialiasing, int maxRecursiveDepth)
	{
		// Store time rendering began
		long startTime = System.currentTimeMillis();
		// Store scene to render
		scene = initScene;

		// Create image to render scene to. Component.createImage returns null when the canvas is
		// not displayable (e.g. not yet shown, or running headless), so fall back to an off-screen
		// image rather than failing with a NullPointerException below.
		image = createImage(width, height);
		if (image == null)
		{
			image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
		}

		Graphics imageGraphics = image.getGraphics();
		try
		{
			// Calculate view point position
			viewPoint = scene.eyeTransform.transform(new Vector3d(0, 0, -scene.viewPointDistance));

			// Calculate view plane scale so image has same view regardless of size.
			// The cast guards against integer division should the view size be integral.
			float xScale = scene.viewWidth  / (float)width;
			float yScale = scene.viewHeight / (float)height;

			for (int y = 0; y < height; ++y)
			{
				for (int x = 0; x < width; ++x)
				{
					// Plot current pixel
					imageGraphics.setColor(getPixelColor(x, y, xScale, yScale, antialiasing, maxRecursiveDepth));
					imageGraphics.fillRect(x, y, 1, 1);
				}

				// Show rendering progress after each completed row
				if (g != null)
				{
					g.drawImage(image, 0, 0, this);
				}
			}
		}
		finally
		{
			// Release the resources held by the image's graphics context
			imageGraphics.dispose();
		}

		// Write rendering time to console
		System.out.println("Scene rendered in " + (int)((System.currentTimeMillis() - startTime) / 1000f) + " seconds.");

		return image;
	}

	/** Returns the color of pixel (x, y).
	 *  @param x                 pixel column.
	 *  @param y                 pixel row.
	 *  @param xScale            view plane units per pixel horizontally.
	 *  @param yScale            view plane units per pixel vertically.
	 *  @param antialiasing      whether to average four sub-pixel samples instead of taking one.
	 *  @param maxRecursiveDepth maximum depth reflected/refracted rays will be traced to.
	 *  @return                  pixel color.
	 */
	private Color getPixelColor(int x, int y, float xScale, float yScale, boolean antialiasing, int maxRecursiveDepth)
	{
		if (!antialiasing)
		{
			return getEyeRayColor(x * xScale, y * yScale, maxRecursiveDepth);
		}

		// Corners of a 0.5 by 0.5 square centered on the pixel position
		float left  = (x - 0.25f) * xScale;
		float right = (x + 0.25f) * xScale;
		float above = (y - 0.25f) * yScale;
		float below = (y + 0.25f) * yScale;

		Color sample1 = getEyeRayColor(left,  above, maxRecursiveDepth);
		Color sample2 = getEyeRayColor(right, above, maxRecursiveDepth);
		Color sample3 = getEyeRayColor(left,  below, maxRecursiveDepth);
		Color sample4 = getEyeRayColor(right, below, maxRecursiveDepth);

		// Average the four samples
		return new Color((sample1.getRed  () + sample2.getRed  () + sample3.getRed  () + sample4.getRed  ()) / 4,
		                 (sample1.getGreen() + sample2.getGreen() + sample3.getGreen() + sample4.getGreen()) / 4,
		                 (sample1.getBlue () + sample2.getBlue () + sample3.getBlue () + sample4.getBlue ()) / 4);
	}

	/** Returns the intersection closest to the ray's origin, or <code>null</code> if the ray hits no shape in the scene. */
	private Intersection getIntersection(Ray ray)
	{
		// Closest intersection found so far
		Intersection closestIntersection = null;

		for (int i = 0; i < scene.shapeList.length; ++i)
		{
			Intersection intersection = scene.shapeList[i].getIntersection(ray, THRESHOLD);

			// Keep the intersection if it is the closest to the ray's origin so far
			if (intersection != null && (closestIntersection == null || intersection.getT() < closestIntersection.getT()))
			{
				closestIntersection = intersection;
			}
		}

		return closestIntersection;
	}

	/** Clamps a color component to the range accepted by {@link Color}, 0 to 255. */
	private static int clampComponent(int value)
	{
		return Math.max(0, Math.min(value, 255));
	}

	/** Returns color created by summing together RGB components of two colors, clamped to 255. */
	private static Color addColors(Color color1, Color color2)
	{
		return new Color(clampComponent(color1.getRed  () + color2.getRed  ()),
		                 clampComponent(color1.getGreen() + color2.getGreen()),
		                 clampComponent(color1.getBlue () + color2.getBlue ()));
	}

	/** Returns color created by scaling RGB components of a color, clamped to the range 0 to 255. */
	public Color scaleColor(Color color, float scale)
	{
		return new Color(clampComponent((int)(color.getRed  () * scale)),
		                 clampComponent((int)(color.getGreen() * scale)),
		                 clampComponent((int)(color.getBlue () * scale)));
	}

	/** Returns ray reflected off a surface.
	 *  @param position  position on surface.
	 *  @param direction direction ray was going in before reflection.
	 *  @param normal    surface normal (either orientation gives the same result).
	 *  @return          reflected ray with a normalised direction.
	 */
	private static Ray getReflectedRay(Vector3d position, Vector3d direction, Vector3d normal)
	{
		return new Ray(position, direction.subtract(normal.multiply(2 * normal.dot(direction))).normalise());
	}

	/** Returns ray refracted on passing through a surface.
	 *  @param position  position on surface.
	 *  @param direction direction ray was going in before refraction.
	 *  @param normal    surface normal; it is flipped internally if it points away from the incoming ray.
	 *  @param n1        refractive index of material ray was in before refraction.
	 *  @param n2        refractive index of material ray will be in after refraction.
	 *  @return          refracted ray, or <code>null</code> if total internal reflection occurs.
	 */
	private static Ray getRefractedRay(Vector3d position, Vector3d direction, Vector3d normal, float n1, float n2)
	{
		// Cosine of the angle of incidence; positive when the normal faces the incoming ray
		float cosIncident = -normal.dot(direction);

		// The refraction formula needs a normal facing the incoming ray. Flip it when the ray
		// strikes the back of the surface (e.g. when leaving a shape through an outward normal).
		Vector3d facingNormal = normal;
		if (cosIncident < 0)
		{
			facingNormal = normal.multiply(-1f);
			cosIncident  = -cosIncident;
		}

		// Ratio of refractive indices
		float indexRatio = n1 / n2;
		// Snell's law: squared cosine of the angle of refraction
		float cosRefractedSquared = 1 - indexRatio * indexRatio * (1 - cosIncident * cosIncident);

		// No real angle of refraction exists: total internal reflection
		if (cosRefractedSquared < 0)
		{
			return null;
		}

		float cosRefracted = (float)Math.sqrt(cosRefractedSquared);

		// Calculate refracted ray direction
		Vector3d refractedDirection = direction.multiply(indexRatio).add(facingNormal.multiply(indexRatio * cosIncident - cosRefracted));

		return new Ray(position, refractedDirection.normalise());
	}

	/** Returns light ray color.
	 *  @param ray    ray to trace.
	 *  @param depth  remaining depth to trace reflected/refracted rays to.
	 *  @param inside shape the ray is currently travelling inside, or <code>null</code> if it is in air.
	 *  @return       ray color.
	 */
	private Color getRayColor(Ray ray, int depth, Shape3d inside)
	{
		Intersection intersection = getIntersection(ray);

		// Ray color is black if it does not hit any shapes
		if (intersection == null)
			return Color.black;

		Shape3d  shape         = intersection.getShape();
		Material material      = shape.getMaterial();
		Vector3d position      = intersection.getPosition();
		Vector3d normal        = intersection.getNormal();
		Color    materialColor = shape.getColor(position);

		// Mirror direction. It is needed for the specular highlight term as well as for traced
		// reflections, so it is computed regardless of depth or reflectivity.
		Ray reflectedRay = getReflectedRay(position, ray.getDirection(), normal);

		Color reflectedColor = Color.black;
		Color refractedColor = Color.black;

		// If should trace reflected and refracted rays
		if (depth > 0)
		{
			if (material.isReflective())
			{
				// Reflection keeps the ray in its current medium
				reflectedColor = getRayColor(reflectedRay, depth - 1, inside);
			}

			if (material.isTransparent())
			{
				// Ray is leaving the shape it is inside, or entering it from air
				boolean leaving      = (shape == inside);
				float   currentIndex = leaving ? material.getRefractiveIndex() : REFRACTIVE_INDEX_AIR;
				float   newIndex     = leaving ? REFRACTIVE_INDEX_AIR : material.getRefractiveIndex();

				Ray refractedRay = getRefractedRay(position, ray.getDirection(), normal, currentIndex, newIndex);

				if (refractedRay == null)
				{
					// Total internal reflection: the light reflects back into the current medium
					refractedColor = getRayColor(reflectedRay, depth - 1, inside);
				}
				else
				{
					// After leaving, the ray travels through air and is no longer inside any shape
					refractedColor = getRayColor(refractedRay, depth - 1, leaving ? null : shape);
				}
			}
		}

		// Intensity of lighting at intersection point
		float lighting = 0;

		for (int i = 0; i < scene.lightList.length; ++i)
		{
			Light light = scene.lightList[i];

			// Offset from intersection point to light; its length is read before normalising
			Vector3d lightOffset          = new Vector3d(position, light.getPosition());
			float    lightDistanceSquared = lightOffset.getSquaredLength();
			Vector3d lightDirection       = lightOffset.normalise();

			// Cosine of angle between surface normal and light direction
			float cosLight = normal.dot(lightDirection);

			// Light source contributes no light if surface normal is facing away from light source
			if (cosLight < 0)
				continue;

			// A blocker only casts a shadow if it lies between the surface point and the light,
			// i.e. it is closer to the surface point than the light is
			Intersection shadowIntersection = getIntersection(new Ray(position, lightDirection));
			if (shadowIntersection != null && new Vector3d(position, shadowIntersection.getPosition()).getSquaredLength() < lightDistanceSquared)
				continue;

			// Cosine of angle between mirror direction and light, clamped to be non-negative
			float cosSpecular = Math.max(reflectedRay.getDirection().dot(lightDirection), 0);
			// Distance attenuation component of lighting equation
			float attenuation = 1 / (1 + lightDistanceSquared);

			float diffuse  = cosLight * material.getDiffuseReflectance();
			float specular = (float)Math.pow(cosSpecular, material.getShine()) * material.getSpecularReflectance();

			lighting += attenuation * (diffuse + specular) * light.getIntensity();
		}

		// Local color intensity, reduced by whatever fraction is reflected or transmitted
		float localIntensity = (scene.ambientLighting + lighting) * Math.max(1 - material.getReflectivity() - material.getTransparency(), 0);

		return combineColors(materialColor, localIntensity, reflectedColor, material.getReflectivity(), refractedColor, material.getTransparency());
	}

	/** Returns combination of three scaled colors.
	 *  @param color1 first color.
	 *  @param scale1 value to scale first color components by.
	 *  @param color2 second color.
	 *  @param scale2 value to scale second color components by.
	 *  @param color3 third color.
	 *  @param scale3 value to scale third color components by.
	 *  @return       color made by summing the scaled components of the three colors, clamped to the range 0 to 255.
	 */
	private static Color combineColors(Color color1, float scale1, Color color2, float scale2, Color color3, float scale3)
	{
		return new Color(clampComponent((int)(color1.getRed  () * scale1 + color2.getRed  () * scale2 + color3.getRed  () * scale3)),
		                 clampComponent((int)(color1.getGreen() * scale1 + color2.getGreen() * scale2 + color3.getGreen() * scale3)),
		                 clampComponent((int)(color1.getBlue () * scale1 + color2.getBlue () * scale2 + color3.getBlue () * scale3)));
	}

	/** Returns color of light ray passing through view point and a point on the view plane.
	 *  @param planeX            view plane x co-ordinate.
	 *  @param planeY            view plane y co-ordinate.
	 *  @param maxRecursiveDepth maximum depth to trace reflected/refracted rays to.
	 *  @return                  ray color.
	 */
	private Color getEyeRayColor(float planeX, float planeY, int maxRecursiveDepth)
	{
		// Plane position relative to the view plane's center
		Vector3d planePosition = new Vector3d(planeX - scene.viewWidth / 2, planeY - scene.viewHeight / 2, 0);
		// Ray direction from view point to transformed plane position
		Vector3d direction = new Vector3d(viewPoint, scene.eyeTransform.transform(planePosition)).normalise();

		// Eye rays start in air, so they are not inside any shape
		return getRayColor(new Ray(viewPoint, direction), maxRecursiveDepth, null);
	}
}