from typing import List
import time
import ast
import streamlit as st
from shared.constants import (
    GPU_INFERENCE_ENABLED_KEY,
    AnimationStyleType,
    AnimationToolType,
    STEERABLE_MOTION_WORKFLOWS,
    ConfigManager,
)
import time
from ui_components.constants import DEFAULT_SHOT_MOTION_VALUES
from ui_components.methods.ml_methods import generate_sm_video
from ui_components.widgets.sm_animation_style_element import (
    animation_sidebar,
    individual_frame_settings_element,
    select_motion_lora_element,
    select_sd_model_element,
    video_motion_settings,
)
from ui_components.models import InternalFileObject, InternalShotObject
from ui_components.methods.animation_style_methods import (
    is_inference_enabled,
    toggle_generate_inference,
    transform_data,
    update_session_state_with_animation_details,
    update_session_state_with_dc_details,
)
from utils import st_memory
from utils.common_decorators import update_refresh_lock
from utils.state_refresh import refresh_app
from utils.data_repo.data_repo import DataRepo

# Default SD checkpoint filename for this page. It is handed to `select_sd_model_element`,
# `update_session_state_with_animation_details` and `animation_sidebar` below.
DEFAULT_SM_MODEL: str = "dreamshaper_8.safetensors"


def sm_video_rendering_page(shot_uuid, img_list: List[InternalFileObject], column1, column2):
    """Render the Steerable Motion (creative interpolation) page for a shot.

    Collects per-frame and overall motion settings from the UI widgets and assembles the
    ``settings`` dict passed to ``generate_sm_video``. It then draws the queue / backlog /
    save / reset buttons and the animation sidebar.

    When a "generate_vid" or "manual_save" request is pending (enabled through
    ``toggle_generate_inference`` from a button callback), this run does the following:
      - persists the shot details;
      - for "generate_vid", queues the video;
      - clears the request, releases the refresh lock and refreshes the app.

    Args:
        shot_uuid: UUID of the shot being edited.
        img_list: Keyframe images of the shot, in order.
        column1: Streamlit container that receives the "Save current settings" button.
        column2: Streamlit container that receives the "Reset to default" button.
    """
    data_repo = DataRepo()
    config_manager = ConfigManager()
    gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False)
    shot: InternalShotObject = data_repo.get_shot_from_uuid(shot_uuid)

    # Session-state keys shared by several places below.
    preview_mode_key = f"{shot_uuid}_preview_mode"
    backlog_key = f"{shot_uuid}_backlog_enabled"

    settings = {
        "animation_tool": AnimationToolType.ANIMATEDIFF.value,
    }
    shot_meta_data = {}

    with st.container():

        # ----------- INDIVIDUAL FRAME SETTINGS -----------
        (
            strength_of_frames,
            distances_to_next_frames,
            speeds_of_transitions,
            freedoms_between_frames,
            individual_prompts,
            individual_negative_prompts,
            motions_during_frames,
        ) = individual_frame_settings_element(shot_uuid, img_list)

        # ----------- SELECT SD MODEL -----------
        sd_model, model_files = select_sd_model_element(shot_uuid, DEFAULT_SM_MODEL)

        # ----------- SELECT MOTION LORA ------------
        # The motion LoRA selector is only shown when GPU inference is enabled in the config.
        if gpu_enabled:
            motion_lora_data = select_motion_lora_element(shot_uuid, model_files)
        else:
            motion_lora_data = []

        # ----------- OTHER SETTINGS ------------
        (
            strength_of_adherence,
            overall_positive_prompt,
            overall_negative_prompt,
            type_of_motion_context,
            allow_for_looping,
            high_detail_mode,
            stabilise_motion,
            styling_lora_data,
        ) = video_motion_settings(shot_uuid, img_list)

        # Values that are fixed on this page (not exposed as widgets).
        type_of_frame_distribution = "dynamic"
        type_of_key_frame_influence = "dynamic"
        type_of_strength_distribution = "dynamic"
        linear_frame_distribution_value = 16
        linear_key_frame_influence_value = 1.0
        linear_cn_strength_value = 1.0
        relative_ipadapter_strength = 1.0
        relative_cn_strength = 0.0
        project_settings = data_repo.get_project_setting(shot.project.uuid)
        width = project_settings.width
        height = project_settings.height
        img_dimension = f"{width}x{height}"
        motion_scale = 1.3
        interpolation_style = "ease-in-out"
        buffer = 4

        (
            dynamic_strength_values,
            dynamic_key_frame_influence_values,
            dynamic_frame_distribution_values,
            context_length,
            context_stride,
            context_overlap,
            multipled_base_end_percent,
            multipled_base_adapter_strength,
            prompt_travel,
            negative_prompt_travel,
            motion_scales,
        ) = transform_data(
            strength_of_frames,
            freedoms_between_frames,
            speeds_of_transitions,
            distances_to_next_frames,
            type_of_motion_context,
            strength_of_adherence,
            individual_prompts,
            individual_negative_prompts,
            buffer,
            motions_during_frames,
        )

        settings.update(
            ckpt=sd_model,
            width=width,
            height=height,
            # Same buffer value that was fed to transform_data above.
            buffer=buffer,
            motion_scale=motion_scale,
            motion_scales=motion_scales,
            high_detail_mode=high_detail_mode,
            image_dimension=img_dimension,
            output_format="video/h264-mp4",
            prompt=overall_positive_prompt,
            allow_for_looping=allow_for_looping,
            negative_prompt=overall_negative_prompt,
            interpolation_type=interpolation_style,
            stmfnet_multiplier=2,
            relative_ipadapter_strength=relative_ipadapter_strength,
            relative_cn_strength=relative_cn_strength,
            type_of_strength_distribution=type_of_strength_distribution,
            linear_strength_value=str(linear_cn_strength_value),
            dynamic_strength_values=str(dynamic_strength_values),
            linear_frame_distribution_value=linear_frame_distribution_value,
            dynamic_frame_distribution_values=dynamic_frame_distribution_values,
            type_of_frame_distribution=type_of_frame_distribution,
            type_of_key_frame_influence=type_of_key_frame_influence,
            linear_key_frame_influence_value=float(linear_key_frame_influence_value),
            dynamic_key_frame_influence_values=dynamic_key_frame_influence_values,
            normalise_speed=True,
            ipadapter_noise=0.3,
            animation_style=AnimationStyleType.CREATIVE_INTERPOLATION.value,
            context_length=context_length,
            context_stride=context_stride,
            context_overlap=context_overlap,
            multipled_base_end_percent=multipled_base_end_percent,
            multipled_base_adapter_strength=multipled_base_adapter_strength,
            individual_prompts=prompt_travel,
            individual_negative_prompts=negative_prompt_travel,
            # NOTE(review): looks like a misspelled duplicate of `animation_style`; kept because
            # downstream consumers of `settings` may read this key — confirm before removing.
            animation_stype=AnimationStyleType.CREATIVE_INTERPOLATION.value,
            max_frames=str(dynamic_frame_distribution_values[-1]),
            stabilise_motion=stabilise_motion,
            motion_lora_data=motion_lora_data,
            styling_lora_data=styling_lora_data,
            shot_data=shot_meta_data,
            pil_img_structure_control_image=st.session_state[
                f"structure_control_image_{shot.uuid}"
            ],  # this is a PIL object
            strength_of_structure_control_image=st.session_state[
                f"strength_of_structure_control_image_{shot.uuid}"
            ],
            filename_prefix="AD_",
        )

        st.markdown("***")
        st.markdown("##### Generation Settings")

        filtered_and_sorted_workflows = sorted(
            (workflow for workflow in STEERABLE_MOTION_WORKFLOWS if workflow["display"]),
            key=lambda x: x["order"],
        )

        generation_types = [workflow["name"] for workflow in filtered_and_sorted_workflows]

        footer1, footer2 = st.columns([1, 1])
        with footer1:
            number_of_generation_steps = st_memory.number_input(
                "Number of generation steps:",
                key=f"number_of_generation_steps_{shot.uuid}",
                min_value=5,
                max_value=30,
                step=5,
                value=20,
                help="You can dial this down to get a faster video. But beware, the quality will be lower.",
            )

            type_of_generation = st.radio(
                "Workflow variant:",
                options=generation_types,
                key="creative_interpolation_type",
                horizontal=True,
                index=st.session_state.get(f"type_of_generation_index_{shot.uuid}", 0),
                help="""
                
                    **Slurshy Realistiche**: good for simple realistic motion.

                    **Smooth n' Steady**: good for slow, smooth transitions. 
                    
                    **Chocky Realistiche**: good for realistic motion and chaotic transitions. 

                    **Liquidy Loop**: good for liquid-like motion with slick transitions. Also loops!
                    
                    **Fast With A Price**: runs fast but with a lot of detail loss.
                    
                    **Rad Attack**: good for realistic motion but with a lot of detail loss.
                    
                    """,
            )

            # Persist the chosen variant's index so the radio reopens on it after a refresh.
            if (
                type_of_generation
                != generation_types[st.session_state.get(f"type_of_generation_index_{shot.uuid}", 0)]
            ):
                st.session_state[f"type_of_generation_index_{shot.uuid}"] = generation_types.index(
                    type_of_generation
                )
                refresh_app()

        generate_vid_inf_tag = "generate_vid"
        manual_save_inf_tag = "manual_save"

        st.write("")
        animate_col_1, _, _ = st.columns([2, 1, 1])
        with animate_col_1:
            variant_count = 1

            if is_inference_enabled(generate_vid_inf_tag) or is_inference_enabled(manual_save_inf_tag):
                # Which pending request this run is serving; "generate_vid" wins over "manual_save".
                position = (
                    generate_vid_inf_tag
                    if is_inference_enabled(generate_vid_inf_tag)
                    else manual_save_inf_tag
                )

                def _clear_pending_request():
                    # Turn the pending request off (backlog flag reset, preview mode carried over)
                    # and release the refresh lock taken below.
                    toggle_generate_inference(
                        position,
                        **{
                            backlog_key: False,
                            preview_mode_key: st.session_state.get(preview_mode_key, False),
                        },
                    )
                    update_refresh_lock(False)

                update_refresh_lock(True)
                # Shot duration = last keyframe position / 16 (presumably frames per second — confirm).
                duration = float(dynamic_frame_distribution_values[-1] / 16)
                data_repo.update_shot(uuid=shot_uuid, duration=duration)

                # Convert the PIL structure-control image into a saved file object and store its
                # uuid in place of the PIL entry. The import is local (presumably to avoid an
                # import cycle).
                from ui_components.methods.common_methods import save_new_image

                key = "pil_img_structure_control_image"
                image = None
                if settings[key]:
                    image = save_new_image(settings[key], shot.project.uuid)
                    del settings[key]
                    new_key = key.replace("pil_img_", "") + "_uuid"
                    settings[new_key] = image.uuid

                if st.session_state.get(preview_mode_key, False):
                    # The preview range is treated as 1-based and inclusive.
                    start_frame, end_frame = st.session_state.get(f"frames_to_preview_{shot_uuid}", (1, 3))
                    img_list = img_list[start_frame - 1 : end_frame]
                    settings["inference_type"] = "preview"
                    trigger_shot_update = False
                else:
                    trigger_shot_update = True

                shot_data = update_session_state_with_animation_details(
                    shot_uuid,
                    img_list,
                    strength_of_frames,
                    distances_to_next_frames,
                    speeds_of_transitions,
                    freedoms_between_frames,
                    motions_during_frames,
                    individual_prompts,
                    individual_negative_prompts,
                    styling_lora_data,
                    DEFAULT_SM_MODEL,
                    high_detail_mode,
                    image.uuid if image else None,
                    settings["strength_of_structure_control_image"],
                    next(
                        (
                            index
                            for index, workflow in enumerate(filtered_and_sorted_workflows)
                            if workflow["name"] == type_of_generation
                        ),
                        0,
                    ),
                    stabilise_motion=stabilise_motion,
                    trigger_shot_update=trigger_shot_update,
                )

                settings.update(
                    shot_data=shot_data,
                    number_of_generation_steps=number_of_generation_steps,
                    type_of_generation=type_of_generation,
                    filename_prefix="AD_",
                )

                # Every keyframe in the (possibly preview-trimmed) list needs an image file.
                # The request is otherwise only cleared at the end of this branch, so clear it
                # here too; otherwise the next run re-enters this branch with the lock held.
                if any(not img.location for img in img_list):
                    st.error("Please generate primary images")
                    _clear_pending_request()
                    time.sleep(0.7)
                    refresh_app()
                    return

                st.success(
                    "Generating clip - see status in the Generation Log in the sidebar. Press 'Refresh log' to update."
                )

                if backlog_key not in st.session_state:
                    st.session_state[backlog_key] = False

                # Only the "generate_vid" request queues a video; "manual_save" just persists
                # the shot details above.
                if position == generate_vid_inf_tag:
                    generate_sm_video(
                        shot_uuid,
                        settings,
                        variant_count,
                        st.session_state[backlog_key],
                        img_list,
                    )

                _clear_pending_request()
                refresh_app()

            btn1, btn2, _ = st.columns([1, 1, 1])
            additional_params = {
                backlog_key: False,
                preview_mode_key: st.session_state.get(preview_mode_key, False),
            }

            with btn1:
                st.button(
                    "Add to queue",
                    key="generate_animation_clip",
                    disabled=False,
                    help="",
                    on_click=lambda: toggle_generate_inference(generate_vid_inf_tag, **additional_params),
                    type="primary",
                    use_container_width=True,
                )

            backlog_update = {backlog_key: True}
            with btn2:
                backlog_help = "This will add the new video generation in the backlog"
                st.button(
                    "Add to backlog",
                    key="generate_animation_clip_backlog",
                    disabled=False,
                    help=backlog_help,
                    on_click=lambda: toggle_generate_inference(generate_vid_inf_tag, **backlog_update),
                    type="secondary",
                )

            with column2:
                if st.button("Reset to default", use_container_width=True, key="reset_to_default"):
                    # Overwrite every per-frame motion setting with its default value.
                    for img in img_list:
                        for k, v in DEFAULT_SHOT_MOTION_VALUES.items():
                            st.session_state[f"{k}_{shot_uuid}_{img.uuid}"] = v

                    st.success("All frames have been reset to default values.")
                    refresh_app()
                st.write("")

            if not st.session_state.get(preview_mode_key, False):
                with column1:
                    if st.button(
                        "Save current settings",
                        key="save_current_settings",
                        use_container_width=True,
                        help="Settings will also be saved when you generate the animation.",
                        on_click=lambda: toggle_generate_inference(manual_save_inf_tag, **additional_params),
                    ):
                        refresh_app()

        # --------------- SIDEBAR ---------------------
        animation_sidebar(
            shot_uuid,
            img_list,
            type_of_frame_distribution,
            dynamic_frame_distribution_values,
            linear_frame_distribution_value,
            type_of_strength_distribution,
            dynamic_strength_values,
            linear_cn_strength_value,
            type_of_key_frame_influence,
            dynamic_key_frame_influence_values,
            linear_key_frame_influence_value,
            strength_of_frames,
            distances_to_next_frames,
            speeds_of_transitions,
            freedoms_between_frames,
            motions_during_frames,
            individual_prompts,
            individual_negative_prompts,
            DEFAULT_SM_MODEL,
        )


def two_img_realistic_interpolation_page(shot_uuid, img_list: List[InternalFileObject]):
    """Render the 2-image realistic interpolation page (``AnimationStyleType.DIRECT_MORPHING``).

    The page shows the first two images of ``img_list`` side by side, with a text area for
    the desired motion between them.

    When the ``"dynamiccrafter"`` request is pending, it does the following:
      - sets the shot duration to 4 seconds;
      - stores the motion description via ``update_session_state_with_dc_details``;
      - queues the video with ``generate_sm_video``;
      - clears the request and refreshes the app.

    Args:
        shot_uuid: UUID of the shot being edited.
        img_list: Images of the shot; at least two are required.
    """
    if not (img_list and len(img_list) >= 2):
        st.error("You need two images for this interpolation")
        return

    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)
    desc_key = f"video_desc_{shot_uuid}"
    backlog_key = f"{shot_uuid}_backlog_enabled"

    st.markdown("***")
    first_col, middle_col, last_col = st.columns([1, 1, 1])
    with first_col:
        st.image(img_list[0].location, use_column_width=True)

    with last_col:
        st.image(img_list[1].location, use_column_width=True)

    with middle_col:
        if desc_key not in st.session_state:
            st.session_state[desc_key] = ""
        description_of_motion = st.text_area(
            "Describe the motion you want between the frames:",
            key=f"description_of_motion_{shot.uuid}",
            value=st.session_state[desc_key],
        )
        st.info("This is very important and will likely require some iteration.")
        st.info("NOTE: The model for this animation is 10.5 GB in size, which can take some time to download")

    variant_count = 1
    position = "dynamiccrafter"

    # Pending-request flag; presumably set by `toggle_generate_inference(position, ...)` from
    # the button callbacks below — confirm against that helper.
    if st.session_state.get(f"{position}_generate_inference"):
        st.success(
            "Generating clip - see status in the Generation Log in the sidebar. Press 'Refresh log' to update."
        )
        duration = 4  # fixed clip length used for this workflow
        data_repo.update_shot(uuid=shot.uuid, duration=duration)

        project_settings = data_repo.get_project_setting(shot.project.uuid)
        meta_data = update_session_state_with_dc_details(
            shot_uuid,
            img_list,
            description_of_motion,
        )
        settings = {
            "shot_data": meta_data,
            "duration": duration,
            "animation_style": AnimationStyleType.DIRECT_MORPHING.value,
            "output_format": "video/h264-mp4",
            "width": project_settings.width,
            "height": project_settings.height,
            "prompt": description_of_motion,
        }

        # The backlog flag may never have been written (e.g. request enabled without it),
        # so default to "not backlogged" instead of raising KeyError.
        generate_sm_video(
            shot_uuid,
            settings,
            variant_count,
            st.session_state.get(backlog_key, False),
            img_list,
        )

        toggle_generate_inference(position, **{backlog_key: False})
        refresh_app()

    # Queue / backlog buttons: both enable the pending request; only the backlog flag differs.
    st.markdown("***")
    btn1, btn2, _ = st.columns([1, 1, 1])
    backlog_no_update = {backlog_key: False}

    with btn1:
        st.button(
            "Add to queue",
            key="generate_animation_clip",
            disabled=False,
            help="Generate the interpolation clip based on the two images and described motion.",
            on_click=lambda: toggle_generate_inference(position, **backlog_no_update),
            type="primary",
            use_container_width=True,
        )

    backlog_update = {backlog_key: True}
    with btn2:
        st.button(
            "Add to backlog",
            key="generate_animation_clip_backlog",
            disabled=False,
            help="Add the 2-Image Realistic Interpolation to the backlog.",
            on_click=lambda: toggle_generate_inference(position, **backlog_update),
            type="secondary",
        )
