import { ModelGraphNode } from 'types/nn-types/ModelGraph';
import { reduceShapeToNumber, reduceShapeToString } from 'tools/helpers';
import { EntityType } from 'types/inspection-types/EntityType';
import { LayerGraphNode } from 'types/nn-types/LayerGraph';

/**
 * Builds the "innard" graph of a single model layer: a small chain of
 * operation / variable nodes describing what the layer computes internally.
 *
 * Side effects on `modelGraphNode`:
 * - `innardGraph` is replaced by a fresh `{ nodes, links }` graph;
 * - `type` is set to the layer's EntityType, chosen from the lowercased `clsName`.
 *   Class names that are not recognised fall back to `LAYER_MISC`.
 *
 * After the layer-specific builder has pushed its nodes, consecutive nodes are
 * linked in insertion order, so the innards form a simple linear chain.
 *
 * @param modelGraphNode - The layer node to populate (mutated in place).
 */
export function constructLayerInnardGraph(modelGraphNode: ModelGraphNode) {
    modelGraphNode.innardGraph = {
        nodes: [],
        links: [],
    };

    switch (modelGraphNode.clsName.toLowerCase()) {
        case 'dense':
            modelGraphNode.type = EntityType.LAYER_DENSE;
            constructDense(modelGraphNode);
            break;
        case 'conv2d':
            modelGraphNode.type = EntityType.LAYER_CONV2D;
            constructConv2D(modelGraphNode);
            break;
        case 'conv2dtranspose':
            modelGraphNode.type = EntityType.LAYER_CONV2DTRANSPOSE;
            constructConv2DTranspose(modelGraphNode);
            break;
        case 'input':
            modelGraphNode.type = EntityType.LAYER_INPUT;
            constructInput(modelGraphNode);
            break;
        case 'dropout':
            modelGraphNode.type = EntityType.LAYER_DROPOUT;
            constructDropout(modelGraphNode);
            break;
        case 'flatten':
            modelGraphNode.type = EntityType.LAYER_FLATTEN;
            constructFlatten(modelGraphNode);
            break;
        case 'reshape':
            modelGraphNode.type = EntityType.LAYER_RESHAPE;
            constructReshape(modelGraphNode);
            break;
        case 'concatenate':
            modelGraphNode.type = EntityType.LAYER_CONCATENATE;
            constructConcatenate(modelGraphNode);
            break;
        case 'maxpooling2d':
            modelGraphNode.type = EntityType.LAYER_MAXPOOLING2D;
            constructMaxPooling2D(modelGraphNode);
            break;
        case 'averagepooling2d':
            modelGraphNode.type = EntityType.LAYER_AVGPOOLING2D;
            constructAvgPooling2D(modelGraphNode);
            break;
        case 'relu':
        case 'leakyrelu':
        case 'softmax':
        case 'linear':
        case 'sigmoid':
        case 'tanh':
        // Generic `Activation` layer: the function name is stored in
        // `config.activation`, which appendActivation reads before falling back
        // to the class name, so it is handled like the dedicated layers above.
        case 'activation':
            modelGraphNode.type = EntityType.LAYER_ACTIVATION_FN;
            constructActivationFnLayer(modelGraphNode);
            break;
        default:
            modelGraphNode.type = EntityType.LAYER_MISC;
            constructMisc(modelGraphNode);
            break;
    }

    // Link each node to its successor so the innards render as a linear pipeline.
    const { nodes, links } = modelGraphNode.innardGraph;
    for (let i = 1; i < nodes.length; i++) {
        links.push({
            source: nodes[i - 1].id,
            target: nodes[i].id,
        });
    }
}

/**
 * Populates the innard graph of a Dense layer.
 *
 * Node order: Matmul -> Kernel, then the optional bias stage (via appendBias)
 * and finally the activation node (via appendActivation).
 * The kernel node's description lists the layer's input/output shapes and the
 * class name of `config.kernel_initializer`.
 */
function constructDense(n: ModelGraphNode) {
    const owner = { parentModelId: n.parentModelId, parentModelGraphNodeId: n.id };
    const { nodes } = n.innardGraph;

    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Matmul',
            ...owner,
            type: EntityType.ALGEBRAIC_OP_MATMUL,
            description: [],
        }),
    );

    const kernelDescription = [
        `In shape: ${reduceShapeToString(n.inputShape)}`,
        `Out shape: ${reduceShapeToString(n.outputShape)}`,
        `Initializer: ${n.config.kernel_initializer.class_name}`,
    ];
    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Kernel',
            ...owner,
            type: EntityType.VARIABLE_DENSE_KERNEL,
            description: kernelDescription,
        }),
    );

    appendBias(n);
    appendActivation(n);
}

/**
 * Populates the innard graph of a Conv2D layer.
 *
 * Node order: Conv2D op -> Kernel, then the optional bias stage (via appendBias)
 * and the activation node (via appendActivation).
 * The kernel node describes filters, kernel size, padding, strides and the
 * kernel initializer class taken from `config`.
 */
function constructConv2D(n: ModelGraphNode) {
    const owner = { parentModelId: n.parentModelId, parentModelGraphNodeId: n.id };
    const { nodes } = n.innardGraph;

    // The convolution op node carries no description; a `Dim: in x out`
    // element-count line used to sit here commented out.
    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Conv2D',
            ...owner,
            type: EntityType.ALGEBRAIC_OP_CONV2D,
            description: [],
        }),
    );

    const { filters, kernel_size, padding, strides, kernel_initializer } = n.config;
    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Kernel',
            ...owner,
            type: EntityType.VARIABLE_CONV2D_KERNEL,
            description: [
                `Filters: ${filters}`,
                `Filter shape: ${kernel_size}`,
                `Padding: ${padding}`,
                `Strides: ${strides}`,
                `Initializer: ${kernel_initializer.class_name}`,
            ],
        }),
    );

    appendBias(n);
    appendActivation(n);
}

/**
 * Populates the innard graph of a Conv2DTranspose layer.
 *
 * Node order: Conv2DTranspose op -> Kernel, then the optional bias stage
 * (via appendBias) and the activation node (via appendActivation).
 * The op node shows `reduceShapeToNumber` of the input and output shapes; the
 * kernel node reuses the Conv2D kernel entity type and describes filters,
 * kernel size, padding, strides and the kernel initializer class.
 */
function constructConv2DTranspose(n: ModelGraphNode) {
    const owner = { parentModelId: n.parentModelId, parentModelGraphNodeId: n.id };
    const { nodes } = n.innardGraph;

    const inSize = reduceShapeToNumber(n.inputShape);
    const outSize = reduceShapeToNumber(n.outputShape);
    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Conv2DTranspose',
            ...owner,
            type: EntityType.ALGEBRAIC_OP_CONV2DTRANSPOSE,
            description: [`Dim: ${inSize} x ${outSize}`],
        }),
    );

    const cfg = n.config;
    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Kernel',
            ...owner,
            type: EntityType.VARIABLE_CONV2D_KERNEL,
            description: [
                `Filters: ${cfg.filters}`,
                `Filter shape: ${cfg.kernel_size}`,
                `Padding: ${cfg.padding}`,
                `Strides: ${cfg.strides}`,
                `Initializer: ${cfg.kernel_initializer.class_name}`,
            ],
        }),
    );

    appendBias(n);
    appendActivation(n);
}

/**
 * Appends the bias stage, an `Add` op followed by a `Bias` variable, to the
 * layer's innard graph. Nothing is added unless `config.use_bias` is truthy.
 *
 * The bias variable is typed as a dense bias when the layer's class name is
 * `dense` (case-insensitive), and as a conv2d bias for any other layer.
 */
function appendBias(n: ModelGraphNode) {
    if (!n.config.use_bias) {
        return;
    }

    const { nodes } = n.innardGraph;

    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Add',
            parentModelId: n.parentModelId,
            parentModelGraphNodeId: n.id,
            type: EntityType.ALGEBRAIC_OP_ADD,
            description: [],
        }),
    );

    const isDenseLayer = n.clsName.toLowerCase() === 'dense';
    const biasType = isDenseLayer ? EntityType.VARIABLE_DENSE_BIAS : EntityType.VARIABLE_CONV2D_BIAS;
    const biasDescription = [`Initializer: ${n.config.bias_initializer.class_name}`];

    nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Bias',
            parentModelId: n.parentModelId,
            parentModelGraphNodeId: n.id,
            type: biasType,
            description: biasDescription,
        }),
    );
}

/**
 * Appends a single `Activation` node to the layer's innard graph.
 *
 * The activation function name is taken from `config.activation` when that is
 * set (e.g. a Dense or Conv layer with an embedded activation). Otherwise the
 * lowercased class name is used, which covers layers that are themselves an
 * activation function (ReLU, LeakyReLU, Softmax, ...).
 * Names that are not recognised map to `ACTIVATION_FN_UNKNOWN`.
 */
function appendActivation(n: ModelGraphNode) {
    const activationFnName = n.config.activation ?? n.clsName.toLowerCase();

    // Built per call rather than at module level, so EntityType is only read when
    // an activation node is actually created.
    const activationTypeByName = new Map<string, EntityType>([
        ['relu', EntityType.ACTIVATION_FN_RELU],
        ['leaky_relu', EntityType.ACTIVATION_FN_LRELU],
        ['leakyrelu', EntityType.ACTIVATION_FN_LRELU],
        ['softmax', EntityType.ACTIVATION_FN_SOFTMAX],
        ['linear', EntityType.ACTIVATION_FN_LINEAR],
        ['sigmoid', EntityType.ACTIVATION_FN_SIGMOID],
        ['tanh', EntityType.ACTIVATION_FN_TANH],
    ]);
    const activationType = activationTypeByName.get(activationFnName) ?? EntityType.ACTIVATION_FN_UNKNOWN;

    const activationNode = LayerGraphNode.fromRecord({
        name: 'Activation',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: activationType,
        description: ['Activation Fn', `f(x): ${activationFnName}`],
    });
    n.innardGraph.nodes.push(activationNode);
}

/**
 * Populates the innard graph of an input layer with one `Input` node whose
 * description shows `config.batch_input_shape` as formatted by reduceShapeToString.
 */
function constructInput(n: ModelGraphNode) {
    const shapeText = reduceShapeToString(n.config.batch_input_shape);
    const inputNode = LayerGraphNode.fromRecord({
        name: 'Input',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.TENSOR_OP_INPUT,
        description: [`Shape: ${shapeText}`],
    });
    n.innardGraph.nodes.push(inputNode);
}

/**
 * Populates the innard graph of a Dropout layer with one `Dropout` node whose
 * description shows `config.rate`.
 */
function constructDropout(n: ModelGraphNode) {
    const dropoutNode = LayerGraphNode.fromRecord({
        name: 'Dropout',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.TENSOR_OP_DROPOUT,
        description: [`Rate: ${n.config.rate}`],
    });
    n.innardGraph.nodes.push(dropoutNode);
}

/**
 * Populates the innard graph of a Flatten layer with one `Flatten` node whose
 * description shows the input shape ("From") and the output shape ("To").
 */
function constructFlatten(n: ModelGraphNode) {
    const fromShape = reduceShapeToString(n.inputShape);
    const toShape = reduceShapeToString(n.outputShape);

    n.innardGraph.nodes.push(
        LayerGraphNode.fromRecord({
            name: 'Flatten',
            parentModelId: n.parentModelId,
            parentModelGraphNodeId: n.id,
            type: EntityType.TENSOR_OP_FLATTEN,
            description: [`From: ${fromShape}`, `To: ${toShape}`],
        }),
    );
}

/**
 * Populates the innard graph of a Reshape layer with one `Reshape` node whose
 * description shows the input shape ("From") and the output shape ("To").
 */
function constructReshape(n: ModelGraphNode) {
    const shapeChange = [
        `From: ${reduceShapeToString(n.inputShape)}`,
        `To: ${reduceShapeToString(n.outputShape)}`,
    ];
    const reshapeNode = LayerGraphNode.fromRecord({
        name: 'Reshape',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.TENSOR_OP_RESHAPE,
        description: shapeChange,
    });
    n.innardGraph.nodes.push(reshapeNode);
}

/**
 * Populates the innard graph of a Concatenate layer with one `Concat` node
 * whose description shows `config.axis`.
 */
function constructConcatenate(n: ModelGraphNode) {
    const { axis } = n.config;
    const concatNode = LayerGraphNode.fromRecord({
        name: 'Concat',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.TENSOR_OP_CONCAT,
        description: [`Axis: ${axis}`],
    });
    n.innardGraph.nodes.push(concatNode);
}

/**
 * Populates the innard graph of a MaxPooling2D layer with one `MaxPool2D` node
 * describing the pool size, strides and padding from `config`.
 */
function constructMaxPooling2D(n: ModelGraphNode) {
    const { pool_size, strides, padding } = n.config;
    const poolingDescription = [`Pool size: ${pool_size}`, `Strides: ${strides}`, `Padding: ${padding}`];

    n.innardGraph.nodes.push(
        LayerGraphNode.fromRecord({
            name: 'MaxPool2D',
            parentModelId: n.parentModelId,
            parentModelGraphNodeId: n.id,
            type: EntityType.TENSOR_OP_MAXPOOL2D,
            description: poolingDescription,
        }),
    );
}

/**
 * Populates the innard graph of an AveragePooling2D layer with one `AvgPool2D`
 * node describing the pool size, strides and padding from `config`.
 */
function constructAvgPooling2D(n: ModelGraphNode) {
    const cfg = n.config;
    const avgPoolNode = LayerGraphNode.fromRecord({
        name: 'AvgPool2D',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.TENSOR_OP_AVGPOOL2D,
        description: [`Pool size: ${cfg.pool_size}`, `Strides: ${cfg.strides}`, `Padding: ${cfg.padding}`],
    });
    n.innardGraph.nodes.push(avgPoolNode);
}

/**
 * Fallback for layer classes without a dedicated builder. Adds a single `Misc`
 * placeholder node whose description flags it as not implemented and records
 * the original class name.
 */
function constructMisc(n: ModelGraphNode) {
    const placeholder = LayerGraphNode.fromRecord({
        name: 'Misc',
        parentModelId: n.parentModelId,
        parentModelGraphNodeId: n.id,
        type: EntityType.MISC,
        description: ['NOT IMPLEMENTED', n.clsName],
    });
    n.innardGraph.nodes.push(placeholder);
}

/**
 * Populates the innard graph of a layer that is itself an activation function.
 * Such a layer consists of exactly one node: the activation added by appendActivation.
 */
function constructActivationFnLayer(n: ModelGraphNode) {
    return appendActivation(n);
}
