import {
    PromisedWidgetDataEntity,
    PromisedWidgetDefinition,
    WidgetDataEntity,
} from 'types/inspection-types/WidgetDefinition';
import { WidgetType } from 'types/inspection-types/WidgetType';
import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import { Model } from 'types/nn-types/Model';
import { Tool } from 'types/inspection-types/Tool';
import BackendQueryEngine from 'tools/BackendQueryEngine';
import {
    deriveChildInnspectorHeader,
    getLastCheckpointStep,
    isInterpretableAsImage,
    matrixToBase64Image,
} from 'tools/helpers';
import { TransformSpec } from 'types/inspection-types/TransformSpec';
import * as mathjs from 'mathjs';
import { IWidgetContext } from 'App/WidgetContext';
import { CheckpointInfo } from 'types/nn-types/CheckpointInfo';
import { DataArray, isMatrix, isNumber, isString } from 'types/inspection-types/DataArray';
import { Matrix } from 'mathjs';

/**
 * Builds the widget that belongs to `activeTool` for `model` and hands it to `addWidgetCb`.
 *
 * Each tool id is routed to one builder function; several tools share a builder (the three
 * classifier-sample tools, and NOTE / BRANCH_MODEL).
 *
 * @param model - The model the widget is created for.
 * @param activeTool - The selected tool; its `id` decides which builder runs.
 * @param addWidgetCb - Widget-context callback that registers the finished widget.
 * @returns The builder's promise, or a promise rejected with a string reason for unsupported tools.
 */
const addWidget = (model: Model, activeTool: Tool, addWidgetCb: IWidgetContext['addWidget']): Promise<void> => {
    type WidgetBuilder = (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => Promise<void>;

    // Ordered routing table: the first entry whose id list contains the active tool id is used,
    // matching the first-match behaviour of a switch statement.
    const routes: ReadonlyArray<readonly [ReadonlyArray<Tool['id']>, WidgetBuilder]> = [
        [[Tool.PERFORMANCE_METRICS.id], addPerformanceWidget],
        [
            [
                Tool.CLASSIFIER_SAMPLES.id,
                Tool.CLASSIFIER_CORRECTLY_CLASSIFIED.id,
                Tool.CLASSIFIER_WRONGLY_CLASSIFIED.id,
            ],
            addInputOutputClassifierWidget,
        ],
        [[Tool.CLASSIFIER_CONFUSION_MATRIX.id], addConfusionMatrixWidget],
        [[Tool.AUTOENCODER_SAMPLES.id], addInputOutputAutoencoderWidget],
        [[Tool.CHECKPOINT_SIZE.id], addCheckpointSizeWidget],
        [[Tool.NOTE.id, Tool.BRANCH_MODEL.id], addAnnotationWidget],
        [[Tool.MODEL_INFO_LENS.id], addModelInfoWidget],
    ];

    const route = routes.find(([toolIds]) => toolIds.some((toolId) => toolId === activeTool.id));
    if (route === undefined) {
        return Promise.reject(`No implementation for ${activeTool.name}.`);
    }

    const [, buildWidget] = route;
    return buildWidget(model, activeTool, addWidgetCb);
};

/**
 * Adds a free-text annotation widget for a single model.
 *
 * For the BRANCH_MODEL tool, the note is pre-filled with a fenced Python snippet built from
 * `deriveChildInnspectorHeader(model)`. For any other tool routed here, the note starts empty.
 *
 * @param model - The annotated model; also supplies the widget label and color.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addAnnotationWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const label: string = model.info.label;
    const startsWithBranchHeader = tool.id === Tool.BRANCH_MODEL.id;

    const initialText: DataArray = startsWithBranchHeader
        ? [{ text: '```python\n' + deriveChildInnspectorHeader(model) + '\n```' }]
        : [];

    const annotationEntity: WidgetDataEntity = {
        entity: model,
        entityName: label,
        color: model.preferences.baseColor,
        data: initialText,
    };

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.ANNOTATION, tool, model, [annotationEntity]),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a widget showing the model's performance statistics over its training steps.
 *
 * Only checkpoints whose `performanceStatistics` is not `undefined` contribute a data point. Each
 * point is `{ step, ...performanceStatistics }`; because the spread comes last, a statistic named
 * `step` would overwrite the checkpoint step.
 *
 * @param model - The model whose checkpoint catalog is plotted.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addPerformanceWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const { checkpoints } = model.checkpointCatalog;

    const checkpointsWithStats = checkpoints.filter((checkpoint) => checkpoint.performanceStatistics !== undefined);
    const statsOverTime: WidgetDataEntity['data'] = checkpointsWithStats.map((checkpoint) => ({
        step: checkpoint.step,
        ...checkpoint.performanceStatistics,
    }));

    const entities: WidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: statsOverTime,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.SINGLE_ENTITY_MULTI_TIME, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a widget plotting each checkpoint's file size against its training step.
 *
 * The unit label is taken from the first checkpoint only, while every checkpoint's `filesize.value`
 * is plotted as-is. This assumes all checkpoints report their size in the same unit (TODO confirm).
 * A model without checkpoints yields an empty series with an `undefined` unit; it no longer throws.
 *
 * @param model - The model whose checkpoint catalog is plotted.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addCheckpointSizeWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const { checkpoints } = model.checkpointCatalog;

    const entities: WidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: checkpoints.map((chkpt) => ({
                step: chkpt.step,
                checkpointFileSize: chkpt.filesize.value,
            })),
            // `?.` keeps an empty catalog from raising a TypeError (which would reject this async
            // function and leave no widget at all).
            unit: checkpoints[0]?.filesize.unit,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.SINGLE_ENTITY_MULTI_TIME, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a verbalization widget summarizing a model.
 *
 * The summary includes the layer count, trainable parameters, the highest training step among its
 * checkpoints, the creation timestamp, and every performance statistic of the checkpoint with that
 * highest step. A model without checkpoints is reported with 0 training steps and no performance
 * statistics; previously this case threw.
 *
 * @param model - The model to summarize.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addModelInfoWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    // Single pass for the checkpoint with the highest step; on ties the earliest one wins (strict `>`).
    // Unlike `Math.max(...steps)`, this neither produces -Infinity for an empty catalog nor spreads an
    // arbitrarily long array into call arguments.
    const lastChkpt = model.checkpointCatalog.checkpoints.reduce<CheckpointInfo | undefined>(
        (latest, chkpt) => (latest === undefined || chkpt.step > latest.step ? chkpt : latest),
        undefined,
    );
    const maxStep = lastChkpt?.step ?? 0;

    const entities: WidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: [
                {
                    label: 'Layers',
                    value: model.stats.numLayers,
                },
                {
                    label: 'Parameters',
                    value: model.stats.numTrainableParameters,
                },
                {
                    label: 'Training steps',
                    value: maxStep,
                },
                {
                    label: 'Creation time',
                    value: model.info.timestamp,
                },
                ...Object.entries(lastChkpt?.performanceStatistics ?? {}).map(([k, v]) => ({ label: k, value: v })),
            ],
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.VERBALIZATION, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a widget comparing classifier inputs with their predicted and true labels.
 *
 * Samples come from the model's last checkpoint. CLASSIFIER_CORRECTLY_CLASSIFIED keeps only rows
 * where `y_label == y_prediction`, and CLASSIFIER_WRONGLY_CLASSIFIED keeps only rows where they
 * differ. Any other tool keeps all rows. At most 60 rows are fetched.
 *
 * The data promise resolves to `[]` when the backend returns no rows (normal for the filters, e.g.
 * a model with no misclassified samples). It also resolves to `[]`, after logging an error, when the
 * first row does not look like an image classifier sample.
 *
 * @param model - The classifier model.
 * @param tool - Selects the optional correct/wrong filter.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addInputOutputClassifierWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const lastCheckpointStep = getLastCheckpointStep(model);

    const transformSpec: TransformSpec = [];

    if (tool.id === Tool.CLASSIFIER_CORRECTLY_CLASSIFIED.id) {
        transformSpec.push({
            type: 'filter',
            query: 'y_label == y_prediction',
        });
    } else if (tool.id === Tool.CLASSIFIER_WRONGLY_CLASSIFIED.id) {
        transformSpec.push({
            type: 'filter',
            query: 'y_label != y_prediction',
        });
    }

    // Execute limit operation after filter operations, to actually retrieve the maximum amount of samples
    transformSpec.push({
        type: 'head',
        count: 60,
    });

    const widgetDataPromise = BackendQueryEngine.getSamplesAndActivations(
        model.id,
        lastCheckpointStep,
        ['x', 'y_label_categorical', 'y_prediction_categorical', 'y_prediction', 'y_label'],
        transformSpec,
    ).then((data) => {
        // A filter can legitimately match nothing. Without this guard, `data[0]` below is undefined
        // and the promise rejects with a TypeError instead of yielding an empty widget.
        if (data.length === 0) {
            return [];
        }

        // Only the first row is inspected; the remaining rows are assumed to share its layout.
        const firstRow = data[0];
        const looksLikeImageClassifier =
            isMatrix(firstRow['x']) &&
            isMatrix(firstRow['y_prediction_categorical']) &&
            isMatrix(firstRow['y_label_categorical']) &&
            (isNumber(firstRow['y_prediction']) || isString(firstRow['y_prediction']));

        if (looksLikeImageClassifier) {
            return data.map((row, idx) => ({
                sampleId: idx,
                xImage: matrixToBase64Image(row['x'] as Matrix),
                yPredictionCategorical: (row['y_prediction_categorical'] as Matrix).toArray() as number[],
                yLabelCategorical: (row['y_label_categorical'] as Matrix).toArray() as number[],
                yPrediction: row['y_prediction'] as number | string,
                yLabel: row['y_label'] as number | string,
            }));
        }

        console.error('This does not seem to be an image classifier model. Widget will be empty.');
        return [];
    });

    const entities: PromisedWidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: widgetDataPromise,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.INPUT_OUTPUT_COMPARISON_CLASSIFIER, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a confusion-matrix widget for the model's last checkpoint.
 *
 * The backend applies a `confusionMatrix` transform over `y_label` / `y_prediction`. The resulting
 * columns are then renamed `y_label` -> `true`, `y_prediction` -> `pred` and `rate` -> `value`,
 * presumably the field names the CONFUSION_MATRIX widget reads (TODO confirm).
 *
 * @param model - The classifier model.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addConfusionMatrixWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const step = getLastCheckpointStep(model);

    const confusionMatrixSpec: TransformSpec = [
        { type: 'confusionMatrix', trueAttribute: 'y_label', predAttribute: 'y_prediction' },
        { type: 'rename', attribute: 'y_label', as: 'true' },
        { type: 'rename', attribute: 'y_prediction', as: 'pred' },
        { type: 'rename', attribute: 'rate', as: 'value' },
    ];

    // NOTE(review): only `y_label` and `y_prediction` are referenced by the transforms above; requesting
    // `x` and the categorical columns may be unnecessary. Confirm with the backend before trimming.
    const matrixPromise = BackendQueryEngine.getSamplesAndActivations(
        model.id,
        step,
        ['x', 'y_label_categorical', 'y_prediction_categorical', 'y_prediction', 'y_label'],
        confusionMatrixSpec,
    );

    const entities: PromisedWidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: matrixPromise,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.CONFUSION_MATRIX, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

/**
 * Adds a widget comparing auto-encoder inputs with their reconstructions.
 *
 * The first 60 samples of the model's last checkpoint are fetched. Each row is rendered as an image
 * pair: `x` becomes `yTarget` and `y_prediction_categorical` becomes `yPrediction`.
 *
 * The data promise resolves to `[]` when the backend returns no rows. It also resolves to `[]`,
 * after logging an error, when the first row is not matrix-valued or not interpretable as an image.
 *
 * @param model - The auto-encoder model.
 * @param tool - The tool that triggered the widget.
 * @param addWidgetCb - Callback that registers the widget at the single-model abstraction level.
 */
const addInputOutputAutoencoderWidget = async (model: Model, tool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const lastCheckpointStep = getLastCheckpointStep(model);

    const transformSpec: TransformSpec = [
        {
            type: 'head',
            count: 60,
        },
    ];

    const widgetDataPromise = BackendQueryEngine.getSamplesAndActivations(
        model.id,
        lastCheckpointStep,
        ['x', 'y_label_categorical', 'y_prediction_categorical', 'y_label'],
        transformSpec,
    ).then((data) => {
        // No samples: yield an empty widget instead of letting `data[0]` throw a TypeError.
        if (data.length === 0) {
            return [];
        }

        // Only the first row is inspected; the remaining rows are assumed to share its layout.
        const firstRow = data[0];
        const looksLikeImageAutoencoder =
            isMatrix(firstRow['x']) &&
            isMatrix(firstRow['y_prediction_categorical']) &&
            isMatrix(firstRow['y_label_categorical']) &&
            isInterpretableAsImage(firstRow['x']) &&
            isInterpretableAsImage(firstRow['y_prediction_categorical']);

        if (looksLikeImageAutoencoder) {
            return data.map((row, idx) => ({
                sampleId: idx,
                yPrediction: matrixToBase64Image(row['y_prediction_categorical'] as Matrix),
                yTarget: matrixToBase64Image(row['x'] as Matrix),
                yLabel: row['y_label'] as number,
            }));
        }

        console.error('This does not seem to be an image auto-encoder model. Widget will be empty.');
        return [];
    });

    const entities: PromisedWidgetDataEntity[] = [
        {
            entity: model,
            entityName: model.info.label,
            color: model.preferences.baseColor,
            data: widgetDataPromise,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.INPUT_OUTPUT_COMPARISON_AUTOENCODER, tool, model, entities),
        LevelOfAbstraction.SINGLE_MODEL,
    );
};

export default addWidget;
