import functools
import itertools

import bpy
import numpy as np

import ddg
from ddg.blender.animation import set_keyframe
from ddg.blender.collection import clear, collection
from ddg.blender.material import material
from ddg.blender.render import (
    set_film_transparency,
    set_render_output_images,
    set_render_stamp_note,
    setup_eevee_renderer,
)
from ddg.math.projective import homogenize

# Radius passed to ddg.blender.vertices for every point drawn by this script
# (chord end points, the initial point and the focal points).
points_size = 0.018
# Bevel depth assigned to the Blender curve data of each chord.
chord_bevel_depth = 0.004


def billiard_map(conic, incidence_chord):
    """
    Apply one step of the billiard map to a chord of a conic.

    The chord is reflected in the tangent line of the conic at the chord's
    end point; the other intersection of the reflected line with the conic
    is the result.

    Parameters
    ----------
    conic : ddg.geometry.Quadric
        A conic in projective space.
    incidence_chord : Tuple of two ddg.geometry.Point
        The incoming chord as ``(start, end)``, given by two distinct points
        on the conic.

    Returns
    -------
    ddg.geometry.Point
        The intersection point of the reflected chord with the conic that
        differs from the end point of the incoming chord.
    """
    start, end = incidence_chord
    incoming_line = ddg.geometry.join(start, end)

    # Polarizing the end point with respect to the conic gives the tangent
    # line at that point.
    tangent_line = ddg.geometry.polarize(end, conic)

    assert end in conic
    assert end in tangent_line

    # Mirror the incoming line across the tangent line.
    projective_model = ddg.geometry.euclidean_models.ProjectiveModel
    outgoing_line = projective_model.reflect_in_hyperplane(
        incoming_line, tangent_line
    )

    # The reflected line meets the conic in the end point and one more point;
    # the latter is the image under the billiard map.
    assert start != end
    common_points = ddg.geometry.intersect(conic, outgoing_line)
    first, second = ddg.geometry.quadric_to_subspaces(common_points)
    return second if first == end else first


def chords(conic, chord_0):
    """Build a memoised function that computes the i-th billiard chord.

    Chord ``i`` starts at the end point of chord ``i - 1`` and ends at the
    image of chord ``i - 1`` under ``billiard_map``.

    Parameters
    ----------
    conic : ddg.geometry.Quadric
        The conic must have dimension == 1.
    chord_0 : pair of ddg.geometry.Point and ddg.geometry.Point
        These points must be distinct and contained in the conic.

    Returns
    -------
    Callable int -> pair of ddg.geometry.Point and ddg.geometry.Point
        A function that computes the i-th chord. It raises ValueError for a
        negative index.
    """
    # Chord i is defined recursively through chord i - 1, so an uncached
    # evaluation costs O(i), and computing the 0-th to the i-th chord (as
    # visualize_chords does) would cost O(i**2). Caching the results brings
    # that down to O(i).
    #
    # We could've written chords with the signature (conic, chord_0, i) -> chord_i
    # but conic and chord_0 aren't hashable, which breaks functools.cache.
    # Closing over them keeps the cache key down to the integer index.

    @functools.cache
    def f(i):
        if i < 0:
            raise ValueError(f"chord index must be non-negative, got {i}")
        if i == 0:
            return chord_0

        previous_chord = f(i - 1)
        _, end = previous_chord
        # The new chord starts where the previous one ended.
        return end, billiard_map(conic, previous_chord)

    return f


def visualize_chord_and_end_point(chord, i, chord_material, point_material, collection):
    """Visualise a chord and its end point.

    If the Blender objects for the chord and its end point already exist,
    they are only made visible again (in render and viewport). If neither
    exists, both are created.

    Parameters
    ----------
    chord : pair of ddg.geometry.Point and ddg.geometry.Point
        A chord, given as (start point, end point).
    i : int
        The index of the chord. The Blender objects of the chord and its end
        point are named "chord i" and "point i+1" respectively.
    chord_material : bpy.types.Material
        The material used for the chords.
    point_material : bpy.types.Material
        The material used for the start and end points of the chords.
    collection : bpy.types.Collection
        The chords and their start and end points are linked to this collection.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If exactly one of the two Blender objects exists, i.e. the scene is
        in an inconsistent state.
    """
    chord_name = f"chord {i}"
    point_name = f"point {i+1}"
    has_chord = chord_name in bpy.data.objects
    has_point = point_name in bpy.data.objects

    # The chord and its end point are always created together, so finding
    # only one of them means the scene was modified behind our back.
    if has_chord != has_point:
        raise RuntimeError(
            f"Inconsistent scene: expected both or neither of {chord_name!r} "
            f"and {point_name!r} to exist "
            f"(chord exists: {has_chord}, point exists: {has_point})."
        )

    if has_chord:
        # Both objects exist already: just unhide them.
        for name in (chord_name, point_name):
            bobj = bpy.data.objects[name]
            bobj.hide_render = False
            bobj.hide_viewport = False
        return

    start, end = chord
    chord_curve = ddg.arrays.line_segment_from_points(end, start)
    chord_bobj = ddg.blender.convert(
        chord_curve,
        chord_name,
        chord_material,
        collection,
    )
    chord_bobj.data.bevel_depth = chord_bevel_depth
    ddg.blender.vertices(
        end,
        point_name,
        radius=points_size,
        material=point_material,
        collection=collection,
    )


def visualize_chords(chords_function, i, chord_material, point_material, collection):
    """Show chords 0 to i of the billiard trajectory and hide all later ones.

    Parameters
    ----------
    chords_function : Callable int -> pair of ddg.geometry.Point and ddg.geometry.Point
        Maps n to the chord obtained after n iterations of the billiard map.
    i : int
        The index of the last chord to be shown.
    chord_material : bpy.types.Material
        The material used for the chords.
    point_material : bpy.types.Material
        The material used for the start and end points of the chords.
    collection : bpy.types.Collection
        The chords and their start and end points are linked to this collection.

    Returns
    -------
    None
    """
    # Show (or create) every chord up to and including index i.
    for index in range(i + 1):
        current_chord = chords_function(index)
        visualize_chord_and_end_point(
            current_chord, index, chord_material, point_material, collection
        )

        # Only the very first chord needs its start point drawn separately;
        # every other start point is the end point of the preceding chord.
        if index != 0 or "point 0" in bpy.data.objects:
            continue
        start_point, _ = current_chord
        ddg.blender.vertices(
            start_point,
            "point 0",
            radius=points_size,
            material=point_material,
            collection=collection,
        )

    # Hide the chords after index i, together with their end points, until
    # the first chord index for which no Blender object exists.
    index = i + 1
    while f"chord {index}" in bpy.data.objects:
        for name in (f"chord {index}", f"point {index+1}"):
            bobj = bpy.data.objects[name]
            bobj.hide_render = True
            bobj.hide_viewport = True
        index += 1


def main():
    """Set up the Blender scene that animates billiards in an ellipse.

    The scene is cleared, the ellipse and its focal points are drawn, and a
    property "i" is registered whose callback shows the chords 0 to i. The
    property is keyframed from 1 at frame 1 to N at frame N, and finally the
    renderer, the image output and the render stamp are configured.

    Returns
    -------
    None
    """
    # Delete all objects and their corresponding data.
    clear()
    static_col = collection("static objects")
    animated_col = collection("animated objects")

    # Choose the ellipse parameters with a > b. Note that a and b enter the
    # quadric matrix as diag(1/a, 1/b, -1), i.e. the ellipse is
    # x**2 / a + y**2 / b = 1 and a, b are the *squared* semi-axis lengths.
    a = 3.5
    b = 1.0
    # Create the ellipse and its parameterization.
    ellipse = ddg.geometry.Quadric(np.diag([1 / a, 1 / b, -1]))
    ellipse_snet = ddg.to_smooth_net(ellipse)
    ellipse_parameterization = ellipse_snet.fct

    # A chord is determined by two points on the ellipse.
    # We choose the initial chord using parameterization function of ellipse.
    # Set t0 and t1 accordingly.
    epsilon = 0.1
    t0 = np.pi / 2 - epsilon
    t1 = 3 * np.pi / 2 + epsilon

    s0 = ellipse_parameterization(t0)
    s1 = ellipse_parameterization(t1)

    chord_0 = (ddg.geometry.Point(homogenize(s0)), ddg.geometry.Point(homogenize(s1)))

    # -------------
    # Visualization
    # -------------

    # Thickness (bevel depth) of the ellipse curve.
    ellipse_thickness = 0.015

    # Setup camera.
    camera = ddg.blender.camera.camera(
        name="top_camera", location=(0.0, 0.0, 7.0), collection=static_col
    )
    bpy.context.scene.camera = camera

    # Setup light.
    ddg.blender.light.light(
        name="top_light", type_="SUN", energy=2.0, collection=static_col
    )

    # Setup basic materials.
    point_material = material("point_material", color=(1.0, 0.0, 0.0))
    chord_material = material("line_material", color=(0.0, 0.130, 0.717))
    ellipse_material = material("ellipse_material", color=(0.015, 0.015, 0.015))

    # Visualize ellipse.
    bobj = ddg.blender.convert(
        ellipse,
        "ellipse",
        ellipse_material,
        static_col,
    )
    bobj.data.bevel_depth = ellipse_thickness
    # Define and visualize its focal points. With squared semi-axes a and b
    # the foci lie at (+-sqrt(a - b), 0).
    focal_distance = np.sqrt(a - b)
    f1 = ddg.geometry.Point([focal_distance, 0.0, 1.0])
    f2 = ddg.geometry.Point([-focal_distance, 0.0, 1.0])
    ddg.blender.vertices(
        f1,
        "first_focal_point",
        radius=points_size,
        material=ellipse_material,
        collection=static_col,
    )
    ddg.blender.vertices(
        f2,
        "second_focal_point",
        radius=points_size,
        material=ellipse_material,
        collection=static_col,
    )

    # Animation settings
    chords_ = chords(ellipse, chord_0)

    def callback(i):
        visualize_chords(chords_, i, chord_material, point_material, animated_col)

    # The property names are passed as a one-element tuple. A bare ("i") is
    # just the string "i" and would only work by accident for one-character
    # names.
    ddg.blender.props.add_props_with_callback(callback, ("i",), 0)

    # Keyframe and frame settings
    scene = bpy.context.scene
    N = 50  # The number of times to run billiard map.
    scene.frame_end = N
    set_keyframe(scene, 1, "i", 1)
    set_keyframe(scene, N, "i", N)

    # Setup render settings
    # NOTE(review): the output directory is the filesystem root; presumably a
    # placeholder to be adjusted before rendering.
    output_dir = "/"
    scene.render.fps = 4
    scene.view_settings.view_transform = "Standard"
    setup_eevee_renderer(scene=scene)
    set_film_transparency(scene=scene)
    set_render_output_images(output_dir, time=False)

    # Display render stamp of parameters, one parameter per line.
    parameters_string = "\n".join(
        [
            " N = " + str(N),
            " a = " + np.format_float_positional(a, precision=3),
            " b = " + np.format_float_positional(b, precision=3),
            " t₀ = " + np.format_float_positional(t0, precision=3),
            " t₁ = " + np.format_float_positional(t1, precision=3),
        ]
    )
    set_render_stamp_note(note=parameters_string, scene=scene)

    # Do not display any other render stamp.
    scene.render.use_stamp_date = False
    scene.render.use_stamp_time = False
    scene.render.use_stamp_frame = False
    scene.render.use_stamp_scene = False
    scene.render.use_stamp_labels = False
    scene.render.use_stamp_camera = False
    scene.render.use_stamp_filename = False
    scene.render.use_stamp_render_time = False


# Build the scene only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
