import React, { useCallback, useMemo } from 'react';
import InspectionPanelContext from 'App/InspectionPanel/InspectionPanelContext';
import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import { produce } from 'immer';
import { isModel } from 'types/nn-types/Model';
import { isModelGraphNode } from 'types/nn-types/ModelGraph';
import { Entity } from 'types/inspection-types/Entity';
import { isLayerGraphNode } from 'types/nn-types/LayerGraph';

/**
 * Derives the level of abstraction from how deep the hierarchy path currently is.
 *
 * A path of one entity ID is the single-model view and two IDs is the layers/units view. Any other depth (the empty
 * path, or anything deeper than two) falls back to the multi-model overview.
 * Depth 3 (WEIGHTS_ACTIVATIONS) is not wired up yet.
 *
 * @param selectedEntityIDs Root-to-leaf path of selected entity IDs.
 * @returns The level of abstraction matching the path depth.
 */
const deduceLevelOfAbstractionFromHierarchyPath = (selectedEntityIDs: string[]) => {
    const depth = selectedEntityIDs.length;
    if (depth === 1) return LevelOfAbstraction.SINGLE_MODEL;
    if (depth === 2) return LevelOfAbstraction.LAYERS_UNITS;
    return LevelOfAbstraction.MULTI_MODEL;
};

/**
 * Provides the inspection panel's navigation state through `InspectionPanelContext`.
 *
 * The state consists of:
 * - the root-to-leaf path of entity IDs,
 * - the level of abstraction derived from that path's depth,
 * - callbacks to move up, move down, or jump to a specific entity.
 */
const InspectionPanelContextProvider = ({ children }: { children: React.ReactNode }) => {
    // Root-to-leaf list of entity IDs: `descendLofa` appends to the end, `ascendLofa` removes from the end.
    const [hierarchyPathEntityIDs, setHierarchyPathEntityIDs] = React.useState<string[]>([]);
    const lofa: LevelOfAbstraction = deduceLevelOfAbstractionFromHierarchyPath(hierarchyPathEntityIDs);

    /**
     * Moves `levels` steps up the hierarchy by dropping that many entries from the end of the path.
     *
     * @param levels Number of levels to ascend (default 1). A value larger than the path length clears the path.
     *     Zero, negative, or NaN values leave the path unchanged.
     */
    const ascendLofa = useCallback((levels = 1) => {
        // Guard explicitly: `splice(-0)` is `splice(0)` and would wipe the entire path. A negative `levels` would
        // become a positive start index and keep only the first `-levels` entries. `!(x > 0)` also rejects NaN.
        if (!(levels > 0)) return;
        setHierarchyPathEntityIDs((prevState) =>
            produce(prevState, (draftState) => {
                // Clamp at 0 so ascending past the root simply empties the path.
                draftState.splice(Math.max(0, draftState.length - levels));
            }),
        );
    }, []);

    /**
     * Moves one level down the hierarchy by appending `focusedEntityID` to the path.
     *
     * @param focusedEntityID ID of the entity to descend into.
     */
    const descendLofa = useCallback((focusedEntityID: string) => {
        setHierarchyPathEntityIDs((prevState) =>
            produce(prevState, (draftState) => {
                draftState.push(focusedEntityID);
            }),
        );
    }, []);

    /**
     * Sets the path of selected entities up to but excluding `entity`. The path is in root-to-leaf order, the same
     * order `descendLofa` builds. If `entity` refers to a model, there are no focused entities.
     *
     * @param entity The entity to focus.
     */
    const focus = useCallback((entity: Entity) => {
        if (isModel(entity)) setHierarchyPathEntityIDs([]);
        if (isModelGraphNode(entity)) {
            setHierarchyPathEntityIDs([entity.parentModelId]);
        }
        if (isLayerGraphNode(entity)) {
            // Same prefix as the model-graph-node case (model ID first), followed by the containing model graph
            // node. The reversed order made `getCurrentlySelectedEntityID` report the model instead of the node.
            setHierarchyPathEntityIDs([entity.parentModelId, entity.parentModelGraphNodeId]);
        }
    }, []);

    /** Returns the last ID on the path (the entity currently descended into), or `null` when the path is empty. */
    const getCurrentlySelectedEntityID = useCallback(
        () => hierarchyPathEntityIDs.slice(-1)[0] ?? null,
        [hierarchyPathEntityIDs],
    );

    // Memoize the value object itself so consumers only re-render when the path (and thus `lofa`) changes. The
    // callbacks above are already stable: `setHierarchyPathEntityIDs` never changes identity.
    const providerValueMemo = useMemo(
        () => ({
            lofa,
            hierarchyPathEntityIDs,
            getCurrentlySelectedEntityID,
            ascendLofa,
            descendLofa,
            focus,
        }),
        [ascendLofa, descendLofa, focus, getCurrentlySelectedEntityID, hierarchyPathEntityIDs, lofa],
    );

    return <InspectionPanelContext.Provider value={providerValueMemo}>{children}</InspectionPanelContext.Provider>;
};

export default InspectionPanelContextProvider;
