import { PromisedWidgetDataEntity, PromisedWidgetDefinition } from 'types/inspection-types/WidgetDefinition';
import { WidgetType } from 'types/inspection-types/WidgetType';
import { LevelOfAbstraction } from 'types/inspection-types/LevelOfAbstraction';
import { Tool } from 'types/inspection-types/Tool';
import BackendQueryEngine from 'tools/BackendQueryEngine';
import { TransformSpec } from 'types/inspection-types/TransformSpec';
import { getLastCheckpointStep, matrixToBase64Image } from 'tools/helpers';
import { IWidgetContext } from 'App/WidgetContext';
import { Neuron } from 'types/inspection-types/Neuron';
import { isMatrix } from 'types/inspection-types/DataArray';
import * as mathjs from 'mathjs';

/**
 * Routes the active tool to the routine that builds its widget for the given neuron.
 * Tools without an implementation only produce a console warning.
 *
 * @param neuron - Neuron the widget is created for.
 * @param activeTool - Tool currently selected by the user.
 * @param addWidgetCb - Widget-context callback used to register the created widget.
 */
const addWidget = (neuron: Neuron, activeTool: Tool, addWidgetCb: IWidgetContext['addWidget']) => {
    const toolId = activeTool.id;
    // Both sample tools share one implementation; the tool itself decides the sort direction.
    const isSampleRankingTool = toolId === Tool.MINIMIZING_SAMPLES.id || toolId === Tool.MAXIMIZING_SAMPLES.id;

    if (!isSampleRankingTool) {
        console.warn('No implementation for', activeTool.name);
        return;
    }

    addMaxMinSamplesWidget(neuron, activeTool, addWidgetCb);
};

/**
 * Adds an image-samples widget listing the samples whose activation of a single
 * neuron (at the model's last checkpoint) is highest or lowest.
 *
 * @param neuron - Neuron whose activations are used to rank the samples.
 * @param tool - `Tool.MINIMIZING_SAMPLES` sorts ascending (lowest activations first);
 *               any other tool (here `Tool.MAXIMIZING_SAMPLES`) sorts descending.
 * @param addWidgetCb - Widget-context callback that registers the new widget.
 * @param sampleCount - Number of top-ranked rows requested from the backend (defaults to 70).
 */
const addMaxMinSamplesWidget = (
    neuron: Neuron,
    tool: Tool,
    addWidgetCb: IWidgetContext['addWidget'],
    sampleCount = 70,
) => {
    const lastCheckpointStep = getLastCheckpointStep(neuron.parentModel);

    const layerName = neuron.parentModelGraphNode.name;

    const transformSpec: TransformSpec = [
        // TODO: Filter for active classes. This will be made possible with #75.
        //       https://gitlab.dbvis.de/innspector/development-v2/-/issues/75
        // {
        //     type: 'filter',
        //     query: 'y_label in [1, 2]',
        // },
        {
            // Expose the layer output under a stable attribute name for the following steps.
            type: 'rename',
            attribute: layerName,
            as: 'activations',
        },
        {
            // NOTE(review): backend semantics of 'mean' + firstOrLast 'last' presumably
            // average over all axes except the last (channel) axis — TODO confirm.
            type: 'reduceToSingleDim',
            attribute: 'activations',
            operation: 'mean',
            firstOrLast: 'last',
        },
        {
            // Keep only this neuron's value.
            type: 'take',
            attribute: 'activations',
            indices: [neuron.index],
        },
        {
            type: 'squeeze',
            attribute: 'activations',
        },
        {
            type: 'sort',
            attribute: 'activations',
            ascending: tool.id === Tool.MINIMIZING_SAMPLES.id,
        },
        {
            type: 'head',
            count: sampleCount,
        },
    ];

    const dataPromise = BackendQueryEngine.getSamplesAndActivations(
        neuron.parentModel.id,
        lastCheckpointStep,
        ['x', 'y_prediction', 'y_label', layerName],
        transformSpec,
    ).then((data) => {
        // Guard before touching data[0]: an empty result would otherwise throw a TypeError
        // and reject the promise instead of yielding an empty widget.
        if (data.length === 0) {
            console.warn('The backend returned no samples. The widget will be empty.');
            return [];
        }

        if (isMatrix(data[0]['x'])) {
            return data.map((row, idx) => ({
                // idx is the row's rank within the sorted, truncated result,
                // not an identifier of the sample in the dataset.
                sampleId: idx,
                xImage: matrixToBase64Image(row['x'] as mathjs.Matrix),
                yPrediction: row['y_prediction'],
                yLabel: row['y_label'],
            }));
        }

        console.error('The input to this model does not seem to be images. The widget will be empty.');
        return [];
    });

    const entities: PromisedWidgetDataEntity[] = [
        {
            entity: neuron,
            entityName: neuron.parentModel.name + '/' + layerName + '/' + neuron.name,
            color: neuron.parentModel.preferences.baseColor,
            data: dataPromise,
        },
    ];

    addWidgetCb(
        new PromisedWidgetDefinition(WidgetType.IMAGE_DATA_SAMPLES, tool, neuron, entities),
        LevelOfAbstraction.WEIGHTS_NEURONS,
    );
};

export default addWidget;
