import React, { useContext, useMemo, useState } from 'react';
import { isCommandKeyPressedOnMacOS, positionsToPathString, reduceShapeToNumber } from 'tools/helpers';
import dagre, { Edge } from 'dagre';
import { DagreLayerNode } from 'types/dagre-nodes/DagreLayerNode';
import { produce } from 'immer';
import { layout } from 'App/InspectionPanel/L2ArchitectureComponent/layout';
import LayerComponent from './LayerComponent/LayerComponent';
import FilterBadges, { FilterBadgeSize } from 'App/InspectionPanel/FilterBadges';
import FilterContext from 'App/FilterContext';
import { EntityType, getEntityTypeColor } from 'types/inspection-types/EntityType';
import { Structure } from 'types/nn-types/ModelGraph';
import { InspectionLayerProps } from 'types/inspection-types/InspectionLayerProps';
import ConvexHullComponent from 'App/InspectionPanel/ConvexHullComponent';
import addWidget from 'App/InspectionPanel/L3TreeComponent/ModelComponent/add-widget';
import ToolContext from 'App/ToolContext';
import WidgetContext from 'App/WidgetContext';
import ModelContext from 'App/ModelContext';
import { useBrush, useLink } from 'tools/hooks/useLinkAndBrush';

/** Rendered width/height last reported by a single layer. */
interface LayerDimensions {
    width: number;
    height: number;
}

/** Measured layer sizes, keyed by model graph node id. */
type LayerDimensionRecord = Record<string, LayerDimensions>;

/** Props of the L2 architecture view; currently identical to the shared inspection-layer props. */
export interface Props extends InspectionLayerProps {}

/**
 * Level-2 architecture view of the model provided by `ModelContext`.
 *
 * Lays the model's graph nodes out with dagre (via `layout`) and draws the view in this order:
 * - a convex hull around the whole model,
 * - one `LinkElement` per dagre edge,
 * - one `LayerComponent` per model graph node,
 * - the filter badges.
 *
 * Sizing is iterative. Each `LayerComponent` reports its rendered size through `onSizeChange`. That size is
 * stored in `layerDimensions`, and the layout is recomputed with it. `onReady` is called once every current
 * node has a measured size.
 */
const L2ArchitectureComponent: React.FunctionComponent<Props> = ({ onReady }: Props) => {
    const { model } = useContext(ModelContext);
    const { activeTool } = useContext(ToolContext);
    const { addWidget: addWidgetCb } = useContext(WidgetContext);
    // Last reported rendered size of each layer, keyed by model graph node id.
    const [layerDimensions, setLayerDimensions] = useState<LayerDimensionRecord>({});

    // Recompute the dagre layout whenever the model or any measured layer size changes.
    // NOTE(review): 50, 100, 50 are forwarded unchanged to `layout`; presumably spacing values — confirm there.
    const g: dagre.graphlib.Graph<DagreLayerNode> = useMemo(
        () => layout(model, layerDimensions, 50, 100, 50),
        [model, layerDimensions],
    );

    // One size-change handler per layer. They are memoized so each LayerComponent gets a stable callback;
    // otherwise the iterative bottom-up size estimation would re-render endlessly.
    const layerSizeChangeHandlers = useMemo(
        () =>
            model.graph.nodes.map((l) => (newWidth: number, newHeight: number) => {
                // Functional setState, so size reports from several layers don't overwrite each other.
                // See: https://stackoverflow.com/a/30341560
                setLayerDimensions((prevState) =>
                    produce(prevState, (draftState: LayerDimensionRecord) => {
                        draftState[l.id] = {
                            width: newWidth,
                            height: newHeight,
                        };
                    }),
                );
            }),
        [model, setLayerDimensions],
    );

    // Place each layer at the top-left corner of its dagre node; the node's x/y are treated as its centre.
    const layerElements: JSX.Element[] = useMemo(
        () =>
            model.graph.nodes.map((l, idx) => {
                const layerNode: dagre.Node & DagreLayerNode = g.node(l.id);

                const layerNodeLeft = layerNode.x - layerNode.width / 2;
                const layerNodeTop = layerNode.y - layerNode.height / 2;

                return (
                    <LayerComponent
                        modelGraphNode={layerNode.modelGraphNode}
                        key={l.id}
                        x={layerNodeLeft}
                        y={layerNodeTop}
                        onSizeChange={layerSizeChangeHandlers[idx]}
                    />
                );
            }),
        [layerSizeChangeHandlers, g, model],
    );

    // This L2 view is complete once every current model graph node has a measured size and the layout above
    // has been recomputed with it. At that point, notify the parent.
    // Checking node ids, rather than counting keys, stops entries left over from a previously shown model
    // from triggering (or permanently blocking) the notification.
    React.useEffect(() => {
        const allLayersMeasured = model.graph.nodes.every((node) =>
            Object.prototype.hasOwnProperty.call(layerDimensions, node.id),
        );
        if (allLayersMeasured) {
            onReady();
        }
    }, [layerDimensions, onReady, model, model.graph.nodes.length]);

    // Key links by their dagre identity (source, target, optional name) rather than by array index.
    // Per-edge state such as hover then stays attached to the right edge when edge order changes.
    const linkElements = g.edges().map((edge) => (
        <LinkElement edge={edge} graph={g} key={`${edge.v}->${edge.w}:${edge.name ?? ''}`} />
    ));

    // A primary-button click without Ctrl/Cmd adds a widget for the model, if the active tool applies to it.
    // The event never propagates further.
    const onClickHandler = (e: React.MouseEvent) => {
        if (e.button === 0) {
            if (!e.ctrlKey && !isCommandKeyPressedOnMacOS(e) && activeTool?.isApplicable(model.type)) {
                addWidget(model, activeTool, addWidgetCb);
            }
        }

        e.stopPropagation();
    };

    return (
        <g>
            <ConvexHullComponent
                entity={model}
                backgroundColor={model.preferences.baseColor}
                graph={g}
                onClick={onClickHandler}
            />
            {linkElements}
            {layerElements}
            <FilterBadges dagreArchitectureGraph={g} recursive={false} size={FilterBadgeSize.LARGE} />
        </g>
    );
};

/** Props for the `LayerTreeEdge` component. */
interface LayerTreeEdgeProps {
    /** SVG path data describing the edge. */
    d: string;
    /** Filter-matched structure this edge belongs to, if any; it determines the stroke colour. */
    filteredStructure: Structure | undefined;
    /** Base stroke width in px. */
    strokeWidth: number;
}

/**
 * One edge between two layers, drawn as an unfilled SVG path.
 *
 * The stroke uses the colour of `filteredStructure` when there is one, and grey (#a7a7a7) otherwise.
 *
 * While the pointer is over the path, its structure is brushed on the 'filter-badge-hovered' channel.
 *
 * The path is drawn 1.5x wider while it is hovered, or while `useLink('filter-hovered', …)` reports it as linked.
 */
const LayerTreeEdge: React.FunctionComponent<LayerTreeEdgeProps> = ({
    d,
    filteredStructure,
    strokeWidth,
}: LayerTreeEdgeProps) => {
    const [isHovered, setIsHovered] = useState<boolean>(false);

    const [isLinked] = useLink<EntityType | undefined>('filter-hovered', filteredStructure);
    const brushedStructure = isHovered ? filteredStructure : undefined;
    useBrush<EntityType | undefined>('filter-badge-hovered', brushedStructure);

    const strokeColor = filteredStructure ? getEntityTypeColor(filteredStructure) : '#a7a7a7';
    const effectiveStrokeWidth = isHovered || isLinked ? 1.5 * strokeWidth : strokeWidth;

    const handleMouseOver = () => setIsHovered(true);
    const handleMouseOut = () => setIsHovered(false);

    return (
        <path
            onMouseOver={handleMouseOver}
            onMouseOut={handleMouseOut}
            d={d}
            style={{
                fill: 'none',
                stroke: strokeColor,
                strokeWidth: `${effectiveStrokeWidth}px`,
            }}
        />
    );
};

/**
 * Renders the dagre-routed `edge` of `graph` as a `LayerTreeEdge`.
 *
 * Stroke width is `capacity / maxCapacity * 10 + 1` px, where:
 * - `capacity` comes from the edge label;
 * - `maxCapacity` is the largest `reduceShapeToNumber(outputShape)` among the model's graph nodes.
 *
 * When `maxCapacity` is not positive, the width falls back to 1 px instead of NaN/Infinity.
 *
 * If a structure matched by the active filters has a path containing this edge (same source and target),
 * that structure is passed on so the edge is coloured accordingly.
 */
const LinkElement = ({ edge, graph }: { edge: Edge; graph: dagre.graphlib.Graph<DagreLayerNode> }) => {
    const { model } = useContext(ModelContext);
    const { getStructuresMatchedByFilters } = useContext(FilterContext);

    // Look the edge label up once; it carries both the routed points and the capacity.
    const graphEdge = graph.edge(edge);

    const maxCapacity = model.graph.nodes.reduce(
        (acc, node) => Math.max(acc, reduceShapeToNumber(node.outputShape)),
        0,
    );
    // Guard the division: with no nodes, or all output shapes reducing to 0, the ratio would be NaN/Infinity.
    const relativeCapacity = maxCapacity > 0 ? graphEdge.capacity / maxCapacity : 0;
    const edgeStrokeWidth = relativeCapacity * 10 + 1;

    const modelStructures: Structure[] = getStructuresMatchedByFilters(model);

    let filteredStructure: Structure | undefined;
    if (modelStructures.length > 0) {
        // This approach assumes that a connection between two layers is only ever associated with a single structure.
        filteredStructure = modelStructures.find((modelStructure) => {
            const paths = model.graph.getPathsBelongingTo(modelStructure);

            return paths.some((path) => path.some((link) => link.source === edge.v && link.target === edge.w));
        });
    }

    return (
        <LayerTreeEdge
            d={positionsToPathString(graphEdge.points)}
            filteredStructure={filteredStructure}
            strokeWidth={edgeStrokeWidth}
        />
    );
};

export default L2ArchitectureComponent;
