import { isModel, Model } from 'types/nn-types/Model';
import { isLayerGraphNode, LayerGraphNode } from 'types/nn-types/LayerGraph';
import { isModelGraphNode, ModelGraphNode } from 'types/nn-types/ModelGraph';
import { EntityType } from 'types/inspection-types/EntityType';

/**
 * Collects the entity types usable as filters that are enclosed in `entity`.
 *
 * - `Model` (recursive only): the types of every model graph node, including the
 *   types inside each node's innard graph, plus the model graph's structures.
 * - `ModelGraphNode`: the node's own type. When `recursive` is true, the types of
 *   its innard layer graph nodes are added as well.
 * - `LayerGraphNode`: just the node's own type.
 *
 * The result contains no duplicates. Types appear in first-encounter order.
 *
 * NOTE(review): a `Model` passed with `recursive = false` matches no branch and
 * yields `[]`. Confirm that this is intended.
 *
 * @param entity - The model, model graph node, or layer graph node to inspect.
 * @param recursive - Whether to descend into nested graphs.
 * @returns The unique enclosed entity types.
 */
export default function getEnclosedFilterTypes(
    entity: Model | ModelGraphNode | LayerGraphNode,
    recursive = true,
): EntityType[] {
    if (isModel(entity) && recursive) {
        // A single Set guarantees uniqueness across nested types AND structures;
        // deduplicating the two groups separately would let overlaps through.
        const types = new Set<EntityType>();

        // Get types hidden in inner model graph structures
        for (const modelGraphNode of entity.graph.nodes) {
            for (const type of getEnclosedFilterTypes(modelGraphNode, recursive)) {
                types.add(type);
            }
        }
        for (const structure of entity.graph.getStructures()) {
            types.add(structure);
        }

        return [...types];
    } else if (isModelGraphNode(entity)) {
        const modelGraphNode: ModelGraphNode = entity;

        if (!recursive) return [modelGraphNode.type];

        // Seed with the node's own type so it comes first, then add layer innards
        const types = new Set<EntityType>([modelGraphNode.type]);
        for (const layerGraphNode of modelGraphNode.innardGraph.nodes) {
            for (const type of getEnclosedFilterTypes(layerGraphNode, recursive)) {
                types.add(type);
            }
        }

        return [...types];
    } else if (isLayerGraphNode(entity)) {
        return [entity.type];
    }

    return [];
}
