import streamlit as st
from shared.logging.constants import LoggingType
from shared.logging.logging import AppLogger
from utils.data_repo.data_repo import DataRepo
from streamlit_option_menu import option_menu

from utils.state_refresh import refresh_app

# Module-level application logger (AppLogger from shared.logging.logging).
# NOTE(review): not referenced by the widget wrappers in this section; confirm
# whether other parts of the file use it.
logger = AppLogger()


# TODO: custom elements must be stateless and completely separate from our code logic
def radio(
    label,
    options,
    index=0,
    key=None,
    help=None,
    on_change=None,
    disabled=False,
    horizontal=False,
    label_visibility="visible",
    default_value=0,
):
    """Render an ``st.radio`` whose selected index is mirrored in ``st.session_state[key]``.

    The underlying widget is keyed ``f"{key}_value"``, while
    ``st.session_state[key]`` holds the index of the selected option. When the
    widget's selection differs from the stored index, the stored index is
    updated and ``refresh_app()`` is called.

    Args:
        label: Widget label.
        options: Sequence of choices; must support ``len`` and ``.index``.
        index: Not used; the initial selection comes from ``default_value``.
            Kept for signature compatibility.
        key: Session-state key holding the selected index.
        help: Tooltip forwarded to ``st.radio``.
        on_change: Callback forwarded to ``st.radio``.
        disabled: Forwarded to ``st.radio``.
        horizontal: Lay the options out in a row.
        label_visibility: Forwarded to ``st.radio``.
        default_value: Index selected when ``key`` is not yet in session state.

    Returns:
        The selected option, or ``None`` if the widget reports no selection
        (e.g. ``options`` is empty).
    """
    if key not in st.session_state:
        st.session_state[key] = default_value

    # The options may shrink between reruns. st.radio rejects an index outside
    # range(len(options)), so fall back to the first option in that case.
    stored_index = st.session_state[key]
    if stored_index is not None and options and not 0 <= stored_index < len(options):
        st.session_state[key] = 0

    selection = st.radio(
        label=label,
        options=options,
        index=st.session_state[key],
        help=help,
        on_change=on_change,
        disabled=disabled,
        horizontal=horizontal,
        label_visibility=label_visibility,
        key=f"{key}_value",
    )

    # No selection means there is no index to mirror (options.index(None) would raise).
    if selection is None:
        return selection

    selected_index = options.index(selection)
    if selected_index != st.session_state[key]:
        st.session_state[key] = selected_index
        refresh_app()

    return selection


def selectbox(label, options, index=0, key=None, help=None, on_change=None, disabled=False, format_func=str):
    """Render an ``st.selectbox`` whose selected index is mirrored in ``st.session_state[key]``.

    When the widget's selection differs from the stored index, the stored
    index is updated and ``refresh_app()`` is called.

    Args:
        label: Widget label.
        options: Sequence of choices; must support ``len`` and ``.index``.
        index: Initial selected index, used when ``key`` is not yet in
            session state.
        key: Session-state key holding the selected index.
        help: Tooltip forwarded to ``st.selectbox``.
        on_change: Callback forwarded to ``st.selectbox``.
        disabled: Forwarded to ``st.selectbox``.
        format_func: Display formatter forwarded to ``st.selectbox``.

    Returns:
        The selected option, or ``None`` if the widget reports no selection
        (e.g. empty ``options`` or ``index=None``).
    """
    if key not in st.session_state:
        st.session_state[key] = index

    # The options may shrink between reruns. st.selectbox rejects an index
    # outside range(len(options)), so fall back to the first option.
    stored_index = st.session_state[key]
    if stored_index is not None and options and not 0 <= stored_index < len(options):
        st.session_state[key] = 0

    selection = st.selectbox(
        label=label,
        options=options,
        index=st.session_state[key],
        format_func=format_func,
        help=help,
        on_change=on_change,
        disabled=disabled,
    )

    # No selection means there is no index to mirror (options.index(None) would raise).
    if selection is None:
        return selection

    selected_index = options.index(selection)
    if selected_index != st.session_state[key]:
        st.session_state[key] = selected_index
        refresh_app()

    return selection


def number_input(
    label,
    min_value=None,
    max_value=None,
    step=None,
    format=None,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    *,
    disabled=False,
    label_visibility="visible",
    value=1,
):
    """Render an ``st.number_input`` whose value is mirrored in ``st.session_state[key]``.

    When the widget's value differs from the stored value, the stored value is
    updated and ``refresh_app()`` is called.

    Args:
        label: Widget label.
        min_value, max_value, step, format: Forwarded to ``st.number_input``.
        key: Session-state key holding the current value.
        help: Tooltip forwarded to ``st.number_input``.
        on_change: Callback forwarded to ``st.number_input``.
        args, kwargs: Positional/keyword arguments for ``on_change``. These are
            forwarded to ``st.number_input``; previously they were silently dropped.
        disabled, label_visibility: Forwarded to ``st.number_input``.
        value: Initial value used when ``key`` is not yet in session state.

    Returns:
        The widget's current value.
    """
    if key not in st.session_state:
        st.session_state[key] = value

    selection = st.number_input(
        label=label,
        min_value=min_value,
        max_value=max_value,
        value=st.session_state[key],
        step=step,
        format=format,
        help=help,
        on_change=on_change,
        args=args,
        kwargs=kwargs,
        disabled=disabled,
        label_visibility=label_visibility,
    )

    if selection != st.session_state[key]:
        st.session_state[key] = selection
        refresh_app()

    return selection


def slider(
    label,
    min_value=None,
    max_value=None,
    value=None,
    step=None,
    format=None,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    *,
    disabled=False,
    label_visibility="visible",
):
    """Render an ``st.slider`` whose value is mirrored in ``st.session_state[key]``.

    The underlying widget is keyed ``f"{key}_value"``. When its value differs
    from the stored value, the stored value is updated and ``refresh_app()``
    is called.

    Args:
        label: Widget label.
        min_value, max_value, step, format: Forwarded to ``st.slider``.
        value: Initial value used when ``key`` is not yet in session state.
        key: Session-state key holding the current value.
        help: Tooltip forwarded to ``st.slider``.
        on_change: Callback forwarded to ``st.slider``.
        args, kwargs: Positional/keyword arguments for ``on_change``. These are
            forwarded to ``st.slider``; previously they were silently dropped.
        disabled, label_visibility: Forwarded to ``st.slider``.

    Returns:
        The widget's current value.
    """
    if key not in st.session_state:
        st.session_state[key] = value

    selection = st.slider(
        label=label,
        min_value=min_value,
        max_value=max_value,
        value=st.session_state[key],
        step=step,
        format=format,
        help=help,
        on_change=on_change,
        args=args,
        kwargs=kwargs,
        disabled=disabled,
        label_visibility=label_visibility,
        key=f"{key}_value",
    )

    if selection != st.session_state[key]:
        st.session_state[key] = selection
        refresh_app()

    return selection


def select_slider(
    label,
    options=(),
    value=None,
    format_func=None,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    *,
    disabled=False,
    label_visibility="visible",
    default_value=None,
    project_settings=None,
):
    """Render an ``st.select_slider`` backed by session state and, optionally, project settings.

    On first render, ``st.session_state[key]`` is seeded from
    ``getattr(project_settings, key)`` when that value is truthy, otherwise
    from ``default_value``. When the selection changes, the stored value is
    updated. If ``project_settings`` has an attribute named ``key``, the new
    value is also saved via ``DataRepo().update_project_setting``. Finally,
    ``refresh_app()`` is called.

    Args:
        label: Widget label.
        options: Choices offered by the slider.
        value: Not used; the initial value comes from ``project_settings`` or
            ``default_value``. Kept for signature compatibility.
        format_func: Display formatter; ``str`` is used when ``None``.
        key: Session-state key, and the project-setting attribute name.
        help: Tooltip forwarded to ``st.select_slider``.
        on_change: Callback forwarded to ``st.select_slider``.
        args, kwargs: Arguments for ``on_change``, forwarded to ``st.select_slider``.
        disabled, label_visibility: Forwarded to ``st.select_slider``.
        default_value: Fallback initial value.
        project_settings: Optional settings object; presumably exposes
            ``.project.uuid`` (TODO confirm against DataRepo callers).

    Returns:
        The selected option.
    """
    if key not in st.session_state:
        # A falsy stored setting (0, "", None, ...) falls back to default_value.
        stored_setting = getattr(project_settings, key, default_value)
        st.session_state[key] = stored_setting if stored_setting else default_value

    # Keyword arguments are required here. In st.select_slider's signature,
    # ``key`` sits between ``format_func`` and ``help``, so passing these
    # positionally shifted every argument after format_func into the wrong slot.
    selection = st.select_slider(
        label,
        options=options,
        value=st.session_state[key],
        # Streamlit calls format_func on every option; None is not callable.
        format_func=format_func if format_func is not None else str,
        help=help,
        on_change=on_change,
        args=args,
        kwargs=kwargs,
        disabled=disabled,
        label_visibility=label_visibility,
    )

    if selection != st.session_state[key]:
        st.session_state[key] = selection
        # Persist only when ``key`` actually names a project setting, and store
        # the new selection under that setting's name (not a literal "key").
        if project_settings is not None and hasattr(project_settings, key):
            data_repo = DataRepo()
            data_repo.update_project_setting(project_settings.project.uuid, **{key: selection})
        refresh_app()

    return selection


def toggle(
    label, value=True, key=None, help=None, on_change=None, disabled=False, label_visibility="visible"
):
    """Render an ``st.toggle`` whose on/off state is mirrored in ``st.session_state[key]``.

    The underlying widget is keyed ``f"{key}_value"``. If the toggle's state
    differs from the mirrored one, the mirror is updated and ``refresh_app()``
    is called.

    Args:
        label: Widget label.
        value: Initial state used when ``key`` is not yet in session state.
        key: Session-state key holding the mirrored state.
        help, on_change, disabled, label_visibility: Forwarded to ``st.toggle``.

    Returns:
        The toggle's current state.
    """
    state = st.session_state
    if key not in state:
        state[key] = value

    is_on = st.toggle(
        label=label,
        value=state[key],
        help=help,
        on_change=on_change,
        disabled=disabled,
        label_visibility=label_visibility,
        key=f"{key}_value",
    )

    # Unchanged: nothing to mirror.
    if is_on == state[key]:
        return is_on

    state[key] = is_on
    refresh_app()
    return is_on


def checkbox(
    label, value=True, key=None, help=None, on_change=None, disabled=False, label_visibility="visible"
):
    """Render an ``st.checkbox`` whose checked state is mirrored in ``st.session_state[key]``.

    The widget itself is keyed ``f"{key}_value"``. A change in the checked
    state is copied into ``st.session_state[key]``, followed by a call to
    ``refresh_app()``.

    Args:
        label: Widget label.
        value: Initial checked state used when ``key`` is not yet in session state.
        key: Session-state key holding the mirrored state.
        help, on_change, disabled, label_visibility: Forwarded to ``st.checkbox``.

    Returns:
        Whether the checkbox is currently checked.
    """
    session = st.session_state
    if key not in session:
        session[key] = value

    widget_options = dict(
        label=label,
        value=session[key],
        help=help,
        on_change=on_change,
        disabled=disabled,
        label_visibility=label_visibility,
        key=f"{key}_value",
    )
    checked = st.checkbox(**widget_options)

    changed = checked != session[key]
    if changed:
        session[key] = checked
        refresh_app()

    return checked


def menu(
    menu_title,
    options,
    icons=None,
    menu_icon=None,
    default_index=0,
    key=None,
    help=None,
    on_change=None,
    disabled=False,
    orientation="horizontal",
    default_value=0,
    styles=None,
):
    """Render a ``streamlit_option_menu`` menu whose selected index persists in session state.

    ``st.session_state[key]`` holds the selected index.
    ``st.session_state[f"{key}_manual_select"]`` holds a value forwarded as
    ``option_menu(manual_select=...)`` and is reset to ``None`` after each
    render. It is presumably set elsewhere to switch the menu
    programmatically; verify against callers. When the selection changes, the
    stored index is updated and ``refresh_app()`` is called.

    Args:
        menu_title: Title passed to ``option_menu``.
        options: Sequence of menu entries; must support ``len`` and ``.index``.
        icons: Per-option icons passed to ``option_menu``.
        menu_icon: Title icon passed to ``option_menu``.
        default_index: Not used; the initial index comes from ``default_value``.
        key: Session-state key holding the selected index.
        help, on_change, disabled: Not used; kept for signature compatibility.
        orientation: Passed to ``option_menu``.
        default_value: Index selected when ``key`` is not yet in session state.
        styles: Style overrides passed to ``option_menu``.

    Returns:
        The selected option.
    """
    if key not in st.session_state:
        st.session_state[key] = default_value

    # One-shot manual selection slot; created as None on first use.
    manual_select_key = f"{key}_manual_select"
    if manual_select_key not in st.session_state:
        st.session_state[manual_select_key] = None

    # option_menu evaluates options[default_index], so a stale index (the
    # options shrank between reruns) would raise. Fall back to the first entry.
    current_index = int(st.session_state[key])
    if options and not 0 <= current_index < len(options):
        current_index = 0
        st.session_state[key] = current_index

    selection = option_menu(
        menu_title,
        options=options,
        icons=icons,
        menu_icon=menu_icon,
        orientation=orientation,
        default_index=current_index,
        styles=styles,
        manual_select=st.session_state[manual_select_key],
    )

    # Consume the manual selection so it is applied only once.
    if st.session_state[manual_select_key] is not None:
        st.session_state[manual_select_key] = None

    if options.index(selection) != st.session_state[key]:
        st.session_state[key] = options.index(selection)
        refresh_app()

    return selection


def text_area(
    label,
    value="",
    height=None,
    max_chars=None,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    *,
    disabled=False,
    label_visibility="visible",
):
    """Render an ``st.text_area`` whose text is mirrored in ``st.session_state[key]``.

    When the widget's text differs from the stored text, the stored text is
    updated and ``refresh_app()`` is called.

    Args:
        label: Widget label.
        value: Initial text used when ``key`` is not yet in session state.
        height, max_chars: Forwarded to ``st.text_area``.
        key: Session-state key holding the current text.
        help: Tooltip forwarded to ``st.text_area``.
        on_change: Callback forwarded to ``st.text_area``.
        args, kwargs: Positional/keyword arguments for ``on_change``. These are
            forwarded to ``st.text_area``; previously they were silently dropped.
        disabled, label_visibility: Forwarded to ``st.text_area``.

    Returns:
        The widget's current text.
    """
    if key not in st.session_state:
        st.session_state[key] = value

    selection = st.text_area(
        label=label,
        value=st.session_state[key],
        height=height,
        max_chars=max_chars,
        help=help,
        on_change=on_change,
        args=args,
        kwargs=kwargs,
        disabled=disabled,
        label_visibility=label_visibility,
    )

    if selection != st.session_state[key]:
        st.session_state[key] = selection
        refresh_app()

    return selection


def text_input(
    label,
    value="",
    max_chars=None,
    key=None,
    help=None,
    on_change=None,
    args=None,
    kwargs=None,
    *,
    disabled=False,
    label_visibility="visible",
):
    """Render an ``st.text_input`` whose text is mirrored in ``st.session_state[key]``.

    When the widget's text differs from the stored text, the stored text is
    updated and ``refresh_app()`` is called.

    Args:
        label: Widget label.
        value: Initial text used when ``key`` is not yet in session state.
        max_chars: Forwarded to ``st.text_input``.
        key: Session-state key holding the current text.
        help: Tooltip forwarded to ``st.text_input``.
        on_change: Callback forwarded to ``st.text_input``.
        args, kwargs: Positional/keyword arguments for ``on_change``. These are
            forwarded to ``st.text_input``; previously they were silently dropped.
        disabled, label_visibility: Forwarded to ``st.text_input``.

    Returns:
        The widget's current text.
    """
    if key not in st.session_state:
        st.session_state[key] = value

    selection = st.text_input(
        label=label,
        value=st.session_state[key],
        max_chars=max_chars,
        help=help,
        on_change=on_change,
        args=args,
        kwargs=kwargs,
        disabled=disabled,
        label_visibility=label_visibility,
    )

    if selection != st.session_state[key]:
        st.session_state[key] = selection
        refresh_app()

    return selection


def multiselect(
    label,
    options,
    default,
    format_func=None,
    key=None,
    help=None,
    on_change=None,
    disabled=False,
    label_visibility="visible",
):
    """Render an ``st.multiselect`` whose selection is mirrored in ``st.session_state[key]``.

    The widget itself is keyed ``f"{key}_widget"``. Stored values no longer
    present in ``options`` are dropped before being used as the widget
    default. When the selected set differs from the stored one, the stored
    list is replaced. Then ``on_change()`` is called if given, otherwise
    ``st.rerun()``.

    Args:
        label: Widget label; also used to derive ``key`` when none is given.
        options: Available choices (compared via ``set``, so they must be hashable).
        default: Initial selection used when ``key`` is not yet in session state.
        format_func: Display formatter; ``str`` is used when ``None``.
        key: Session-state key holding the selected list. Defaults to
            ``f"multiselect_{label}"``.
        help, disabled, label_visibility: Forwarded to ``st.multiselect``.
        on_change: Forwarded to ``st.multiselect`` and also called directly on change.

    Returns:
        The list of selected options.
    """
    # Generate a unique key if none is provided
    if key is None:
        key = f"multiselect_{label}"

    # Initialize session state if it doesn't exist
    if key not in st.session_state:
        st.session_state[key] = default

    # Ensure that the session state contains valid options
    valid_session_state = [v for v in st.session_state[key] if v in options]

    selection = st.multiselect(
        label=label,
        options=options,
        default=valid_session_state,
        # st.multiselect calls format_func on every option; None is not callable.
        format_func=format_func if format_func is not None else str,
        help=help,
        on_change=on_change,
        disabled=disabled,
        label_visibility=label_visibility,
        key=f"{key}_widget",
    )

    # Update session state if selection has changed
    if set(selection) != set(st.session_state[key]):
        st.session_state[key] = selection
        # NOTE(review): on_change is also registered as the widget callback
        # above. A user-driven change therefore invokes it twice: once by
        # Streamlit before this run, once here. Left as-is because callers may
        # rely on this call seeing the updated st.session_state[key].
        if on_change:
            on_change()
        else:
            st.rerun()

    return selection
