import React, { useContext, useState } from 'react';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import BackendQueryEngine from 'tools/BackendQueryEngine';
import ClassSelectionContext from 'App/ClassSelectionContext';
import { DataArray, isMatrix, isNumber } from 'types/inspection-types/DataArray';
import {
    generateScalarValuesTransformSpec,
    generateFilteredNodesTransformSpec,
} from 'App/InspectionPanel/L1LayerUnitComponent/transform-specification';
import L1LayerUnitContent from 'App/InspectionPanel/L1LayerUnitComponent/L1LayerUnitContent';
import { EntityType } from 'types/inspection-types/EntityType';
import { NeuronConnection } from 'types/inspection-types/NeuronConnection';
import { Matrix } from 'mathjs';
import NeuronClassPixelVis from 'App/InspectionPanel/L1LayerUnitComponent/NeuronClassPixelVis';
import WeightNetworkSettings from './Weights/WeightNetworkSettings';
import * as mathjs from 'mathjs';
import { Neuron } from 'types/inspection-types/Neuron';
import { getLastCheckpointStep } from 'tools/helpers';
import { InspectionLayerProps } from 'types/inspection-types/InspectionLayerProps';
import ModelContext from 'App/ModelContext';

/**
 * Result of one round of backend queries for this layer unit; everything the view
 * needs to render is bundled here so it is replaced atomically in state.
 */
interface ActivationData {
    /** The class selection the data below was queried for. */
    correspondingClasses: (number | string)[];
    /** Scalar mean activations of the current layer (rendered by the pixel visualization). */
    currentActivationsScalar: DataArray;
    /** Units of the previous and current layer that appear in the displayed connections. */
    neurons: Neuron[];
    /** Weighted connections from previous-layer units to current-layer units. */
    neuronConnections: NeuronConnection[];
}

/** Shared inspection-layer props plus the graph node this unit visualizes. */
interface Props extends InspectionLayerProps {
    /** The "current" layer; its first parent (if any) is shown as the previous layer. */
    modelGraphNode: ModelGraphNode;
}

/**
 * Converts the rows of a weight-subset query into connections between two layers.
 *
 * Each row's `index` matrix holds `[fromIndex, toIndex]`. A scalar `value` is used as
 * the weight directly; a matrix `value` (e.g. a conv kernel slice) is reduced to its
 * mean; anything else yields a weight of 0. The raw `value` is kept in `data`.
 */
const toNeuronConnections = (
    rows: DataArray,
    fromMGN: Props['modelGraphNode'],
    toMGN: Props['modelGraphNode'],
): NeuronConnection[] =>
    rows.map((row) => {
        const index = row['index'] as Matrix;
        const fromIndex = index.get([0]);
        const toIndex = index.get([1]);
        const value = row['value'];

        return {
            fromIndex,
            fromId: `${fromMGN.id}_${fromIndex}`,
            toIndex,
            toId: `${toMGN.id}_${toIndex}`,
            weight: isNumber(value) ? value : isMatrix(value) ? mathjs.mean(value.toArray() as number[]) : 0,
            data: value,
        };
    });

/**
 * Visualizes one layer ("current") together with its first parent layer ("previous"):
 * the top-n weighted connections between them, the units involved, and a per-class
 * pixel view of the current layer's mean activations.
 *
 * `onReady` is invoked whenever a complete set of activation data has been loaded.
 */
const L1LayerUnitComponent: React.FunctionComponent<Props> = ({
    onReady,
    modelGraphNode: currentMGN,
    ...props
}: Props) => {
    const { model } = useContext(ModelContext);
    const { selectedClasses } = useContext(ClassSelectionContext);

    // Only the first parent is shown as the previous layer.
    const parentMGNs = model.graph.getParents(currentMGN);
    const previousMGN = parentMGNs.length > 0 ? parentMGNs[0] : undefined;

    const [activationData, setActivationData] = useState<ActivationData | undefined>(undefined);
    const [numberOfConnections, setNumberOfConnections] = useState<number>(100);

    React.useEffect(() => {
        // Set by the cleanup when the dependencies change or the component unmounts, so a
        // slower, superseded request can no longer overwrite newer state.
        let cancelled = false;

        // Turns mean-activation rows into Neuron entities belonging to `parentModelGraphNode`.
        const toNeurons = (rows: DataArray, parentModelGraphNode: Props['modelGraphNode'], rank: number): Neuron[] =>
            rows.map((row) => {
                const { id, meanSummedScalar, ...data } = row;
                return {
                    id: `${parentModelGraphNode.id}_${id as number}`,
                    type: EntityType.SINGLE_NEURON,
                    name: `Neuron ${id as number}`,
                    interestingness: [],
                    parentModel: model,
                    parentModelGraphNode,
                    data: data,
                    activation: meanSummedScalar as number,
                    index: id as number,
                    rank,
                };
            });

        // Resolves to the data to display, or `undefined` when the current layer has no activations.
        const loadActivationData = async (): Promise<ActivationData | undefined> => {
            const step = getLastCheckpointStep(model);

            const scalarValuesTransformSpec = generateScalarValuesTransformSpec(selectedClasses);
            const currentScalarPromise = BackendQueryEngine.getMeanActivations(
                model.id,
                step,
                currentMGN.name,
                selectedClasses,
                scalarValuesTransformSpec,
            );

            // Only dense and conv2d layers are queried for kernel weights; other layers get no connections.
            const connectionsPromise = [EntityType.LAYER_DENSE, EntityType.LAYER_CONV2D].includes(currentMGN.type)
                ? BackendQueryEngine.getWeightsSubset(model.id, step, currentMGN.name, 'kernel', numberOfConnections)
                : Promise.resolve([] as DataArray);

            const [currentScalar, connections] = await Promise.all([currentScalarPromise, connectionsPromise]);
            if (currentScalar.length === 0) {
                return undefined;
            }

            const neuronConnections: NeuronConnection[] = previousMGN
                ? toNeuronConnections(connections, previousMGN, currentMGN)
                : [];

            // Query only the units that belong to the top-n connections.
            const filteredPreviousNodes = neuronConnections.map((connection) => connection.fromIndex);
            const filteredCurrentNodes = neuronConnections.map((connection) => connection.toIndex);

            const filteredPreviousNodesTransformSpec = generateFilteredNodesTransformSpec(
                selectedClasses,
                filteredPreviousNodes,
            );
            const filteredCurrentNodesTransformSpec = generateFilteredNodesTransformSpec(
                selectedClasses,
                filteredCurrentNodes,
            );

            const previousFilteredNodesPromise =
                previousMGN && previousMGN !== currentMGN
                    ? BackendQueryEngine.getMeanActivations(
                          model.id,
                          step,
                          previousMGN.name,
                          selectedClasses,
                          filteredPreviousNodesTransformSpec,
                      )
                    : Promise.resolve([] as DataArray);
            const currentFilteredNodesPromise = BackendQueryEngine.getMeanActivations(
                model.id,
                step,
                currentMGN.name,
                selectedClasses,
                filteredCurrentNodesTransformSpec,
            );

            const [filteredPrevious, filteredCurrent] = await Promise.all([
                previousFilteredNodesPromise,
                currentFilteredNodesPromise,
            ]);

            // Previous-layer units get rank 0, current-layer units rank 1.
            const neurons: Neuron[] = [
                ...(previousMGN ? toNeurons(filteredPrevious, previousMGN, 0) : []),
                ...toNeurons(filteredCurrent, currentMGN, 1),
            ];

            return {
                neurons,
                neuronConnections,
                currentActivationsScalar: currentScalar,
                correspondingClasses: selectedClasses,
            };
        };

        void loadActivationData()
            .then((data) => {
                if (!cancelled) {
                    setActivationData(data);
                }
            })
            .catch((error: unknown) => {
                console.error(`Failed to load activation data for layer ${currentMGN.name}:`, error);
            });

        return () => {
            cancelled = true;
        };
    }, [model, currentMGN, numberOfConnections, selectedClasses, previousMGN]);

    // Once all data has been queried, this layer's visual representation is complete: notify the parent.
    React.useEffect(() => {
        if (activationData) {
            onReady();
        }
    }, [activationData, onReady]);

    return (
        <g>
            {activationData && (
                <>
                    <L1LayerUnitContent
                        classes={activationData.correspondingClasses}
                        neurons={activationData.neurons}
                        neuronConnections={activationData.neuronConnections}
                        currentMGN={currentMGN}
                        previousMGN={previousMGN}
                        {...props}
                    />
                    <WeightNetworkSettings
                        marginTop={40}
                        numberOfConnections={numberOfConnections}
                        setNumberOfConnections={setNumberOfConnections}
                    />
                    <NeuronClassPixelVis
                        visWidth={300}
                        visHeight={400}
                        marginTop={210}
                        classes={activationData.correspondingClasses}
                        data={activationData.currentActivationsScalar}
                    />
                </>
            )}
        </g>
    );
};
export default L1LayerUnitComponent;
