import { graphlib, Node } from 'dagre';
import { DagreLayerNode } from 'types/dagre-nodes/DagreLayerNode';
import * as React from 'react';
import { useContext } from 'react';
import FilterContext from 'App/FilterContext';
import {
    ENTITY_TYPE_BADGE_SIZE,
    EntityType,
    getEntityCategory,
    getEntityTypeColor,
    getEntityTypeSymbol,
} from 'types/inspection-types/EntityType';
import tinyColor from 'tinycolor2';
import { greyBaseColor, halfTransparent, opaque } from 'styles/colors';
import { DagreLayerInnardsNode } from 'types/dagre-nodes/DagreLayerInnardsNode';
import { useHover } from 'react-use';
import { useBrush, useLink } from 'tools/hooks/useLinkAndBrush';

/**
 * Scaling factors for filter badges. The value is used as a multiplier for the badge radius and the badge symbol size.
 */
export enum FilterBadgeSize {
    /** Default badge scale (factor 1). */
    SMALL = 1,
    /** Twice the scale of a `SMALL` badge. */
    LARGE = SMALL * 2,
}

/**
 * Indicates the filters applied to entities of a model through badge-like glyphs on the lower edge of the respective
 * layers. The glyphs are colored according to the color maps associated with the entities in the [[SettingsPanel]].
 *
 * For every node of the graph, one [[FilterBadge]] is rendered per entity type that the active filters match; nodes
 * without matches receive no badge.
 *
 * @param dagreArchitectureGraph The model graph that contains the entities to filter.
 * @param recursive Indicates whether to attach a badge for children of nodes in `dagreArchitectureGraph` that the
 *                  filter applies to. `true` attaches badges for filtered children, `false` only considers the
 *                  top-level nodes in `dagreArchitectureGraph`.
 * @param size The scaling of the filter badges. `SMALL` is the default, `LARGE` doubles the radius of the badges.
 */
const FilterBadges = ({
    dagreArchitectureGraph,
    recursive = true,
    size = FilterBadgeSize.SMALL,
}: {
    dagreArchitectureGraph: graphlib.Graph<DagreLayerNode> | graphlib.Graph<DagreLayerInnardsNode>;
    recursive?: boolean;
    size?: FilterBadgeSize;
}) => {
    const { getEntityTypesMatchedByFilters } = useContext(FilterContext);

    return (
        <g>
            {dagreArchitectureGraph.nodes().flatMap((nodeID) => {
                const node = dagreArchitectureGraph.node(nodeID);
                const matchingEntityTypes = getEntityTypesMatchedByFilters(
                    isDagreLayerNode(node) ? node.modelGraphNode : node.layerGraphNode,
                    recursive,
                );

                // Keys combine the graph-unique node ID with the badge index, separated so that e.g. node 1 / badge 12
                // and node 11 / badge 2 no longer both produce "112" (duplicate keys break React reconciliation).
                return matchingEntityTypes.map((entityType, entityTypeIndex) => (
                    <FilterBadge
                        entityType={entityType}
                        index={entityTypeIndex}
                        key={`${nodeID}-${entityTypeIndex}`}
                        node={node}
                        size={size}
                    />
                ));
            })}
        </g>
    );
};

/**
 * A colored glyph that indicates an applied filter as a badge on a component of a neural network.
 *
 * Badges are laid out two per row below the node: even indices start half a node-width left of `node.x`, odd indices
 * start at `node.x`, and every further pair is shifted down by one `ENTITY_TYPE_BADGE_SIZE`. On non-layer (innards)
 * nodes, badges for algebraic operations and activation functions are instead placed near the lower-left corner of
 * the (circular) node. Hovering a badge scales it by 1.5 and brushes its entity type on the
 * `'filter-badge-hovered'` channel.
 *
 * @param entityType The type of the entity, i.e., the component of the neural network.
 * @param index The index of this badge on the component of the neural network; determines its position.
 * @param node The laid-out node that this badge is attached to.
 * @param size The scaling factor of this badge.
 */
const FilterBadge = ({
    entityType,
    index,
    node,
    size,
}: {
    entityType: EntityType;
    index: number;
    node: Node<DagreLayerNode> | Node<DagreLayerInnardsNode>;
    size: FilterBadgeSize;
}) => {
    const [isLinked] = useLink<EntityType | undefined>('filter-hovered', entityType);

    const badgeColor = getEntityTypeColor(entityType);
    const BadgeSymbol = getEntityTypeSymbol(entityType);
    // NOTE(review): `!!entityType` is presumably truthy for every entity type, which makes `isLinked` irrelevant and
    // renders every badge opaque. Preserved as-is; confirm whether only linked badges were meant to be opaque.
    const colorTransparency = !!entityType || isLinked ? opaque : halfTransparent;

    const nodeHalfHeight = node.height / 2;
    const nodeHalfWidth = node.width / 2;
    const radius = (ENTITY_TYPE_BADGE_SIZE / 2) * size;
    let xCoordinate = node.x - nodeHalfWidth + (index % 2) * nodeHalfWidth;
    let yCoordinate = node.y + nodeHalfHeight + Math.floor(index / 2) * ENTITY_TYPE_BADGE_SIZE;

    // Darken the badge color in 1% lightness steps until it is readable against the grey background. The loop is
    // capped at 100 steps (by then the color is black): if even black is not readable against `greyBaseColor`, an
    // unbounded loop would never terminate because `darken` clamps at zero lightness.
    let strokeColor = tinyColor(badgeColor);
    for (let step = 0; step < 100 && !tinyColor.isReadable(strokeColor, greyBaseColor); step++) {
        strokeColor = strokeColor.darken(1);
    }
    const strokeHex = `#${strokeColor.toHex()}`;

    if (
        !isDagreLayerNode(node) &&
        (getEntityCategory(entityType) === 'Algebraic Operation' ||
            getEntityCategory(entityType) === 'Activation Function')
    ) {
        /*
         * The node to attach the filter badge to is a circle, we display the badge on the lower
         * corner, regardless.
         *
         * This assumes that there are never two or more badges attached to a circular node on L2.
         */
        xCoordinate = node.x - nodeHalfWidth / 1.5;
        yCoordinate = node.y + nodeHalfHeight / 1.5;
    }

    // Render function for `useHover`, which passes the current hover state.
    const badge = (hovered: boolean) => (
        <g transform={`translate(${xCoordinate} ${yCoordinate}) scale(${hovered ? 1.5 : 1})`}>
            <circle cx={radius} cy={radius} r={radius} fill={'rgba(255,255,255,0.5)'} />
            <BadgeSymbol
                style={{ color: badgeColor }}
                fillOpacity={colorTransparency}
                size={ENTITY_TYPE_BADGE_SIZE * size}
                stroke={strokeHex}
                strokeOpacity={colorTransparency}
            />
        </g>
    );

    const [hoverableCircle, hovered] = useHover(badge);
    useBrush<EntityType | undefined>('filter-badge-hovered', hovered ? entityType : undefined);

    return hoverableCircle;
};

/**
 * Type guard that tells layer-level dagre nodes apart from nodes representing the innards of a layer.
 *
 * @param node A node of either kind.
 * @returns `true` if `node` carries a `modelGraphNode` that is not `undefined`.
 */
const isDagreLayerNode = (node: DagreLayerNode | DagreLayerInnardsNode): node is DagreLayerNode => {
    const { modelGraphNode } = node as DagreLayerNode;
    return typeof modelGraphNode !== 'undefined';
};

// The badge layer component is this module's default export; `FilterBadgeSize` is exported by name.
export default FilterBadges;
