import * as React from 'react';
import { GraphEdge, Node } from 'dagre';
import { useTooltipInPortal } from '@visx/tooltip';
import styled from 'styled-components';
import Flatten from 'flatten-js';
import { DagreModelNode } from 'types/dagre-nodes/DagreModelNode';
import { area, curveBasis } from 'd3-shape';

/** Props accepted by the edge component defined in this file. */
interface Props {
    /** Dagre node the edge starts at; its position (`x`, `y`) and `model.stats` are read. */
    srcNode: Node<DagreModelNode>;
    /** Dagre node the edge ends at; its position (`x`, `y`) and `model.stats` are read. */
    targetNode: Node<DagreModelNode>;
    /** Dagre edge whose `points` supply the guide points of the drawn shape. */
    edge: GraphEdge;
    /** Divisor applied to each node's `numTrainableParameters` when scaling the edge thickness. */
    maxParameters: number;
    /** Factor the parameter ratio is multiplied with to obtain the thickness offset; the component defaults it to 50. */
    maxEdgeThickness?: number;
}

/** Borderless `div` laid out as a CSS table; wraps the tooltip rows. */
const TooltipTable = styled.div`
    display: table;
    border: none;
`;

/** `div` displayed as a CSS table row inside the tooltip table. */
const TooltipTableRow = styled.div`
    display: table-row;
`;

/** `div` displayed as a CSS table cell with 5px padding. */
const TooltipTableCell = styled.div`
    display: table-cell;
    padding: 5px;
`;

/** SVG `path` used for the edge area: light grey fill, darker grey stroke, default cursor. */
const EdgeAreaPath = styled.path`
    fill: #eeeeee;
    stroke: #9c9c9c;
    cursor: default;
`;

/** Lengths and sines whose magnitude is at or below this value are treated as zero by the edge geometry. */
const EDGE_GEOMETRY_EPSILON = 1e-6;

/** One sample of the edge outline: matching points on the two boundary curves. */
interface EdgeBoundaryPoint {
    p1: Flatten.Vector;
    p2: Flatten.Vector;
}

/**
 * Formats a numeric delta for the tooltip: `-` for negative values, `+` for positive values and `~` otherwise,
 * followed by the absolute value.
 */
const numberToSignedText = (v: number): string => {
    let sign = '~';
    if (v < 0) {
        sign = '-';
    } else if (v > 0) {
        sign = '+';
    }

    return `${sign}${Math.abs(v)}`;
};

/** Area generator that fills the region between the two boundary curves, smoothed with a basis spline. */
const edgeAreaGenerator = area<EdgeBoundaryPoint>()
    .x0((d) => d.p1.x)
    .x1((d) => d.p2.x)
    .y0((d) => d.p1.y)
    .y1((d) => d.p2.y)
    .curve(curveBasis);

/**
 * Keeps a bend point between the source and target borders. Distances are measured from the offset source centre:
 * a point closer than the source border point is replaced by it, a point farther than the target border point is
 * replaced by the target border point.
 */
const clampBetweenBorders = (
    candidate: Flatten.Vector,
    srcOffsetCenter: Flatten.Vector,
    srcOffsetBorder: Flatten.Vector,
    targetOffsetBorder: Flatten.Vector,
): Flatten.Vector => {
    let result = candidate;
    if (result.subtract(srcOffsetCenter).length < srcOffsetBorder.subtract(srcOffsetCenter).length) {
        result = srcOffsetBorder;
    }
    if (result.subtract(srcOffsetCenter).length > targetOffsetBorder.subtract(srcOffsetCenter).length) {
        result = targetOffsetBorder;
    }
    return result;
};

/**
 * Builds the two boundary polylines of an edge whose thickness at each end is proportional to the number of
 * trainable parameters of the adjacent node.
 * For documentation of the construction, see: misc/l3-edge-construction.svg
 *
 * @returns the paired boundary points, or `null` when the edge has fewer than three points or the guide points
 * coincide, in which case no shape can be constructed.
 */
const computeEdgeBoundaries = (
    srcNode: Node<DagreModelNode>,
    targetNode: Node<DagreModelNode>,
    edge: GraphEdge,
    maxParameters: number,
    maxEdgeThickness: number,
): EdgeBoundaryPoint[] | null => {
    const { points } = edge;
    if (!points || points.length < 3) {
        return null;
    }

    // Use the first and last points as the node border points and the middle point as the bend. For the usual
    // three-point edge this is exactly [n1, m, n2]; longer point lists no longer end at their third point.
    const n1 = points[0];
    const m = points[Math.floor(points.length / 2)];
    const n2 = points[points.length - 1];

    // Make edge thickness at the source and target node dependent on the number of parameters.
    // A non-positive (or NaN) maximum would turn the ratio into NaN/Infinity, so fall back to zero thickness.
    const toOffset = (numParameters: number): number =>
        maxParameters > 0 ? maxEdgeThickness * (numParameters / maxParameters) : 0;
    const offsetSourceNode = toOffset(srcNode.model.stats.numTrainableParameters);
    const offsetTargetNode = toOffset(targetNode.model.stats.numTrainableParameters);

    const srcNodeCenterPoint = new Flatten.Vector(srcNode.x, srcNode.y);
    const srcNodeEdgePoint = new Flatten.Vector(n1.x, n1.y);
    const intersectionPoint = new Flatten.Vector(m.x, m.y);
    const targetNodeEdgePoint = new Flatten.Vector(n2.x, n2.y);
    const targetNodeCenterPoint = new Flatten.Vector(targetNode.x, targetNode.y);

    // Direction vectors pointing from each border point towards the bend point.
    const g1 = intersectionPoint.subtract(srcNodeEdgePoint);
    const g2 = intersectionPoint.subtract(targetNodeEdgePoint);

    // Flatten cannot normalize a (near-)zero-length vector, so bail out instead of letting it throw during render.
    if (g1.length <= EDGE_GEOMETRY_EPSILON || g2.length <= EDGE_GEOMETRY_EPSILON) {
        return null;
    }

    // Perpendicular displacements: side 1 is counter-clockwise of g1 / clockwise of g2, side 2 the opposite.
    const srcShift1 = g1.rotate90CCW().normalize().multiply(offsetSourceNode);
    const srcShift2 = g1.rotate90CW().normalize().multiply(offsetSourceNode);
    const targetShift1 = g2.rotate90CW().normalize().multiply(offsetTargetNode);
    const targetShift2 = g2.rotate90CCW().normalize().multiply(offsetTargetNode);

    const srcNodeOffsetCenterPoint1 = srcNodeCenterPoint.add(srcShift1);
    const srcNodeOffsetCenterPoint2 = srcNodeCenterPoint.add(srcShift2);
    const srcNodeOffsetEdgePoint1 = srcNodeEdgePoint.add(srcShift1);
    const srcNodeOffsetEdgePoint2 = srcNodeEdgePoint.add(srcShift2);
    const targetNodeOffsetEdgePoint1 = targetNodeEdgePoint.add(targetShift1);
    const targetNodeOffsetEdgePoint2 = targetNodeEdgePoint.add(targetShift2);
    const targetNodeOffsetCenterPoint1 = targetNodeCenterPoint.add(targetShift1);
    const targetNodeOffsetCenterPoint2 = targetNodeCenterPoint.add(targetShift2);

    // Construct guide points that are offset from m by offsetSourceNode and offsetTargetNode: the corners of the
    // parallelogram spanned by the offset lines of both segments.
    const sinAlpha = Math.sin(g1.angleTo(g2));

    let parallelogramPoint1: Flatten.Vector;
    let parallelogramPoint2: Flatten.Vector;
    if (Math.abs(sinAlpha) <= EDGE_GEOMETRY_EPSILON) {
        // Collinear segments: the offset lines are parallel, so dividing by sin(alpha) either blows up or
        // collapses both points onto m. Offset m perpendicular to the edge by the mean thickness instead.
        const meanOffset = (offsetSourceNode + offsetTargetNode) / 2;
        parallelogramPoint1 = intersectionPoint.add(g1.rotate90CCW().normalize().multiply(meanOffset));
        parallelogramPoint2 = intersectionPoint.add(g1.rotate90CW().normalize().multiply(meanOffset));
    } else {
        const alongG2 = g2.normalize().multiply(offsetSourceNode / sinAlpha);
        const alongG1 = g1.normalize().multiply(offsetTargetNode / sinAlpha);
        parallelogramPoint1 = intersectionPoint.add(alongG2).add(alongG1);
        parallelogramPoint2 = intersectionPoint.subtract(alongG2).subtract(alongG1);
    }

    // If a parallelogram point lies before the source border or behind the target border, push it back onto the
    // corresponding offset border point.
    parallelogramPoint1 = clampBetweenBorders(
        parallelogramPoint1,
        srcNodeOffsetCenterPoint1,
        srcNodeOffsetEdgePoint1,
        targetNodeOffsetEdgePoint1,
    );
    parallelogramPoint2 = clampBetweenBorders(
        parallelogramPoint2,
        srcNodeOffsetCenterPoint2,
        srcNodeOffsetEdgePoint2,
        targetNodeOffsetEdgePoint2,
    );

    const path1 = [
        srcNodeOffsetCenterPoint1,
        srcNodeOffsetEdgePoint1,
        parallelogramPoint1,
        targetNodeOffsetEdgePoint1,
        targetNodeOffsetCenterPoint1,
    ];
    const path2 = [
        srcNodeOffsetCenterPoint2,
        srcNodeOffsetEdgePoint2,
        parallelogramPoint2,
        targetNodeOffsetEdgePoint2,
        targetNodeOffsetCenterPoint2,
    ];

    return path1.map((p1, idx) => ({ p1, p2: path2[idx] }));
};

/**
 * Renders a graph edge as a filled area between two model nodes. The area is as thick at each end as the
 * adjacent node's share of `maxParameters` (scaled by `maxEdgeThickness`). Hovering the area shows a tooltip
 * with the difference in layers and trainable parameters between target and source model.
 *
 * Renders nothing when the edge geometry cannot be constructed.
 */
const EdgeComponent: React.FunctionComponent<Props> = ({
    srcNode,
    targetNode,
    edge,
    maxParameters,
    maxEdgeThickness = 50,
}: Props) => {
    const [tooltipPosition, setTooltipPosition] = React.useState<{ x: number; y: number } | null>(null);

    const { TooltipInPortal } = useTooltipInPortal({
        // use TooltipWithBounds
        detectBounds: true,
        // when tooltip containers are scrolled, this will correctly update the Tooltip position
        scroll: true,
    });

    // Hooks are called above; bailing out here keeps the hook order stable across renders.
    const edgeBoundaries = computeEdgeBoundaries(srcNode, targetNode, edge, maxParameters, maxEdgeThickness);
    if (edgeBoundaries === null) {
        return null;
    }

    const layerChange = targetNode.model.stats.numLayers - srcNode.model.stats.numLayers;
    const paramChange = targetNode.model.stats.numTrainableParameters - srcNode.model.stats.numTrainableParameters;

    // Track the pointer in client coordinates so the tooltip follows the cursor.
    const handleMouseMove = (e: React.MouseEvent<SVGElement>) => {
        const { clientX: x, clientY: y } = e;
        setTooltipPosition({ x, y });
    };

    const handleMouseOut = () => {
        setTooltipPosition(null);
    };

    return (
        <>
            <g className="model-graph-edge">
                <EdgeAreaPath
                    d={edgeAreaGenerator(edgeBoundaries) ?? ''}
                    onMouseMove={handleMouseMove}
                    onMouseOut={handleMouseOut}
                />
                {/* Debug aid: uncomment to draw the guide points of both boundaries. */}
                {/*{edgeBoundaries.map(({ p1 }, key) => <circle key={key} cx={p1.x} cy={p1.y} r={3}/>)}*/}
                {/*{edgeBoundaries.map(({ p2 }, key) => <circle key={key} cx={p2.x} cy={p2.y} r={3}/>)}*/}
            </g>

            {tooltipPosition && (
                <TooltipInPortal
                    key={Math.random()} // set this to random so it correctly updates with parent bounds
                    left={tooltipPosition.x + 2}
                    top={tooltipPosition.y + 2}
                >
                    <TooltipTable>
                        <TooltipTableRow>
                            <TooltipTableCell>&#120491; layers:</TooltipTableCell>
                            <TooltipTableCell>
                                <strong>{numberToSignedText(layerChange)}</strong>
                            </TooltipTableCell>
                        </TooltipTableRow>
                        <TooltipTableRow>
                            <TooltipTableCell>&#120491; parameters:</TooltipTableCell>
                            <TooltipTableCell>
                                <strong>{numberToSignedText(paramChange)}</strong>
                            </TooltipTableCell>
                        </TooltipTableRow>
                    </TooltipTable>
                </TooltipInPortal>
            )}
        </>
    );
};
export default EdgeComponent;
