"""
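In this example we construct discrete channel surfaces and their focal surfaces.
Let
M  = {[x] in RP^4 | <x, x>_{4, 1} = 0}
be the Moebius quadric, where <., .>_{4, 1} denotes the inner product of
signature (4, 1) on R^{4, 1}, and let M^+ be the set of points outside M,
M^+ = {[x] in RP^4 | <x, x>_{4, 1} > 0}.
Points of M^+ represent spheres in R^3.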

"""
from functools import lru_cache

import bpy

# [setup-1]
# Import necessary libraries
import numpy as np

import ddg
from ddg.blender.material import material
from ddg.geometry.euclidean_models import moebius_to_projective, projective_to_moebius

# Clear the existing objects in the Blender scene (and, presumably, the existing
# materials) so that re-running the script starts from an empty scene.
ddg.blender.scene.clear()
ddg.blender.material.clear()


# Euclidean 3-space and the 3-dimensional Moebius model; `mob.absolute` is the
# Moebius quadric used for polarization, inversion and projection below.
euc = ddg.geometry.euclidean(3)
mob = ddg.geometry.euclidean_models.MoebiusModel(3)
# [setup-1]


################################################
# Helper functions
################################################
def polarized(fct):
    """
    Wrap `fct` so that every object it returns is replaced by its polar
    with respect to the Moebius quadric `mob.absolute`.

    The returned function takes an index `i`, evaluates `fct(i)` and
    polarizes the result. Results are memoized in an LRU cache (size 128).
    """

    def _polar_of(i):
        obj = fct(i)
        return mob.absolute.polarize(obj)

    return lru_cache(maxsize=128)(_polar_of)


def m_points_to_spheres(fct):
    """
    Create a new function returning Moebius spheres as conic sections,
    while the given function returns Moebius spheres as points in M^+.

    For index `i` the polar of `fct(i)` with respect to the Moebius quadric is
    intersected with the quadric itself. Results are memoized (LRU, size 128).
    """
    # Build the polarized function once, so that its LRU cache is shared by
    # all calls. Creating it inside `point_to_circle` would produce a new,
    # empty cache on every call.
    polar_fct = polarized(fct)

    @lru_cache(maxsize=128)
    def point_to_circle(i):
        return ddg.geometry.intersect(polar_fct(i), mob.absolute)

    return point_to_circle


def tangent_edges(fct):
    """Returns a function that returns the edge tangent lines
    for a given function returning Points as vertices of a discrete curve.

    The i'th edge tangent line is the join of fct(i) and fct(i+1). It is passed
    through `ddg.geometry.orthonormalize_and_center_subspace`, with the midpoint
    of the edge's affine endpoints as the center. Results are memoized
    (LRU, size 128).
    """

    @lru_cache(maxsize=128)
    def tangent_edge(i):
        # Evaluate each vertex once; the original evaluated both endpoints twice.
        start, end = fct(i), fct(i + 1)
        tangent = ddg.geometry.join(start, end)
        midpoint = np.sum([start.affine_point, end.affine_point], axis=0) / 2
        return ddg.geometry.orthonormalize_and_center_subspace(tangent, midpoint)

    return tangent_edge


def m_sphere_inversion(point, fix_sphere):
    """
    Invert `point` in the sphere `fix_sphere` within the Moebius model.

    `point` is expected to lie on M and `fix_sphere` to be a point of M^+
    (i.e. a sphere). The homogeneous coordinates are combined by
    `ddg.math.inner_product.reflect` with respect to the inner product of the
    Moebius quadric. The image, which lies on M, is returned as a new
    `ddg.geometry.Point`.
    """
    quadric_inner_product = mob.absolute.inner_product
    reflected_coordinates = ddg.math.inner_product.reflect(
        fix_sphere.point, point.point, quadric_inner_product
    )
    return ddg.geometry.Point(reflected_coordinates)


def m_iterative_sphere_inversions(fct, starting_points):
    """
    For a given list of starting points, apply sphere inversions
    iteratively in a family of spheres given by the function `fct`.
    The spheres are assumed to be given as points in M^+.

    The returned function maps `i` to a list of points: `starting_points` for
    `i == 0`, and for `i > 0` the points of step `i - 1` inverted in the
    sphere `fct(i - 1)`. Results are memoized (LRU, size 128).

    Raises
    ------
    ValueError
        If called with a negative index (the recursion would never terminate).
    """

    @lru_cache(maxsize=128)
    def new_curve_fct(i):
        if i < 0:
            raise ValueError(f"Index must be non-negative, got {i}.")
        if i == 0:
            return starting_points
        previous_points = new_curve_fct(i - 1)
        # The inversion sphere is the same for every point of this step, so
        # evaluate it once instead of once per point.
        sphere = fct(i - 1)
        return [m_sphere_inversion(p, sphere) for p in previous_points]

    return new_curve_fct


def m_points_of_anti_similitude_from_homogeneous_lifts(fct):
    """
    Build the spheres of anti-similitude for a family of spheres.

    `fct(i)` is expected to return the (normalized) homogeneous lift of the
    i'th sphere as a vector in R^{4, 1}. The returned function maps `i` to
    the point of M^+ given by the difference of the lifts of spheres `i` and
    `i + 1`; that point represents their sphere of anti-similitude.
    Results are memoized (LRU, size 128).
    """

    def _anti_similitude_point(i):
        difference = fct(i) - fct(i + 1)
        return ddg.geometry.Point(difference)

    return lru_cache(maxsize=128)(_anti_similitude_point)


def homogeneous_lift_of_spheres_fct(fct):
    """
    Wrapper to apply `homogeneous_lift_of_spheres` to
    a function returning Euclidean spheres.

    The returned function maps `i` to the homogeneous lift of the sphere
    `fct(i)`, using its center's affine point and its radius.
    Returns a new function.
    """

    def new_fct(i):
        # Evaluate fct(i) only once: it may be costly (e.g. build a new sphere
        # object on every call).
        sphere = fct(i)
        return homogeneous_lift_of_spheres(sphere.center.affine_point, sphere.radius)

    return new_fct


def homogeneous_lift_of_spheres(c, r):
    """
    Lift a Euclidean sphere (c, r) to the homogeneous coordinates
        1 / r * (c, (-1 + |c|^2 - r^2) / 2, (1 + |c|^2 - r^2) / 2)
    in R^{4, 1}.

    With the form x_1^2 + ... + x_{n+1}^2 - x_{n+2}^2 the lift is normalized
    to <x, x> = 1.

    Parameters
    ----------
    c : sequence of float
        Affine coordinates of the center.
    r : float
        Radius of the sphere; must be non-zero.

    Returns
    -------
    numpy.ndarray
        Vector of length ``len(c) + 2``.

    Raises
    ------
    ValueError
        If ``r == 0``: the normalization by 1 / r is undefined.
    """
    if r == 0:
        raise ValueError("The radius of the sphere must be non-zero.")
    c_norm_sq = np.dot(c, c)
    r_sq = r**2
    lift = np.array([*c, (-1 + c_norm_sq - r_sq) / 2, (1 + c_norm_sq - r_sq) / 2])
    lift *= 1 / r
    return lift


def project_function_of_objects_down(fct):
    """
    Wrapper to apply `moebius_to_projective` to
    a function returning objects of the Moebius model (or lists of them).

    If `fct(i)` is a list, every element is converted; otherwise the single
    object is converted. Results are memoized (LRU, size 128).
    Returns a new function.
    """

    @lru_cache(maxsize=128)
    def new_func(i):
        # Evaluate fct(i) once instead of up to three times.
        objects = fct(i)
        if isinstance(objects, list):
            return [moebius_to_projective(p) for p in objects]
        return moebius_to_projective(objects)

    return new_func


def function_projective_to_affine(fct):
    """
    Wrapper to apply `.affine_point` to
    a function returning projective points (or lists of them).

    If `fct(i)` is a list, the affine point of every element is returned as a
    list; otherwise the affine point of the single object is returned.
    Results are memoized (LRU, size 128).
    Returns a new function.
    """

    @lru_cache(maxsize=128)
    def new_fct(i):
        # Evaluate fct(i) once instead of up to three times.
        points = fct(i)
        if isinstance(points, list):
            return [p.affine_point for p in points]
        return points.affine_point

    return new_fct


# [visualization-1]
################################################
# Visualization setup
################################################
# Materials used by the example. The second argument is presumably the base
# color; the remaining arguments are shading parameters of
# ddg.blender.material.material (imported above as `material`).
orange = material("orange", (0.8, 0.1, 0.036), 0, 0)
gray = material("gray", (0.1, 0.1, 0.1), 0, 0, 1)
black = material("black", (0, 0, 0), 0, 0, 1)
lightgray = material("lightgray", (0.7, 0.7, 0.7), 0, 0, 1)
transparent_sphere = material(
    "transparent_sphere", (0.018, 0.313, 0.656), 0.5, 0.5, alpha=0.8
)

# Bevel depth applied to curve objects (edges, circles, normal lines).
bevel = 0.03
# NOTE(review): `sampling` is not referenced in the code visible here.
sampling = [0.1, 100, "c"]


# [visualization-1]

################################################
# Example ("smooth" or "discrete")
################################################
def channel_surface_example(example_name):
    """
    Build and visualize a discrete channel surface and one of its focal surfaces.

    The curve of sphere centers is sampled either from a stretched trefoil knot
    ("smooth") or from a planar parabola ("discrete"). The spheres, the discrete
    envelope (channel surface), the circumcircles of its faces, the face normal
    lines and the focal surface are added to a new Blender collection. A camera
    is then created and set as the scene camera.

    Parameters
    ----------
    example_name : str
        Either "smooth" or "discrete".

    Raises
    ------
    ValueError
        If `example_name` is not one of the implemented examples.
    """
    ###############################################
    # Initial data
    ###############################################
    if example_name == "smooth":
        # Samplings in two parameter directions
        sampling_curve = 50
        v_circle_sampling = 30

        # Stretched trefoil knot parameterization
        def trefoil_parameterization(t):
            pt = np.array(
                [
                    np.sin(t) + 2 * np.sin(2 * t),
                    np.cos(t) - 2 * np.cos(2 * t) - 2 * t,
                    -np.sin(3 * t),
                ]
            )
            return 3 * pt

        snet = ddg.nets.SmoothNet(trefoil_parameterization, [[0, np.pi]])
        dnet = ddg.nets.sample_smooth_net(snet, sampling=[sampling_curve, "t"])

        # Radii of the spheres associated to the vertices
        radii = np.linspace(1, 2, sampling_curve)

        # Start of discrete envelope. The first circle is the intersection of this plane
        # with the first sphere
        plane = ddg.geometry.hyperplane_from_normal((0.8, -0.4, -0.5), level=1.2)

        # Visualization domain of the focal surface
        domain_focal_surf = ddg.nets.DiscreteDomain([[0, 20], [-4, 3]])

        # Visualization
        dnet_bevel = 0.04
        camera_location = (0.2, 0.65, 6.2)
        camera_look_at_point = (5, -2, 0)

    # [example-1]
    elif example_name == "discrete":
        # Samplings in two parameter directions
        sampling_curve = 5
        v_circle_sampling = 20

        # Parameterization of a 2d parabola
        def parabola_parameterization(u):
            return 5 * np.array([u, -(u**2) / 2])

        snet_2d = ddg.nets.SmoothNet(parabola_parameterization, [[0, 1 / 2 * np.pi]])
        snet = ddg.nets.embed(snet_2d)
        dnet = ddg.nets.sample_smooth_net(snet, sampling=[sampling_curve, "t"])

        # Radii of the spheres associated to the vertices
        radii = np.linspace(1, 2, sampling_curve)

        # Start of discrete envelope. The first circle is the intersection of this plane
        # with the first sphere
        plane = ddg.geometry.Subspace([-0.2, 1, 0, 1], [-0.2, 1, 1, 1], [-0.2, 0, 1, 1])

        # Visualization domain of the focal surface
        domain_focal_surf = ddg.nets.DiscreteDomain([[0, 2], [-4, 3]])

        # Visualization
        dnet_bevel = 0.04
        camera_location = (-13, -5, 8)
        camera_look_at_point = (5, -2, 0)
    else:
        raise ValueError(
            "The only examples implemented are called 'smooth' and 'discrete'."
        )
    # [example-1]

    # [example-2]
    ###############################################
    # Spheres
    ###############################################

    # Euclidean ###################################
    def e_spheres_fct(i):
        return euc.sphere_from_affine_point_and_normals(dnet.fct(i), radii[i])

    # Moebius #####################################
    m_homogeneous_points_fct = homogeneous_lift_of_spheres_fct(
        e_spheres_fct
    )  # Homogeneous Points in R^{4, 1}
    # [example-2]

    # [example-3]
    ###############################################
    # Spheres of anti-similitude
    ###############################################

    # Moebius #####################################
    m_points_of_anti_similitude_fct = (
        m_points_of_anti_similitude_from_homogeneous_lifts(m_homogeneous_points_fct)
    )  # Points in M+
    m_spheres_of_anti_similitude_fct = m_points_to_spheres(
        m_points_of_anti_similitude_fct
    )  # ddg.geometry.Quadrics in M

    # Euclidean ####################################
    e_spheres_of_anti_similitude_fct = project_function_of_objects_down(
        m_spheres_of_anti_similitude_fct
    )  # Spheres in R3
    # [example-3]

    # [example-4]
    ###############################################
    # Envelope (as sphere inversions in spheres of anti-similitude)
    ###############################################

    # Intersecting the first sphere with the plane (given in the initial values
    # of the example) determines a starting circle. A sampling yields a list of points.
    # Iterative sphere inversions in the mid-spheres of anti-similitude
    # form the discrete channel surface.

    # Starting circle ##############################
    circle = ddg.geometry.intersect(
        plane, euc.sphere_to_quadric(e_spheres_fct(0))
    )  # Conic in R3
    circle_dnet = ddg.nets.sample_smooth_net(
        ddg.to_smooth_net(circle), sampling=[v_circle_sampling, "t"]
    )  # DiscreteNet in R3
    circle_pts = [
        circle_dnet.fct(k) for k, in circle_dnet.domain.traverser
    ]  # List of affine coordinates
    m_starting_pts = [
        projective_to_moebius(ddg.geometry.subspace_from_affine_points(p))
        for p in circle_pts
    ]  # List of points in M

    # Moebius #####################################
    m_surface_points_fct = m_iterative_sphere_inversions(
        m_points_of_anti_similitude_fct, m_starting_pts
    )  # Two-parameter family of points in M

    # Euclidean ####################################
    e_surface_points_fct = function_projective_to_affine(
        project_function_of_objects_down(m_surface_points_fct)
    )  # Two-parameter family of affine coordinates

    def channel_surface_parameterization(i, j):
        return e_surface_points_fct(i)[j]

    # [example-4]

    # [example-5]
    ###############################################
    # Circumcircles of the channel surface
    ###############################################

    # Channel surfaces are circular: for each face we determine the
    # circle through its four vertices. It suffices to determine the circle through
    # three of its vertices.
    def face_circles_data(i, j):
        """
        Returns the center, radius, and normal of the
        circumcircle of the face [(i, j), (i+1, j), (i+1, j+1), (i, j+1)].
        """
        points = [
            channel_surface_parameterization(i, j),
            channel_surface_parameterization(i + 1, j),
            channel_surface_parameterization(i, j + 1),
        ]
        c, r, n = ddg.math.euclidean.circle_through_three_points(*points)
        return c, r, n

    def face_circles(i, j):
        """
        Returns the circumcircle of the face [(i, j), (i+1, j), (i+1, j+1), (i, j+1)],
        built with euc.sphere_from_affine_point_and_normals.
        """
        return euc.sphere_from_affine_point_and_normals(*face_circles_data(i, j))

    # [example-5]

    # [example-6]
    ##############################################
    # Normal line congruence
    ##############################################
    # From the same data we can compute face normal lines,
    # orthogonal to the faces and passing through the centers
    # of the circumcircles.
    def normal_line_congruence(i, j):
        """
        Returns the normal line of face (i, j) as a pair (point, direction):
        the center of the face's circumcircle and the circle's normal.
        """
        circumcircles_center, _, normal = face_circles_data(i, j)
        return circumcircles_center, normal

    # [example-6]

    # [example-7]
    ##############################################
    # Focal surfaces
    ##############################################

    # Neighbouring normal lines intersect. The intersection points
    # determine two new surfaces, the focal surfaces.
    # One of these surfaces degenerates to a discrete curve. We compute the other.

    def focal_surf_i(i, j):
        """
        Returns the intersection points of successive normal lines in the first
        parameter direction.
        """
        line = ddg.geometry.subspace_from_affine_points_and_directions(
            *normal_line_congruence(i, j)
        )
        line_ = ddg.geometry.subspace_from_affine_points_and_directions(
            *normal_line_congruence(i + 1, j)
        )
        return ddg.geometry.intersect(line, line_).affine_point

    # [example-7]
    ################################################
    # Visualization
    ################################################
    # [visualization-2]
    col = ddg.blender.collection.collection(
        f"Channel_surface_{example_name}",
        children=[
            f"spheres_{example_name}",
            f"envelope_{example_name}",
            f"circumscribed_circles_{example_name}",
            f"normal_lines_{example_name}",
            f"focal_surfaces_{example_name}",
        ],
    )
    # [visualization-2]

    # [visualization-3]
    ###############################################
    # Channel surface
    ###############################################

    # Discrete curve
    dnet_bobj = ddg.blender.convert(dnet, "centers_dnet", black)
    dnet_bobj.data.bevel_depth = dnet_bevel

    # Spheres
    for (i,) in dnet.domain.traverser:
        bobj = ddg.blender.convert(
            e_spheres_fct(i), f"sphere_{i}", transparent_sphere, col.children[0]
        )
        ddg.blender.mesh.shade_smooth(bobj)

    # Channel surface
    channel_surface_domain = ddg.nets.DiscreteDomain(
        [dnet.domain[0], [0, v_circle_sampling - 1, True]]
    )
    channel_surface_mesh = ddg.arrays.from_discrete_net(
        ddg.nets.DiscreteNet(channel_surface_parameterization, channel_surface_domain)
    )
    channel_surface_bobj = ddg.blender.edges(
        channel_surface_mesh,
        "channel_surface",
        material=lightgray,
        collection=col.children[1],
    )
    channel_surface_bobj.data.bevel_depth = bevel

    ###############################################
    # Circumscribed circles
    ###############################################
    circumcircles_domain = ddg.nets.DiscreteDomain(
        [[dnet.domain[0][0], dnet.domain[0][1] - 1], [0, v_circle_sampling - 2]]
    )
    for i, j in circumcircles_domain.traverser:
        bobj = ddg.blender.convert(
            face_circles(i, j), f"face_circles({i}, {j})", gray, col.children[2]
        )
        bobj.data.bevel_depth = bevel

    ###############################################
    # Normal line congruence
    ###############################################

    for i, j in circumcircles_domain.traverser:
        c, n = normal_line_congruence(i, j)
        line_segment = ddg.arrays.line_segment_from_point_and_direction(c, n, [-2, 2])
        line_bobj = ddg.blender.convert(
            line_segment, f"normal_line_congruence_{i}_{j}", orange, col.children[3]
        )
        line_bobj.data.bevel_depth = bevel

    ###############################################
    # Focal surface
    ###############################################

    # Negative j in domain_focal_surf index the sampled circle from its end,
    # since channel_surface_parameterization uses Python list indexing.
    for i, j in domain_focal_surf.traverser:
        c, n = normal_line_congruence(i, j)
        line_segment = ddg.arrays.line_segment_from_point_and_direction(c, n, [0, 100])
        line_bobj = ddg.blender.convert(
            line_segment,
            f"focal_normal_line_congruence_{i}_{j}",
            orange,
            col.children[3],
        )
        line_bobj.data.bevel_depth = bevel

    focal_dnet = ddg.nets.DiscreteNet(focal_surf_i, domain=domain_focal_surf)
    ddg.blender.convert(focal_dnet, "focal_surface", gray, col.children[4])

    focal_mesh = ddg.arrays.convert(focal_dnet)
    focal_surface_wire = ddg.blender.edges(
        focal_mesh, "focal_surface_wire", material=gray, collection=col.children[4]
    )
    focal_surface_wire.data.bevel_depth = bevel
    # [visualization-3]

    #############################################
    # CAMERAS
    #############################################

    camera = ddg.blender.camera.camera(location=camera_location, collection=col)
    ddg.blender.camera.look_at_point(camera, camera_look_at_point)
    bpy.context.scene.camera = camera


channel_surface_example("discrete")
# channel_surface_example("smooth")
#############################################
# RENDERING AND LIGHT
#############################################

# Render with Cycles using `samples` samples, a white world background and a
# square output image of `resolution` x `resolution` pixels.
samples = 8
resolution = 500

ddg.blender.render.setup_cycles_renderer(samples=samples)

ddg.blender.render.set_world_background(color=(1, 1, 1, 1), strength=1)
render_settings = bpy.context.scene.render
render_settings.resolution_x = render_settings.resolution_y = resolution

# Film transparency is disabled; uncomment to enable it:
# ddg.blender.render.set_film_transparency()
bpy.context.scene.view_settings.view_transform = "Standard"
# A single sun light placed high above the scene.
light = ddg.blender.light.light(location=(0, 0, 100), type_="SUN", energy=0.5)
