import {
    PromisedWidgetDataEntity,
    PromisedWidgetDefinition,
    WidgetDataEntity,
} from 'types/inspection-types/WidgetDefinition';
import { WidgetType } from 'types/inspection-types/WidgetType';
import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import { Model } from 'types/nn-types/Model';
import { Tool } from 'types/inspection-types/Tool';
import BackendQueryEngine from 'tools/BackendQueryEngine';
import { isFlatShape, ModelGraphNode } from 'types/nn-types/ModelGraph';
import { TransformSpec } from 'types/inspection-types/TransformSpec';
import { getLastCheckpointStep } from 'tools/helpers';
import { IWidgetContext } from 'App/WidgetContext';
import _ from 'lodash';

/**
 * Adds a widget for `modelGraphNode` using the builder that matches `activeTool`.
 *
 * @param model - Model whose data backs the widget.
 * @param modelGraphNode - Graph node the widget is created for.
 * @param activeTool - Selected tool; its `id` selects the widget builder.
 * @param addWidgetCb - Callback that registers the resulting widget definition.
 * @returns The selected builder's promise, or a promise rejected with a string message
 *   when no builder is registered for `activeTool`.
 */
const addWidget = (
    model: Model,
    modelGraphNode: ModelGraphNode,
    activeTool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
): Promise<void> => {
    type WidgetBuilder = (
        model: Model,
        modelGraphNode: ModelGraphNode,
        tool: Tool,
        addWidgetCb: IWidgetContext['addWidget'],
    ) => Promise<void>;

    // Assembled per call instead of at module level: the builder constants are
    // declared further down in this module and are not initialized at load time.
    const builders: ReadonlyArray<readonly [unknown, WidgetBuilder]> = [
        [Tool.NOTE.id, addAnnotationWidget],
        [Tool.PROJECTION_2D_SCATTERPLOT.id, addScatterplotWidget],
        [Tool.DISTRIBUTION_HISTOGRAM.id, addHistogramWidget],
        [Tool.DISTRIBUTION_MULTI_HISTOGRAM.id, addMultiHistogramWidget],
        [Tool.DISTRIBUTION_FEATURE_HISTOGRAM.id, addFeatureHistogramWidget],
    ];

    const match = builders.find(([toolId]) => toolId === activeTool.id);
    if (match === undefined) {
        return Promise.reject(`No implementation for ${activeTool.name}.`);
    }

    const [, build] = match;
    return build(model, modelGraphNode, activeTool, addWidgetCb);
};

/**
 * Registers a 2D scatterplot widget for `modelGraphNode`.
 *
 * The backend query reads sample index, `y_label` and the layer's activations at the
 * model's last checkpoint. It projects the activations with UMAP and renames the columns
 * to `point` and `label`. The widget is registered immediately; its single data entity
 * carries the still-pending query promise.
 */
const addScatterplotWidget = async (
    model: Model,
    modelGraphNode: ModelGraphNode,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
) => {
    const step = getLastCheckpointStep(model);
    const { name: nodeName } = modelGraphNode;

    const projectAndRename: TransformSpec = [
        { type: 'project', algorithm: 'umap', attribute: nodeName },
        { type: 'rename', attribute: nodeName, as: 'point' },
        { type: 'rename', attribute: 'y_label', as: 'label' },
    ];

    const pendingPoints = BackendQueryEngine.getSamplesAndActivations(
        model.id,
        step,
        ['index', 'y_label', nodeName],
        projectAndRename,
    );

    const scatterEntity: PromisedWidgetDataEntity = {
        entity: modelGraphNode,
        entityName: `${model.name}/${modelGraphNode.name}`,
        color: model.preferences.baseColor,
        data: pendingPoints,
    };

    const definition = new PromisedWidgetDefinition(WidgetType.SCATTERPLOT_2D, tool, modelGraphNode, [scatterEntity]);
    addWidgetCb(definition, LevelOfAbstraction.LAYERS_UNITS);
};

/**
 * Registers a histogram widget over all activations of `modelGraphNode`.
 *
 * A single backend query at the model's last checkpoint computes a 20-bin histogram of
 * the layer's activations. The widget is added right away and its data entity holds the
 * pending query promise.
 */
const addHistogramWidget = async (
    model: Model,
    modelGraphNode: ModelGraphNode,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
) => {
    const checkpointStep = getLastCheckpointStep(model);
    const attribute = modelGraphNode.name;

    const pendingHistogram = BackendQueryEngine.getSamplesAndActivations(
        model.id,
        checkpointStep,
        [attribute],
        [{ type: 'histogram', attribute, numbins: 20 }],
    );

    const histogramEntity: PromisedWidgetDataEntity = {
        entity: modelGraphNode,
        entityName: `${model.name}/${modelGraphNode.name}`,
        color: model.preferences.baseColor,
        data: pendingHistogram,
    };

    const definition = new PromisedWidgetDefinition(WidgetType.HISTOGRAM, tool, modelGraphNode, [histogramEntity]);
    addWidgetCb(definition, LevelOfAbstraction.LAYERS_UNITS);
};

/**
 * Registers a multi-histogram widget: one 20-bin activation histogram of `modelGraphNode`
 * per checkpoint in the model's checkpoint catalog.
 *
 * The queries for all checkpoint steps are issued in parallel. Each returned row is tagged
 * with its `step`, and the per-step results are concatenated into one flat list. The widget
 * is added immediately with that combined promise as its data.
 */
const addMultiHistogramWidget = async (
    model: Model,
    modelGraphNode: ModelGraphNode,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
) => {
    const attribute = modelGraphNode.name;
    const steps = model.checkpointCatalog.checkpoints.map(({ step }) => step);

    const pendingPerStep = steps.map(async (step) => {
        const rows = await BackendQueryEngine.getSamplesAndActivations(
            model.id,
            step,
            [attribute],
            [{ type: 'histogram', attribute, numbins: 20 }],
        );

        // Tag each row with its checkpoint so the widget can separate the histograms.
        for (const row of rows) {
            row['step'] = step;
        }

        return rows;
    });

    const allRows = Promise.all(pendingPerStep).then((perStep) => _.flatten(perStep));

    const multiHistogramEntity: PromisedWidgetDataEntity = {
        entity: modelGraphNode,
        entityName: `${model.name}/${modelGraphNode.name}`,
        color: model.preferences.baseColor,
        data: allRows,
    };

    const definition = new PromisedWidgetDefinition(
        WidgetType.MULTI_HISTOGRAM,
        tool,
        modelGraphNode,
        [multiHistogramEntity],
    );
    addWidgetCb(definition, LevelOfAbstraction.LAYERS_UNITS);
};

/**
 * Registers a feature-histogram widget: one 20-bin histogram per output feature of
 * `modelGraphNode`, at the model's last checkpoint.
 *
 * One backend query is issued per feature `i`. Each query splits the layer's activations
 * into `feature_<i>` columns, drops the original column, and computes a histogram over
 * `feature_<i>`. The resulting rows are tagged with `feature: 'feature_<i>'`, and all
 * feature results are concatenated.
 *
 * Like the other widget adders in this module, the widget is registered immediately and
 * its data entity carries the pending promise. The widget can therefore render while the
 * queries run.
 *
 * Rejects with a string message in these cases: the output shape is not flat, the feature
 * count is not determined, or the count exceeds 10.
 */
const addFeatureHistogramWidget = async (
    model: Model,
    modelGraphNode: ModelGraphNode,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
): Promise<void> => {
    const layerName = modelGraphNode.name;
    const outputShape = modelGraphNode.outputShape;

    if (!isFlatShape(outputShape)) {
        return Promise.reject(
            `Output dimension 0 of ${layerName} is multi-dimensional. Cannot get histograms over its features.`,
        );
    }

    // NOTE(review): falls back to index 1 when index 0 is nullish; presumably index 0 can be
    // an undetermined (batch) dimension for some shapes. Confirm against isFlatShape.
    const outputDim = outputShape[0] ?? outputShape[1];

    // `== null` also rejects `undefined` (e.g. a missing index 1). Previously that case slipped
    // through and produced an empty widget.
    if (outputDim == null) {
        return Promise.reject(
            `Output dimension 0 of ${layerName} is not determined. Cannot get histograms over its features.`,
        );
    }

    if (outputDim > 10) {
        return Promise.reject(
            `Output dimension 0 of ${layerName} is greater than 10 (= ${outputDim}). Refusing to get histograms over its features due to runtime complexity.`,
        );
    }

    const lastChkptStep = getLastCheckpointStep(model);

    const histogramsPerFeaturePromises = _.range(outputDim).map(async (featureIdx) => {
        const featureColumn = `feature_${featureIdx}`;

        const histoForFeature = await BackendQueryEngine.getSamplesAndActivations(
            model.id,
            lastChkptStep,
            [layerName],
            [
                {
                    type: 'firstDimToColumns',
                    attribute: layerName,
                    columnPrefix: 'feature_',
                },
                {
                    type: 'delete',
                    attribute: layerName,
                },
                {
                    type: 'histogram',
                    attribute: featureColumn,
                    numbins: 20,
                },
            ],
        );

        // Tag rows with their feature so the widget can separate the histograms.
        histoForFeature.forEach((dataRow) => (dataRow['feature'] = featureColumn));

        return histoForFeature;
    });

    // Deliberately not awaited: the widget is added now and resolves its data once all
    // per-feature queries finish. A query failure surfaces through the widget's data promise.
    const histograms = Promise.all(histogramsPerFeaturePromises).then((perFeature) => _.flatten(perFeature));

    const entities: PromisedWidgetDataEntity[] = [
        {
            entity: modelGraphNode,
            entityName: model.name + '/' + modelGraphNode.name,
            color: model.preferences.baseColor,
            data: histograms,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.FEATURE_HISTOGRAM, tool, modelGraphNode, entities),
        LevelOfAbstraction.LAYERS_UNITS,
    );
};

/**
 * Registers an annotation (note) widget attached to `modelGraphNode`.
 *
 * No backend query is made. The widget gets a single entity named after the node, with an
 * empty data list.
 */
const addAnnotationWidget = async (
    model: Model,
    modelGraphNode: ModelGraphNode,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
) => {
    const noteEntity: WidgetDataEntity = {
        entity: modelGraphNode,
        entityName: modelGraphNode.name,
        color: model.preferences.baseColor,
        data: [],
    };

    const definition = new PromisedWidgetDefinition(WidgetType.ANNOTATION, tool, modelGraphNode, [noteEntity]);
    addWidgetCb(definition, LevelOfAbstraction.LAYERS_UNITS);
};

// The module's public entry point: dispatches to the widget adder matching the active tool.
export default addWidget;
