import React, { useCallback, useEffect, useMemo, useState } from 'react';
import InterestingnessContext from './InterestingnessContext';
import { produce } from 'immer';
import _ from 'lodash';
import { Model } from 'types/nn-types/Model';
import { LayerGraphNode } from 'types/nn-types/LayerGraph';
import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import { InterestingnessType } from 'types/inspection-types/InterestingnessType';
import { Entity } from 'types/inspection-types/Entity';
import { max } from 'd3-array';
import { Group } from 'types/nn-types/Group';

/**
 * Type guard that narrows an optional number to `number`.
 *
 * @param maybeNumber - Value that may be `undefined`.
 * @returns `true` unless the value is `undefined` (so `NaN` still counts as a number).
 */
function isNumber(maybeNumber: number | undefined): maybeNumber is number {
    const isMissing = typeof maybeNumber === 'undefined';
    return !isMissing;
}

/**
 * Arithmetic mean of the defined entries of `interestingnessValues`.
 *
 * @param interestingnessValues - Values to average; `undefined` entries are skipped.
 * @returns The mean of the remaining numbers, or `undefined` when no numbers remain.
 */
function avg(interestingnessValues: (number | undefined)[]) {
    const definedValues = interestingnessValues.filter(isNumber);
    if (definedValues.length === 0) {
        return undefined;
    }

    const total = _.sum(definedValues);
    return total / definedValues.length;
}

/**
 * An entity id paired with the interestingness computed for that entity.
 * The value is `undefined` when there was nothing to average for the entity.
 */
type EntityInterestingnessPair = [string, number | undefined];

/**
 * Provides interestingness state and derived per-entity interestingness values to its subtree.
 *
 * The provider tracks which interestingness types are active and, from those, computes:
 * - for every layer innard: the average of the absolute values of its active interestingness descriptors;
 * - for every layer / model: the average over its direct children's interestingness values together with
 *   the absolute values of its own active descriptors;
 * - for every layer, model and group: the maximum interestingness among its direct children.
 *
 * @param children - Subtree that consumes `InterestingnessContext`.
 * @param groups - Groups whose models, layers and layer innards are evaluated.
 */
const InterestingnessContextProvider = ({ children, groups }: { children: React.ReactNode; groups: Group[] }) => {
    const [activeInterestingnesses, setActiveInterestingnesses] = useState<InterestingnessType[]>([]);
    // Entity id -> interestingness of that entity (models, layers and layer innards).
    const [interestingnessPerEntityCache, setInterestingnessPerEntityCache] = useState<
        Record<string, number | undefined>
    >({});
    // Entity id -> maximum interestingness among the entity's direct children (groups, models and layers).
    const [maxInnerInterestingnessCache, setMaxInnerInterestingnessCache] = useState<
        Record<string, number | undefined>
    >({});

    /** Adds `interestingness` to the active set, or removes it if an equal one is already active. */
    const toggleInterestingnessMemo = useCallback((interestingness: InterestingnessType) => {
        setActiveInterestingnesses((prevState) =>
            produce(prevState, (draftState) => {
                const index = draftState.findIndex((ai) => ai.equals(interestingness));

                // Mutate the draft in place (never reassign it); immer derives the next state from these changes.
                if (index >= 0) {
                    draftState.splice(index, 1);
                } else {
                    draftState.push(interestingness);
                }
            }),
        );
    }, []);

    /** Whether an interestingness equal to `interestingnessType` is currently active. */
    const isInterestingnessActiveMemo = useCallback(
        (interestingnessType: InterestingnessType) =>
            activeInterestingnesses.some((ai) => ai.equals(interestingnessType)),
        [activeInterestingnesses],
    );

    /** Returns the interestingness descriptors of `entity` whose type is currently active. */
    const filterActiveInterestingnessesMemo = useCallback(
        (entity: Entity) =>
            entity.interestingness.filter((d) => isInterestingnessActiveMemo(InterestingnessType.fromDescriptor(d))),
        [isInterestingnessActiveMemo],
    );

    /** Returns the absolute values of the active interestingness descriptors of `entity`. */
    const getValuesOfActiveInterestingnessesMemo = useCallback(
        (entity: Entity) => filterActiveInterestingnessesMemo(entity).map((d) => Math.abs(d.value)),
        [filterActiveInterestingnessesMemo],
    );

    /** Interestingness of a layer innard: the average of its active values, or `undefined` if it has none. */
    const computeLayerInnardInterestingness = useCallback(
        (layerInnard: LayerGraphNode) => avg(getValuesOfActiveInterestingnessesMemo(layerInnard)),
        [getValuesOfActiveInterestingnessesMemo],
    );

    /** Interestingness of a layer: average over its innards' interestingnesses and its own active values. */
    const computeLayerInterestingness = useCallback(
        (layer: ModelGraphNode, innerInterestingnesses: EntityInterestingnessPair[]) =>
            avg([...innerInterestingnesses.map((ii) => ii[1]), ...getValuesOfActiveInterestingnessesMemo(layer)]),
        [getValuesOfActiveInterestingnessesMemo],
    );

    /** Interestingness of a model: average over its layers' interestingnesses and its own active values. */
    const computeModelInterestingness = useCallback(
        (model: Model, innerInterestingnesses: EntityInterestingnessPair[]) =>
            avg([...innerInterestingnesses.map((ii) => ii[1]), ...getValuesOfActiveInterestingnessesMemo(model)]),
        [getValuesOfActiveInterestingnessesMemo],
    );

    // Rebuild both caches whenever the active interestingnesses (via the compute callbacks) or the groups change.
    // NOTE(review): because the caches are filled in an effect, they are written only after the render in which
    // the inputs changed; during that render consumers still read the previous cache contents.
    useEffect(() => {
        const interestingnessMap = new Map<string, number | undefined>();
        const maxInnerInterestingnessMap = new Map<string, number | undefined>();

        groups.forEach((group) => {
            const modelInterestingnesses: EntityInterestingnessPair[] = group.models.map((model) => {
                const layerInterestingnesses: EntityInterestingnessPair[] = model.graph.nodes.map((layer) => {
                    const layerInnardInterestingnesses: EntityInterestingnessPair[] = layer.innardGraph.nodes.map(
                        (layerInnard) => [layerInnard.id, computeLayerInnardInterestingness(layerInnard)],
                    );

                    // Save layer innard interestingnesses to map.
                    layerInnardInterestingnesses.forEach((lii) => interestingnessMap.set(lii[0], lii[1]));

                    // Save max inner interestingness for this layer (undefined when no innard has a value).
                    maxInnerInterestingnessMap.set(
                        layer.id,
                        max(layerInnardInterestingnesses.map((lii) => lii[1]).filter(isNumber)),
                    );

                    // Calculate layer interestingness.
                    return [layer.id, computeLayerInterestingness(layer, layerInnardInterestingnesses)];
                });

                // Save layer interestingnesses to map.
                layerInterestingnesses.forEach((li) => interestingnessMap.set(li[0], li[1]));

                // Save max inner interestingness for this model.
                maxInnerInterestingnessMap.set(
                    model.id,
                    max(layerInterestingnesses.map((li) => li[1]).filter(isNumber)),
                );

                // Calculate model interestingness.
                return [model.id, computeModelInterestingness(model, layerInterestingnesses)];
            });

            // Save model interestingnesses to map.
            modelInterestingnesses.forEach((mi) => interestingnessMap.set(mi[0], mi[1]));

            // Save max inner interestingness for this group (groups get no own interestingness value).
            maxInnerInterestingnessMap.set(group.id, max(modelInterestingnesses.map((mi) => mi[1]).filter(isNumber)));
        });

        // Finally, save the newly created interestingness catalogs to state variables.
        setInterestingnessPerEntityCache(() => Object.fromEntries(interestingnessMap));
        setMaxInnerInterestingnessCache(() => Object.fromEntries(maxInnerInterestingnessMap));
    }, [computeLayerInnardInterestingness, computeLayerInterestingness, computeModelInterestingness, groups]);

    // The own-property checks below go through Object.prototype so that an entity id such as "hasOwnProperty"
    // (stored as a key of the cache record) cannot shadow the method and make the lookup throw.

    /** Cached interestingness of `entity`, or `undefined` if it has not been computed. */
    const getInterestingnessMemo = useCallback(
        (entity: Entity) =>
            Object.prototype.hasOwnProperty.call(interestingnessPerEntityCache, entity.id)
                ? interestingnessPerEntityCache[entity.id]
                : undefined,
        [interestingnessPerEntityCache],
    );

    /** Cached maximum inner interestingness of `entity`, or `undefined` if it has not been computed. */
    const getMaxInnerInterestingnessMemo = useCallback(
        (entity: Entity) =>
            Object.prototype.hasOwnProperty.call(maxInnerInterestingnessCache, entity.id)
                ? maxInnerInterestingnessCache[entity.id]
                : undefined,
        [maxInnerInterestingnessCache],
    );

    // Memoize the value object itself, so it doesn't lead to unnecessary re-renders of consumers.
    const providerValueMemo = useMemo(
        () => ({
            activeInterestingnesses,
            toggleInterestingness: toggleInterestingnessMemo,
            isInterestingnessActive: isInterestingnessActiveMemo,
            getInterestingness: getInterestingnessMemo,
            getMaxInnerInterestingness: getMaxInnerInterestingnessMemo,
        }),
        [
            activeInterestingnesses,
            getInterestingnessMemo,
            getMaxInnerInterestingnessMemo,
            isInterestingnessActiveMemo,
            toggleInterestingnessMemo,
        ],
    );

    return <InterestingnessContext.Provider value={providerValueMemo}>{children}</InterestingnessContext.Provider>;
};

export default InterestingnessContextProvider;
