import { LayerGraph } from 'types/nn-types/LayerGraph';
import { EntityType } from 'types/inspection-types/EntityType';
import { Entity } from 'types/inspection-types/Entity';

/** A single tensor shape; `null` marks an unknown dimension (e.g. a dynamic batch size). */
export type FlatShape = Array<number | null>;
/** A shape whose entries are either dimensions or, for multi-input layers, whole nested flat shapes. */
export type Shape = Array<number | null | FlatShape>;

/**
 * Type-guard telling whether `s` is a flat shape, i.e. an array holding only numbers and `null`s.
 *
 * @param s A bare dimension or a (possibly nested) shape.
 * @returns `false` for non-arrays and for arrays containing nested shapes; `true` otherwise (including `[]`).
 */
export const isFlatShape = (s: number | Shape): s is FlatShape =>
    Array.isArray(s) && !s.some((entry) => entry !== null && typeof entry !== 'number');

/** The kinds of architectural structure that `ModelGraph` can detect and extract paths for. */
export type Structure =
    | EntityType.STRUCTURE_STREAMLINE
    | EntityType.STRUCTURE_SKIP_CONNECTION
    | EntityType.STRUCTURE_MULTI_BRANCH;

/** A layer of a model, represented as a node of a `ModelGraph`. */
export interface ModelGraphNode extends Entity {
    /** Identifier of the layer on the Keras side (presumably the Keras layer name — TODO confirm). */
    kerasId: string;
    /** Class name of the layer (presumably the Keras layer class, e.g. `Dense`). */
    clsName: string;
    /** ID of the model this node belongs to. */
    parentModelId: string;
    /** Shape(s) of the layer input; multi-input layers hold one nested flat shape per input. */
    inputShape: Shape;
    /** Shape(s) of the layer output. */
    outputShape: Shape;
    /** Presumably the number of trainable and non-trainable parameters of the layer. */
    numParameter: number;
    /** Raw layer configuration; its schema depends on the layer class. */
    config: Record<string, any>; // eslint-disable-line @typescript-eslint/no-explicit-any
    /** Graph describing the inner structure of this layer. */
    innardGraph: LayerGraph;
}

/** A link between two nodes of a model graph; both ends reference a node by its `id`. */
export interface ModelGraphLink {
    /** `id` of the node the link starts at. */
    source: string;
    /** `id` of the node the link ends at. */
    target: string;
}

/** Plain-data form of a model graph: its nodes, the links between them, and a `directed` flag. */
export interface IModelGraph {
    nodes: ModelGraphNode[];
    links: ModelGraphLink[];
    directed: boolean;
}

export class ModelGraph implements IModelGraph {
    public directed = false;
    public nodes: ModelGraphNode[] = [];
    public links: ModelGraphLink[] = [];

    /**
     * Returns `true` if both paths consist of the very same link objects in the same order.
     *
     * NOTE(review): currently not referenced within this class.
     */
    private static arePathsEqual(firstPath: ModelGraphLink[], secondPath: ModelGraphLink[]): boolean {
        if (firstPath.length !== secondPath.length) return false;

        const pathLength = firstPath.length;
        for (let linkIndex = 0; linkIndex < pathLength; linkIndex++) {
            if (firstPath[linkIndex] !== secondPath[linkIndex]) return false;
        }

        return true;
    }

    /**
     * @param imodelGraph Optional plain-data graph; its `nodes` and `links` arrays are adopted as-is (not copied).
     */
    constructor(imodelGraph?: IModelGraph) {
        if (imodelGraph) {
            const { directed, nodes, links } = imodelGraph;
            this.directed = directed;
            this.nodes = nodes;
            this.links = links;
        }
    }

    /**
     * Returns all paths that belong to a certain type of structure.
     *
     * For skip connections and multi-branch structures, every path from the next common ancestor of a matching
     * concatenation layer to that layer is returned. For the streamline structure a single entry is returned that
     * holds every link between the first and the last node that is not part of one of the other two structures.
     *
     * @param structure The type of structure to filter by.
     * @returns A list of paths, each path being a list of links.
     */
    getPathsBelongingTo(structure: Structure): ModelGraphLink[][] {
        const paths: ModelGraphLink[][] = [];

        if (structure === EntityType.STRUCTURE_STREAMLINE) {
            // NOTE(review): assumes `nodes` is ordered so that the input layer comes first and the output layer last.
            const first = this.nodes[0];
            const last = this.nodes[this.nodes.length - 1];
            const allLinks = first && last ? this.getPathsBetween(first, last).flat() : [];

            const occupiedLinks = new Set([
                ...this.getPathsBelongingTo(EntityType.STRUCTURE_MULTI_BRANCH).flat(),
                ...this.getPathsBelongingTo(EntityType.STRUCTURE_SKIP_CONNECTION).flat(),
            ]);

            // Links are compared by identity; every link in these paths is an element of `this.links`.
            paths.push(allLinks.filter((link) => !occupiedLinks.has(link)));
        } else {
            this.nodes
                .filter((node) => node.type === EntityType.LAYER_CONCATENATE)
                .forEach((node) => {
                    if (this.determineStructure(node) !== structure) return;

                    const nextCommonAncestor = this.getNextCommonAncestor(this.getAncestors(node));

                    // `determineStructure` only reports a structure when a common ancestor exists; guard anyway.
                    if (nextCommonAncestor) paths.push(...this.getPathsBetween(nextCommonAncestor, node));
                });
        }

        return paths;
    }

    /**
     * Returns all paths that start at `start` and end at `end`, following links in their source-to-target direction.
     *
     * Nodes are matched by `id`, like links reference them. There is no visited-set, so the graph is assumed to
     * be acyclic between `start` and `end`.
     *
     * @param start The first node that all paths start with.
     * @param end The last node that all paths end with.
     * @param path The current state of the path being built, necessary for recursion. Leave empty for the initial call.
     * @returns One list of links per path; `[path]` if `start` already is `end`, `[]` if `end` is unreachable.
     */
    getPathsBetween(start: ModelGraphNode, end: ModelGraphNode, path: ModelGraphLink[] = []): ModelGraphLink[][] {
        if (start.id === end.id) return [path];

        // Dead ends (leaves other than `end`) contribute no paths.
        return this.getChildren(start).flatMap((child) => {
            const link = this.getLinkBetween(start, child);

            return link ? this.getPathsBetween(child, end, [...path, link]) : [];
        });
    }

    /**
     * Returns the direct parents of `child`, i.e. the source nodes of all links targeting it.
     * Returns an empty array if `child` has no incoming links (e.g. the input layer).
     *
     * @param child The child to determine the parents of.
     */
    getParents(child: ModelGraphNode): ModelGraphNode[] {
        return this.links
            .filter((link) => link.target === child.id)
            .map((link) => this.getNodeByID(link.source))
            .filter((node): node is ModelGraphNode => node !== undefined);
    }

    /**
     * Returns the distinct structures present in this model, in order of first occurrence.
     */
    getStructures(): Structure[] {
        const structures = new Set<Structure>();

        this.nodes.forEach((node) => {
            const structure = this.determineStructure(node);

            if (structure) structures.add(structure);
        });

        return Array.from(structures);
    }

    /**
     * Determines the structure a node takes part in.
     *
     * - Non-concatenation nodes always yield the streamline structure.
     * - A concatenation layer yields a skip connection if it has exactly two inputs and one of its parents is the
     *   next common ancestor of its branches while the other is not.
     * - It yields a multi-branch structure if its branches have a common ancestor that is none of its parents.
     * - Otherwise it yields `undefined`.
     *
     * @param node The node to classify.
     */
    private determineStructure(node: ModelGraphNode): Structure | undefined {
        if (node.type !== EntityType.LAYER_CONCATENATE) {
            // This will include false positives. However, we do not need a precise answer to determine structures.
            // The actual visualization of structures will use a more sound approach (refer to `getPathsBelongingTo`).
            return EntityType.STRUCTURE_STREAMLINE;
        }

        const nextCommonAncestor = this.getNextCommonAncestor(this.getAncestors(node));
        if (!nextCommonAncestor) return undefined;

        const parents = this.getParents(node);
        const isAncestorAParent = parents.some((parent) => parent.id === nextCommonAncestor.id);

        if (isAncestorAParent) {
            // Skip connection: two incoming branches, one coming straight from the common ancestor and the other
            // passing through at least one further layer.
            const isSkipConnection =
                node.inputShape.length === 2 && !parents.every((parent) => parent.id === nextCommonAncestor.id);

            return isSkipConnection ? EntityType.STRUCTURE_SKIP_CONNECTION : undefined;
        }

        // Multi branch: at least two incoming branches whose common ancestor is not itself a parent.
        return EntityType.STRUCTURE_MULTI_BRANCH;
    }

    /**
     * Returns all possible strains of ancestors from a layer of this graph to the input layer.
     * Each strain is ordered from the nearest parent of `node` up to (and including) the root.
     *
     * @param node The node to build ancestor strains from.
     * @param strain The current state of the strain being built, necessary for recursion. Leave empty for the initial call.
     */
    private getAncestors(node: ModelGraphNode, strain: ModelGraphNode[] = []): ModelGraphNode[][] {
        const parents = this.getParents(node);

        // A node without parents is the input layer; a node that is its own sole parent is treated the same way.
        // Either way the strain built so far is complete. (`getParents` yields `[]` for the root, so checking only
        // for a self-parent would drop every strain.)
        const isRoot = parents.length === 0 || (parents.length === 1 && parents[0]?.id === node.id);
        if (isRoot) return [strain];

        return parents.flatMap((parent) => this.getAncestors(parent, [...strain, parent]));
    }

    /**
     * Returns the children of `parent`, i.e. the target nodes of all links leaving it; `[]` for leaves.
     *
     * @param parent The parent to determine the children of.
     */
    private getChildren(parent: ModelGraphNode): ModelGraphNode[] {
        return this.links
            .filter((link) => link.source === parent.id)
            .map((link) => this.getNodeByID(link.target))
            .filter((node): node is ModelGraphNode => node !== undefined);
    }

    /**
     * Returns a directed link between the adjacent nodes `source` and `target`, and `undefined` if there is no such link.
     *
     * @param source The node which the directed link originates from.
     * @param target The node which the directed link targets.
     */
    private getLinkBetween(source: ModelGraphNode, target: ModelGraphNode): ModelGraphLink | undefined {
        return this.links.find((link) => link.source === source.id && link.target === target.id);
    }

    /**
     * Returns the next common ancestor of two or more strains of ancestors and `undefined` if there is no such common ancestor.
     * "Next" means the common node that appears first along the first strain, i.e. the one closest to its start.
     *
     * @param strains Sequences of ancestors from a layer to the input layer of this graph.
     */
    private getNextCommonAncestor(strains: ModelGraphNode[][]): ModelGraphNode | undefined {
        if (strains.length === 0) return undefined;

        // Intersect all strains while keeping the order of the first one.
        return strains.reduce((common, strain) => common.filter((node) => strain.includes(node)))[0];
    }

    /**
     * Returns the node identified by `id` and `undefined` if there is no such node.
     *
     * @param id The ID of the node to find.
     */
    private getNodeByID(id: string): ModelGraphNode | undefined {
        return this.nodes.find((node) => node.id === id);
    }
}

// Runtime stand-in for the `ModelGraphNode` interface, which is erased at compile time. `Object.keys` of an
// instance lists the properties initialized below, and the type-guard `isModelGraphNode` iterates over exactly
// those keys. Every field therefore carries an initializer. The `implements` clause only makes the compiler flag
// missing required members, so keep this list in sync with `ModelGraphNode` (and the `Entity` it extends) by hand.
// NOTE(review): the field order is observable through `Object.keys`; do not reorder without checking other users.
class ModelGraphNodeImplementation implements ModelGraphNode {
    id = '';
    kerasId = '';
    name = '';
    clsName = '';
    inputShape: Shape = [];
    outputShape: Shape = [];
    numParameter = 0;
    config: Record<string, any> = {}; // eslint-disable-line @typescript-eslint/no-explicit-any
    innardGraph: LayerGraph = { nodes: [], links: [] };
    type: EntityType = EntityType.MISC;
    parentModelId = '';
    interestingness = [];
}

// Shared template instance whose own keys define the properties `isModelGraphNode` requires.
export const MODEL_GRAPH_NODE_INSTANCE = new ModelGraphNodeImplementation();

/**
 * Type-guard for the `ModelGraphNode` type.
 *
 * Checks that `entity` is not nullish and that every key of `MODEL_GRAPH_NODE_INSTANCE` is present on it with a
 * value other than `undefined`. Only presence is checked, not the type of each value.
 *
 * @param entity The value to test.
 * @returns `true` if `entity` carries all properties of a `ModelGraphNode`, `false` otherwise.
 */
export function isModelGraphNode(entity: unknown): entity is ModelGraphNode {
    // Indexing into `null`/`undefined` would throw a TypeError; neither can be a node.
    if (entity === null || entity === undefined) return false;

    const candidate = entity as Record<string, unknown>;

    return Object.keys(MODEL_GRAPH_NODE_INSTANCE).every((key) => candidate[key] !== undefined);
}
