import React, { useEffect, useState } from 'react';
import { Button, Spinner, Table } from 'react-bootstrap';
import { Model } from 'types/nn-types/Model';
import { Tool } from 'types/inspection-types/Tool';
import addWidgetL3 from 'App/InspectionPanel/L3TreeComponent/ModelComponent/add-widget';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import addWidgetL2 from 'App/InspectionPanel/L2ArchitectureComponent/LayerComponent/add-widget';
import tinyColor from 'tinycolor2';
import { PromisedWidgetDefinition } from 'types/inspection-types/WidgetDefinition';
import Queue from 'queue-promise';
import styled from 'styled-components';

// Outermost container of the dashboard; adds 2rem of padding on every side of its content.
const PaddedWrapper = styled.div`
    padding: 2rem;
`;

// Allow-list of tools that get a row in the cache dashboard: CacheDashboard skips every tool
// that is not contained in this array.
// NOTE(review): judging by the name, these are the tools whose widgets need data from the
// backend (and therefore benefit from cache warming) — confirm against the Tool definitions.
const toolRequiresBackend = [
    Tool.AUTOENCODER_SAMPLES,
    Tool.CLASSIFIER_CONFUSION_MATRIX,
    Tool.CLASSIFIER_CORRECTLY_CLASSIFIED,
    Tool.CLASSIFIER_SAMPLES,
    Tool.CLASSIFIER_WRONGLY_CLASSIFIED,
    Tool.DISTRIBUTION_FEATURE_HISTOGRAM,
    Tool.DISTRIBUTION_HISTOGRAM,
    Tool.DISTRIBUTION_MULTI_HISTOGRAM,
    Tool.MINIMIZING_SAMPLES,
    Tool.MAXIMIZING_SAMPLES,
    Tool.PROJECTION_2D_SCATTERPLOT,
];

/**
 * Dashboard that lists one row per (model, tool) and per (model graph node, tool) combination
 * and offers a "Cache All" button that triggers the request behind every row.
 *
 * Only tools contained in `toolRequiresBackend` are listed, and only for entities whose type
 * the tool reports as applicable. All rows share a single request queue so that the queue's
 * `concurrent` / `interval` settings apply across the whole dashboard.
 *
 * @param models Models whose own rows and graph-node rows are rendered.
 */
const CacheDashboard = ({ models }: { models: Model[] }) => {
    const [shouldCacheAll, setShouldCacheAll] = useState<boolean>(false);

    // Lazily created once and kept for the lifetime of the component. The queue is an effect
    // dependency of every row, so creating a new instance on each render would make all rows
    // enqueue their request again whenever this component re-renders.
    const [requestsQueue] = useState(() => new Queue({ concurrent: 3, interval: 1000 }));

    // `filter` returns a new array, so the in-place `sort` never reorders the array handed
    // out by `Tool.getTools()`. Filtering here also keeps the allow-list lookup out of the
    // nested loops below.
    const backendTools = Tool.getTools()
        .filter((tool) => toolRequiresBackend.includes(tool))
        .sort((firstTool, secondTool) => firstTool.name.localeCompare(secondTool.name));

    return (
        <PaddedWrapper>
            <h1 style={{ marginBottom: '1rem' }}>Cache Dashboard</h1>
            {/* The flag is never reset, so only the first click has an effect. */}
            <Button onClick={() => setShouldCacheAll(true)} size="sm" variant="primary">
                Cache All
            </Button>
            <Table bordered hover size="sm" striped style={{ marginTop: '1rem' }}>
                <thead>
                    <tr>
                        <th>Entity</th>
                        <th>Tool</th>
                        <th>Response Time (in seconds)</th>
                    </tr>
                </thead>
                <tbody>
                    {/* Model-level rows first ... */}
                    {models.map((model) =>
                        backendTools.map((tool) => {
                            if (!tool.isApplicable(model.type)) return null;

                            return (
                                <CacheableEntity
                                    key={`${model.id}-${tool.id}`}
                                    model={model}
                                    requestsQueue={requestsQueue}
                                    shouldCache={shouldCacheAll}
                                    tool={tool}
                                />
                            );
                        }),
                    )}
                    {/* ... followed by one row per applicable (graph node, tool) pair. */}
                    {models.map((model) =>
                        model.graph.nodes.map((node) =>
                            backendTools.map((tool) => {
                                if (!tool.isApplicable(node.type)) return null;

                                return (
                                    <CacheableEntity
                                        key={`${model.id}-${tool.id}-${node.id}`}
                                        model={model}
                                        node={node}
                                        requestsQueue={requestsQueue}
                                        shouldCache={shouldCacheAll}
                                        tool={tool}
                                    />
                                );
                            }),
                        ),
                    )}
                </tbody>
            </Table>
        </PaddedWrapper>
    );
};

/**
 * Outcome of one cache request as stored in the state of a dashboard row.
 * A completed request sets `responseTime`; a rejected one sets `message`.
 */
interface RequestResult {
    /** Time between starting the request and all widget entities resolving, in seconds. */
    responseTime?: number;
    /** Rejection reason of a failed request; shown to the user after "Failed: ". */
    message?: string;
}

/** Converts an arbitrary rejection reason into text that can safely be rendered by React. */
const formatFailureReason = (reason: unknown): string =>
    reason instanceof Error ? reason.message : String(reason);

/**
 * Table row for a single cacheable entity: a model, or one node of a model, combined with a tool.
 *
 * When `shouldCache` is true the row enqueues one task on `requestsQueue`. The task builds the
 * widget for the entity (`addWidgetL2` when a node is given, `addWidgetL3` otherwise) and waits
 * until all entities of the resulting widget definition have resolved. The row shows a spinner
 * while the task runs, then the elapsed time in seconds, or the failure reason.
 *
 * @param model Model the row belongs to.
 * @param node Optional graph node; when present the row refers to this node of `model`.
 * @param requestsQueue Queue shared by all rows, used to schedule the request.
 * @param shouldCache Whether the request should be enqueued.
 * @param tool Tool whose widget is requested.
 */
const CacheableEntity = ({
    model,
    node,
    requestsQueue,
    shouldCache,
    tool,
}: {
    model: Model;
    node?: ModelGraphNode;
    requestsQueue: Queue;
    shouldCache: boolean;
    tool: Tool;
}) => {
    const [isLoading, setIsLoading] = useState<boolean>(false);
    const [result, setResult] = useState<RequestResult | undefined>(undefined);

    useEffect(() => {
        if (!shouldCache) return undefined;

        // Cleared by the cleanup below. A task that was enqueued by an outdated effect run
        // (or by a row that has been unmounted since) still executes and still settles its
        // promise, but it no longer writes to this row's state.
        let isCurrent = true;

        requestsQueue.enqueue(
            () =>
                new Promise<void>((resolve, reject) => {
                    if (isCurrent) setIsLoading(true);

                    const startTime = performance.now();

                    // Single failure path for both the widget creation and the entity promises.
                    const fail = (reason: unknown) => {
                        if (isCurrent) {
                            setIsLoading(false);
                            // Store text only: rendering an Error object as a child makes React throw.
                            setResult({ message: formatFailureReason(reason) });
                        }
                        reject(reason);
                    };

                    const onWidgetDefinition = (widgetDefinition: PromisedWidgetDefinition) => {
                        Promise.all(widgetDefinition.entities).then(() => {
                            if (isCurrent) {
                                // performance.now() is in milliseconds; the table shows seconds.
                                setResult({ responseTime: (performance.now() - startTime) / 1000 });
                                setIsLoading(false);
                            }
                            resolve();
                        }, fail); // without this handler a rejected entity would never settle the task

                        return '';
                    };

                    if (node) {
                        addWidgetL2(model, node, tool, onWidgetDefinition).catch(fail);
                    } else {
                        addWidgetL3(model, tool, onWidgetDefinition).catch(fail);
                    }
                }),
        );

        return () => {
            isCurrent = false;
        };
    }, [model, node, requestsQueue, shouldCache, tool]);

    const entityName = node ? `${model.name}/${node.name}` : model.name;

    let responseTimeElement: JSX.Element;
    if (isLoading) {
        responseTimeElement = <Spinner animation="border" />;
    } else if (result === undefined) {
        // Nothing has been requested yet.
        responseTimeElement = <i>N/A</i>;
    } else if (result.responseTime !== undefined) {
        // Traffic-light colouring: above 2 s red, above 0.5 s orange, green otherwise.
        let color = 'green';
        if (result.responseTime > 2) {
            color = 'red';
        } else if (result.responseTime > 0.5) {
            color = 'orange';
        }

        responseTimeElement = (
            <span style={{ backgroundColor: color, color: tinyColor(color).darken(25).toHexString() }}>
                {result.responseTime.toFixed(5)}
            </span>
        );
    } else {
        responseTimeElement = <i style={{ color: 'red' }}>Failed: {result.message}</i>;
    }

    return (
        <tr>
            <td>{entityName}</td>
            <td>{tool.name}</td>
            <td>{responseTimeElement}</td>
        </tr>
    );
};

export default CacheDashboard;
