# AUTHOR NikitaKrutov
# VERSION 0.1.1
# This script can explode groups and put them back together. Rendering is possible, but the camera animation needs to be set up manually.


import lux
import luxmath
import math
import os
import sys


def get_user_input():
    """
    Show the "Explosion Settings" dialog and return the chosen options.

    Returns a 14-tuple:
        (dxt, dyt, dzt, angle_x, angle_y, angle_z, bool_absolute, frames,
         bool_animation, easing_type, camera_name, bool_reset, bool_reverse,
         bool_use_matrix)

    dxt/dyt/dzt are floats (total offset). angle_x/y/z (degrees) and frames
    are ints. When "Reverse" is checked, the translations and angles are
    negated so the animation plays backwards.

    easing_type is element [0] of the "easing" dialog value, and camera_name
    is element [1] of the "camera_name" value. NOTE(review): these look like
    (index, label) pairs returned by lux.getInputDialog for DIALOG_ITEM —
    confirm against the KeyShot scripting docs.

    If the dialog is cancelled the script is stopped with sys.exit(0).
    """
    active_camera = lux.getCamera()
    # Offer every camera (in reverse order) except the "last_active" entry.
    all_cameras_only = [cam for cam in lux.getCameras()[::-1] if cam != "last_active"]

    dialog_fields = [
        ("dxt", lux.DIALOG_TEXT, "Translation in x", "250"),
        ("dyt", lux.DIALOG_TEXT, "Translation in y", "1000"),
        ("dzt", lux.DIALOG_TEXT, "Translation in z", "0"),
        (lux.DIALOG_LABEL, "Due to matrix weirdness it's best to avoid 45° and 90° angles:"),
        ("angle_x", lux.DIALOG_INTEGER, "Angle in x:", 0, (-180, 180)),
        ("angle_y", lux.DIALOG_INTEGER, "Angle in y:", 0, (-180, 180)),
        ("angle_z", lux.DIALOG_INTEGER, "Angle in z:", -30, (-180, 180)),
        ("frames", lux.DIALOG_INTEGER, "Number of Frames:", 30, (1, 10000)),
        (lux.DIALOG_LABEL, "Check Reverse to play invert the transformations, playing the animation backwards"),
        ("bool_reverse", lux.DIALOG_CHECK, "Reverse transformations", False),
        ("easing", lux.DIALOG_ITEM, "Easing", "ease out", ["linear", "ease in", "ease out", "ease in/out"]),
        (lux.DIALOG_LABEL, "Render will produce frames and an animation in the .bip's location"),
        ("bool_animation", lux.DIALOG_CHECK, "Render", False),
        ("camera_name", lux.DIALOG_ITEM, "Camera Name", active_camera, all_cameras_only),
        (lux.DIALOG_LABEL, "Checking Use Group Transformations will use the Scene's matrix, but may cause inaccuracies"),
        ("bool_use_matrix", lux.DIALOG_CHECK, "Use Group Transformations", False),
        (lux.DIALOG_LABEL, "Keep the translation global. Unchecking will likely not properly reverse"),
        ("bool_absolute", lux.DIALOG_CHECK, "Absolute Rotation", True),
        (lux.DIALOG_LABEL, "Checking RESET will only reset the Groups transformations"),
        ("bool_reset", lux.DIALOG_CHECK, "RESET GROUPS?", False)]

    user_input = lux.getInputDialog(title="Explosion Settings",
                                    desc="Enter the total offset:",
                                    values=dialog_fields,
                                    id="material_name")

    if not user_input:
        # Dialog cancelled: stop the script instead of continuing with
        # missing settings (which would fail later, e.g. on `None > 0`).
        sys.exit(0)

    dxt = float(user_input["dxt"])
    dyt = float(user_input["dyt"])
    dzt = float(user_input["dzt"])
    frames = int(user_input["frames"])
    angle_x, angle_y, angle_z = int(user_input["angle_x"]), int(user_input["angle_y"]), int(user_input["angle_z"])

    bool_reverse = bool(user_input["bool_reverse"])
    if bool_reverse:
        # Play the animation backwards by inverting every transformation.
        dxt, dyt, dzt = -dxt, -dyt, -dzt
        angle_x, angle_y, angle_z = -angle_x, -angle_y, -angle_z

    print("#########################################")

    return (dxt, dyt, dzt, angle_x, angle_y, angle_z, user_input["bool_absolute"], frames,
            user_input["bool_animation"], user_input["easing"][0], user_input["camera_name"][1],
            bool(user_input["bool_reset"]), bool_reverse, bool(user_input["bool_use_matrix"]))


# Read all explosion settings from the dialog.
_settings = get_user_input()
if _settings is None or _settings[0] is None:
    # Defensive: a cancelled dialog may come back as None values; stop here
    # instead of failing later on comparisons such as `easing_type > 0`.
    sys.exit(0)
dxt, dyt, dzt, angle_x, angle_y, angle_z, bool_absolute, frames, bool_animation, easing_type, camera_name, bool_reset, bool_reverse, bool_use_matrix = _settings



def convert_angles(angle_x, angle_y, angle_z, angle_current, print_position, frame, bool_reverse):
    """
    Re-express a group's current Euler angles (radians) as an equivalent set
    chosen by the empirical fix-ups "f1".."f9" below, so that adding the next
    rotation increment does not jump to a flipped decomposition.

    angle_x, angle_y, angle_z: total user rotation per axis, in degrees.
    angle_current: the group's current Euler angles in radians (the caller
        passes R_original.getTransformation()[1].val()).
    print_position: debug index; overwritten below with the last part index.
    frame: current animation frame; several fixes only latch on frame 0.
    bool_reverse: True when the animation is being played backwards.

    Returns a list of three angles in radians. Each angle is moved by one
    full turn (2*pi) if it lies outside [-pi, pi].

    NOTE(review): reads module-level globals of the main script (explosion,
    n, index, frames, parts, easing, euler_convert, euler_convert_c6) and sets
    entries of euler_convert / euler_convert_c6 to True.
    """
    # Default: keep the current angles. This used to be assigned only inside
    # the branch below, so inputs that skip it raised UnboundLocalError.
    angle_current_adjusted = angle_current

    # Don't adjust for angles <90°
    # NOTE(review): angle_x/y/z are degrees but are compared with math.pi/2
    # (radians), so any axis angle of about 1.57° or more enters this branch.
    # The comment above suggests a 90° threshold was meant — confirm first.
    if any(round(abs(angle),8) >= math.pi/2 for angle in [angle_x, angle_y, angle_z]):

        a1,a2,a3 = angle_current
        pi = math.pi
        frame_start = 0

        # Set by fix f8 within this call.
        bool_convert = False

        #######################################
        ############PRINT POSITION#############
        #######################################
        # Debug prints below fire only for the last part.
        print_position = (len(explosion)-1)-0

        # Snap values that round to zero (0.0 or -0.0) to exactly 0
        a1 = 0 if round(a1,12) == -0.0 else a1
        a2 = 0 if round(a2,12) == -0.0 else a2
        a3 = 0 if round(a3,12) == -0.0 else a3

        # Fixes 0,120,-30 -> +/-180,60,150
        if round(abs(a1), 12) == round(pi, 12):
            angle_current_adjusted = [(pi - a1) % (2 * pi), (pi - a2) % (2 * pi), (-pi + a3) % (2 * pi)]
            if n == print_position: print("f1")

        # Fixes 100,100,-20 -> -80,80,160
        # Fixes -80, 100, -20 -> 100, 80, 160
        if (round(abs(a3), 12) > round(pi / 2, 12) and a1 < 0 and a3 > 0) or (round(abs(a3), 12) > round(pi / 2, 12) and abs(angle_z) < 90):
            angle_current_adjusted = [(pi + a1) % (2 * pi), (pi - a2) % (2 * pi), (-pi + a3) % (2 * pi)]
            if n == print_position: print("f2")

        # Fixes -100,-100,80 -> 80,-80,-100
        if round(abs(a3), 12) > round(pi / 2, 12) and a1 > 0 and a2 < 0:
            angle_current_adjusted = [(-pi + a1) % (2 * pi), (-pi - a2) % (2 * pi), (pi + a3) % (2 * pi)]
            if n == print_position: print("f3")

        # Fixes 100,-100,80 -> -80,-80,-100
        if round(abs(a3), 12) > round(pi / 2, 12) and a1 < 0 and a2 < 0:
            angle_current_adjusted = [(pi + a1) % (2 * pi), -(pi + a2) % (2 * pi), (pi + a3) % (2 * pi)]
            if n == print_position: print("f4")

        # Fixes -150,120,-60 -> 30,60,120
        if round(abs(a3), 12) > round(pi / 2, 12) and a1 > 0 and round(abs(a1), 12) < round(abs(a2), 12):
            angle_current_adjusted = [(pi + a1) % (2 * pi), (pi - a2) % (2 * pi), (-pi + a3) % (2 * pi)]
            if n == print_position: print("f5")

        # Fixes -150,-150,120 -> -30,-30,-60
        # Fixes 150,120,150 -> -30,60,-30
        # Fixes -150,-150,120 -> 30,-30,-60
        # Reverse mode only; once applied on frame 0 it stays latched for this
        # part via euler_convert_c6[index].
        if bool_reverse == True:
            if frame == frame_start or euler_convert_c6[index] == True:
                if round(abs(a3), 12) <= round(pi / 2, 12) and all(round(angle, 12) != 0 for angle in angle_current):
                    if (a1 < 0 and a2 < 0 and a3 < 0) or (a1 < 0 and a3 < 0 and round(abs(a2),12)>round(abs(a1),12) and round(abs(a2),12)>round(abs(a3),12)) or (a2 < 0 and a3 < 0  and round(abs(a3),12)>round(abs(a1),12) and round(abs(a3),12)>round(abs(a2),12)):
                        if round(abs(a1),8) < round(abs(a2),8) or round(abs(a1),8) < round(abs(a3),8): # -150,-150,-50 (-50,-50,-16,67)
                            angle_current_adjusted = [(pi + a1) % (2 * pi), -(pi + a2) % (2 * pi), (pi + a3) % (2 * pi)]
                            if frame == frame_start: euler_convert_c6[index] = True
                            if n == print_position: print("f6")

        # Fixes 120, 120, 120 -> -60,60,-60
        # Fixes 120, -120, 120 -> -60,-60,-60
        # Fixes -120, 120, 120 -> 60,60,-60
        # Fixes -120, -120, 120 -> 60,-60,-60
        # Fixes 120, 120, -120 -> -60,60,60
        # Fixes 120, -120, -120 -> -60,-60,60
        # Fixes -120, 120, -120 -> 60,60,60
        # Fixes -120, -120, -120 -> 60,-60,60
        if abs(angle_y)>=90 and (abs(round(angle_y,6)) + abs(round(math.degrees(a2),6))) == 180:
            if frame == frame_start or bool_convert == True or bool_reverse == False:
                angle_current_adjusted = [(pi + a1) % (2 * pi), -(pi + a2) % (2 * pi), (pi + a3) % (2 * pi)]
                # Was `bool_convert == True` (a no-op comparison).
                bool_convert = True

                if n == len(explosion)-1:
                    print("f8")

        # Fixes y = 90
        if bool_reverse == True:
            if abs(round(a1,6)) == 0 and abs(angle_x) > 0 and abs(a3) > 0: # Fixes -72,90,-18 - > 0, 90, 54
                angle_current_adjusted = [-(pi - 2*a3) % (2 * pi), (pi/2) % (2 * pi), (pi - 2*a3 - pi/2) % (2 * pi)]
            if frame == frame_start or euler_convert[index] == True:

                if abs(angle_y)>=90 and abs(round(a2,6)) <= round(pi/2,6):
                    if (abs(round(angle_y*n/frames/parts*easing,6)) + abs(round(math.degrees(a2*n/frames/parts*easing),6))) == 180:

                        if (a2 > 0 and a1 < 0 and a3 < 0) or (a2 < 0 and a1 > 0 and a3 > 0):
                            angle_current_adjusted = [(pi + a1) % (2 * pi), -(pi + a2) % (2 * pi), (pi + a3) % (2 * pi)]
                            if frame == frame_start: euler_convert[index] = True
                            if n == print_position: print("f9",a1,a2,a3)
                    if abs(round(a2,6)) == abs(round(pi/2,6)) and a1 >= 0 and a3 > 0:
                        angle_current_adjusted = [-(pi - a3) % (2 * pi), (pi/2) % (2 * pi), (-pi/4) % (2 * pi)]
                        if frame == frame_start: euler_convert[index] = True
                        if n == print_position: print("f9-2",a1,a2,a3)
                    if abs(a1) < pi/2 and abs(a2) < pi/2 and abs(a3) > pi/2: # Fixes -150,-150,-50 -> 30,-30,130 # Fixes -150,120,-60 -> 30,60,120
                        angle_current_adjusted = [(-pi + a1) % (2 * pi), -(pi + a2) % (2 * pi), (-pi + a3) % (2 * pi)]
                        if frame == frame_start: euler_convert[index] = True
                        if n == print_position: print("f7")

    # Fixes angles > 180°: move each angle by one full turn back into [-pi, pi].
    angle_current_adjusted = list(angle_current_adjusted)
    for i, value in enumerate(angle_current_adjusted):
        if value > math.pi:
            angle_current_adjusted[i] = value - 2 * math.pi
        elif value < -math.pi:
            # Add a full turn; the old `2*pi - value` negated the angle instead.
            angle_current_adjusted[i] = value + 2 * math.pi

    return angle_current_adjusted





scene_root = lux.getSceneTree()
model_nodes = scene_root.find(types=[lux.NODE_TYPE_MODEL])
# Alternative: collect group nodes instead of model nodes.
#model_nodes = scene_root.find(types=[lux.NODE_TYPE_GROUP])


# This is the collection of parts to explode, built from the scene's model nodes.
# Any node whose name starts with three digits is collected. Nodes are sorted by
# name, so the numeric prefix sets the explosion order. The old sort key was
# str(node), whose text is not necessarily the name.
# Collecting/sorting by assembly names, or combining several detection rules,
# would also be possible.
explosion = sorted(
    (model for model in model_nodes
     if len(model.getName()) >= 3 and model.getName()[:3].isdigit()),
    key=lambda model: model.getName(),
)

# Index of the last part; debug output is limited to that part.
print_position = (len(explosion)-1)-0

if bool_reset:
    # Undo every collected group's transformation by applying the inverse of
    # its current matrix relative to itself.
    for exp in explosion:
        rotation_matrix_inverse = exp.getTransform()
        rotation_matrix_inverse.invert()

        exp.applyTransform(rotation_matrix_inverse, absolute = False)

else:

    if bool_animation:
        scene_info = lux.getSceneInfo()
        file_name = scene_info['name'].replace(".bip", "")  # File name without the .bip part
        path = scene_info['file'].replace(scene_info['name'], "")  # File path with the file name removed
        width = scene_info['render_width']
        height = scene_info['render_height']
        lux.setCamera(camera_name)

        # Make sure that the video can be encoded (even frame dimensions)
        if width % 2 != 0: width += 1
        if height % 2 != 0: height += 1

    compose_only = False    # Option for only compositing the video

    # Per-frame easing weights, normalised so they sum to `frames`. Summed over
    # all frames, the applied increments therefore add up to the requested total.
    if easing_type > 0 and frames > 1:
        easing_values = []
        for i in range(frames):
            if easing_type == 1: easing = (1 - math.cos(math.pi * i / (frames - 1))) / 2  # ease in
            elif easing_type == 2: easing = (1 - math.cos(math.pi * (frames - 1 - i) / (frames - 1))) / 2  # ease out
            elif easing_type == 3: easing = math.sin(math.pi * i / (frames - 1))  # ease in/out
            easing_values.append(easing)

        sum_easing = sum(easing_values)
        normalized_easing_values = [easing / sum_easing * frames for easing in easing_values]
    easing = 1

    # `parts` is used as a divisor below. With a single collected group it would
    # be 0, so clamp it to 1. That group's factor n/parts is 0 either way.
    parts = max(len(explosion) - 1, 1)

    # Per-group latches read and set by convert_angles() (module-level globals).
    euler_convert = [False] * len(explosion)
    euler_convert_c6 = [False] * len(explosion)

    current_rotation = []

    # Create starting Rotations
    for index, exp in enumerate(explosion):

        if bool_reverse:
            initial_rotation = [math.radians(-angle_x/parts*index), math.radians(-angle_y/parts*index), math.radians(-angle_z/parts*index)]
            current_rotation.append(initial_rotation)
            if not bool_use_matrix:
                # Clear the group's transform, then apply the rotation and offset
                # the reversed animation starts from.
                rotation_matrix_inverse = exp.getTransform()
                rotation_matrix_inverse.invert()
                exp.applyTransform(rotation_matrix_inverse, absolute = False)

                Ri = luxmath.Matrix().makeRotate(initial_rotation)
                dx, dy, dz = -dxt/parts*index, -dyt/parts*index, -dzt/parts*index
                T = luxmath.Matrix().makeIdentity().translate(luxmath.Vector(dx,dy,dz))
                exp.applyTransform(Ri.mul(T), absolute = False)

        else:
            current_rotation.append([0, 0, 0])

    # For the non-absolute path: the sign of the last non-zero axis angle decides
    # whether translation is applied before or after rotation. This is the value
    # the old code read from the leaked loop variable `a`.
    nonzero_angles = [a for a in [angle_x, angle_y, angle_z] if a != 0]
    last_nonzero_angle = nonzero_angles[-1] if nonzero_angles else 0

    for frame in range(frames+1):
        print(frame)
        n = 0 # for counting the current part (read by convert_angles)

        if bool_animation:

            lux.setAnimationFrame(frame)
            # NOTE(review): rendering happens before this frame's transforms are
            # applied. As a result, frames 0 and 1 show the same pose and the
            # final, fully exploded pose is never rendered. Confirm whether the
            # render should move after the per-part loop.
            try:
                if not compose_only:
                    output_file_path = os.path.join(rf"{path}{file_name}_Frames", f"{file_name}_{frame:04d}.png")   #set path
                    lux.renderImage(output_file_path, width, height)
            except KeyboardInterrupt:
                print("Interrupt received, exiting...")
                sys.exit(0)

        if easing_type > 0 and frames > 1:
            easing = normalized_easing_values[frame-1]

        for index, exp in enumerate(explosion):

            if frame >= 1: # Make sure the first frame is rendered without transformations in its original state

                # Part n moves proportionally to its position in the sequence.
                dx, dy, dz = dxt*n/frames/parts*easing, dyt*n/frames/parts*easing, dzt*n/frames/parts*easing

                T = luxmath.Matrix().makeIdentity().translate(luxmath.Vector(dx,dy,dz))

                if angle_x == 0 and angle_y == 0 and angle_z == 0:
                    exp.applyTransform(T, absolute = False)

                else:
                    ax = angle_x*n/frames/parts*easing
                    ay = angle_y*n/frames/parts*easing
                    az = angle_z*n/frames/parts*easing

                    angle_transform = [math.radians(ax), math.radians(ay), math.radians(az)]

                    R_original = exp.getTransform()
                    R_transorm = luxmath.Matrix().makeRotate([math.radians(ax), math.radians(ay), math.radians(az)])

                    angle_current = tuple(R_original.getTransformation()[1].val())

                    if bool_absolute:

                        # Rotate to align local axis with global
                        rotation_matrix_inverse = exp.getTransform()
                        rotation_matrix_inverse.invert()

                        rotation_matrix_reset = luxmath.Matrix().makeRotate(rotation_matrix_inverse.getTransformation()[1].val())
                        exp.applyTransform(rotation_matrix_reset, absolute = False)

                        # Do the transformations
                        exp.applyTransform(T, absolute = False)

                        if bool_use_matrix:
                            angle_current_adjusted = convert_angles(angle_x, angle_y, angle_z, angle_current, print_position, frame, bool_reverse)
                        else:
                            angle_current_adjusted = current_rotation[index]

                        # Fixes angles > 180°: move each angle by one full turn
                        # back into [-pi, pi] (working on a copy of the list).
                        angle_current_adjusted = list(angle_current_adjusted)
                        for i, value in enumerate(angle_current_adjusted):
                            if value > math.pi:
                                angle_current_adjusted[i] = value - 2 * math.pi
                            elif value < -math.pi:
                                # Add a full turn; the old `2*pi - value` negated the angle.
                                angle_current_adjusted[i] = value + 2 * math.pi

                        # Restore original Rotation and add Rotation
                        combined_angle = [angle_transform[0] + angle_current_adjusted[0], angle_transform[1] + angle_current_adjusted[1], angle_transform[2] + angle_current_adjusted[2]]
                        R = luxmath.Matrix().makeRotate(combined_angle)

                        current_rotation[index] = combined_angle

                        # Normalize Matrix (NOTE(review): assumes normalize() works in place)
                        R.normalize()
                        exp.applyTransform(R, absolute = False)

                        # Matrix debugging
                        bool_debugging = False

                        if bool_debugging:
                            a1, a2, a3 = angle_current
                            b1, b2, b3 = angle_transform
                            C = luxmath.Matrix().makeRotate([a1+b1, a2+b2, a3+b3])

                            if n == print_position:

                                print("rotation")
                                print(f"O: {[round(math.degrees(angle),2) for angle in R_original.getTransformation()[1].val()]}")
                                print(f"B: {[round(math.degrees(angle),2) for angle in R_transorm.getTransformation()[1].val()]}")
                                print(f"R: {[round(math.degrees(angle),2) for angle in R.getTransformation()[1].val()]}")
                                print(f"C: {[round(math.degrees(angle),2) for angle in C.getTransformation()[1].val()]}")

                    else:
                        combined_transform = [angle_transform[0]+angle_current[0],angle_transform[1]+angle_current[1],angle_transform[2]+angle_current[2]]
                        # Every other applyTransform call here passes a luxmath.Matrix,
                        # so build one rather than passing the raw angle list.
                        # NOTE(review): applied relatively, this includes the current
                        # rotation again; R_transorm (delta only) may be what was
                        # meant — confirm.
                        R_combined = luxmath.Matrix().makeRotate(combined_transform)
                        if last_nonzero_angle > 0:
                            exp.applyTransform(T, absolute = False)
                            exp.applyTransform(R_combined, absolute = False)
                        else:
                            exp.applyTransform(R_combined, absolute = False)
                            exp.applyTransform(T, absolute = False)

            n += 1

        print(f"-##################-{frame}")

    if bool_animation:
        print (f"Frames rendered to \n{path}{file_name}_Frames")

        path2 = os.path.join(f"{path}{file_name}_Frames", f"{file_name}.mp4")

        print("Encoding video \n{}".format(path2))
        # Frames are written as <name>_0000.png etc., so the pattern must be zero-padded.
        print(f"{file_name}_%04d.png")
        lux.encodeVideo(folder = f"{path}{file_name}_Frames", frameFiles = f"{file_name}_%04d.png", videoName = f"{file_name}.mp4",
                        fps = 30, firstFrame = 1, lastFrame = frames)  # Create video in the frames folder