import { Neuron } from 'types/inspection-types/Neuron';
import { NeuronConnection } from 'types/inspection-types/NeuronConnection';
import dagre from 'dagre';
import { isInterpretableAsImage, matrixToBase64Image, toSubscriptText } from 'tools/helpers';
import { Matrix } from 'mathjs';
import React from 'react';
import NeuronComponent from 'App/InspectionPanel/L1LayerUnitComponent/Neurons/NeuronComponent';
import ActivationBar from 'App/InspectionPanel/L1LayerUnitComponent/Neurons/ActivationBarComponent';
import { L1VisualNeuron } from 'types/inspection-types/L1VisualNeuron';
import { HeaderText } from 'App/InspectionPanel/L1LayerUnitComponent/HeaderText';
import { isMatrix } from 'types/inspection-types/DataArray';
import { Point2D } from 'types/Point2D';

/**
 * Computes a fully manual layout for the given neurons and their connections. Compared to the layout functions of
 * higher levels, dagre is only used as the graph container here; no dagre layout algorithm is run.
 *
 * Neurons are grouped into columns by their `rank` (ascending). Within a column they are ordered by descending
 * activation, and each column is centered vertically relative to the largest column. For every rank an additional
 * dummy node named `rank-<rankIdx>_header-dummy-node` is added that carries the column headers.
 *
 * Every connection whose source and target neuron are both known becomes an edge with two "whiskers" (short
 * three-point paths, one at the source and one at the target node) and a straight link between the whisker ends.
 *
 * @param neurons Neurons to place. Each neuron becomes one graph node keyed by its `id`.
 * @param neuronConnections Connections between neurons. Connections referencing an unknown `fromId` or `toId`
 *     are skipped.
 * @param nodeColorScale Color scale forwarded to the rendering of each neuron.
 * @param maxActivation Forwarded to the activation bar of each neuron as its maximum value.
 * @param summedActivations Forwarded to the activation bar of each neuron as the sum of all values.
 * @param classes Class labels, forwarded to the node and header rendering.
 * @param nodesep Vertical distance between two consecutive neurons of the same rank.
 * @param ranksep Horizontal distance between two consecutive ranks.
 * @param nodesize Edge length of a neuron; half of it is used as the whisker radius.
 * @returns The graph with positioned nodes (`x`, `y`, `width`, `height`, `visualElements`), edges carrying
 *     `points` and `whiskerPoints`, and a graph label with `width`, `height`, `marginx` and `marginy`. For empty
 *     input all four graph label values are 0.
 */
export function layout(
    neurons: Neuron[],
    neuronConnections: NeuronConnection[],
    nodeColorScale: (value: number) => string,
    maxActivation: number,
    summedActivations: number,
    classes: (number | string)[],
    nodesep = 50,
    ranksep = 500,
    nodesize = 30,
): dagre.graphlib.Graph<L1VisualNeuron> {
    // Create the final graph object. It will be filled later.
    const dagreGraph = new dagre.graphlib.Graph<L1VisualNeuron>();
    dagreGraph.setGraph({});

    // Collect the distinct ranks in ascending order. An explicit comparator is required: the default sort compares
    // string representations, which would e.g. place rank 10 before rank 2.
    const rankIndices = [...new Set(neurons.map((n) => n.rank))].sort((r1, r2) => {
        if (r1 < r2) {
            return -1;
        }
        return r1 > r2 ? 1 : 0;
    });

    // Assign the neurons to their rank, strongest activation first. `filter` returns a fresh array, so sorting in
    // place does not touch the caller's `neurons` array.
    const ranks = rankIndices.map((rank) =>
        neurons.filter((n) => n.rank === rank).sort((n1, n2) => n2.activation - n1.activation),
    );

    // Get the maximum layer size (needed to center the layers in y-direction). Starting the fold at 0 keeps the
    // value finite when there are no ranks at all.
    const maxNodesInRank = ranks.reduce((max, rank) => Math.max(max, rank.length), 0);

    // Use dict so referencing nodes in the following code is efficient. The linksOut and linksIn attributes are needed
    // locally for whiskers and links computation and will be removed later.
    const nodeDict: Record<
        string,
        L1VisualNeuron & {
            indexInRank: number;
            linksOut: NeuronConnection[];
            linksIn: NeuronConnection[];
        }
    > = {};

    // Layout nodes in each rank
    ranks.forEach((rank, rankIdx) => {
        // Number of (fractional) rows this rank is shifted down so that it is centered next to the largest rank.
        const currentRankYOffset = (maxNodesInRank - rank.length) / 2;

        rank.forEach((neuron, neuronIdx) => {
            // Only the first rank grows to the left; all other ranks grow to the right.
            const layoutDirection = rankIdx === 0 ? 'left' : 'right';
            const isImageNeuron = isMatrix(neuron.data[0]) && isInterpretableAsImage(neuron.data[0]);
            const [cx, cy, width, height, visualElements] = getVisualNodeAppearance(
                neuron,
                nodesize,
                layoutDirection,
                nodeColorScale,
                maxActivation,
                summedActivations,
                classes,
                isImageNeuron,
            );

            // In the first row, generate the header elements and add them as dummy nodes to the graph.
            if (neuronIdx === 0) {
                const [headerCx, headerCy, headerWidth, headerHeight, headerElements] = getNodeHeader(
                    nodesize,
                    layoutDirection,
                    classes,
                    isImageNeuron,
                );

                dagreGraph.setNode(`rank-${rankIdx}_header-dummy-node`, {
                    x: ranksep * rankIdx + headerCx,
                    y: nodesep * (neuronIdx + currentRankYOffset) + headerCy,
                    width: headerWidth,
                    height: headerHeight,
                    visualElements: headerElements,
                    ...neuron,
                    index: -1,
                });
            }

            nodeDict[neuron.id] = {
                x: ranksep * rankIdx + cx,
                y: nodesep * (neuronIdx + currentRankYOffset) + cy,
                indexInRank: neuronIdx,
                linksIn: [],
                linksOut: [],
                width,
                height,
                visualElements,
                ...neuron,
            };
        });
    });

    // Add information on incoming and outgoing links to each node
    neuronConnections.forEach((con) => {
        const sourceNode = nodeDict[con.fromId];
        const targetNode = nodeDict[con.toId];

        if (sourceNode && targetNode) {
            sourceNode.linksOut.push(con);
            targetNode.linksIn.push(con);
        }
    });

    // Sort the incoming and outgoing edges for each node by their position on the other side (so links don't cross)
    Object.values(nodeDict).forEach((n) => {
        n.linksOut.sort((l1, l2) => nodeDict[l1.toId].y - nodeDict[l2.toId].y);
        n.linksIn.sort((l1, l2) => nodeDict[l1.fromId].y - nodeDict[l2.fromId].y);
    });

    // Generate the whiskers and links for each connection. The angular step between two neighbouring whiskers is
    // derived from the busiest node, so that no node fans out over more than 90 (NOTE(review): presumably degrees -
    // confirm against Point2D.rotate).
    const maxNeuronLinks = Object.values(nodeDict).reduce(
        (max, n) => Math.max(max, n.linksIn.length, n.linksOut.length),
        0,
    );
    // Without any links the angle is never used; avoid the division by zero nonetheless.
    const deltaAngle = maxNeuronLinks > 0 ? 90 / maxNeuronLinks : 0;
    const r = nodesize / 2;

    neuronConnections.forEach((con) => {
        const sourceNode = nodeDict[con.fromId];
        const targetNode = nodeDict[con.toId];

        if (sourceNode && targetNode) {
            // Create 1:1 mapping between nodes and links
            const sourceNodeConIndex = sourceNode.linksOut.indexOf(con);
            const targetNodeConIndex = targetNode.linksIn.indexOf(con);

            // Define the center points for both source and target node. These are the centers of the neuron that
            // sits at the right end of the source node and at the left end of the target node, respectively.
            const sourceNodeCenter = new Point2D(sourceNode.x + sourceNode.width / 2 - nodesize / 2, sourceNode.y);
            const targetNodeCenter = new Point2D(targetNode.x - targetNode.width / 2 + nodesize / 2, targetNode.y);

            // Construct the whiskers for the source node: start on the neuron outline, center the whole fan around
            // the horizontal axis, then rotate to the slot of this connection.
            const p0Out = sourceNodeCenter
                .add(new Point2D(r, 0))
                .rotate((-deltaAngle * (sourceNode.linksOut.length - 1)) / 2, sourceNodeCenter)
                .rotate(deltaAngle * sourceNodeConIndex, sourceNodeCenter);
            const p0OutVec = p0Out.subtract(sourceNodeCenter);
            const p1Out = p0Out.add(new Point2D(r / 2, 0));
            const p2Out = sourceNodeCenter.add(p0OutVec.scale(2));

            // Construct the whiskers for the target node (mirrored version of the source side)
            const p0In = targetNodeCenter
                .subtract(new Point2D(r, 0))
                .rotate((deltaAngle * (targetNode.linksIn.length - 1)) / 2, targetNodeCenter)
                .rotate(-deltaAngle * targetNodeConIndex, targetNodeCenter);
            const p0InVec = p0In.subtract(targetNodeCenter);
            const p1In = p0In.subtract(new Point2D(r / 2, 0));
            const p2In = targetNodeCenter.add(p0InVec.scale(2));

            const { fromId, toId, ...linkData } = con;
            dagreGraph.setEdge(fromId, toId, {
                // These two points form the actual edge, connecting the whiskers on both sides.
                points: [p2Out, p2In],
                // Set the whisker paths for the source and target node.
                whiskerPoints: {
                    out: [p0Out, p1Out, p2Out],
                    in: [p0In, p1In, p2In],
                },
                ...linkData,
            });
        }
    });

    // Convert node-dict to list and attach to graph. Thereby, remove attributes that were needed for computation.
    Object.values(nodeDict).forEach((node) => {
        // eslint-disable-next-line @typescript-eslint/no-unused-vars
        const { linksOut, linksIn, id, ...nodeData } = node;
        dagreGraph.setNode(id, nodeData);
    });

    // Set properties of the graph itself: compute the bounding box of all nodes (header dummy nodes included) in a
    // single pass.
    const nodeIds = dagreGraph.nodes();
    let graphLeft = Infinity;
    let graphRight = -Infinity;
    let graphTop = Infinity;
    let graphBottom = -Infinity;

    nodeIds.forEach((nId) => {
        const { x, y, width, height } = dagreGraph.node(nId);
        graphLeft = Math.min(graphLeft, x - width / 2);
        graphRight = Math.max(graphRight, x + width / 2);
        graphTop = Math.min(graphTop, y - height / 2);
        graphBottom = Math.max(graphBottom, y + height / 2);
    });

    // An empty graph has no bounding box; report a zero-sized graph instead of infinite values.
    const hasNodes = nodeIds.length > 0;

    dagreGraph.graph().width = hasNodes ? graphRight - graphLeft : 0;
    dagreGraph.graph().height = hasNodes ? graphBottom - graphTop : 0;

    dagreGraph.graph().marginx = hasNodes ? graphLeft : 0;
    dagreGraph.graph().marginy = hasNodes ? graphTop : 0;

    return dagreGraph;
}

/**
 * Builds the column headers of a rank; the result is meant to be attached to the graph as a dummy node.
 * TODO: This is a dirty hack. Could be improved in the future - for now, it works.
 *
 * @param neuronSize Edge length of a neuron; together with a fixed padding it determines the header spacing.
 * @param layoutDirection Direction in which the columns are appended: `'right'` advances towards positive x,
 *     `'left'` towards negative x.
 * @param classes Class labels. For image neurons, one header per class is emitted after the `avg(c)` header.
 * @param isImageNeuron Whether the image column headers are generated in addition to the two fixed headers.
 * @returns The tuple `[cx, cy, width, height, visualElements]`: center and size of the header box, and the header
 *     texts, each wrapped in a group that shifts it by the box's left/top offset.
 */
function getNodeHeader(
    neuronSize: number,
    layoutDirection: 'left' | 'right',
    classes: (number | string)[],
    isImageNeuron: boolean,
): [number, number, number, number, JSX.Element[]] {
    const padding = 10;
    const sign = layoutDirection === 'right' ? 1 : -1;

    // The neuron and activation bar columns are followed by a double padding, image columns by a single one.
    const wideStep = (neuronSize + padding * 2) * sign;
    const narrowStep = (neuronSize + padding) * sign;

    // Describe all header columns first; they are positioned afterwards in one go.
    const columns: { key: string; label: string; advance: number }[] = [
        { key: `neuron-circle-header`, label: 'node ID', advance: wideStep },
        { key: `activation-bar-header`, label: 'rel. / abs. contr.', advance: wideStep },
    ];

    if (isImageNeuron) {
        // One header for the averaged image, followed by one header per class.
        ['mean', ...classes].forEach((c, idx) => {
            columns.push({
                key: `image-header_${c}`,
                label: idx > 0 ? `c${toSubscriptText(idx - 1)}: ${c}` : 'avg(c)',
                advance: narrowStep,
            });
        });
    }

    // Walk along the x-axis, starting at the center of the neuron column.
    let cursor = neuronSize / 2;
    const headerTexts = columns.map((column) => {
        const headerText = (
            <HeaderText key={column.key} angle={-60} x={cursor} y={-10}>
                {column.label}
            </HeaderText>
        );
        cursor += column.advance;
        return headerText;
    });

    // The cursor now sits one step behind the last column; move it back to that column's outer edge.
    const outerEdge = cursor - (neuronSize / 2 + padding) * sign;

    const left = Math.min(0, outerEdge);
    const right = Math.max(neuronSize, outerEdge);
    const top = -80;
    const bottom = 0;

    const width = right - left;
    const height = bottom - top;

    // Shift every header so that the box's top-left corner becomes the origin.
    const shift = `translate(${-left}, ${-top})`;
    const shiftedHeaderTexts = headerTexts.map((headerText) => (
        <g key={headerText.key} transform={shift}>
            {headerText}
        </g>
    ));

    return [left + width / 2, top + height / 2, width, height, shiftedHeaderTexts];
}

/**
 * Builds the visual elements of a single neuron node: the neuron itself, its activation bar and - for image
 * neurons - one image for the `'mean'` entry and one per class of the neuron's data.
 *
 * @param neuron Neuron to render.
 * @param neuronSize Width and height of every visual element.
 * @param layoutDirection Direction in which the elements are appended: `'right'` advances towards positive x,
 *     `'left'` towards negative x.
 * @param nodeColorScale Color scale handed to the neuron component.
 * @param maxActivation Handed to the activation bar as `maxValue`.
 * @param summedActivations Handed to the activation bar as `summedValues`.
 * @param classes Class labels; used as keys into `neuron.data` for image neurons.
 * @param isImageNeuron Whether images are appended after the activation bar.
 * @returns The tuple `[cx, cy, width, height, visualElements]`: center and size of the bounding box of all
 *     elements, and the elements re-positioned so that the bounding box starts at (0, 0).
 */
function getVisualNodeAppearance(
    neuron: Neuron,
    neuronSize: number,
    layoutDirection: 'left' | 'right',
    nodeColorScale: (value: number) => string,
    maxActivation: number,
    summedActivations: number,
    classes: (number | string)[],
    isImageNeuron: boolean,
): [number, number, number, number, JSX.Element[]] {
    const padding = 10;
    const sign = layoutDirection === 'right' ? 1 : -1;

    // The neuron and the activation bar are followed by a double padding, images by a single one.
    const wideStep = (neuronSize + padding * 2) * sign;
    const narrowStep = (neuronSize + padding) * sign;

    // Returns the current x-position and advances the cursor by the given step afterwards.
    let cursor = 0;
    const placeAndAdvance = (step: number): number => {
        const position = cursor;
        cursor += step;
        return position;
    };

    const elements: JSX.Element[] = [
        // The neuron circle itself.
        <NeuronComponent
            key={`neuron-circle_${neuron.id}`}
            neuron={neuron}
            nodeColorScale={nodeColorScale}
            x={placeAndAdvance(wideStep)}
            y={0}
            width={neuronSize}
            height={neuronSize}
        />,
        // The activation bar.
        <ActivationBar
            key={`activation-bar_${neuron.id}`}
            width={neuronSize}
            height={neuronSize}
            x={placeAndAdvance(wideStep)}
            y={0}
            value={neuron.activation}
            maxValue={maxActivation}
            summedValues={summedActivations}
        />,
    ];

    if (isImageNeuron) {
        // One image for the 'mean' entry, followed by one image per class.
        for (const c of ['mean', ...classes]) {
            elements.push(
                <image
                    key={`image_${neuron.id}_${c}`}
                    width={neuronSize}
                    height={neuronSize}
                    x={placeAndAdvance(narrowStep)}
                    y={0}
                    href={matrixToBase64Image(neuron.data[c] as Matrix)}
                />,
            );
        }
    }

    // Determine the bounding box of all elements in a single pass.
    let left = Infinity;
    let right = -Infinity;
    let top = Infinity;
    let bottom = -Infinity;
    for (const element of elements) {
        const box = element.props;
        left = Math.min(left, box.x);
        right = Math.max(right, box.x + box.width);
        top = Math.min(top, box.y);
        bottom = Math.max(bottom, box.y + box.height);
    }

    const width = right - left;
    const height = bottom - top;

    // Re-position every element relative to the top-left corner of the bounding box.
    const repositionedElements = elements.map((element) =>
        React.cloneElement(element, { x: element.props.x - left, y: element.props.y - top }),
    );

    return [left + width / 2, top + height / 2, width, height, repositionedElements];
}
