import React, { useContext, useMemo, useState } from 'react';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import { layout } from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/layout';
import { isCommandKeyPressedOnMacOS, positionsToPathString, reduceShapeToString } from 'tools/helpers';
import styled from 'styled-components';
import { produce } from 'immer';
import LayerInnardComponent from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/LayerInnardComponent/LayerInnardComponent';
import FilterBadges, { FilterBadgeSize } from 'App/InspectionPanel/FilterBadges';
import SelectionContext from 'App/SelectionContext';
import InspectionPanelContext from 'App/InspectionPanel/InspectionPanelContext';
import ToolContext from 'App/ToolContext';
import LayerBoundingBox from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/LayerBoundingBox';
import WidgetContext from 'App/WidgetContext';
import { WidgetType } from 'types/inspection-types/WidgetType';
import { mid } from 'tools/colors';
import { greyStroke } from 'styles/colors';
import addWidget from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/add-widget';
import LayerLabelText from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/LayerLabelText';
import ModelContext from 'App/ModelContext';

/** Base style for the SVG `<tspan>` text fragments that make up a layer's labels. */
const LayerLabel = styled.tspan`
    text-anchor: start;
    fill: #2c2c2c;
    font-size: 20px;
`;

/** Layer name, drawn at the top edge of the layer; `hanging` places the text below its y coordinate. */
const LayerNameLabel = styled(LayerLabel)`
    dominant-baseline: hanging;
`;

/** Bold layer class name, drawn at the bottom edge; `ideographic` keeps the glyphs above its y coordinate. */
const LayerTypeLabel = styled(LayerLabel)`
    dominant-baseline: ideographic;
    font-weight: bold;
`;

/** Output-shape text that follows the class name on the bottom label line. */
const LayerSizeLabel = styled(LayerLabel)`
    dominant-baseline: ideographic;
`;

/** Unfilled grey stroke used for the paths drawn between a layer's innard nodes. */
const InnardConnector = styled.path`
    stroke-width: 2px;
    stroke: #a7a7a7;
    fill: none;
`;

/** Props accepted by `LayerComponent`. */
interface Props {
    /** Horizontal offset of the layer's SVG group (applied via `translate`, also passed to the innard layout). */
    x: number;
    /** Vertical offset of the layer's SVG group (applied via `translate`, also passed to the innard layout). */
    y: number;
    /** Layer to render; its `innardGraph` supplies the nodes laid out inside the layer. */
    modelGraphNode: ModelGraphNode;
    /** Optional callback receiving the layer's computed width and height after they (or the callback) change. */
    onSizeChange?: (newWidth: number, newHeight: number) => void;
}

/**
 * Renders one model-graph layer as an SVG `<g>` translated to (`x`, `y`).
 *
 * The group contains the layer's bounding box, its name label (top), its class name and
 * output shape (bottom), icons of associated annotation widgets (top-right), the laid-out
 * innard nodes with their connecting paths, and filter badges.
 *
 * Sizing is bottom-up: each innard component reports its size through a memoized handler,
 * the collected sizes are fed into `layout`, and the resulting graph width/height is
 * reported to the parent through `onSizeChange`.
 *
 * Mouse interaction (left button only):
 * - Ctrl/Cmd + click toggles this layer in the selection.
 * - A plain click with an active tool applicable to this layer's type adds a widget.
 * - A plain double-click with no active tool descends into the layer via `descendLofa`.
 * Both click handlers stop propagation so enclosing elements do not react as well.
 */
const LayerComponent: React.FunctionComponent<Props> = ({ x, y, modelGraphNode, onSizeChange }: Props) => {
    const { model } = useContext(ModelContext);
    const { addWidget: addWidgetCb, getAssociatedWidgets } = useContext(WidgetContext);
    const { toggleSelection } = useContext(SelectionContext);
    const { descendLofa } = useContext(InspectionPanelContext);
    const { activeTool } = useContext(ToolContext);
    // Last size reported by each innard component, keyed by innard node id.
    const [layerInnardDimensions, setLayerInnardDimensions] = useState<
        Record<string, { width: number; height: number }>
    >({});
    const [hovered, setHovered] = useState<boolean>(false);

    // Only annotation widgets are rendered as icons on the layer.
    const associatedWidgets = getAssociatedWidgets(modelGraphNode.id).filter((w) =>
        [WidgetType.ANNOTATION].includes(w.widgetType),
    );

    // Lay out the innard graph using the innard sizes reported so far.
    const g = useMemo(
        () => layout(modelGraphNode.innardGraph, layerInnardDimensions, x, y, 40, 60),
        [modelGraphNode.innardGraph, layerInnardDimensions, x, y],
    );

    const width = g.graph().width ?? 0;
    const height = g.graph().height ?? 0;

    React.useEffect(() => {
        // Notify parent about size re-calculation.
        onSizeChange?.(width, height);
    }, [onSizeChange, width, height]);

    // Memoize one size-change handler per innard node, so they do not lead to infinite
    // re-renders when doing the iterative bottom-up size estimation.
    const layerInnardSizeChangeHandlers = useMemo(
        () =>
            modelGraphNode.innardGraph.nodes.map((n) => (newWidth: number, newHeight: number) => {
                // Pass a function to setState so that concurrent reports from several innards
                // are all applied. See: https://stackoverflow.com/a/30341560
                setLayerInnardDimensions((prevDimensions) => {
                    const previous = prevDimensions[n.id];
                    if (previous !== undefined && previous.width === newWidth && previous.height === newHeight) {
                        // Unchanged size: return the same state object so React bails out and
                        // the innard graph is not laid out again for nothing.
                        return prevDimensions;
                    }
                    return produce(prevDimensions, (draftDimensions) => {
                        draftDimensions[n.id] = { width: newWidth, height: newHeight };
                    });
                });
            }),
        [modelGraphNode.innardGraph.nodes],
    );

    const innardElements: JSX.Element[] = useMemo(
        () =>
            modelGraphNode.innardGraph.nodes.map((n, idx) => {
                // Fall back to 0 when the layout has no entry for this node.
                const position = g.node(n.id);
                // Handlers were built from the same `nodes` array, so indices line up.
                const handleSizeChange = layerInnardSizeChangeHandlers[idx];

                return (
                    <LayerInnardComponent
                        key={n.id}
                        x={position?.x ?? 0}
                        y={position?.y ?? 0}
                        onSizeChange={handleSizeChange}
                        modelGraphNode={modelGraphNode}
                        layerGraphNode={n}
                    />
                );
            }),
        [modelGraphNode, layerInnardSizeChangeHandlers, g],
    );

    // Key each connector by its edge endpoints (plus the edge name, if any) rather than its
    // index, so React keeps a path matched to the same edge across re-layouts.
    const linkElements: JSX.Element[] = useMemo(
        () =>
            g.edges().map((e) => {
                const points = g.edge(e).points as { x: number; y: number }[];
                const key = `${e.v}->${e.w}${e.name ? `:${e.name}` : ''}`;
                return <InnardConnector key={key} d={positionsToPathString(points)} />;
            }),
        [g],
    );

    const onClickHandler = (e: React.MouseEvent) => {
        if (e.button === 0) {
            if (e.ctrlKey || isCommandKeyPressedOnMacOS(e)) {
                toggleSelection(modelGraphNode);
            } else if (activeTool?.isApplicable(modelGraphNode.type)) {
                addWidget(model, modelGraphNode, activeTool, addWidgetCb);
            }
        }

        e.stopPropagation();
    };

    const onDoubleClickHandler = (e: React.MouseEvent) => {
        if (e.button === 0) {
            if (!e.ctrlKey && !isCommandKeyPressedOnMacOS(e) && !activeTool) {
                descendLofa(modelGraphNode.id);
            }
        }

        e.stopPropagation();
    };

    return (
        <g
            onClick={onClickHandler}
            onDoubleClick={onDoubleClickHandler}
            onMouseEnter={() => setHovered(true)}
            onMouseLeave={() => setHovered(false)}
            transform={`translate(${x} ${y})`}
        >
            <LayerBoundingBox modelGraphNode={modelGraphNode} width={width} height={height} hovered={hovered} />

            <LayerLabelText x={0} y={0} transform="translate(10 10)" maxWidth={width - 20}>
                <LayerNameLabel x={0} y={0}>
                    {modelGraphNode.name}
                </LayerNameLabel>
            </LayerLabelText>
            <g style={{ color: mid(greyStroke) }}>
                {associatedWidgets.map((w, idx) =>
                    React.cloneElement(w.tool.icon, { x: width - 25 * (idx + 1), y: 10, key: w.widgetId }),
                )}
            </g>
            <LayerLabelText x={0} y={height} transform="translate(10 -10)" maxWidth={width - 20}>
                <LayerTypeLabel x={0} y={height}>
                    {modelGraphNode.clsName}
                </LayerTypeLabel>
                <LayerSizeLabel dx={10} y={height}>
                    {reduceShapeToString(modelGraphNode.outputShape)}
                </LayerSizeLabel>
            </LayerLabelText>

            {innardElements}
            {linkElements}
            <FilterBadges dagreArchitectureGraph={g} size={FilterBadgeSize.LARGE} />
        </g>
    );
};

export default LayerComponent;
