"""
In this example we will construct
discrete channel surfaces and their focal surfaces.
Let
M  = {[x] in RP^3 | <x, x>_{4, 1} = 0}
be the Moebius quadric and M^+
the set of points outside M
M^+ = {[x] in RP^3 | <x, x>_{4, 1} > 0}.

"""
from functools import lru_cache

import bpy

# [setup-1]
# Import necessary libraries
import numpy as np

import ddg
from ddg.geometry.euclidean_models import moebius_to_projective, projective_to_moebius
from ddg.geometry.intersection import intersect, join
from ddg.geometry.subspaces import (
    Point,
    orthonormalize_and_center_subspace,
    subspace_from_affine_points,
)
from ddg.visualization.blender.material import material

# Start from an empty scene: clear the existing Blender objects and materials
ddg.visualization.blender.scene.clear()
ddg.visualization.blender.material.clear()


# Euclidean 3-space and its Moebius model. `mob.absolute` (the Moebius
# quadric) is used by the helper functions below for polarity and inversions.
euc = ddg.geometry.euclidean(3)
mob = ddg.geometry.euclidean_models.MoebiusModel(3)
# [setup-1]


################################################
# Helper functions
################################################
def polarized(fct):
    """
    Wrap `fct` so that every value it returns is replaced by its polar
    with respect to the Moebius quadric ``mob.absolute``.

    The returned function takes the same index argument as `fct` and is
    memoized with an LRU cache of 128 entries.
    """

    def _polar_of(i):
        quadric = mob.absolute
        return quadric.polarize(fct(i))

    return lru_cache(maxsize=128)(_polar_of)


def m_points_to_spheres(fct):
    """
    Turn a family of Moebius spheres given as points into conic sections.

    `fct(i)` is assumed to return a Moebius sphere as a point in M^+
    (outside the Moebius quadric). The returned function maps ``i`` to the
    intersection of the polar of ``fct(i)`` with the Moebius quadric
    ``mob.absolute``, i.e. the same sphere as a section of M. Results are
    memoized (LRU cache of 128 entries).
    """
    # Build the polarized family once so its LRU cache is shared by all calls.
    # Creating it inside `point_to_circle` (as before) produced a fresh, empty
    # cache on every invocation, so polars were never reused.
    polar_fct = polarized(fct)

    @lru_cache(maxsize=128)
    def point_to_circle(i):
        return intersect(polar_fct(i), mob.absolute)

    return point_to_circle


def tangent_edges(fct):
    """
    Return a function giving the edge tangent lines of a discrete curve.

    `fct(i)` returns the i'th vertex of the curve as a Point. The returned
    function maps ``i`` to the line ``join(fct(i), fct(i + 1))``, passed
    through `orthonormalize_and_center_subspace` with the midpoint of the two
    vertices' affine coordinates as center. Results are memoized (LRU cache
    of 128 entries).
    """

    @lru_cache(maxsize=128)
    def tangent_edge(i):
        # Evaluate each endpoint once; `fct` is not necessarily cached.
        start, end = fct(i), fct(i + 1)
        tangent = join(start, end)
        midpoint = np.sum([start.affine_point, end.affine_point], axis=0) / 2
        return orthonormalize_and_center_subspace(tangent, midpoint)

    return tangent_edge


# @lru_cache(maxsize=128)
def m_sphere_inversion(point, fix_sphere):
    """
    Invert `point` in the sphere `fix_sphere` inside the Moebius model.

    `point` is expected to lie on the Moebius quadric M, and `fix_sphere` to
    be a point of M^+ (outside M). The homogeneous coordinates of `point` are
    reflected in those of `fix_sphere` using
    ``ddg.math.inner_product.reflect`` with the inner product of the Moebius
    quadric. The result is returned as a new Point, again on M.
    """
    reflected_coordinates = ddg.math.inner_product.reflect(
        fix_sphere.point, point.point, mob.absolute.inner_product
    )
    return Point(reflected_coordinates)


def m_iterative_sphere_inversions(fct, starting_points):
    """
    Sweep a list of starting points through successive sphere inversions.

    The inversion spheres are given by `fct`, each as a point in M^+. The
    returned function maps ``i`` to the list obtained from `starting_points`
    by inverting, in order, in ``fct(0), fct(1), ..., fct(i - 1)``. For
    ``i == 0`` it returns `starting_points` itself. Results are memoized
    (LRU cache of 128 entries), so consecutive indices reuse earlier work.

    Raises
    ------
    ValueError
        If the returned function is called with a negative index.
    """

    @lru_cache(maxsize=128)
    def new_curve_fct(i):
        if i < 0:
            # The recursion below only terminates at i == 0; a negative
            # index would otherwise recurse until RecursionError.
            raise ValueError(f"index must be non-negative, got {i}")
        if i == 0:
            return starting_points
        previous_points = new_curve_fct(i - 1)
        # The inversion sphere is the same for every point: evaluate it once.
        sphere = fct(i - 1)
        return [m_sphere_inversion(p, sphere) for p in previous_points]

    return new_curve_fct


def m_points_of_anti_similitude_from_homogeneous_lifts(fct):
    """
    Compute the spheres of anti-similitude of consecutive spheres.

    `fct(i)` is expected to return a sphere as a (specially normalized)
    homogeneous vector in R^{4, 1}. The returned function maps ``i`` to the
    Point spanned by ``fct(i) - fct(i + 1)``, which represents the sphere of
    anti-similitude of spheres ``i`` and ``i + 1`` as a point in M^+.
    Results are memoized (LRU cache of 128 entries).
    """

    def _anti_similitude(i):
        difference = fct(i) - fct(i + 1)
        return Point(difference)

    return lru_cache(maxsize=128)(_anti_similitude)


def homogeneous_lift_of_spheres_fct(fct):
    """
    Apply `homogeneous_lift_of_spheres` to a family of Euclidean spheres.

    `fct(i)` must return a Euclidean sphere exposing ``center.affine_point``
    and ``radius``. The returned function maps ``i`` to the homogeneous lift
    of that sphere in R^{4, 1}. It is not memoized.
    """

    def new_fct(i):
        # Evaluate once: `fct` may rebuild the sphere object on every call.
        sphere = fct(i)
        return homogeneous_lift_of_spheres(sphere.center.affine_point, sphere.radius)

    return new_fct


# @lru_cache(maxsize=128)
def homogeneous_lift_of_spheres(c, r):
    """
    Lift a Euclidean sphere with center `c` and radius `r` to R^{4, 1}.

    The lift is
        1 / r * (c, (-1 + |c|^2 - r^2) / 2, (1 + |c|^2 - r^2) / 2),
    a vector with ``len(c) + 2`` entries. With the last coordinate carrying
    the negative sign of the inner product, <x, x>_{4, 1} = 1 for this
    vector, so it represents the sphere as a point of M^+.

    Parameters
    ----------
    c : array_like
        Affine coordinates of the sphere's center.
    r : float
        Radius of the sphere; must be non-zero.

    Returns
    -------
    numpy.ndarray
        The homogeneous lift as a float array.

    Raises
    ------
    ValueError
        If `r` is zero: the lift is normalized by ``1 / r``.
    """
    if r == 0:
        raise ValueError("radius must be non-zero to normalize the lift")
    c = np.asarray(c, dtype=float)
    norm_sq = np.dot(c, c)
    lift = np.concatenate(
        [c, [(-1 + norm_sq - r**2) / 2, (1 + norm_sq - r**2) / 2]]
    )
    # Multiply by the reciprocal (as before) so results stay bit-identical.
    return lift * (1 / r)


def project_function_of_objects_down(fct):
    """
    Map the values of `fct` from the Moebius model down to projective space.

    `fct(i)` is expected to return an object of the Moebius model or a list
    of such objects. The returned function applies `moebius_to_projective` to
    that object, or element-wise to the list, and memoizes the result (LRU
    cache of 128 entries).
    """

    @lru_cache(maxsize=128)
    def new_func(i):
        # Evaluate once instead of once for the type check and again for use.
        value = fct(i)
        if isinstance(value, list):
            return [moebius_to_projective(p) for p in value]
        return moebius_to_projective(value)

    return new_func


def function_projective_to_affine(fct):
    """
    Replace the projective points returned by `fct` with affine coordinates.

    `fct(i)` is expected to return a projective point or a list of them. The
    returned function yields ``.affine_point`` of that point, or the list of
    ``.affine_point`` values, and memoizes the result (LRU cache of 128
    entries).
    """

    @lru_cache(maxsize=128)
    def new_fct(i):
        # Evaluate once instead of once for the type check and again for use.
        value = fct(i)
        if isinstance(value, list):
            return [p.affine_point for p in value]
        return value.affine_point

    return new_fct


# [visualization-1]
################################################
# Visualization setup
################################################
orange = material("orange", (0.8, 0.1, 0.036), 0, 0)
gray = material("gray", (0.1, 0.1, 0.1), 0, 0, 1)
black = material("black", (0, 0, 0), 0, 0, 1)
lightgray = material("lightgray", (0.7, 0.7, 0.7), 0, 0, 1)
transparent_sphere = ddg.visualization.blender.material.material(
    "transparent_sphere", (0.018, 0.313, 0.656), 0.5, 0.5, alpha=0.8
)

bevel = 0.03
sampling = [0.1, 100, "c"]


def visualize_spheres(g, domain, **kwargs):
    """
    Create one Blender object per sphere ``g(i)`` for every index ``(i,)``
    of the one-dimensional `domain`.

    Each object is named ``sphere_<i>`` and sampled with the module-level
    `sampling`; any extra keyword arguments (material, collection, ...) are
    forwarded to ``ddg.to_blender_object_helper``. Returns the created
    objects in traversal order.
    """
    sphere_objects = []
    for (index,) in domain.traverser:
        sphere_objects.append(
            ddg.to_blender_object_helper(
                g(index),
                sampling=sampling,
                name=f"sphere_{index}",
                **kwargs,
            )
        )
    return sphere_objects


def visualize_lines(g, domain, line_domain, **kwargs):
    """
    Create one beveled Blender curve per line ``g(i, j)`` for every index
    ``(i, j)`` of the two-dimensional `domain`.

    `line_domain` is passed as the ``domain`` of every line object; the
    module-level `sampling` and `bevel` are used as well. Objects are named
    ``line_<i>`` (``j`` is not part of the name). Extra keyword arguments are
    forwarded to ``ddg.to_blender_object_helper``. Returns the created
    objects in traversal order.
    """

    def _draw_line(i, j):
        return ddg.to_blender_object_helper(
            g(i, j),
            domain=line_domain,
            sampling=sampling,
            curve_properties={"bevel_depth": bevel},
            name=f"line_{i}",
            **kwargs,
        )

    return [_draw_line(i, j) for i, j in domain.traverser]


def visualize_circles(g, domain, **kwargs):
    """
    Create one beveled Blender curve per circle ``g(i, j)`` for every index
    ``(i, j)`` of the two-dimensional `domain`.

    Uses the module-level `sampling` and `bevel`. Objects are named
    ``circle_<i>`` (``j`` is not part of the name). Extra keyword arguments
    are forwarded to ``ddg.to_blender_object_helper``. Returns the created
    objects in traversal order.
    """
    circle_objects = []
    for index_i, index_j in domain.traverser:
        circle_objects.append(
            ddg.to_blender_object_helper(
                g(index_i, index_j),
                sampling=sampling,
                curve_properties={"bevel_depth": bevel},
                name=f"circle_{index_i}",
                **kwargs,
            )
        )
    return circle_objects


def wire_bobj_to_curve(bobj, bevel_depth: float = 0.025):
    """
    Convert the Blender object `bobj` into a curve object and bevel it.

    In this script it is applied to wireframe objects (created with
    ``only_wire=True``) so that their edges get a visible thickness.

    The work happens inside ``ddg.visualization.blender.context.mode`` with
    mode ``"OBJECT"``. Everything is deselected and only `bobj` is selected
    while ``bpy.ops.object.convert(target="CURVE")`` runs, so no other object
    is converted. The material `bobj` had before conversion is re-applied
    afterwards.

    Parameters
    ----------
    bobj
        Blender object to convert. The code keeps using `bobj` after the
        conversion, so the operator is expected to convert it in place.
    bevel_depth : float
        Value written to ``bobj.data.bevel_depth``.

    Returns
    -------
    The same `bobj`, after conversion.
    """
    # Choose the object mode context (the convert operator is run from it)
    with ddg.visualization.blender.context.mode(bobj, mode="OBJECT"):
        # Save the material so it can be re-applied after the conversion
        # (NOTE(review): presumably conversion may drop it — confirm)
        mat = bobj.active_material
        # Select only this object so the operator converts nothing else
        bpy.ops.object.select_all(action="DESELECT")
        bobj.select_set(True)
        # Convert it to a curve
        bpy.ops.object.convert(target="CURVE")
        # Unselect the object again
        bobj.select_set(False)
        # Set its bevel_depth and restore its material
        bobj.data.bevel_depth = bevel_depth
        ddg.visualization.blender.material.set_material(bobj, mat)
        return bobj


# [visualization-1]

################################################
# Example ("smooth" or "discrete")
################################################
def channel_surface_example(example_name):
    """
    Build the channel-surface scene for one of two example data sets.

    Parameters
    ----------
    example_name : str
        ``"smooth"``: spheres centered on 50 samples of a stretched trefoil
        knot parameterization over [0, pi].
        ``"discrete"``: spheres centered on 5 samples of a parabola over
        [0, pi / 2], embedded into 3-space.
        In both cases the sphere radii grow linearly from 1 to 2.

    The function creates a collection ``Channel_surface_<example_name>``
    with child collections for the spheres, the envelope (channel surface),
    the circumscribed circles, the normal lines and the focal surface. It
    also creates a camera and makes it the scene camera.

    Raises
    ------
    ValueError
        If `example_name` is neither ``"smooth"`` nor ``"discrete"``.
    """
    ###############################################
    # Initial data
    ###############################################
    if example_name == "smooth":
        # Samplings in two parameter directions
        sampling_curve = 50
        v_circle_sampling = 30

        # Stretched trefoil knot parameterization
        def trefoil_parameterization(t):
            pt = np.array(
                [
                    np.sin(t) + 2 * np.sin(2 * t),
                    np.cos(t) - 2 * np.cos(2 * t) - 2 * t,
                    -np.sin(3 * t),
                ]
            )
            return 3 * pt

        snet = ddg.SmoothNet(trefoil_parameterization, [[0, np.pi]])
        dnet = ddg.sample_smooth_net(snet, sampling=[sampling_curve, "t"])

        # Radii of the spheres associated to the vertices
        radii = np.linspace(1, 2, sampling_curve)

        # Start of the discrete envelope. The first circle is the intersection
        # of this plane with the first sphere.
        plane = ddg.geometry.subspaces.hyperplane_from_normal(
            (0.8, -0.4, -0.5), level=1.2
        )

        # Visualization domain of the focal surface. Negative circle indices
        # wrap around the list of circle points (Python list indexing in
        # `channel_surface_parameterization`).
        domain_focal_surf = ddg.nets.DiscreteDomain([[0, 20], [-4, 3]])

        # Visualization
        dnet_bevel = 0.04
        camera_location = (0.2, 0.65, 6.2)
        camera_look_at_point = (5, -2, 0)

    # [example-1]
    elif example_name == "discrete":
        # Samplings in two parameter directions
        sampling_curve = 5
        v_circle_sampling = 20

        # Parameterization of a 2d parabola

        def parabola_parameterization(u):
            return 5 * np.array([u, -(u**2) / 2])

        snet_2d = ddg.nets.net.SmoothNet(
            parabola_parameterization, [[0, 1 / 2 * np.pi]]
        )
        snet = ddg.nets.utils.embed(snet_2d)
        dnet = ddg.sample_smooth_net(snet, sampling=[sampling_curve, "t"])

        # Radii of the spheres associated to the vertices
        radii = np.linspace(1, 2, sampling_curve)

        # Start of the discrete envelope. The first circle is the intersection
        # of this plane with the first sphere.
        plane = ddg.geometry.subspaces.Subspace(
            [-0.2, 1, 0, 1], [-0.2, 1, 1, 1], [-0.2, 0, 1, 1]
        )

        # Visualization domain of the focal surface. Negative circle indices
        # wrap around the list of circle points (Python list indexing in
        # `channel_surface_parameterization`).
        domain_focal_surf = ddg.nets.DiscreteDomain([[0, 2], [-4, 3]])

        # Visualization
        dnet_bevel = 0.04
        camera_location = (-13, -5, 8)
        camera_look_at_point = (5, -2, 0)
    else:
        raise ValueError(
            "The only examples implemented are called 'smooth' and 'discrete'."
        )
    # [example-1]

    # [example-2]
    ###############################################
    # Spheres
    ###############################################

    # Euclidean ###################################
    def e_spheres_fct(i):
        # Sphere i is centered at vertex i of the sampled curve; `radii[i]`
        # assumes the curve's domain starts at index 0.
        return euc.sphere_from_affine_point_and_normals(dnet.fct(i), radii[i])

    # Moebius #####################################
    m_homogeneous_points_fct = homogeneous_lift_of_spheres_fct(
        e_spheres_fct
    )  # Homogeneous Points in R^{4, 1}
    # [example-2]

    # [example-3]
    ###############################################
    # Spheres of anti-similitude
    ###############################################

    # Moebius #####################################
    m_points_of_anti_similitude_fct = (
        m_points_of_anti_similitude_from_homogeneous_lifts(m_homogeneous_points_fct)
    )  # Points in M^+
    m_spheres_of_anti_similitude_fct = m_points_to_spheres(
        m_points_of_anti_similitude_fct
    )  # Quadrics in M

    # Euclidean ####################################
    e_spheres_of_anti_similitude_fct = project_function_of_objects_down(
        m_spheres_of_anti_similitude_fct
    )  # Spheres in R3
    # [example-3]

    # [example-4]
    ###############################################
    # Envelope (as sphere inversions in spheres of anti-similitude)
    ###############################################

    # Intersecting the first sphere with the plane (given in the initial values
    # of the example) determines a starting circle. Sampling it yields a list of
    # points. Iterative sphere inversions in the mid-spheres of anti-similitude
    # then form the discrete channel surface.

    # Starting circle ##############################
    circle = ddg.geometry.intersection.intersect(
        plane, euc.sphere_to_quadric(e_spheres_fct(0))
    )  # Conic in R3
    circle_dnet = ddg.sample_smooth_net(
        ddg.to_smooth_net(circle), sampling=[v_circle_sampling, "t"]
    )  # DiscreteNet in R3
    cirlce_pts = [
        circle_dnet.fct(k) for k, in circle_dnet.domain.traverser
    ]  # List of affine coordinates
    m_starting_pts = [
        projective_to_moebius(subspace_from_affine_points(p)) for p in cirlce_pts
    ]  # List of points in M

    # Moebius #####################################
    m_surface_points_fct = m_iterative_sphere_inversions(
        m_points_of_anti_similitude_fct, m_starting_pts
    )  # Two-parameter family of points in M

    # Euclidean ####################################
    e_surface_points_fct = function_projective_to_affine(
        project_function_of_objects_down(m_surface_points_fct)
    )  # Two-parameter family of affine coordinates

    def channel_surface_parameterization(i, j):
        # i selects the inversion step (curve direction), j the point on the
        # circle. j indexes a Python list, so negative j wraps around the circle.
        return e_surface_points_fct(i)[j]

    # [example-4]

    # [example-5]
    ###############################################
    # Circumcircles of the channel surface
    ###############################################

    # Discrete channel surfaces are circular nets: the four vertices of each
    # face lie on a circle. It therefore suffices to determine the circle
    # through three of its vertices.
    def face_circles_data(i, j):
        """
        Return the center, radius, and normal of the circumcircle of the face
        [(i, j), (i+1, j), (i+1, j+1), (i, j+1)].

        The circle is computed through the three vertices (i, j), (i+1, j)
        and (i, j+1).
        """
        points = [
            channel_surface_parameterization(i, j),
            channel_surface_parameterization(i + 1, j),
            channel_surface_parameterization(i, j + 1),
        ]
        c, r, n = ddg.math.euclidean.circle_through_three_points(*points)
        return c, r, n

    def face_circles(i, j):
        """
        Return the circumcircle of the face [(i, j), (i+1, j), (i+1, j+1), (i, j+1)],
        built with `euc.sphere_from_affine_point_and_normals` from its center,
        radius and normal.
        """
        return euc.sphere_from_affine_point_and_normals(*face_circles_data(i, j))

    # [example-5]

    # [example-6]
    ##############################################
    # Normal line congruence
    ##############################################
    # From the same data we can compute face normal lines,
    # orthogonal to the faces and passing through the centers
    # of the circumcircles.
    def normal_line_congruence(i, j):
        """
        Return the normal line of face (i, j): the line through the center of
        its circumcircle in the direction of the circle's normal.
        """
        c, r, n = face_circles_data(i, j)
        line = ddg.geometry.subspaces.subspace_from_affine_points_and_directions(c, n)
        return line

    # [example-6]

    # [example-7]
    ##############################################
    # Focal surfaces
    ##############################################

    # Neighbouring normal lines intersect. The intersection points
    # determine two new surfaces, the focal surfaces.
    # One of these surfaces degenerates to a discrete curve. We compute the other.

    def focal_surf_i(i, j):
        """
        Return the affine intersection point of the normal lines of faces
        (i, j) and (i + 1, j), i.e. of successive normal lines in the first
        parameter direction.
        """
        return ddg.geometry.intersection.intersect(
            normal_line_congruence(i, j), normal_line_congruence(i + 1, j)
        ).affine_point

    # [example-7]
    ################################################
    # Visualization
    ################################################
    # [visualization-2]
    col = ddg.visualization.blender.collection.collection(
        f"Channel_surface_{example_name}",
        children=[
            f"spheres_{example_name}",
            f"envelope_{example_name}",
            f"circumscribed_circles_{example_name}",
            f"normal_lines_{example_name}",
            f"focal_surfaces_{example_name}",
        ],
    )
    # [visualization-2]

    # [visualization-3]
    ###############################################
    # Channel surface
    ###############################################

    # Discrete curve
    ddg.to_blender_object_helper(
        dnet, curve_properties={"bevel_depth": dnet_bevel}, material=black
    )
    # Spheres
    visualize_spheres(
        e_spheres_fct,
        domain=dnet.domain,
        material=transparent_sphere,
        collection=col.children[0],
    )

    # Channel surface
    # NOTE(review): the trailing True presumably marks the circle direction as
    # periodic — confirm against ddg.nets.DiscreteDomain.
    channel_surface_domain = ddg.nets.DiscreteDomain(
        [dnet.domain[0], [0, v_circle_sampling - 1, True]]
    )
    channel_bobj = ddg.to_blender_object_helper(
        ddg.nets.DiscreteNet(
            channel_surface_parameterization,
            domain=channel_surface_domain,
        ),
        material=lightgray,
        only_wire=True,
        collection=col.children[1],
    )

    wire_bobj_to_curve(channel_bobj)

    ###############################################
    # Circumscribed circles
    ###############################################
    # One index shorter than the vertex domain in each direction, because
    # face (i, j) also uses vertices with index i + 1 and j + 1.
    circumcircles_domain = ddg.nets.DiscreteDomain(
        [[dnet.domain[0][0], dnet.domain[0][1] - 1], [0, v_circle_sampling - 2]]
    )
    visualize_circles(
        face_circles,
        domain=circumcircles_domain,
        material=gray,
        collection=col.children[2],
    )

    ###############################################
    # Normal line congruence
    ###############################################

    visualize_lines(
        normal_line_congruence,
        domain=circumcircles_domain,
        line_domain=[[-2, 2]],
        material=orange,
        collection=col.children[3],
    )

    ###############################################
    # Focal surface
    ###############################################

    # Long normal lines over the focal-surface domain, drawn together with
    # the focal surface they meet.
    visualize_lines(
        normal_line_congruence,
        domain=domain_focal_surf,
        line_domain=[[0, 100]],
        collection=col.children[4],
        material=orange,
    )

    # The focal surface is drawn twice: as a surface and as a beveled wireframe.
    focal_dnet = ddg.DiscreteNet(focal_surf_i, domain=domain_focal_surf)
    ddg.to_blender_object_helper(focal_dnet, material=gray, collection=col.children[4])
    focal_wire = ddg.to_blender_object_helper(
        focal_dnet, only_wire=True, material=gray, collection=col.children[4]
    )
    wire_bobj_to_curve(focal_wire, bevel_depth=0.04)
    # [visualization-3]

    #############################################
    # CAMERAS
    #############################################

    camera = ddg.visualization.blender.camera.camera(
        location=camera_location, collection=col
    )
    ddg.visualization.blender.camera.look_at_point(camera, camera_look_at_point)
    bpy.context.scene.camera = camera


channel_surface_example("discrete")
# channel_surface_example("smooth")
#############################################
# RENDERING AND LIGHT
#############################################

# Change to cycles rendering engine, increase max_samples size
# enable film transparency, and scale up the resolution for a good quality image
max_samples = 8
resolution = 500

ddg.visualization.blender.render.setup_cycles_renderer(
    device="CPU", max_samples=max_samples, time_limit=0
)

ddg.visualization.blender.render.set_world_background(color=(1, 1, 1, 1), strength=1)
bpy.context.scene.render.resolution_x = resolution
bpy.context.scene.render.resolution_y = resolution

# ddg.visualization.blender.render.set_film_transparency()
bpy.context.scene.view_settings.view_transform = "Standard"
light = ddg.visualization.blender.light.light(
    location=(0, 0, 100), type_="SUN", energy=0.5
)

from testing.tests.examples.blender.snapshot import opt_in  # noqa: E402

opt_in()
